src/Pure/axclass.ML
changeset 19574 7c761751e998
parent 19531 89970e06351f
child 19585 70a1ce3b23ae
--- a/src/Pure/axclass.ML	Fri May 05 21:59:41 2006 +0200
+++ b/src/Pure/axclass.ML	Fri May 05 21:59:43 2006 +0200
@@ -29,7 +29,9 @@
   val axiomatize_classrel_i: (class * class) list -> theory -> theory
   val axiomatize_arity: xstring * string list * string -> theory -> theory
   val axiomatize_arity_i: arity -> theory -> theory
-  val of_sort: theory -> typ * sort -> thm list option
+  type cache
+  val cache: cache
+  val of_sort: theory -> typ * sort -> cache -> thm list * cache  (*exception Sorts.CLASS_ERROR*)
 end;
 
 structure AxClass: AX_CLASS =
@@ -78,26 +80,13 @@
 val classrelN = "classrel";
 val arityN = "arity";
 
-datatype instances = Instances of
- {classes: unit Graph.T,                 (*raw relation -- no closure!*)
-  classrel: ((class * class) * thm) list,
-  arities: ((class * sort list) * thm) list Symtab.table,
-  types: (class * thm) list Typtab.table};
-
-fun make_instances (classes, classrel, arities, types) =
-  Instances {classes = classes, classrel = classrel, arities = arities, types = types};
+type instances =
+  ((class * class) * thm) list *
+  ((class * sort list) * thm) list Symtab.table;
 
-fun map_instances f (Instances {classes, classrel, arities, types}) =
-  make_instances (f (classes, classrel, arities, types));
-
-fun merge_instances
-   (Instances {classes = classes1, classrel = classrel1, arities = arities1, types = types1},
-    Instances {classes = classes2, classrel = classrel2, arities = arities2, types = types2}) =
-  make_instances
-   (Graph.merge (K true) (classes1, classes2),
-    merge (eq_fst op =) (classrel1, classrel2),
-    Symtab.join (K (merge (eq_fst op =))) (arities1, arities2),
-    Typtab.join (K (merge (eq_fst op =))) (types1, types2));
+fun merge_instances ((classrel1, arities1): instances, (classrel2, arities2)) =
+ (merge (eq_fst op =) (classrel1, classrel2),
+  Symtab.join (K (merge (eq_fst op =))) (arities1, arities2));
 
 
 (* setup data *)
@@ -105,22 +94,24 @@
 structure AxClassData = TheoryDataFun
 (struct
   val name = "Pure/axclass";
-  type T = axclasses * instances ref;
-  val empty : T =
-    ((Symtab.empty, []), ref (make_instances (Graph.empty, [], Symtab.empty, Typtab.empty)));
-  fun copy (axclasses, ref instances) : T = (axclasses, ref instances);
-  val extend = copy;
-  fun merge pp ((axclasses1, ref instances1), (axclasses2, ref instances2)) =
-    (merge_axclasses pp (axclasses1, axclasses2), ref (merge_instances (instances1, instances2)));
+  type T = axclasses * instances;
+  val empty : T = ((Symtab.empty, []), ([], Symtab.empty));
+  val copy = I;
+  val extend = I;
+  fun merge pp ((axclasses1, instances1), (axclasses2, instances2)) =
+    (merge_axclasses pp (axclasses1, axclasses2), (merge_instances (instances1, instances2)));
   fun print _ _ = ();
 end);
 
 val _ = Context.add_setup AxClassData.init;
 
 
-(* retrieve axclasses *)
+(* maintain axclasses *)
 
-val lookup_def = Symtab.lookup o #1 o #1 o AxClassData.get;
+val get_axclasses = #1 o AxClassData.get;
+fun map_axclasses f = AxClassData.map (apfst f);
+
+val lookup_def = Symtab.lookup o #1 o get_axclasses;
 
 fun get_definition thy c =
   (case lookup_def thy c of
@@ -135,10 +126,8 @@
   in map (Thm.class_triv thy) classes @ fold add_intro classes [] end;
 
 
-(* retrieve parameters *)
-
 fun get_params thy pred =
-  let val params = #2 (#1 (AxClassData.get thy))
+  let val params = #2 (get_axclasses thy);
   in fold (fn (x, c) => if pred c then cons x else I) params [] end;
 
 fun params_of thy c = get_params thy (fn c' => c' = c);
@@ -147,40 +136,33 @@
 
 (* maintain instances *)
 
-val get_instances = AxClassData.get #> (fn (_, ref (Instances insts)) => insts);
-
-fun lookup_classrel thy =
-  Option.map (Thm.transfer thy) o AList.lookup (op =) (#classrel (get_instances thy));
-
-fun lookup_arity thy =
-  Option.map (Thm.transfer thy) oo
-    (AList.lookup (op =) o Symtab.lookup_list (#arities (get_instances thy)));
-
-val lookup_type = AList.lookup (op =) oo (Typtab.lookup_list o #types o get_instances);
+val get_instances = #2 o AxClassData.get;
+fun map_instances f = AxClassData.map (apsnd f);
 
 
-fun store_instance f thy (x, th) =
-  (change (#2 (AxClassData.get thy)) (map_instances (f (x, th))); th);
+fun the_classrel thy (c1, c2) =
+  (case AList.lookup (op =) (#1 (get_instances thy)) (c1, c2) of
+    SOME th => Thm.transfer thy th
+  | NONE => error ("Unproven class relation " ^ Sign.string_of_classrel thy [c1, c2]));
+
+fun put_classrel arg = map_instances (fn (classrel, arities) =>
+  (insert (eq_fst op =) arg classrel, arities));
 
-val store_classrel = store_instance (fn ((c1, c2), th) => fn (classes, classrel, arities, types) =>
-  (classes
-    |> Graph.default_node (c1, ())
-    |> Graph.default_node (c2, ())
-    |> Graph.add_edge (c1, c2),
-    insert (eq_fst op =) ((c1, c2), th) classrel, arities, types));
 
-val store_arity = store_instance (fn ((t, Ss, c), th) => fn (classes, classrel, arities, types) =>
-  (classes, classrel, arities |> Symtab.insert_list (eq_fst op =) (t, ((c, Ss), th)), types));
+fun the_arity thy a (c, Ss) =
+  (case AList.lookup (op =) (Symtab.lookup_list (#2 (get_instances thy)) a) (c, Ss)  of
+    SOME th => Thm.transfer thy th
+  | NONE => error ("Unproven type arity " ^ Sign.string_of_arity thy (a, Ss, [c])));
 
-val store_type = store_instance (fn ((T, c), th) => fn (classes, classrel, arities, types) =>
-  (classes, classrel, arities, types |> Typtab.insert_list (eq_fst op =) (T, (c, th))));
+fun put_arity ((t, Ss, c), th) = map_instances (fn (classrel, arities) =>
+  (classrel, arities |> Symtab.insert_list (eq_fst op =) (t, ((c, Ss), th))));
 
 
 (* print data *)
 
 fun print_axclasses thy =
   let
-    val axclasses = #1 (#1 (AxClassData.get thy));
+    val axclasses = #1 (get_axclasses thy);
     val ctxt = ProofContext.init thy;
 
     fun pretty_axclass (class, AxClass {def, intro, axioms}) =
@@ -223,9 +205,11 @@
     val prop = Drule.plain_prop_of (Thm.transfer thy th);
     val rel = Logic.dest_classrel prop handle TERM _ => err ();
     val (c1, c2) = cert_classrel thy rel handle TYPE _ => err ();
-    val thy' = thy |> Sign.primitive_classrel (c1, c2);
-    val _ = store_classrel thy' ((c1, c2), Drule.unconstrainTs th);
-  in thy' end;
+  in
+    thy
+    |> Sign.primitive_classrel (c1, c2)
+    |> put_classrel ((c1, c2), Drule.unconstrainTs th)
+  end;
 
 fun add_arity th thy =
   let
@@ -233,9 +217,11 @@
     val prop = Drule.plain_prop_of (Thm.transfer thy th);
     val (t, Ss, c) = Logic.dest_arity prop handle TERM _ => err ();
     val _ = if map (Sign.certify_sort thy) Ss = Ss then () else err ();
-    val thy' = thy |> Sign.primitive_arity (t, Ss, [c]);
-    val _ = store_arity thy' ((t, Ss, c), Drule.unconstrainTs th);
-  in thy' end;
+  in
+    thy
+    |> Sign.primitive_arity (t, Ss, [c])
+    |> put_arity ((t, Ss, c), Drule.unconstrainTs th)
+  end;
 
 
 (* tactical proofs *)
@@ -325,19 +311,19 @@
         [((introN, []), [([Drule.standard raw_intro], [])]),
          ((superN, []), [(map Drule.standard raw_classrel, [])]),
          ((axiomsN, []), [(map (fn th => Drule.standard (class_triv RS th)) raw_axioms, [])])];
-    val _ = map (store_classrel facts_thy) (map (pair class) super ~~ classrel);
 
 
     (* result *)
 
     val result_thy =
       facts_thy
+      |> fold put_classrel (map (pair class) super ~~ classrel)
       |> Sign.add_path bconst
       |> PureThy.note_thmss_i "" (name_atts ~~ map Thm.simple_fact (unflat axiomss axioms)) |> snd
       |> Sign.restore_naming facts_thy
-      |> AxClassData.map (apfst (fn (axclasses, parameters) =>
+      |> map_axclasses (fn (axclasses, parameters) =>
         (Symtab.update (class, make_axclass (def, intro, axioms)) axclasses,
-          fold (fn x => add_param pp (x, class)) params parameters)));
+          fold (fn x => add_param pp (x, class)) params parameters));
 
   in (class, result_thy) end;
 
@@ -389,27 +375,13 @@
 
 (** explicit derivations -- cached **)
 
+datatype cache = Types of (class * thm) list Typtab.table;
+val cache = Types Typtab.empty;
+
 local
 
-fun derive_classrel thy (th, c1) c2 =
-  let
-    fun derive [c, c'] = the (lookup_classrel thy (c, c'))
-      | derive (c :: c' :: cs) = derive [c, c'] RS derive (c' :: cs);
-    val th' =
-      (case lookup_classrel thy (c1, c2) of
-        SOME rule => rule
-      | NONE =>
-          (case Graph.find_paths (#classes (get_instances thy)) (c1, c2) of
-            [] => error ("Cannot derive class relation " ^ Sign.string_of_classrel thy [c1, c2])
-          | path :: _ => store_classrel thy ((c1, c2), derive path)))
-  in th RS th' end;
-
-fun derive_constructor thy a dom c =
-  let val Ss = map (map snd) dom and ths = maps (map fst) dom in
-    (case lookup_arity thy a (c, Ss) of
-      SOME rule => ths MRS rule
-    | NONE => error ("Cannot derive type arity " ^ Sign.string_of_arity thy (a, Ss, [c])))
-  end;
+fun lookup_type (Types cache) = AList.lookup (op =) o Typtab.lookup_list cache;
+fun insert_type T der (Types cache) = Types (Typtab.insert_list (eq_fst op =) (T, der) cache);
 
 fun derive_type _ (_, []) = []
   | derive_type thy (typ, sort) =
@@ -418,28 +390,35 @@
             (fn T as TFree (_, S) => insert (eq_fst op =) (T, S)
               | T as TVar (_, S) => insert (eq_fst op =) (T, S)
               | _ => I) typ [];
-        val hyps = vars |> map (fn (T, S) => (T, Drule.sort_triv thy (T, S) ~~ S));
+        val hyps = vars
+          |> map (fn (T, S) => (T, Drule.sort_triv thy (T, S) ~~ S));
         val ths = (typ, sort)
           |> Sorts.of_sort_derivation (Sign.pp thy) (Sign.classes_of thy, Sign.arities_of thy)
-            {classrel = derive_classrel thy,
-              constructor = derive_constructor thy,
-              variable = the_default [] o AList.lookup (op =) hyps};
-      in map (store_type thy) (map (pair typ) sort ~~ ths) end;
+           {classrel =
+              fn (th, c1) => fn c2 => th RS the_classrel thy (c1, c2),
+            constructor =
+              fn a => fn dom => fn c =>
+                let val Ss = map (map snd) dom and ths = maps (map fst) dom
+                in ths MRS the_arity thy a (c, Ss) end,
+            variable =
+              the_default [] o AList.lookup (op =) hyps};
+      in ths end;
 
 in
 
-fun of_sort thy (typ, sort) =
-  if Sign.of_sort thy (typ, sort) then
-    let
-      val cert = Thm.cterm_of thy;
-      fun derive c =
-        Goal.finish (the (lookup_type thy typ c) RS Goal.init (cert (Logic.mk_inclass (typ, c))))
-        |> Thm.adjust_maxidx_thm;
-      val _ = derive_type thy (typ, filter (is_none o lookup_type thy typ) sort)
-        handle ERROR msg => cat_error msg ("The error(s) above occurred for sort derivation: " ^
-          Sign.string_of_typ thy typ ^ " :: " ^ Sign.string_of_sort thy sort);
-    in SOME (map derive sort) end
-  else NONE;
+fun of_sort thy (typ, sort) cache =
+  let
+    val sort' = filter (is_none o lookup_type cache typ) sort;
+    val ths' = derive_type thy (typ, sort')
+      handle ERROR msg => cat_error msg ("The error(s) above occurred for sort derivation: " ^
+        Sign.string_of_typ thy typ ^ " :: " ^ Sign.string_of_sort thy sort');
+    val cache' = cache |> fold (insert_type typ) (sort' ~~ ths');
+    val ths =
+      sort |> map (fn c =>
+        Goal.finish (the (lookup_type cache' typ c) RS
+          Goal.init (Thm.cterm_of thy (Logic.mk_inclass (typ, c))))
+        |> Thm.adjust_maxidx_thm);
+  in (ths, cache') end;
 
 end;