src/Pure/axclass.ML
changeset 36325 8715343af626
parent 36106 19deea200358
child 36326 85d026788fce
--- a/src/Pure/axclass.ML	Sun Apr 25 16:10:05 2010 +0200
+++ b/src/Pure/axclass.ML	Sun Apr 25 19:09:37 2010 +0200
@@ -89,11 +89,12 @@
 val arity_prefix = "arity_";
 
 type instances =
-  ((class * class) * thm) list *  (*classrel theorems*)
-  ((class * sort list) * (thm * string)) list Symtab.table;  (*arity theorems with theory name*)
+  (thm * proof) Symreltab.table *  (*classrel theorems*)
+  ((class * sort list) * ((thm * string) * proof)) list Symtab.table;  (*arity theorems with theory name*)
 
+(*transitive closure of classrels and arity completion is done in Theory.at_begin hook*)
 fun merge_instances ((classrel1, arities1): instances, (classrel2, arities2)) =
- (merge (eq_fst op =) (classrel1, classrel2),
+ (Symreltab.join (K fst) (classrel1, classrel2),
   Symtab.join (K (merge (eq_fst op =))) (arities1, arities2));
 
 
@@ -113,12 +114,21 @@
 
 structure AxClassData = Theory_Data_PP
 (
-  type T = axclasses * (instances * inst_params);
-  val empty = ((Symtab.empty, []), (([], Symtab.empty), (Symtab.empty, Symtab.empty)));
+  type T = axclasses * ((instances * inst_params) * (class * class) list);
+  val empty = ((Symtab.empty, []), (((Symreltab.empty, Symtab.empty), (Symtab.empty, Symtab.empty)), []));
   val extend = I;
-  fun merge pp ((axclasses1, (instances1, inst_params1)), (axclasses2, (instances2, inst_params2))) =
-    (merge_axclasses pp (axclasses1, axclasses2),
-      (merge_instances (instances1, instances2), merge_inst_params (inst_params1, inst_params2)));
+  fun merge pp ((axclasses1, ((instances1, inst_params1), diff_merge_classrels1)),
+    (axclasses2, ((instances2, inst_params2), diff_merge_classrels2))) =
+    let
+      val (classrels1, classrels2) = pairself (Symreltab.keys o fst) (instances1, instances2)
+      val diff_merge_classrels = subtract (op =) classrels1 classrels2
+        @ subtract (op =) classrels2 classrels1
+        @ diff_merge_classrels1 @ diff_merge_classrels2
+    in
+      (merge_axclasses pp (axclasses1, axclasses2),
+        ((merge_instances (instances1, instances2), merge_inst_params (inst_params1, inst_params2)),
+          diff_merge_classrels))
+    end;
 );
 
 
@@ -155,48 +165,103 @@
 
 fun instance_name (a, c) = Long_Name.base_name c ^ "_" ^ Long_Name.base_name a;
 
-val get_instances = #1 o #2 o AxClassData.get;
-val map_instances = AxClassData.map o apsnd o apfst;
+val get_instances = #1 o #1 o #2 o AxClassData.get;
+val map_instances = AxClassData.map o apsnd o apfst o apfst;
+
+val get_diff_merge_classrels = #2 o #2 o AxClassData.get;
+val clear_diff_merge_classrels = AxClassData.map (apsnd (apsnd (K [])));
 
 
 fun the_classrel thy (c1, c2) =
-  (case AList.lookup (op =) (#1 (get_instances thy)) (c1, c2) of
-    SOME th => Thm.transfer thy th
+  (case Symreltab.lookup (#1 (get_instances thy)) (c1, c2) of
+    SOME classrel => classrel
   | NONE => error ("Unproven class relation " ^
       Syntax.string_of_classrel (ProofContext.init thy) [c1, c2]));
 
-fun put_classrel arg = map_instances (fn (classrel, arities) =>
-  (insert (eq_fst op =) arg classrel, arities));
+fun the_classrel_thm thy = Thm.transfer thy o fst o the_classrel thy;
+fun the_classrel_prf thy = snd o the_classrel thy;
+
+fun put_trancl_classrel ((c1, c2), th) thy =
+  let
+    val classrels = fst (get_instances thy)
+    val alg = Sign.classes_of thy
+    val {classes, ...} = alg |> Sorts.rep_algebra
+
+    fun reflcl_classrel (c1', c2') =
+      if c1' = c2' then Thm.trivial (Logic.mk_of_class (TVar(("'a",0),[]), c1') |> cterm_of thy)
+      else the_classrel_thm thy (c1', c2')
+    fun gen_classrel (c1_pred, c2_succ) =
+      let
+        val th' = ((reflcl_classrel (c1_pred, c1) RS th) RS reflcl_classrel (c2, c2_succ))
+          |> Drule.instantiate' [SOME (ctyp_of thy (TVar ((Name.aT, 0), [])))] []
+          |> Thm.close_derivation
+        val prf' = th' |> Thm.proof_of
+      in ((c1_pred, c2_succ), (th',prf')) end
+
+    val new_classrels = Library.map_product pair
+        (c1 :: Graph.imm_preds classes c1) (c2 :: Graph.imm_succs classes c2)
+      |> filter_out (Symreltab.defined classrels)
+      |> map gen_classrel
+    val needed = length new_classrels > 0
+  in
+    (needed,
+     if needed then
+       thy |> map_instances (fn (classrels, arities) =>
+         (classrels |> fold Symreltab.update new_classrels, arities))
+     else thy)
+  end;
+
+fun complete_classrels thy =
+  let
+    val diff_merge_classrels = get_diff_merge_classrels thy
+    val classrels = fst (get_instances thy)
+    val (needed, thy') = (false, thy) |>
+      fold (fn c12 => fn (needed, thy) =>
+          put_trancl_classrel (c12, Symreltab.lookup classrels c12 |> the |> fst) thy
+          |>> (fn b => needed orelse b))
+        diff_merge_classrels
+  in
+    if null diff_merge_classrels then NONE
+    else thy' |> clear_diff_merge_classrels |> SOME
+  end;
 
 
 fun the_arity thy a (c, Ss) =
   (case AList.lookup (op =) (Symtab.lookup_list (#2 (get_instances thy)) a) (c, Ss) of
-    SOME (th, _) => Thm.transfer thy th
+    SOME arity => arity
   | NONE => error ("Unproven type arity " ^
       Syntax.string_of_arity (ProofContext.init thy) (a, Ss, [c])));
 
+fun the_arity_thm thy a c_Ss = the_arity thy a c_Ss |> fst |> fst |> Thm.transfer thy;
+fun the_arity_prf thy a c_Ss = the_arity thy a c_Ss |> snd;
+
 fun thynames_of_arity thy (c, a) =
   Symtab.lookup_list (#2 (get_instances thy)) a
-  |> map_filter (fn ((c', _), (_, name)) => if c = c' then SOME name else NONE)
+  |> map_filter (fn ((c', _), ((_, name),_)) => if c = c' then SOME name else NONE)
   |> rev;
 
-fun insert_arity_completions thy (t, ((c, Ss), (th, thy_name))) arities =
+fun insert_arity_completions thy (t, ((c, Ss), ((th, thy_name), _))) arities =
   let
     val algebra = Sign.classes_of thy;
     val super_class_completions =
       Sign.super_classes thy c
       |> filter_out (fn c1 => exists (fn ((c2, Ss2), _) => c1 = c2
           andalso Sorts.sorts_le algebra (Ss2, Ss)) (Symtab.lookup_list arities t));
-    val completions = map (fn c1 => (Sorts.classrel_derivation algebra
-      (fn (th, c2) => fn c3 => th RS the_classrel thy (c2, c3)) (th, c) c1
-        |> Thm.close_derivation, c1)) super_class_completions;
-    val arities' = fold (fn (th1, c1) => Symtab.cons_list (t, ((c1, Ss), (th1, thy_name))))
+    val names_and_Ss = Name.names Name.context Name.aT (map (K []) Ss);
+    val completions = super_class_completions |> map (fn c1 =>
+      let
+        val th1 = (th RS the_classrel_thm thy (c, c1))
+          |> Drule.instantiate' (map (SOME o ctyp_of thy o TVar o apfst (rpair 0)) names_and_Ss) []
+          |> Thm.close_derivation
+        val prf1 = Thm.proof_of th1
+      in (((th1,thy_name), prf1), c1) end)
+    val arities' = fold (fn (th_thy_prf1, c1) => Symtab.cons_list (t, ((c1, Ss), th_thy_prf1)))
       completions arities;
   in (null completions, arities') end;
 
 fun put_arity ((t, Ss, c), th) thy =
   let
-    val arity' = (t, ((c, Ss), (th, Context.theory_name thy)));
+    val arity' = (t, ((c, Ss), ((th, Context.theory_name thy), Thm.proof_of th)));
   in
     thy
     |> map_instances (fn (classrel, arities) => (classrel,
@@ -216,13 +281,14 @@
     else SOME (thy |> map_instances (fn (classrel, _) => (classrel, arities')))
   end;
 
-val _ = Context.>> (Context.map_theory (Theory.at_begin complete_arities));
+val _ = Context.>> (Context.map_theory
+  (Theory.at_begin complete_classrels #> Theory.at_begin complete_arities))
 
 
 (* maintain instance parameters *)
 
-val get_inst_params = #2 o #2 o AxClassData.get;
-val map_inst_params = AxClassData.map o apsnd o apsnd;
+val get_inst_params = #2 o #1 o #2 o AxClassData.get;
+val map_inst_params = AxClassData.map o apsnd o apfst o apsnd;
 
 fun get_inst_param thy (c, tyco) =
   case Symtab.lookup ((the_default Symtab.empty o Symtab.lookup (fst (get_inst_params thy))) c) tyco
@@ -280,6 +346,11 @@
   cert_classrel thy (pairself (ProofContext.read_class (ProofContext.init thy)) raw_rel)
     handle TYPE (msg, _, _) => error msg;
 
+fun check_shyps_topped th errmsg =
+  let val {shyps, ...} = Thm.rep_thm th
+  in
+    forall null shyps orelse raise Fail errmsg
+  end;
 
 (* declaration and definition of instances of overloaded constants *)
 
@@ -338,10 +409,14 @@
     fun err () = raise THM ("add_classrel: malformed class relation", 0, [th]);
     val rel = Logic.dest_classrel prop handle TERM _ => err ();
     val (c1, c2) = cert_classrel thy rel handle TYPE _ => err ();
+    val th' = th
+      |> Drule.instantiate' [SOME (ctyp_of thy (TVar ((Name.aT, 0), [c1])))] []
+      |> Drule.unconstrainTs;
+    val _ = check_shyps_topped th' "add_classrel: nontop shyps after unconstrain"
   in
     thy
     |> Sign.primitive_classrel (c1, c2)
-    |> put_classrel ((c1, c2), Thm.close_derivation (Drule.unconstrainTs th))
+    |> (snd oo put_trancl_classrel) ((c1, c2), th')
     |> perhaps complete_arities
   end;
 
@@ -351,17 +426,22 @@
     val prop = Thm.plain_prop_of th;
     fun err () = raise THM ("add_arity: malformed type arity", 0, [th]);
     val (t, Ss, c) = Logic.dest_arity prop handle TERM _ => err ();
-    val T = Type (t, map TFree (Name.names Name.context Name.aT Ss));
+    val names = Name.names Name.context Name.aT Ss;
+    val T = Type (t, map TFree names);
     val missing_params = Sign.complete_sort thy [c]
       |> maps (these o Option.map #params o try (get_info thy))
       |> filter_out (fn (const, _) => can (get_inst_param thy) (const, t))
       |> (map o apsnd o map_atyps) (K T);
     val _ = map (Sign.certify_sort thy) Ss = Ss orelse err ();
+    val th' = th
+      |> Drule.instantiate' (map (SOME o ctyp_of thy o TVar o apfst (rpair 0)) names) []
+      |> Drule.unconstrainTs;
+    val _ = check_shyps_topped th' "add_arity: nontop shyps after unconstrain"
   in
     thy
     |> fold (snd oo declare_overloaded) missing_params
     |> Sign.primitive_arity (t, Ss, [c])
-    |> put_arity ((t, Ss, c), Thm.close_derivation (Drule.unconstrainTs th))
+    |> put_arity ((t, Ss, c), th')
   end;
 
 
@@ -495,7 +575,7 @@
     val axclass = make_axclass ((def, intro, axioms), params);
     val result_thy =
       facts_thy
-      |> fold put_classrel (map (pair class) super ~~ classrel)
+      |> fold (snd oo put_trancl_classrel) (map (pair class) super ~~ classrel)
       |> Sign.qualified_path false bconst
       |> PureThy.note_thmss "" (name_atts ~~ map Thm.simple_fact (unflat axiomss axioms)) |> snd
       |> Sign.restore_naming facts_thy