merged
authorwenzelm
Tue, 27 Apr 2010 16:24:57 +0200
changeset 36426 cc8db7295249
parent 36425 a0297b98728c (current diff)
parent 36421 066e35d1c0d7 (diff)
child 36427 85bc9b7c4d18
child 36444 027879c5637d
merged
--- a/src/Pure/axclass.ML	Tue Apr 27 12:20:17 2010 +0200
+++ b/src/Pure/axclass.ML	Tue Apr 27 16:24:57 2010 +0200
@@ -2,7 +2,7 @@
     Author:     Markus Wenzel, TU Muenchen
 
 Type classes defined as predicates, associated with a record of
-parameters.
+parameters.  Proven class relations and type arities.
 *)
 
 signature AX_CLASS =
@@ -72,20 +72,23 @@
 datatype data = Data of
  {axclasses: info Symtab.table,
   params: param list,
-  proven_classrels: (thm * proof) Symreltab.table,
-  proven_arities: ((class * sort list) * ((thm * string) * proof)) list Symtab.table,
+  proven_classrels: thm Symreltab.table,
+  proven_arities: ((class * sort list) * (thm * string)) list Symtab.table,
     (*arity theorems with theory name*)
   inst_params:
     (string * thm) Symtab.table Symtab.table *
       (*constant name ~> type constructor ~> (constant name, equation)*)
     (string * string) Symtab.table (*constant name ~> (constant name, type constructor)*),
-  diff_merge_classrels: (class * class) list};
+  diff_classrels: (class * class) list};
 
 fun make_data
-    (axclasses, params, proven_classrels, proven_arities, inst_params, diff_merge_classrels) =
+    (axclasses, params, proven_classrels, proven_arities, inst_params, diff_classrels) =
   Data {axclasses = axclasses, params = params, proven_classrels = proven_classrels,
     proven_arities = proven_arities, inst_params = inst_params,
-    diff_merge_classrels = diff_merge_classrels};
+    diff_classrels = diff_classrels};
+
+fun diff_table tab1 tab2 =
+  Symreltab.fold (fn (x, _) => if Symreltab.defined tab2 x then I else cons x) tab1 [];
 
 structure Data = Theory_Data_PP
 (
@@ -96,62 +99,60 @@
   fun merge pp
       (Data {axclasses = axclasses1, params = params1, proven_classrels = proven_classrels1,
         proven_arities = proven_arities1, inst_params = inst_params1,
-        diff_merge_classrels = diff_merge_classrels1},
+        diff_classrels = diff_classrels1},
        Data {axclasses = axclasses2, params = params2, proven_classrels = proven_classrels2,
         proven_arities = proven_arities2, inst_params = inst_params2,
-        diff_merge_classrels = diff_merge_classrels2}) =
+        diff_classrels = diff_classrels2}) =
     let
       val axclasses' = Symtab.merge (K true) (axclasses1, axclasses2);
       val params' =
         if null params1 then params2
-        else fold_rev (fn q => if member (op =) params1 q then I else add_param pp q) params2 params1;
+        else fold_rev (fn p => if member (op =) params1 p then I else add_param pp p) params2 params1;
 
       (*transitive closure of classrels and arity completion is done in Theory.at_begin hook*)
       val proven_classrels' = Symreltab.join (K #1) (proven_classrels1, proven_classrels2);
       val proven_arities' =
         Symtab.join (K (Library.merge (eq_fst op =))) (proven_arities1, proven_arities2);
 
-      val classrels1 = Symreltab.keys proven_classrels1;
-      val classrels2 = Symreltab.keys proven_classrels2;
-      val diff_merge_classrels' =
-        subtract (op =) classrels1 classrels2 @
-        subtract (op =) classrels2 classrels1 @
-        diff_merge_classrels1 @ diff_merge_classrels2;
+      val diff_classrels' =
+        diff_table proven_classrels1 proven_classrels2 @
+        diff_table proven_classrels2 proven_classrels1 @
+        diff_classrels1 @ diff_classrels2;
 
       val inst_params' =
         (Symtab.join (K (Symtab.merge (K true))) (#1 inst_params1, #1 inst_params2),
           Symtab.merge (K true) (#2 inst_params1, #2 inst_params2));
     in
-      make_data (axclasses', params', proven_classrels', proven_arities', inst_params',
-        diff_merge_classrels')
+      make_data
+        (axclasses', params', proven_classrels', proven_arities', inst_params', diff_classrels')
     end;
 );
 
 fun map_data f =
-  Data.map (fn Data {axclasses, params, proven_classrels, proven_arities, inst_params, diff_merge_classrels} =>
-    make_data (f (axclasses, params, proven_classrels, proven_arities, inst_params, diff_merge_classrels)));
+  Data.map (fn Data {axclasses, params, proven_classrels, proven_arities, inst_params, diff_classrels} =>
+    make_data (f (axclasses, params, proven_classrels, proven_arities, inst_params, diff_classrels)));
 
 fun map_axclasses f =
-  map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params, diff_merge_classrels) =>
-    (f axclasses, params, proven_classrels, proven_arities, inst_params, diff_merge_classrels));
+  map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params, diff_classrels) =>
+    (f axclasses, params, proven_classrels, proven_arities, inst_params, diff_classrels));
 
 fun map_params f =
-  map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params, diff_merge_classrels) =>
-    (axclasses, f params, proven_classrels, proven_arities, inst_params, diff_merge_classrels));
+  map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params, diff_classrels) =>
+    (axclasses, f params, proven_classrels, proven_arities, inst_params, diff_classrels));
 
 fun map_proven_classrels f =
-  map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params, diff_merge_classrels) =>
-    (axclasses, params, f proven_classrels, proven_arities, inst_params, diff_merge_classrels));
+  map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params, diff_classrels) =>
+    (axclasses, params, f proven_classrels, proven_arities, inst_params, diff_classrels));
 
 fun map_proven_arities f =
-  map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params, diff_merge_classrels) =>
-    (axclasses, params, proven_classrels, f proven_arities, inst_params, diff_merge_classrels));
+  map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params, diff_classrels) =>
+    (axclasses, params, proven_classrels, f proven_arities, inst_params, diff_classrels));
 
 fun map_inst_params f =
-  map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params, diff_merge_classrels) =>
-    (axclasses, params, proven_classrels, proven_arities, f inst_params, diff_merge_classrels));
+  map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params, diff_classrels) =>
+    (axclasses, params, proven_classrels, proven_arities, f inst_params, diff_classrels));
 
-val clear_diff_merge_classrels =
+val clear_diff_classrels =
   map_data (fn (axclasses, params, proven_classrels, proven_arities, inst_params, _) =>
     (axclasses, params, proven_classrels, proven_arities, inst_params, []));
 
@@ -162,7 +163,7 @@
 val proven_classrels_of = #proven_classrels o rep_data;
 val proven_arities_of = #proven_arities o rep_data;
 val inst_params_of = #inst_params o rep_data;
-val diff_merge_classrels_of = #diff_merge_classrels o rep_data;
+val diff_classrels_of = #diff_classrels o rep_data;
 
 
 (* axclasses with parameters *)
@@ -187,38 +188,36 @@
 fun instance_name (a, c) = Long_Name.base_name c ^ "_" ^ Long_Name.base_name a;
 
 
+infix 0 RSO;
+
+fun (SOME a RSO SOME b) = SOME (a RS b)
+  | (x RSO NONE) = x
+  | (NONE RSO y) = y;
+
 fun the_classrel thy (c1, c2) =
   (case Symreltab.lookup (proven_classrels_of thy) (c1, c2) of
-    SOME classrel => classrel
+    SOME thm => Thm.transfer thy thm
   | NONE => error ("Unproven class relation " ^
       Syntax.string_of_classrel (ProofContext.init thy) [c1, c2]));
 
-fun the_classrel_thm thy = Thm.transfer thy o #1 o the_classrel thy;
-fun the_classrel_prf thy = #2 o the_classrel thy;
-
 fun put_trancl_classrel ((c1, c2), th) thy =
   let
-    val cert = Thm.cterm_of thy;
-    val certT = Thm.ctyp_of thy;
-
     val classes = Sorts.classes_of (Sign.classes_of thy);
     val classrels = proven_classrels_of thy;
 
     fun reflcl_classrel (c1', c2') =
-      if c1' = c2'
-      then Thm.trivial (cert (Logic.mk_of_class (TVar ((Name.aT, 0), []), c1')))
-      else the_classrel_thm thy (c1', c2');
+      if c1' = c2' then NONE else SOME (the_classrel thy (c1', c2'));
     fun gen_classrel (c1_pred, c2_succ) =
       let
-        val th' = ((reflcl_classrel (c1_pred, c1) RS th) RS reflcl_classrel (c2, c2_succ))
-          |> Drule.instantiate' [SOME (certT (TVar ((Name.aT, 0), [])))] []
+        val th' =
+          the ((reflcl_classrel (c1_pred, c1) RSO SOME th) RSO reflcl_classrel (c2, c2_succ))
+          |> Drule.instantiate' [SOME (ctyp_of thy (TVar ((Name.aT, 0), [])))] []
           |> Thm.close_derivation;
-        val prf' = Thm.proof_of th';
-      in ((c1_pred, c2_succ), (th', prf')) end;
+      in ((c1_pred, c2_succ), th') end;
 
     val new_classrels =
       Library.map_product pair (c1 :: Graph.imm_preds classes c1) (c2 :: Graph.imm_succs classes c2)
-      |> filter_out (Symreltab.defined classrels)
+      |> filter_out ((op =) orf Symreltab.defined classrels)
       |> map gen_classrel;
     val needed = not (null new_classrels);
   in
@@ -230,75 +229,77 @@
 fun complete_classrels thy =
   let
     val classrels = proven_classrels_of thy;
-    val diff_merge_classrels = diff_merge_classrels_of thy;
+    val diff_classrels = diff_classrels_of thy;
     val (needed, thy') = (false, thy) |>
-      fold (fn c12 => fn (needed, thy) =>
-          put_trancl_classrel (c12, Symreltab.lookup classrels c12 |> the |> #1) thy
+      fold (fn rel => fn (needed, thy) =>
+          put_trancl_classrel (rel, Symreltab.lookup classrels rel |> the) thy
           |>> (fn b => needed orelse b))
-        diff_merge_classrels;
+        diff_classrels;
   in
-    if null diff_merge_classrels then NONE
-    else SOME (clear_diff_merge_classrels thy')
+    if null diff_classrels then NONE
+    else SOME (clear_diff_classrels thy')
   end;
 
 
 fun the_arity thy a (c, Ss) =
   (case AList.lookup (op =) (Symtab.lookup_list (proven_arities_of thy) a) (c, Ss) of
-    SOME arity => arity
+    SOME (thm, _) => Thm.transfer thy thm
   | NONE => error ("Unproven type arity " ^
       Syntax.string_of_arity (ProofContext.init thy) (a, Ss, [c])));
 
-fun the_arity_thm thy a c_Ss = the_arity thy a c_Ss |> #1 |> #1 |> Thm.transfer thy;
-fun the_arity_prf thy a c_Ss = the_arity thy a c_Ss |> #2;
-
 fun thynames_of_arity thy (c, a) =
   Symtab.lookup_list (proven_arities_of thy) a
-  |> map_filter (fn ((c', _), ((_, name),_)) => if c = c' then SOME name else NONE)
+  |> map_filter (fn ((c', _), (_, name)) => if c = c' then SOME name else NONE)
   |> rev;
 
-fun insert_arity_completions thy (t, ((c, Ss), ((th, thy_name), _))) arities =
+fun insert_arity_completions thy t ((c, Ss), ((th, thy_name))) (finished, arities) =
   let
     val algebra = Sign.classes_of thy;
+    val ars = Symtab.lookup_list arities t;
     val super_class_completions =
       Sign.super_classes thy c
-      |> filter_out (fn c1 => exists (fn ((c2, Ss2), _) => c1 = c2
-          andalso Sorts.sorts_le algebra (Ss2, Ss)) (Symtab.lookup_list arities t));
-    val names_and_Ss = Name.names Name.context Name.aT (map (K []) Ss);
+      |> filter_out (fn c1 => exists (fn ((c2, Ss2), _) =>
+            c1 = c2 andalso Sorts.sorts_le algebra (Ss2, Ss)) ars);
+
+    val names = Name.invents Name.context Name.aT (length Ss);
+    val std_vars = map (fn a => SOME (ctyp_of thy (TVar ((a, 0), [])))) names;
+
     val completions = super_class_completions |> map (fn c1 =>
       let
-        val th1 = (th RS the_classrel_thm thy (c, c1))
-          |> Drule.instantiate' (map (SOME o ctyp_of thy o TVar o apfst (rpair 0)) names_and_Ss) []
+        val th1 =
+          (th RS the_classrel thy (c, c1))
+          |> Drule.instantiate' std_vars []
           |> Thm.close_derivation;
-        val prf1 = Thm.proof_of th1;
-      in (((th1, thy_name), prf1), c1) end);
-    val arities' = fold (fn (th_thy_prf1, c1) => Symtab.cons_list (t, ((c1, Ss), th_thy_prf1)))
-      completions arities;
-  in (null completions, arities') end;
+      in ((th1, thy_name), c1) end);
+
+    val finished' = finished andalso null completions;
+    val arities' = fold (fn (th, c1) => Symtab.cons_list (t, ((c1, Ss), th))) completions arities;
+  in (finished', arities') end;
 
 fun put_arity ((t, Ss, c), th) thy =
-  let
-    val arity' = (t, ((c, Ss), ((th, Context.theory_name thy), Thm.proof_of th)));
-  in
+  let val ar = ((c, Ss), (th, Context.theory_name thy)) in
     thy
     |> map_proven_arities
-      (Symtab.insert_list (eq_fst op =) arity' #>
-        insert_arity_completions thy arity' #> snd)
+      (Symtab.insert_list (eq_fst op =) (t, ar) #>
+       curry (insert_arity_completions thy t ar) true #> #2)
   end;
 
 fun complete_arities thy =
   let
     val arities = proven_arities_of thy;
-    val (finished, arities') = arities
-      |> fold_map (insert_arity_completions thy) (Symtab.dest_list arities);
+    val (finished, arities') =
+      Symtab.fold (fn (t, ars) => fold (insert_arity_completions thy t) ars) arities (true, arities);
   in
-    if forall I finished
-    then NONE
+    if finished then NONE
     else SOME (map_proven_arities (K arities') thy)
   end;
 
 val _ = Context.>> (Context.map_theory
   (Theory.at_begin complete_classrels #> Theory.at_begin complete_arities));
 
+val the_classrel_prf = Thm.proof_of oo the_classrel;
+val the_arity_prf = Thm.proof_of ooo the_arity;
+
 
 (* maintain instance parameters *)
 
@@ -309,15 +310,15 @@
 
 fun add_inst_param (c, tyco) inst =
   (map_inst_params o apfst o Symtab.map_default (c, Symtab.empty)) (Symtab.update_new (tyco, inst))
-  #> (map_inst_params o apsnd) (Symtab.update_new (fst inst, (c, tyco)));
+  #> (map_inst_params o apsnd) (Symtab.update_new (#1 inst, (c, tyco)));
 
 val inst_of_param = Symtab.lookup o #2 o inst_params_of;
-val param_of_inst = fst oo get_inst_param;
+val param_of_inst = #1 oo get_inst_param;
 
 fun inst_thms thy =
   Symtab.fold (Symtab.fold (cons o #2 o #2) o #2) (#1 (inst_params_of thy)) [];
 
-fun get_inst_tyco consts = try (fst o dest_Type o the_single o Consts.typargs consts);
+fun get_inst_tyco consts = try (#1 o dest_Type o the_single o Consts.typargs consts);
 
 fun unoverload thy = MetaSimplifier.simplify true (inst_thms thy);
 fun overload thy = MetaSimplifier.simplify true (map Thm.symmetric (inst_thms thy));
@@ -376,7 +377,7 @@
       | NONE => error ("Not a class parameter: " ^ quote c));
     val tyco = inst_tyco_of thy (c, T);
     val name_inst = instance_name (tyco, class) ^ "_inst";
-    val c' = Long_Name.base_name c ^ "_" ^ Long_Name.base_name tyco;
+    val c' = instance_name (tyco, c);
     val T' = Type.strip_sorts T;
   in
     thy
@@ -388,7 +389,7 @@
       #>> apsnd Thm.varifyT_global
       #-> (fn (_, thm) => add_inst_param (c, tyco) (c'', thm)
         #> PureThy.add_thm ((Binding.conceal (Binding.name c'), thm), [])
-        #> snd
+        #> #2
         #> pair (Const (c, T))))
     ||> Sign.restore_naming thy
   end;
@@ -399,8 +400,7 @@
     val tyco = inst_tyco_of thy (c, T);
     val (c', eq) = get_inst_param thy (c, tyco);
     val prop = Logic.mk_equals (Const (c', T), t);
-    val b' = Thm.def_binding_optional
-      (Binding.name (Long_Name.base_name c ^ "_" ^ Long_Name.base_name tyco)) b;
+    val b' = Thm.def_binding_optional (Binding.name (instance_name (tyco, c))) b;
   in
     thy
     |> Thm.add_def false false (b', prop)
@@ -426,7 +426,7 @@
   in
     thy
     |> Sign.primitive_classrel (c1, c2)
-    |> (snd oo put_trancl_classrel) ((c1, c2), th')
+    |> (#2 oo put_trancl_classrel) ((c1, c2), th')
     |> perhaps complete_arities
   end;
 
@@ -436,20 +436,23 @@
     val prop = Thm.plain_prop_of th;
     fun err () = raise THM ("add_arity: malformed type arity", 0, [th]);
     val (t, Ss, c) = Logic.dest_arity prop handle TERM _ => err ();
-    val names = Name.names Name.context Name.aT Ss;
-    val T = Type (t, map TFree names);
+
+    val args = Name.names Name.context Name.aT Ss;
+    val T = Type (t, map TFree args);
+    val std_vars = map (fn (a, S) => SOME (ctyp_of thy (TVar ((a, 0), S)))) args;
+
     val missing_params = Sign.complete_sort thy [c]
       |> maps (these o Option.map #params o try (get_info thy))
       |> filter_out (fn (const, _) => can (get_inst_param thy) (const, t))
       |> (map o apsnd o map_atyps) (K T);
     val _ = map (Sign.certify_sort thy) Ss = Ss orelse err ();
     val th' = th
-      |> Drule.instantiate' (map (SOME o ctyp_of thy o TVar o apfst (rpair 0)) names) []
+      |> Drule.instantiate' std_vars []
       |> Thm.unconstrain_allTs;
     val _ = shyps_topped th' orelse raise Fail "add_arity: nontop shyps after unconstrain";
   in
     thy
-    |> fold (snd oo declare_overloaded) missing_params
+    |> fold (#2 oo declare_overloaded) missing_params
     |> Sign.primitive_arity (t, Ss, [c])
     |> put_arity ((t, Ss, c), th')
   end;
@@ -585,9 +588,9 @@
     val axclass = make_axclass (def, intro, axioms, params);
     val result_thy =
       facts_thy
-      |> fold (snd oo put_trancl_classrel) (map (pair class) super ~~ classrel)
+      |> fold (#2 oo put_trancl_classrel) (map (pair class) super ~~ classrel)
       |> Sign.qualified_path false bconst
-      |> PureThy.note_thmss "" (name_atts ~~ map Thm.simple_fact (unflat axiomss axioms)) |> snd
+      |> PureThy.note_thmss "" (name_atts ~~ map Thm.simple_fact (unflat axiomss axioms)) |> #2
       |> Sign.restore_naming facts_thy
       |> map_axclasses (Symtab.update (class, axclass))
       |> map_params (fold (fn (x, _) => add_param pp (x, class)) params);
@@ -600,8 +603,7 @@
 
 local
 
-(* old-style axioms *)
-
+(*old-style axioms*)
 fun add_axiom (b, prop) =
   Thm.add_axiom (b, prop) #->
   (fn (_, thm) => PureThy.add_thm ((b, Drule.export_without_context thm), []));