more precise and complete transitive closure of proven_classrel, using existing Sorts.classes_of which is already closed;
authorwenzelm
Mon, 07 Jan 2013 22:21:56 +0100
changeset 50764 2bbc7ae80634
parent 50763 e33921360f06
child 50765 ba79e2cb3cbe
more precise and complete transitive closure of proven_classrel, using existing Sorts.classes_of which is already closed; transfer theorems where they are picked from the theory; tuned;
src/Pure/axclass.ML
--- a/src/Pure/axclass.ML	Mon Jan 07 21:49:59 2013 +0100
+++ b/src/Pure/axclass.ML	Mon Jan 07 22:21:56 2013 +0100
@@ -78,31 +78,24 @@
   inst_params:
     (string * thm) Symtab.table Symtab.table *
       (*constant name ~> type constructor ~> (constant name, equation)*)
-    (string * string) Symtab.table (*constant name ~> (constant name, type constructor)*),
-  diff_classrels: (class * class) list};
+    (string * string) Symtab.table (*constant name ~> (constant name, type constructor)*)};
 
 fun make_data
-    (axclasses, params, proven_classrels, proven_arities, inst_params, diff_classrels) =
+    (axclasses, params, proven_classrels, proven_arities, inst_params) =
   Data {axclasses = axclasses, params = params, proven_classrels = proven_classrels,
-    proven_arities = proven_arities, inst_params = inst_params,
-    diff_classrels = diff_classrels};
-
-fun diff_table tab1 tab2 =
-  Symreltab.fold (fn (x, _) => if Symreltab.defined tab2 x then I else cons x) tab1 [];
+    proven_arities = proven_arities, inst_params = inst_params};
 
 structure Data = Theory_Data_PP
 (
   type T = data;
   val empty =
-    make_data (Symtab.empty, [], Symreltab.empty, Symtab.empty, (Symtab.empty, Symtab.empty), []);
+    make_data (Symtab.empty, [], Symreltab.empty, Symtab.empty, (Symtab.empty, Symtab.empty));
   val extend = I;
   fun merge pp
       (Data {axclasses = axclasses1, params = params1, proven_classrels = proven_classrels1,
-        proven_arities = proven_arities1, inst_params = inst_params1,
-        diff_classrels = diff_classrels1},
+        proven_arities = proven_arities1, inst_params = inst_params1},
        Data {axclasses = axclasses2, params = params2, proven_classrels = proven_classrels2,
-        proven_arities = proven_arities2, inst_params = inst_params2,
-        diff_classrels = diff_classrels2}) =
+        proven_arities = proven_arities2, inst_params = inst_params2}) =
     let
       val ctxt = Syntax.init_pretty pp;
 
@@ -118,47 +111,37 @@
       val proven_arities' =
         Symtab.join (K (Library.merge (eq_fst op =))) (proven_arities1, proven_arities2);
 
-      val diff_classrels' =
-        diff_table proven_classrels1 proven_classrels2 @
-        diff_table proven_classrels2 proven_classrels1 @
-        diff_classrels1 @ diff_classrels2;
-
       val inst_params' =
         (Symtab.join (K (Symtab.merge (K true))) (#1 inst_params1, #1 inst_params2),
           Symtab.merge (K true) (#2 inst_params1, #2 inst_params2));
     in
-      make_data
-        (axclasses', params', proven_classrels', proven_arities', inst_params', diff_classrels')
+      make_data (axclasses', params', proven_classrels', proven_arities', inst_params')
     end;
 );
 
 fun map_data f =
-  Data.map (fn Data {axclasses, params, proven_classrels, proven_arities, inst_params, diff_classrels} =>
-    make_data (f (axclasses, params, proven_classrels, proven_arities, inst_params, diff_classrels)));
+  Data.map (fn Data {axclasses, params, proven_classrels, proven_arities, inst_params} =>
+    make_data (f (axclasses, params, proven_classrels, proven_arities, inst_params)));
 
 fun map_axclasses f =
-  map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params, diff_classrels) =>
-    (f axclasses, params, proven_classrels, proven_arities, inst_params, diff_classrels));
+  map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params) =>
+    (f axclasses, params, proven_classrels, proven_arities, inst_params));
 
 fun map_params f =
-  map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params, diff_classrels) =>
-    (axclasses, f params, proven_classrels, proven_arities, inst_params, diff_classrels));
+  map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params) =>
+    (axclasses, f params, proven_classrels, proven_arities, inst_params));
 
 fun map_proven_classrels f =
-  map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params, diff_classrels) =>
-    (axclasses, params, f proven_classrels, proven_arities, inst_params, diff_classrels));
+  map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params) =>
+    (axclasses, params, f proven_classrels, proven_arities, inst_params));
 
 fun map_proven_arities f =
-  map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params, diff_classrels) =>
-    (axclasses, params, proven_classrels, f proven_arities, inst_params, diff_classrels));
+  map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params) =>
+    (axclasses, params, proven_classrels, f proven_arities, inst_params));
 
 fun map_inst_params f =
-  map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params, diff_classrels) =>
-    (axclasses, params, proven_classrels, proven_arities, f inst_params, diff_classrels));
-
-val clear_diff_classrels =
-  map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params, _) =>
-    (axclasses, params, proven_classrels, proven_arities, inst_params, []));
+  map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params) =>
+    (axclasses, params, proven_classrels, proven_arities, f inst_params));
 
 val rep_data = Data.get #> (fn Data args => args);
 
@@ -167,7 +150,6 @@
 val proven_classrels_of = #proven_classrels o rep_data;
 val proven_arities_of = #proven_arities o rep_data;
 val inst_params_of = #inst_params o rep_data;
-val diff_classrels_of = #diff_classrels o rep_data;
 
 
 (* axclasses with parameters *)
@@ -192,66 +174,47 @@
 fun instance_name (a, c) = Long_Name.base_name c ^ "_" ^ Long_Name.base_name a;
 
 
-infix 0 RSO;
+val update_classrel = map_proven_classrels o Symreltab.update;
 
-fun (SOME a) RSO (SOME b) = SOME (a RS b)
-  | x RSO NONE = x
-  | NONE RSO y = y;
+val is_classrel = Symreltab.defined o proven_classrels_of;
 
 fun the_classrel thy (c1, c2) =
   (case Symreltab.lookup (proven_classrels_of thy) (c1, c2) of
-    SOME thm => thm
+    SOME thm => Thm.transfer thy thm
   | NONE => error ("Unproven class relation " ^
-      Syntax.string_of_classrel (Proof_Context.init_global thy) [c1, c2]));  (* FIXME stale thy (!?) *)
-
-fun put_trancl_classrel ((c1, c2), th) thy =
-  let
-    val classes = Sorts.classes_of (Sign.classes_of thy);
-    val classrels = proven_classrels_of thy;
-
-    fun reflcl_classrel (c1', c2') =
-      if c1' = c2' then NONE else SOME (Thm.transfer thy (the_classrel thy (c1', c2')));
-    fun gen_classrel (c1_pred, c2_succ) =
-      let
-        val th' =
-          the ((reflcl_classrel (c1_pred, c1) RSO SOME th) RSO reflcl_classrel (c2, c2_succ))
-          |> Drule.instantiate' [SOME (ctyp_of thy (TVar ((Name.aT, 0), [])))] []
-          |> Thm.close_derivation;
-      in ((c1_pred, c2_succ), th') end;
-
-    val new_classrels =
-      Library.map_product pair
-        (c1 :: Graph.immediate_preds classes c1)
-        (c2 :: Graph.immediate_succs classes c2)
-      |> filter_out ((op =) orf Symreltab.defined classrels)
-      |> map gen_classrel;
-    val needed = not (null new_classrels);
-  in
-    (needed,
-      if needed then map_proven_classrels (fold Symreltab.update new_classrels) thy
-      else thy)
-  end;
+      Syntax.string_of_classrel (Proof_Context.init_global thy) [c1, c2]));
 
 fun complete_classrels thy =
   let
-    val classrels = proven_classrels_of thy;
-    val diff_classrels = diff_classrels_of thy;
-    val (needed, thy') = (false, thy) |>
-      fold (fn rel => fn (needed, thy) =>
-          put_trancl_classrel (rel, Symreltab.lookup classrels rel |> the) thy
-          |>> (fn b => needed orelse b))
-        diff_classrels;
+    fun complete (c, (_, (all_preds, all_succs))) (finished1, thy1) =
+      let
+        val proven = is_classrel thy1;
+        val preds = Graph.Keys.fold (fn c1 => proven (c1, c) ? cons c1) all_preds [];
+        val succs = Graph.Keys.fold (fn c2 => proven (c, c2) ? cons c2) all_succs [];
+
+        fun complete c1 c2 (finished2, thy2) =
+          if is_classrel thy2 (c1, c2) then (finished2, thy2)
+          else
+            (false,
+              thy2
+              |> update_classrel ((c1, c2),
+                (the_classrel thy2 (c1, c) RS the_classrel thy2 (c, c2))
+                |> Drule.instantiate' [SOME (ctyp_of thy2 (TVar ((Name.aT, 0), [])))] []
+                |> Thm.close_derivation));
+      in fold_product complete preds succs (finished1, thy1) end;
+
+    val (finished', thy') =
+      Graph.fold complete (Sorts.classes_of (Sign.classes_of thy)) (true, thy);
   in
-    if null diff_classrels then NONE
-    else SOME (clear_diff_classrels thy')
+    if finished' then NONE else SOME thy'
   end;
 
 
 fun the_arity thy (a, Ss, c) =
   (case AList.lookup (op =) (Symtab.lookup_list (proven_arities_of thy) a) (c, Ss) of
-    SOME (thm, _) => thm
+    SOME (thm, _) => Thm.transfer thy thm
   | NONE => error ("Unproven type arity " ^
-      Syntax.string_of_arity (Proof_Context.init_global thy) (a, Ss, [c])));  (* FIXME stale thy (!?) *)
+      Syntax.string_of_arity (Proof_Context.init_global thy) (a, Ss, [c])));
 
 fun thynames_of_arity thy (c, a) =
   Symtab.lookup_list (proven_arities_of thy) a
@@ -273,7 +236,7 @@
     val completions = super_class_completions |> map (fn c1 =>
       let
         val th1 =
-          (th RS Thm.transfer thy (the_classrel thy (c, c1)))
+          (th RS the_classrel thy (c, c1))
           |> Drule.instantiate' std_vars []
           |> Thm.close_derivation;
       in ((th1, thy_name), c1) end);
@@ -282,7 +245,7 @@
     val arities' = fold (fn (th, c1) => Symtab.cons_list (t, ((c1, Ss), th))) completions arities;
   in (finished', arities') end;
 
-fun put_arity ((t, Ss, c), th) thy =
+fun put_arity_completion ((t, Ss, c), th) thy =
   let val ar = ((c, Ss), (th, Context.theory_name thy)) in
     thy
     |> map_proven_arities
@@ -433,7 +396,8 @@
   in
     thy'
     |> Sign.primitive_classrel (c1, c2)
-    |> (#2 oo put_trancl_classrel) ((c1, c2), th'')
+    |> map_proven_classrels (Symreltab.update ((c1, c2), th''))
+    |> perhaps complete_classrels
     |> perhaps complete_arities
   end;
 
@@ -462,7 +426,7 @@
     thy'
     |> fold (#2 oo declare_overloaded) missing_params
     |> Sign.primitive_arity (t, Ss, [c])
-    |> put_arity ((t, Ss, c), th'')
+    |> put_arity_completion ((t, Ss, c), th'')
   end;
 
 
@@ -590,7 +554,9 @@
     val axclass = make_axclass (def, intro, axioms, params);
     val result_thy =
       facts_thy
-      |> fold (#2 oo put_trancl_classrel) (map (pair class) super ~~ classrel)
+      |> map_proven_classrels
+          (fold2 (fn c => fn th => Symreltab.update ((class, c), th)) super classrel)
+      |> perhaps complete_classrels
       |> Sign.qualified_path false bconst
       |> Global_Theory.note_thmss "" (name_atts ~~ map Thm.simple_fact (unflat axiomss axioms))
       |> #2