src/Pure/axclass.ML
changeset 36329 85004134055c
parent 36328 4d9deabf6474
child 36330 0584e203960e
--- a/src/Pure/axclass.ML	Sun Apr 25 21:18:04 2010 +0200
+++ b/src/Pure/axclass.ML	Sun Apr 25 22:50:47 2010 +0200
@@ -44,6 +44,18 @@
 
 (** theory data **)
 
+(* axclass info *)
+
+type info =
+ {def: thm,
+  intro: thm,
+  axioms: thm list,
+  params: (string * typ) list};
+
+fun make_axclass (def, intro, axioms, params): info =
+  {def = def, intro = intro, axioms = axioms, params = params};
+
+
 (* class parameters (canonical order) *)
 
 type param = string * class;
@@ -55,85 +67,109 @@
       " for " ^ Pretty.string_of_sort pp [c] ^
       (if c = c' then "" else " and " ^ Pretty.string_of_sort pp [c'])));
 
-fun merge_params _ ([], qs) = qs
-  | merge_params pp (ps, qs) =
-      fold_rev (fn q => if member (op =) ps q then I else add_param pp q) qs ps;
-
-
-(* axclass info *)
-
-type info =
- {def: thm,
-  intro: thm,
-  axioms: thm list,
-  params: (string * typ) list};
-
-type axclasses = info Symtab.table * param list;
-
-fun make_axclass ((def, intro, axioms), params): info =
-  {def = def, intro = intro, axioms = axioms, params = params};
-
-fun merge_axclasses pp ((tab1, params1), (tab2, params2)) : axclasses =
-  (Symtab.merge (K true) (tab1, tab2), merge_params pp (params1, params2));
-
-
-(* instances *)
-
-val classrel_prefix = "classrel_";
-val arity_prefix = "arity_";
-
-type instances =
-  (thm * proof) Symreltab.table *  (*classrel theorems*)
-  ((class * sort list) * ((thm * string) * proof)) list Symtab.table;  (*arity theorems with theory name*)
-
-(*transitive closure of classrels and arity completion is done in Theory.at_begin hook*)
-fun merge_instances ((classrel1, arities1): instances, (classrel2, arities2)) =
- (Symreltab.join (K fst) (classrel1, classrel2),
-  Symtab.join (K (merge (eq_fst op =))) (arities1, arities2));
-
-
-(* instance parameters *)
-
-type inst_params =
-  (string * thm) Symtab.table Symtab.table
-    (*constant name ~> type constructor ~> (constant name, equation)*)
-  * (string * string) Symtab.table; (*constant name ~> (constant name, type constructor)*)
-
-fun merge_inst_params ((const_param1, param_const1), (const_param2, param_const2)) =
-  (Symtab.join  (K (Symtab.merge (K true))) (const_param1, const_param2),
-    Symtab.merge (K true) (param_const1, param_const2));
-
 
 (* setup data *)
 
+datatype data = Data of
+ {axclasses: info Symtab.table,
+  params: param list,
+  proven_classrels: (thm * proof) Symreltab.table,
+  proven_arities: ((class * sort list) * ((thm * string) * proof)) list Symtab.table,
+    (*arity theorems with theory name*)
+  inst_params:
+    (string * thm) Symtab.table Symtab.table *
+      (*constant name ~> type constructor ~> (constant name, equation)*)
+    (string * string) Symtab.table (*constant name ~> (constant name, type constructor)*),
+  diff_merge_classrels: (class * class) list};
+
+fun make_data
+    (axclasses, params, proven_classrels, proven_arities, inst_params, diff_merge_classrels) =
+  Data {axclasses = axclasses, params = params, proven_classrels = proven_classrels,
+    proven_arities = proven_arities, inst_params = inst_params,
+    diff_merge_classrels = diff_merge_classrels};
+
 structure Data = Theory_Data_PP
 (
-  type T = axclasses * ((instances * inst_params) * (class * class) list);
-  val empty = ((Symtab.empty, []), (((Symreltab.empty, Symtab.empty), (Symtab.empty, Symtab.empty)), []));
+  type T = data;
+  val empty =
+    make_data (Symtab.empty, [], Symreltab.empty, Symtab.empty, (Symtab.empty, Symtab.empty), []);
   val extend = I;
-  fun merge pp ((axclasses1, ((instances1, inst_params1), diff_merge_classrels1)),
-    (axclasses2, ((instances2, inst_params2), diff_merge_classrels2))) =
+  fun merge pp
+      (Data {axclasses = axclasses1, params = params1, proven_classrels = proven_classrels1,
+        proven_arities = proven_arities1, inst_params = inst_params1,
+        diff_merge_classrels = diff_merge_classrels1},
+       Data {axclasses = axclasses2, params = params2, proven_classrels = proven_classrels2,
+        proven_arities = proven_arities2, inst_params = inst_params2,
+        diff_merge_classrels = diff_merge_classrels2}) =
     let
-      val (classrels1, classrels2) = pairself (Symreltab.keys o fst) (instances1, instances2);
-      val diff_merge_classrels =
+      val axclasses' = Symtab.merge (K true) (axclasses1, axclasses2);
+      val params' =
+        if null params1 then params2
+        else fold_rev (fn q => if member (op =) params1 q then I else add_param pp q) params2 params1;
+
+      (*transitive closure of classrels and arity completion is done in Theory.at_begin hook*)
+      val proven_classrels' = Symreltab.join (K #1) (proven_classrels1, proven_classrels2);
+      val proven_arities' =
+        Symtab.join (K (Library.merge (eq_fst op =))) (proven_arities1, proven_arities2);
+
+      val classrels1 = Symreltab.keys proven_classrels1;
+      val classrels2 = Symreltab.keys proven_classrels2;
+      val diff_merge_classrels' =
         subtract (op =) classrels1 classrels2 @
         subtract (op =) classrels2 classrels1 @
         diff_merge_classrels1 @ diff_merge_classrels2;
+
+      val inst_params' =
+        (Symtab.join (K (Symtab.merge (K true))) (#1 inst_params1, #1 inst_params2),
+          Symtab.merge (K true) (#2 inst_params1, #2 inst_params2));
     in
-      (merge_axclasses pp (axclasses1, axclasses2),
-        ((merge_instances (instances1, instances2), merge_inst_params (inst_params1, inst_params2)),
-          diff_merge_classrels))
+      make_data (axclasses', params', proven_classrels', proven_arities', inst_params',
+        diff_merge_classrels')
     end;
 );
 
+fun map_data f =
+  Data.map (fn Data {axclasses, params, proven_classrels, proven_arities, inst_params, diff_merge_classrels} =>
+    make_data (f (axclasses, params, proven_classrels, proven_arities, inst_params, diff_merge_classrels)));
+
+fun map_axclasses f =
+  map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params, diff_merge_classrels) =>
+    (f axclasses, params, proven_classrels, proven_arities, inst_params, diff_merge_classrels));
+
+fun map_params f =
+  map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params, diff_merge_classrels) =>
+    (axclasses, f params, proven_classrels, proven_arities, inst_params, diff_merge_classrels));
+
+fun map_proven_classrels f =
+  map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params, diff_merge_classrels) =>
+    (axclasses, params, f proven_classrels, proven_arities, inst_params, diff_merge_classrels));
+
+fun map_proven_arities f =
+  map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params, diff_merge_classrels) =>
+    (axclasses, params, proven_classrels, f proven_arities, inst_params, diff_merge_classrels));
+
+fun map_inst_params f =
+  map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params, diff_merge_classrels) =>
+    (axclasses, params, proven_classrels, proven_arities, f inst_params, diff_merge_classrels));
+
+val clear_diff_merge_classrels =
+  map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params, _) =>
+    (axclasses, params, proven_classrels, proven_arities, inst_params, []));
+
+val rep_data = Data.get #> (fn Data args => args);
+
+val axclasses_of = #axclasses o rep_data;
+val params_of = #params o rep_data;
+val proven_classrels_of = #proven_classrels o rep_data;
+val proven_arities_of = #proven_arities o rep_data;
+val inst_params_of = #inst_params o rep_data;
+val diff_merge_classrels_of = #diff_merge_classrels o rep_data;
+
 
 (* maintain axclasses *)
 
-val get_axclasses = #1 o Data.get;
-val map_axclasses = Data.map o apfst;
-
 fun get_info thy c =
-  (case Symtab.lookup (#1 (get_axclasses thy)) c of
+  (case Symtab.lookup (axclasses_of thy) c of
     SOME info => info
   | NONE => error ("No such axclass: " ^ quote c));
 
@@ -143,40 +179,40 @@
     val classes = Sign.all_classes thy;
   in map (Thm.class_triv thy) classes @ fold add_intro classes [] end;
 
+
+(* maintain params *)
+
 fun all_params_of thy S =
-  let val params = #2 (get_axclasses thy);
+  let val params = params_of thy;
   in fold (fn (x, c) => if Sign.subsort thy (S, [c]) then cons x else I) params [] end;
 
-fun class_of_param thy = AList.lookup (op =) (#2 (get_axclasses thy));
+fun class_of_param thy = AList.lookup (op =) (params_of thy);
 
 
 (* maintain instances *)
 
-fun instance_name (a, c) = Long_Name.base_name c ^ "_" ^ Long_Name.base_name a;
+val classrel_prefix = "classrel_";
+val arity_prefix = "arity_";
 
-val get_instances = #1 o #1 o #2 o Data.get;
-val map_instances = Data.map o apsnd o apfst o apfst;
-
-val get_diff_merge_classrels = #2 o #2 o Data.get;
-val clear_diff_merge_classrels = Data.map (apsnd (apsnd (K [])));
+fun instance_name (a, c) = Long_Name.base_name c ^ "_" ^ Long_Name.base_name a;
 
 
 fun the_classrel thy (c1, c2) =
-  (case Symreltab.lookup (#1 (get_instances thy)) (c1, c2) of
+  (case Symreltab.lookup (proven_classrels_of thy) (c1, c2) of
     SOME classrel => classrel
   | NONE => error ("Unproven class relation " ^
       Syntax.string_of_classrel (ProofContext.init thy) [c1, c2]));
 
-fun the_classrel_thm thy = Thm.transfer thy o fst o the_classrel thy;
-fun the_classrel_prf thy = snd o the_classrel thy;
+fun the_classrel_thm thy = Thm.transfer thy o #1 o the_classrel thy;
+fun the_classrel_prf thy = #2 o the_classrel thy;
 
 fun put_trancl_classrel ((c1, c2), th) thy =
   let
     val cert = Thm.cterm_of thy;
     val certT = Thm.ctyp_of thy;
 
-    val classrels = fst (get_instances thy);
     val classes = Sorts.classes_of (Sign.classes_of thy);
+    val classrels = proven_classrels_of thy;
 
     fun reflcl_classrel (c1', c2') =
       if c1' = c2'
@@ -187,7 +223,7 @@
         val th' = ((reflcl_classrel (c1_pred, c1) RS th) RS reflcl_classrel (c2, c2_succ))
           |> Drule.instantiate' [SOME (certT (TVar ((Name.aT, 0), [])))] []
           |> Thm.close_derivation;
-        val prf' = th' |> Thm.proof_of;
+        val prf' = Thm.proof_of th';
       in ((c1_pred, c2_succ), (th', prf')) end;
 
     val new_classrels =
@@ -197,38 +233,36 @@
     val needed = not (null new_classrels);
   in
     (needed,
-     if needed then
-       thy |> map_instances (fn (classrels, arities) =>
-         (classrels |> fold Symreltab.update new_classrels, arities))
-     else thy)
+      if needed then map_proven_classrels (fold Symreltab.update new_classrels) thy
+      else thy)
   end;
 
 fun complete_classrels thy =
   let
-    val diff_merge_classrels = get_diff_merge_classrels thy;
-    val classrels = fst (get_instances thy);
+    val classrels = proven_classrels_of thy;
+    val diff_merge_classrels = diff_merge_classrels_of thy;
     val (needed, thy') = (false, thy) |>
       fold (fn c12 => fn (needed, thy) =>
-          put_trancl_classrel (c12, Symreltab.lookup classrels c12 |> the |> fst) thy
+          put_trancl_classrel (c12, Symreltab.lookup classrels c12 |> the |> #1) thy
           |>> (fn b => needed orelse b))
         diff_merge_classrels;
   in
     if null diff_merge_classrels then NONE
-    else thy' |> clear_diff_merge_classrels |> SOME
+    else SOME (clear_diff_merge_classrels thy')
   end;
 
 
 fun the_arity thy a (c, Ss) =
-  (case AList.lookup (op =) (Symtab.lookup_list (#2 (get_instances thy)) a) (c, Ss) of
+  (case AList.lookup (op =) (Symtab.lookup_list (proven_arities_of thy) a) (c, Ss) of
     SOME arity => arity
   | NONE => error ("Unproven type arity " ^
       Syntax.string_of_arity (ProofContext.init thy) (a, Ss, [c])));
 
-fun the_arity_thm thy a c_Ss = the_arity thy a c_Ss |> fst |> fst |> Thm.transfer thy;
-fun the_arity_prf thy a c_Ss = the_arity thy a c_Ss |> snd;
+fun the_arity_thm thy a c_Ss = the_arity thy a c_Ss |> #1 |> #1 |> Thm.transfer thy;
+fun the_arity_prf thy a c_Ss = the_arity thy a c_Ss |> #2;
 
 fun thynames_of_arity thy (c, a) =
-  Symtab.lookup_list (#2 (get_instances thy)) a
+  Symtab.lookup_list (proven_arities_of thy) a
   |> map_filter (fn ((c', _), ((_, name),_)) => if c = c' then SOME name else NONE)
   |> rev;
 
@@ -256,34 +290,30 @@
     val arity' = (t, ((c, Ss), ((th, Context.theory_name thy), Thm.proof_of th)));
   in
     thy
-    |> map_instances (fn (classrel, arities) => (classrel,
-      arities
-      |> Symtab.insert_list (eq_fst op =) arity'
-      |> insert_arity_completions thy arity'
-      |> snd))
+    |> map_proven_arities
+      (Symtab.insert_list (eq_fst op =) arity' #>
+        insert_arity_completions thy arity' #> snd)
   end;
 
 fun complete_arities thy =
   let
-    val arities = snd (get_instances thy);
+    val arities = proven_arities_of thy;
     val (finished, arities') = arities
       |> fold_map (insert_arity_completions thy) (Symtab.dest_list arities);
   in
-    if forall I finished then NONE
-    else SOME (thy |> map_instances (fn (classrel, _) => (classrel, arities')))
+    if forall I finished
+    then NONE
+    else SOME (map_proven_arities (K arities') thy)
   end;
 
 val _ = Context.>> (Context.map_theory
-  (Theory.at_begin complete_classrels #> Theory.at_begin complete_arities))
+  (Theory.at_begin complete_classrels #> Theory.at_begin complete_arities));
 
 
 (* maintain instance parameters *)
 
-val get_inst_params = #2 o #1 o #2 o Data.get;
-val map_inst_params = Data.map o apsnd o apfst o apsnd;
-
 fun get_inst_param thy (c, tyco) =
-  (case Symtab.lookup (the_default Symtab.empty (Symtab.lookup (#1 (get_inst_params thy)) c)) tyco of
+  (case Symtab.lookup (the_default Symtab.empty (Symtab.lookup (#1 (inst_params_of thy)) c)) tyco of
     SOME c' => c'
   | NONE => error ("No instance parameter for constant " ^ quote c ^ " on type " ^ quote tyco));
 
@@ -291,11 +321,11 @@
   (map_inst_params o apfst o Symtab.map_default (c, Symtab.empty)) (Symtab.update_new (tyco, inst))
   #> (map_inst_params o apsnd) (Symtab.update_new (fst inst, (c, tyco)));
 
-val inst_of_param = Symtab.lookup o snd o get_inst_params;
+val inst_of_param = Symtab.lookup o #2 o inst_params_of;
 val param_of_inst = fst oo get_inst_param;
 
 fun inst_thms thy =
-  (Symtab.fold (Symtab.fold (cons o snd o snd) o snd) o fst) (get_inst_params thy) [];
+  Symtab.fold (Symtab.fold (cons o #2 o #2) o #2) (#1 (inst_params_of thy)) [];
 
 fun get_inst_tyco consts = try (fst o dest_Type o the_single o Consts.typargs consts);
 
@@ -339,8 +369,6 @@
   cert_classrel thy (pairself (ProofContext.read_class (ProofContext.init thy)) raw_rel)
     handle TYPE (msg, _, _) => error msg;
 
-val shyps_topped = forall null o #shyps o Thm.rep_thm;
-
 
 (* declaration and definition of instances of overloaded constants *)
 
@@ -392,6 +420,8 @@
 
 (* primitive rules *)
 
+val shyps_topped = forall null o #shyps o Thm.rep_thm;
+
 fun add_classrel raw_th thy =
   let
     val th = Thm.strip_shyps (Thm.transfer thy raw_th);
@@ -562,16 +592,15 @@
 
     (* result *)
 
-    val axclass = make_axclass ((def, intro, axioms), params);
+    val axclass = make_axclass (def, intro, axioms, params);
     val result_thy =
       facts_thy
       |> fold (snd oo put_trancl_classrel) (map (pair class) super ~~ classrel)
       |> Sign.qualified_path false bconst
       |> PureThy.note_thmss "" (name_atts ~~ map Thm.simple_fact (unflat axiomss axioms)) |> snd
       |> Sign.restore_naming facts_thy
-      |> map_axclasses (fn (axclasses, parameters) =>
-        (Symtab.update (class, axclass) axclasses,
-          fold (fn (x, _) => add_param pp (x, class)) params parameters));
+      |> map_axclasses (Symtab.update (class, axclass))
+      |> map_params (fold (fn (x, _) => add_param pp (x, class)) params);
 
   in (class, result_thy) end;