improved class syntax
authorhaftmann
Thu, 18 Oct 2007 16:09:38 +0200
changeset 25083 765528b4b419
parent 25082 c93a234ccf2b
child 25084 30ce1e078b72
improved class syntax
src/Pure/Isar/class.ML
src/Pure/Isar/theory_target.ML
--- a/src/Pure/Isar/class.ML	Thu Oct 18 16:09:36 2007 +0200
+++ b/src/Pure/Isar/class.ML	Thu Oct 18 16:09:38 2007 +0200
@@ -16,14 +16,14 @@
     -> string list -> theory -> string * Proof.context
   val class_cmd: bstring -> xstring list -> Element.context Locale.element list
     -> xstring list -> theory -> string * Proof.context
-  val init: class -> Proof.context -> Proof.context
-  val add_const_in_class: string -> (string * mixfix) * term
-    -> theory -> string * theory
-  val add_abbrev_in_class: string -> Syntax.mode -> (string * mixfix) * term
-    -> theory -> string * theory
-  val remove_constraint: class -> string -> Proof.context -> Proof.context
   val is_class: theory -> class -> bool
   val these_params: theory -> sort -> (string * (string * typ)) list
+  val init: class -> Proof.context -> Proof.context
+  val add_logical_const: string -> (string * mixfix) * term
+    -> theory -> string * theory
+  val add_syntactic_const: string -> Syntax.mode -> (string * mixfix) * term
+    -> theory -> string * theory
+  val refresh_syntax: class -> Proof.context -> Proof.context
   val intro_classes_tac: thm list -> tactic
   val default_intro_classes_tac: thm list -> tactic
   val print_classes: theory -> unit
@@ -309,8 +309,8 @@
   consts: (string * string) list
     (*locale parameter ~> constant name*),
   base_sort: sort,
-  inst: (typ option list * term option list) * term Symtab.table
-    (*canonical interpretation FIXME*),
+  inst: term option list
+    (*canonical interpretation*),
   morphism: morphism,
     (*partial morphism of canonical interpretation*)
   intro: thm,
@@ -433,23 +433,16 @@
 
 (* updaters *)
 
-fun add_class_data ((class, superclasses), (cs, base_sort, inst, phi, intro)) thy =
+fun add_class_data ((class, superclasses), (cs, base_sort, insttab, phi, intro)) thy =
   let
-    (*FIXME*)
-    val is_empty = null (fold (fn ((_, ty), _) => fold_atyps cons ty) cs [])
-      andalso null ((fold o fold_types o fold_atyps) cons
-        (maps snd (Locale.global_asms_of thy class)) []);
-    (*FIXME*)
-    val inst_params = map
-      (SOME o the o Symtab.lookup inst o fst o fst)
+    val inst = map
+      (SOME o the o Symtab.lookup insttab o fst o fst)
         (Locale.parameters_of_expr thy (Locale.Locale class));
-    val instT = if is_empty then [] else [SOME (TFree (Name.aT, [class]))];
-    val inst' = ((instT, inst_params), inst);
     val operations = map (fn (v_ty as (_, ty'), (c, ty)) =>
       (c, ((Free v_ty, ty'), (Logic.varifyT ty, 0)))) cs;
     val cs = (map o pairself) fst cs;
     val add_class = Graph.new_node (class,
-        mk_class_data ((cs, base_sort, inst', phi, intro), ([], operations)))
+        mk_class_data ((cs, base_sort, inst, phi, intro), ([], operations)))
       #> fold (curry Graph.add_edge class) superclasses;
   in
     ClassData.map add_class thy
@@ -521,7 +514,7 @@
 fun class_interpretation class facts defs thy =
   let
     val params = these_params thy [class];
-    val { inst = ((_, inst), _), ... } = the_class_data thy class;
+    val inst = (#inst o the_class_data thy) class;
     val tac = ALLGOALS (ProofContext.fact_tac facts);
     val prfx = class_prefix class;
   in
@@ -562,38 +555,78 @@
 
 (* class context syntax *)
 
-fun internal_remove_constraint base_sort (c, (_, (ty, _))) ctxt =
+structure ClassSyntax = ProofDataFun(
+  type T = {
+    constraints: (string * typ) list,
+    base_sort: sort,
+    local_operation: string * typ -> (typ * term) option,
+    rews: (term * term) list,
+    passed: bool
+  } option;
+  fun init _ = NONE;
+);
+
+fun synchronize_syntax thy sups base_sort ctxt =
   let
-    val ty' = ty
-      |> map_atyps (fn ty as TVar ((v, 0), _) =>
-           if v = Name.aT then TVar ((v, 0), base_sort) else ty)
-      |> SOME;
-  in ProofContext.add_const_constraint (c, ty') ctxt end;
+    val operations = these_operations thy sups;
+
+    (* constraints *)
+    fun local_constraint (c, (_, (ty, _))) =
+      let
+        val ty' = ty
+          |> map_atyps (fn ty as TVar ((v, 0), _) =>
+               if v = Name.aT then TVar ((v, 0), base_sort) else ty)
+          |> SOME;
+      in (c, ty') end
+    val constraints = (map o apsnd) (fst o snd) operations;
+
+    (* check phase *)
+    val typargs = Consts.typargs (ProofContext.consts_of ctxt);
+    fun check_const (c, ty) ((t, _), (_, idx)) =
+      ((nth (typargs (c, ty)) idx), t);
+    fun local_operation (c_ty as (c, _)) = AList.lookup (op =) operations c
+      |> Option.map (check_const c_ty);
 
-fun remove_constraint class c ctxt =
+    (* uncheck phase *)
+    val proto_rews = map (fn (c, ((t, ty), _)) => (t, Const (c, ty))) operations;
+    fun rew_app f (t1 $ t2) = rew_app f t1 $ f t2
+      | rew_app f t = t;
+    val rews = (map o apfst o rew_app)
+      (Pattern.rewrite_term thy proto_rews []) proto_rews;
+  in
+    ctxt
+    |> fold (ProofContext.add_const_constraint o local_constraint) operations
+    |> ClassSyntax.map (K (SOME {
+        constraints = constraints,
+        base_sort = base_sort,
+        local_operation = local_operation,
+        rews = rews,
+        passed = false
+      }))
+  end;
+
+fun refresh_syntax class ctxt =
   let
     val thy = ProofContext.theory_of ctxt;
     val base_sort = (#base_sort o the_class_data thy) class;
-    val SOME entry = local_operation thy [class] c;
-  in
-    internal_remove_constraint base_sort (c, entry) ctxt
-  end;
+  in synchronize_syntax thy [class] base_sort ctxt end;
 
-fun sort_term_check sups base_sort ts ctxt =
+val mark_passed = (ClassSyntax.map o Option.map)
+  (fn { constraints, base_sort, local_operation, rews, passed } =>
+    { constraints = constraints, base_sort = base_sort,
+      local_operation = local_operation, rews = rews, passed = true });
+
+fun sort_term_check ts ctxt =
   let
-    val thy = ProofContext.theory_of ctxt;
-    val local_operation = local_operation thy sups o fst;
-    val typargs = Consts.typargs (ProofContext.consts_of ctxt);
-    val constraints = these_operations thy sups |> (map o apsnd) (fst o snd);
-    fun check_typ (c, ty) (TFree (v, _)) t = if v = Name.aT
+    val { constraints, base_sort, local_operation, passed, ... } =
+      the (ClassSyntax.get ctxt);
+    fun check_typ (c, ty) (TFree (v, _), t) = if v = Name.aT
           then apfst (AList.update (op =) ((c, ty), t)) else I
-      | check_typ (c, ty) (TVar (vi, _)) t = if TypeInfer.is_param vi
+      | check_typ (c, ty) (TVar (vi, _), t) = if TypeInfer.is_param vi
           then apfst (AList.update (op =) ((c, ty), t))
             #> apsnd (insert (op =) vi) else I
-      | check_typ _ _ _ = I;
-    fun check_const (c, ty) ((t, _), (_, idx)) =
-      check_typ (c, ty) (nth (typargs (c, ty)) idx) t;
-    fun add_const (Const c_ty) = Option.map (check_const c_ty) (local_operation c_ty)
+      | check_typ _ _ = I;
+    fun add_const (Const c_ty) = Option.map (check_typ c_ty) (local_operation c_ty)
           |> the_default I
       | add_const _ = I;
     val (cs, typarams) = (fold o fold_aterms) add_const ts ([], []);
@@ -603,45 +636,41 @@
         (fn t as Const (c, ty) => the_default t (AList.lookup (op =) cs (c, ty)) | t => t)
           #> map_types subst_typ;
     val ts' = map subst_term ts;
-    val ctxt' = fold (ProofContext.add_const_constraint o apsnd SOME) constraints ctxt;
-  in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', ctxt') end;
+  in if eq_list (op aconv) (ts, ts') andalso passed then NONE
+  else
+    ctxt
+    |> fold (ProofContext.add_const_constraint o apsnd SOME) constraints
+    |> mark_passed
+    |> pair ts'
+    |> SOME
+  end;
 
 val uncheck = ref true;
 
-fun sort_term_uncheck sups ts ctxt =
+fun sort_term_uncheck ts ctxt =
   let
     (*FIXME abbreviations*)
     val thy = ProofContext.theory_of ctxt;
-    fun rew_app f (t1 $ t2) = rew_app f t1 $ f t2
-      | rew_app f t = t;
-    val rews = map (fn (c, ((t, ty), _)) => (t, Const (c, ty))) (these_operations thy sups);
-    val rews' = (map o apfst o rew_app) (Pattern.rewrite_term thy rews []) rews;
-    val _ = map (Thm.cterm_of thy o Logic.mk_equals) rews';
+    val rews = (#rews o the o ClassSyntax.get) ctxt;
     val ts' = if ! uncheck
-      then map (Pattern.rewrite_term thy rews' []) ts else ts;
+      then map (Pattern.rewrite_term thy rews []) ts else ts;
   in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', ctxt) end;
 
-fun init_class_ctxt sups base_sort ctxt =
-  let
-    val operations = these_operations (ProofContext.theory_of ctxt) sups;
-    fun standard_infer_types ts ctxt =
-      let
-        val ts' = ProofContext.standard_infer_types ctxt ts;
-      in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', ctxt) end;
-  in
-    ctxt
-    |> Variable.declare_term
-        (Logic.mk_type (TFree (Name.aT, base_sort)))
-    |> fold (internal_remove_constraint base_sort) operations
-    |> Context.proof_map (Syntax.add_term_check 1 "class"
-            (sort_term_check sups base_sort)
-        #> Syntax.add_term_check 1 "standard" standard_infer_types
-        #> Syntax.add_term_uncheck (~10) "class" (sort_term_uncheck sups))
-  end;
+fun init_ctxt thy sups base_sort ctxt =
+  ctxt
+  |> Variable.declare_term
+      (Logic.mk_type (TFree (Name.aT, base_sort)))
+  |> synchronize_syntax thy sups base_sort
+  |> Context.proof_map (
+      Syntax.add_term_check 0 "class" sort_term_check
+      #> Syntax.add_term_uncheck (~10) "class" sort_term_uncheck)
 
 fun init class ctxt =
-  init_class_ctxt [class]
-    ((#base_sort o the_class_data (ProofContext.theory_of ctxt)) class) ctxt;
+  let
+    val thy = ProofContext.theory_of ctxt;
+  in
+    init_ctxt thy [class] ((#base_sort o the_class_data thy) class) ctxt
+  end;
 
 
 (* class definition *)
@@ -667,7 +696,7 @@
     ProofContext.init thy
     |> Locale.cert_expr supexpr [constrain]
     |> snd
-    |> init_class_ctxt sups base_sort
+    |> init_ctxt thy sups base_sort
     |> process_expr Locale.empty raw_elems
     |> fst
     |> (fn elems => ((((sups, supconsts), (supsort, base_sort, mergeexpr)),
@@ -681,7 +710,8 @@
   let
     val superclasses = map (Sign.certify_class thy) raw_superclasses;
     val consts = (map o apfst o apsnd) (Sign.certify_typ thy) raw_consts;
-    fun add_const ((c, ty), syn) = Sign.declare_const [] (c, ty, syn) #>> Term.dest_Const;
+    fun add_const ((c, ty), syn) =
+      Sign.declare_const [] (c, Type.strip_sorts ty, syn) #>> Term.dest_Const;
     fun mk_axioms cs thy =
       raw_dep_axioms thy cs
       |> (map o apsnd o map) (Sign.cert_prop thy)
@@ -773,36 +803,37 @@
 
 (* definition in class target *)
 
-fun add_const_in_class class ((c, mx), rhs) thy =
+fun add_logical_const class ((c, mx), dict) thy =
   let
     val prfx = class_prefix class;
     val thy' = thy |> Sign.add_path prfx;
     val phi = morphism thy' class;
 
     val c' = Sign.full_name thy' c;
-    val rhs' = (map_types Logic.unvarifyT o Morphism.term phi) rhs;
-    val ty' = Term.fastype_of rhs';
-    val def = (c, Logic.mk_equals (Const (c', ty'), rhs'));
+    val dict' = (map_types Logic.unvarifyT o Morphism.term phi) dict;
+    val ty' = Term.fastype_of dict';
+    val ty'' = Type.strip_sorts ty';
+    val def_eq = Logic.mk_equals (Const (c', ty'), dict');
     val c'' = NameSpace.full (Sign.naming_of thy' |> NameSpace.add_path prfx) c;
   in
     thy'
     |> Sign.hide_consts_i false [c'']
-    |> Sign.declare_const [] (c, ty', mx) |> snd
+    |> Sign.declare_const [] (c, ty'', mx) |> snd
     |> Sign.parent_path
     |> Sign.sticky_prefix prfx
-    |> yield_singleton (PureThy.add_defs_i false) (def, [])
+    |> Thm.add_def false (c, def_eq)
     |>> Thm.symmetric
-    |-> (fn def => class_interpretation class [def]
-                [(map_types Logic.unvarifyT o Thm.prop_of) def]
-          #> register_operation class ((c', rhs), SOME def))
+    |-> (fn def => class_interpretation class [def] [Thm.prop_of def]
+          #> register_operation class ((c', dict), SOME (Thm.varifyT def)))
     |> Sign.restore_naming thy
+    |> Sign.add_const_constraint (c', SOME ty')
     |> pair c'
   end;
 
 
 (* abbreviation in class target *)
 
-fun add_abbrev_in_class class prmode ((c, mx), rhs) thy =
+fun add_syntactic_const class prmode ((c, mx), rhs) thy =
   let
     val prfx = class_prefix class;
     val phi = morphism thy class;
--- a/src/Pure/Isar/theory_target.ML	Thu Oct 18 16:09:36 2007 +0200
+++ b/src/Pure/Isar/theory_target.ML	Thu Oct 18 16:09:38 2007 +0200
@@ -195,8 +195,9 @@
         val t = Term.list_comb (const, map Free xs);
       in (((c, mx12), t), thy') end;
     fun class_const ((c, _), _) ((_, (mx1, _)), t) =
-      LocalTheory.raw_theory_result (Class.add_const_in_class target ((c, mx1), t))
-      #-> LocalTheory.target o Class.remove_constraint target;
+      LocalTheory.raw_theory_result (Class.add_logical_const target ((c, mx1), t))
+      #> snd
+      #> LocalTheory.target (Class.refresh_syntax target);
 
     val (abbrs, lthy') = lthy
       |> LocalTheory.theory_result (fold_map const decls)
@@ -218,9 +219,10 @@
   |> LocalDefs.add_def ((c, NoSyn), t);
 
 fun class_abbrev target prmode ((c, mx), rhs) lthy = lthy   (* FIXME pos *)
-  |> LocalTheory.raw_theory_result (Class.add_abbrev_in_class target prmode
+  |> LocalTheory.raw_theory_result (Class.add_syntactic_const target prmode
       ((c, mx), rhs))
-  |-> LocalTheory.target o Class.remove_constraint target;
+  |> snd
+  |> LocalTheory.target (Class.refresh_syntax target);
 
 in