overloading target
authorhaftmann
Mon, 03 Dec 2007 16:04:17 +0100
changeset 25519 8570745cb40b
parent 25518 00d5cc16e891
child 25520 e123c81257a5
overloading target
src/Pure/Isar/ROOT.ML
src/Pure/Isar/isar_syn.ML
src/Pure/Isar/overloading.ML
src/Pure/Isar/theory_target.ML
--- a/src/Pure/Isar/ROOT.ML	Mon Dec 03 16:04:16 2007 +0100
+++ b/src/Pure/Isar/ROOT.ML	Mon Dec 03 16:04:17 2007 +0100
@@ -47,6 +47,7 @@
 
 (*local theories and target primitives*)
 use "local_theory.ML";
+use "overloading.ML";
 use "locale.ML";
 use "class.ML";
 
--- a/src/Pure/Isar/isar_syn.ML	Mon Dec 03 16:04:16 2007 +0100
+++ b/src/Pure/Isar/isar_syn.ML	Mon Dec 03 16:04:17 2007 +0100
@@ -445,7 +445,7 @@
       Toplevel.print o Toplevel.local_theory_to_proof loc (Subclass.subclass_cmd class)));
 
 val _ =
-  OuterSyntax.command "instantiation" "prove type arity" K.thy_decl
+  OuterSyntax.command "instantiation" "instantiate and prove type arity" K.thy_decl
    (P.and_list1 P.arity --| P.begin
      >> (fn arities => Toplevel.print o
          Toplevel.begin_local_theory true (Instance.instantiation_cmd arities)));
@@ -459,6 +459,17 @@
   || Scan.succeed (Toplevel.print o Toplevel.local_theory_to_proof NONE (Class.instantiation_instance I)));
 
 
+(* arbitrary overloading *)
+
+val _ =
+  OuterSyntax.command "overloading" "overloaded definitions" K.thy_decl
+   (Scan.repeat1 (P.xname --| P.$$$ "::" -- P.typ --| P.$$$ "is" -- P.name --
+      Scan.optional (P.$$$ "(" |-- (P.$$$ "unchecked" >> K false) --| P.$$$ ")") true)
+        --| P.begin
+   >> (fn operations => Toplevel.print o
+         Toplevel.begin_local_theory true (TheoryTarget.overloading_cmd operations)));
+
+
 (* code generation *)
 
 val _ =
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/Pure/Isar/overloading.ML	Mon Dec 03 16:04:17 2007 +0100
@@ -0,0 +1,94 @@
+(*  Title:      Pure/Isar/overloading.ML
+    ID:         $Id$
+    Author:     Florian Haftmann, TU Muenchen
+
+Overloaded definitions without any discipline.
+*)
+
+signature OVERLOADING =
+sig
+  val init: ((string * typ) * (string * bool)) list -> theory -> local_theory
+  val conclude: local_theory -> local_theory
+  val declare: string * typ -> theory -> term * theory
+  val confirm: string -> local_theory -> local_theory
+  val define: bool -> string -> string * term -> theory -> thm * theory
+  val operation: Proof.context -> string -> (string * bool) option
+end;
+
+structure Overloading: OVERLOADING =
+struct
+
+(* bookkeeping *)
+
+structure OverloadingData = ProofDataFun
+(
+  type T = ((string * typ) * (string * bool)) list;
+  fun init _ = [];
+);
+
+val get_overloading = OverloadingData.get o LocalTheory.target_of;
+val map_overloading = LocalTheory.target o OverloadingData.map;
+
+fun operation lthy v = get_overloading lthy
+  |> get_first (fn ((c, _), (v', checked)) => if v = v' then SOME (c, checked) else NONE);
+
+fun confirm c = map_overloading (filter_out (fn (_, (c', _)) => c' = c));
+
+
+(* overloaded declarations and definitions *)
+
+fun declare c_ty = pair (Const c_ty);
+
+fun define checked name (c, t) =
+  Thm.add_def (not checked) true (name, Logic.mk_equals (Const (c, Term.fastype_of t), t));
+
+
+(* syntax *)
+
+fun term_check ts lthy =
+  let
+    val overloading = get_overloading lthy;
+    fun subst (t as Const (c, ty)) = (case AList.lookup (op =) overloading (c, ty)
+         of SOME (v, _) => Free (v, ty)
+          | NONE => t)
+      | subst t = t;
+    val ts' = (map o map_aterms) subst ts;
+  in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', lthy) end;
+
+fun term_uncheck ts lthy =
+  let
+    val overloading = get_overloading lthy;
+    fun subst (t as Free (v, ty)) = (case get_first (fn ((c, _), (v', _)) => if v = v' then SOME c else NONE) overloading
+         of SOME c => Const (c, ty)
+          | NONE => t)
+      | subst t = t;
+    val ts' = (map o map_aterms) subst ts;
+  in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', lthy) end;
+
+
+(* target *)
+
+fun init overloading thy =
+  let
+    val _ = if null overloading then error "At least one parameter must be given" else ();
+  in
+    thy
+    |> ProofContext.init
+    |> OverloadingData.put overloading
+    |> fold (Variable.declare_term o Logic.mk_type o snd o fst) overloading
+    |> Context.proof_map (
+        Syntax.add_term_check 0 "overloading" term_check
+        #> Syntax.add_term_uncheck 0 "overloading" term_uncheck)
+  end;
+
+fun conclude lthy =
+  let
+    val overloading = get_overloading lthy;
+    val _ = if null overloading then () else
+      error ("Missing definition(s) for parameters " ^ commas (map (quote
+        o Syntax.string_of_term lthy o Const o fst) overloading));
+  in
+    lthy
+  end;
+
+end;
--- a/src/Pure/Isar/theory_target.ML	Mon Dec 03 16:04:16 2007 +0100
+++ b/src/Pure/Isar/theory_target.ML	Mon Dec 03 16:04:17 2007 +0100
@@ -8,11 +8,14 @@
 signature THEORY_TARGET =
 sig
   val peek: local_theory -> {target: string, is_locale: bool,
-    is_class: bool, instantiation: arity list}
+    is_class: bool, instantiation: arity list,
+    overloading: ((string * typ) * (string * bool)) list}
   val init: string option -> theory -> local_theory
   val begin: string -> Proof.context -> local_theory
   val context: xstring -> theory -> local_theory
   val instantiation: arity list -> theory -> local_theory
+  val overloading: ((string * typ) * (string * bool)) list -> theory -> local_theory
+  val overloading_cmd: (((xstring * xstring) * string) * bool) list -> theory -> local_theory
 end;
 
 structure TheoryTarget: THEORY_TARGET =
@@ -21,13 +24,13 @@
 (* context data *)
 
 datatype target = Target of {target: string, is_locale: bool,
-  is_class: bool, instantiation: arity list};
+  is_class: bool, instantiation: arity list, overloading: ((string * typ) * (string * bool)) list};
 
-fun make_target target is_locale is_class instantiation =
+fun make_target target is_locale is_class instantiation overloading =
   Target {target = target, is_locale = is_locale,
-    is_class = is_class, instantiation = instantiation};
+    is_class = is_class, instantiation = instantiation, overloading = overloading};
 
-val global_target = make_target "" false false [];
+val global_target = make_target "" false false [] [];
 
 structure Data = ProofDataFun
 (
@@ -40,7 +43,7 @@
 
 (* pretty *)
 
-fun pretty (Target {target, is_locale, is_class, instantiation}) ctxt =
+fun pretty (Target {target, is_locale, is_class, instantiation, overloading}) ctxt =
   let
     val thy = ProofContext.theory_of ctxt;
     val target_name = (if is_class then "class " else "locale ") ^ Locale.extern thy target;
@@ -196,10 +199,16 @@
     val xs = filter depends (#1 (ProofContext.inferred_fixes (LocalTheory.target_of lthy)));
     val U = map #2 xs ---> T;
     val (mx1, mx2, mx3) = fork_mixfix ta mx;
+    fun syntax_error c = error ("Illegal mixfix syntax for overloaded constant " ^ quote c);
     val declare_const = case Class.instantiation_param lthy c
-       of SOME c' => LocalTheory.theory_result (Class.overloaded_const (c', U, mx3))
+       of SOME c' => if mx3 <> NoSyn then syntax_error c'
+          else LocalTheory.theory_result (Class.overloaded_const (c', U))
             ##> Class.confirm_declaration c
-        | NONE => LocalTheory.theory_result (Sign.declare_const pos (c, U, mx3));
+        | NONE => (case Overloading.operation lthy c
+           of SOME (c', _) => if mx3 <> NoSyn then syntax_error c'
+              else LocalTheory.theory_result (Overloading.declare (c', U))
+                ##> Overloading.confirm c
+            | NONE => LocalTheory.theory_result (Sign.declare_const pos (c, U, mx3)));
     val (const, lthy') = lthy |> declare_const;
     val t = Term.list_comb (const, map Free xs);
   in
@@ -261,9 +270,12 @@
     val (_, lhs') = Logic.dest_equals (Thm.prop_of local_def);
 
     (*def*)
-    val define_const = if is_none (Class.instantiation_param lthy c)
-      then (fn name => fn eq => Thm.add_def false (name, Logic.mk_equals eq))
-      else (fn name => fn (Const (c, _), rhs) => Class.overloaded_def name (c, rhs));
+    val define_const = case Overloading.operation lthy c
+     of SOME (_, checked) =>
+          (fn name => fn (Const (c, _), rhs) => Overloading.define checked name (c, rhs))
+      | NONE => if is_none (Class.instantiation_param lthy c)
+          then (fn name => fn eq => Thm.add_def false false (name, Logic.mk_equals eq))
+          else (fn name => fn (Const (c, _), rhs) => Class.overloaded_def name (c, rhs));
     val (global_def, lthy3) = lthy2
       |> LocalTheory.theory_result (define_const name' (lhs', rhs'));
     val def = LocalDefs.trans_terms lthy3
@@ -309,12 +321,15 @@
 local
 
 fun init_target _ NONE = global_target
-  | init_target thy (SOME target) = make_target target true (Class.is_class thy target) [];
+  | init_target thy (SOME target) = make_target target true (Class.is_class thy target) [] [];
+
+fun init_instantiation arities = make_target "" false false arities [];
 
-fun init_instantiaton arities = make_target "" false false arities
+fun init_overloading operations = make_target "" false false [] operations;
 
-fun init_ctxt (Target {target, is_locale, is_class, instantiation}) =
+fun init_ctxt (Target {target, is_locale, is_class, instantiation, overloading}) =
   if not (null instantiation) then Class.init_instantiation instantiation
+  else if not (null overloading) then Overloading.init overloading
   else if not is_locale then ProofContext.init
   else if not is_class then Locale.init target
   else Class.init target;
@@ -343,8 +358,15 @@
 fun context "-" thy = init NONE thy
   | context target thy = init (SOME (Locale.intern thy target)) thy;
 
-fun instantiation arities thy =
-  init_lthy_ctxt (init_instantiaton arities) thy;
+val instantiation = init_lthy_ctxt o init_instantiation;
+
+fun gen_overloading prep_operation operations thy =
+  thy
+  |> init_lthy_ctxt (init_overloading (map (prep_operation thy) operations));
+
+val overloading = gen_overloading (K I);
+val overloading_cmd = gen_overloading (fn thy => fn (((raw_c, rawT), v), checked) =>
+  ((Sign.intern_const thy raw_c, Sign.read_typ thy rawT), (v, checked)));
 
 end;