more direct proofs for type classes;
authorwenzelm
Fri, 02 Aug 2019 14:14:49 +0200
changeset 70458 9e2173eb23eb
parent 70457 a8b5d668bf13
child 70459 f0a445c5a82c
more direct proofs for type classes; misc tuning and cleanup;
src/Pure/Proof/proof_rewrite_rules.ML
src/Pure/proofterm.ML
src/Pure/thm.ML
--- a/src/Pure/Proof/proof_rewrite_rules.ML	Fri Aug 02 11:43:36 2019 +0200
+++ b/src/Pure/Proof/proof_rewrite_rules.ML	Fri Aug 02 14:14:49 2019 +0200
@@ -359,7 +359,8 @@
       Same.commit (Proofterm.map_proof_same Same.same Same.same hyp)
   in
     map2 reconstruct
-      (Proofterm.of_sort_proof thy (OfClass o apfst Type.strip_sorts) (subst T, S))
+      (Proofterm.of_sort_proof (Sign.classes_of thy) (Thm.classrel_proof thy) (Thm.arity_proof thy)
+        (OfClass o apfst Type.strip_sorts) (subst T, S))
       (Logic.mk_of_sort (T, S))
   end;
 
--- a/src/Pure/proofterm.ML	Fri Aug 02 11:43:36 2019 +0200
+++ b/src/Pure/proofterm.ML	Fri Aug 02 14:14:49 2019 +0200
@@ -131,12 +131,10 @@
   val equal_elim: term -> term -> proof -> proof -> proof
   val strip_shyps_proof: Sorts.algebra -> (typ * sort) list -> (typ * sort) list ->
     sort list -> proof -> proof
-  val classrel_proof: theory -> class * class -> proof
-  val arity_proof: theory -> string * sort list * class -> proof
-  val of_sort_proof: theory -> (typ * class -> proof) -> typ * sort -> proof list
-  val install_axclass_proofs:
-   {classrel_proof: theory -> class * class -> proof,
-    arity_proof: theory -> string * sort list * class -> proof} -> theory -> theory
+  val of_sort_proof: Sorts.algebra ->
+    (class * class -> proof) ->
+    (string * class list list * class -> proof) ->
+    (typ * class -> proof) -> typ * sort -> proof list
   val axm_proof: string -> term -> proof
   val oracle_proof: string -> term -> oracle * proof
   val shrink_proof: proof -> proof
@@ -159,9 +157,11 @@
 
   val proof_serial: unit -> proof_serial
   val fulfill_norm_proof: theory -> (serial * proof_body) list -> proof_body -> proof_body
-  val thm_proof: theory -> string -> sort list -> term list -> term ->
+  val thm_proof: theory -> (class * class -> proof) ->
+    (string * class list list * class -> proof) -> string -> sort list -> term list -> term ->
     (serial * proof_body future) list -> proof_body -> pthm * proof
-  val unconstrain_thm_proof: theory -> sort list -> term ->
+  val unconstrain_thm_proof: theory -> (class * class -> proof) ->
+    (string * class list list * class -> proof) -> sort list -> term ->
     (serial * proof_body future) list -> proof_body -> pthm * proof
   val get_name: sort list -> term list -> term -> proof -> string
 end
@@ -1053,61 +1053,15 @@
         | NONE => raise Fail "strip_shyps_proof: bad type variable in proof term");
   in Same.commit (map_proof_types_same (Term_Subst.map_atypsT_same replace)) prf end;
 
-
-local
-
-type axclass_proofs =
- {classrel_proof: theory -> class * class -> proof,
-  arity_proof: theory -> string * sort list * class -> proof};
-
-structure Axclass_Proofs = Theory_Data
-(
-  type T = axclass_proofs option;
-  val empty = NONE;
-  val extend = I;
-  val merge = merge_options;
-);
-
-fun the_axclass_proofs which thy x =
-  (case Axclass_Proofs.get thy of
-    NONE => raise Fail "Axclass proof operations not installed"
-  | SOME proofs => which proofs thy x);
-
-in
-
-val classrel_proof = the_axclass_proofs #classrel_proof;
-val arity_proof = the_axclass_proofs #arity_proof;
-
-fun install_axclass_proofs proofs =
-  Axclass_Proofs.map
-    (fn NONE => SOME proofs
-      | SOME _ => raise Fail "Axclass proof operations already installed");
-
-end;
-
-
-local
-
-fun canonical_instance typs =
-  let
-    val names = Name.invent Name.context Name.aT (length typs);
-    val instT = map2 (fn a => fn T => (((a, 0), []), Type.strip_sorts T)) names typs;
-  in instantiate (instT, []) end;
-
-in
-
-fun of_sort_proof thy hyps =
-  Sorts.of_sort_derivation (Sign.classes_of thy)
-   {class_relation = fn typ => fn _ => fn (prf, c1) => fn c2 =>
-      if c1 = c2 then prf
-      else canonical_instance [typ] (classrel_proof thy (c1, c2)) %% prf,
-    type_constructor = fn (a, typs) => fn dom => fn c =>
+fun of_sort_proof algebra classrel_proof arity_proof hyps =
+  Sorts.of_sort_derivation algebra
+   {class_relation = fn _ => fn _ => fn (prf, c1) => fn c2 =>
+      if c1 = c2 then prf else classrel_proof (c1, c2) %% prf,
+    type_constructor = fn (a, _) => fn dom => fn c =>
       let val Ss = map (map snd) dom and prfs = maps (map fst) dom
-      in proof_combP (canonical_instance typs (arity_proof thy (a, Ss, c)), prfs) end,
+      in proof_combP (arity_proof (a, Ss, c), prfs) end,
     type_variable = fn typ => map (fn c => (hyps (typ, c), c)) (Type.sort_of_atyp typ)};
 
-end;
-
 
 
 (** axioms and theorems **)
@@ -2006,17 +1960,17 @@
 
 local
 
-fun unconstrainT_prf thy (ucontext: Logic.unconstrain_context) =
+fun unconstrainT_proof algebra classrel_proof arity_proof (ucontext: Logic.unconstrain_context) =
   let
     fun hyp_map hyp =
       (case AList.lookup (op =) (#constraints ucontext) hyp of
         SOME t => Hyp t
-      | NONE => raise Fail "unconstrainT_prf: missing constraint");
+      | NONE => raise Fail "unconstrainT_proof: missing constraint");
 
     val typ = Term_Subst.map_atypsT_same (Type.strip_sorts o #atyp_map ucontext);
     fun ofclass (ty, c) =
       let val ty' = Term.map_atyps (#atyp_map ucontext) ty;
-      in the_single (of_sort_proof thy hyp_map (ty', [c])) end;
+      in the_single (of_sort_proof algebra classrel_proof arity_proof  hyp_map (ty', [c])) end;
   in
     Same.commit (map_proof_same (Term_Subst.map_types_same typ) typ ofclass)
     #> fold_rev (implies_intr_proof o snd) (#constraints ucontext)
@@ -2062,13 +2016,9 @@
   if Options.default_bool "prune_proofs" then MinProof
   else proof;
 
-fun prepare_thm_proof unconstrain thy name shyps hyps concl promises body =
+fun prepare_thm_proof unconstrain thy classrel_proof arity_proof
+    name shyps hyps concl promises body =
   let
-(*
-    val FIXME =
-      Output.physical_stderr ("pthm " ^ quote name ^ " " ^ Position.here (Position.thread_data ()) ^ "\n");
-*)
-
     val named = name <> "";
 
     val prop = Logic.list_implies (hyps, concl);
@@ -2088,7 +2038,9 @@
     fun new_prf () =
       let
         val i = proof_serial ();
-        val postproc = map_proof_of (unconstrainT_prf thy ucontext) #> named ? publish i;
+        val unconstrainT =
+          unconstrainT_proof (Sign.classes_of thy) classrel_proof arity_proof ucontext;
+        val postproc = map_proof_of unconstrainT #> named ? publish i;
       in (i, fulfill_proof_future thy promises postproc body0) end;
 
     val (i, body') =
@@ -2115,8 +2067,8 @@
 
 val thm_proof = prepare_thm_proof false;
 
-fun unconstrain_thm_proof thy shyps concl promises body =
-  prepare_thm_proof true thy "" shyps [] concl promises body;
+fun unconstrain_thm_proof thy classrel_proof arity_proof shyps concl promises body =
+  prepare_thm_proof true thy classrel_proof arity_proof "" shyps [] concl promises body;
 
 end;
 
--- a/src/Pure/thm.ML	Fri Aug 02 11:43:36 2019 +0200
+++ b/src/Pure/thm.ML	Fri Aug 02 14:14:49 2019 +0200
@@ -115,6 +115,11 @@
   val map_tags: (Properties.T -> Properties.T) -> thm -> thm
   val norm_proof: thm -> thm
   val adjust_maxidx_thm: int -> thm -> thm
+  (*type classes*)
+  val the_classrel: theory -> class * class -> thm
+  val the_arity: theory -> string * sort list * class -> thm
+  val classrel_proof: theory -> class * class -> proof
+  val arity_proof: theory -> string * sort list * class -> proof
   (*oracles*)
   val add_oracle: binding * ('a -> cterm) -> theory -> (string * ('a -> thm)) * theory
   val extern_oracles: bool -> Proof.context -> (Markup.T * xstring) list
@@ -160,8 +165,6 @@
   val bicompose: Proof.context option -> {flatten: bool, match: bool, incremented: bool} ->
     bool * thm * int -> int -> thm -> thm Seq.seq
   val biresolution: Proof.context option -> bool -> (bool * thm) list -> int -> thm -> thm Seq.seq
-  val the_classrel: theory -> class * class -> thm
-  val the_arity: theory -> string * sort list * class -> thm
   val thynames_of_arity: theory -> string * class -> string list
   val add_classrel: thm -> theory -> theory
   val add_arity: thm -> theory -> theory
@@ -747,32 +750,6 @@
   end;
 
 
-(* closed derivations with official name *)
-
-(*non-deterministic, depends on unknown promises*)
-fun derivation_closed (Thm (Deriv {body, ...}, _)) =
-  Proofterm.compact_proof (Proofterm.proof_of body);
-
-(*non-deterministic, depends on unknown promises*)
-fun derivation_name (Thm (Deriv {body, ...}, {shyps, hyps, prop, ...})) =
-  Proofterm.get_name shyps hyps prop (Proofterm.proof_of body);
-
-fun name_derivation name (thm as Thm (der, args)) =
-  let
-    val Deriv {promises, body} = der;
-    val {shyps, hyps, prop, tpairs, ...} = args;
-    val _ = null tpairs orelse raise THM ("put_name: unsolved flex-flex constraints", 0, [thm]);
-    val thy = theory_of_thm thm;
-
-    val ps = map (apsnd (Future.map fulfill_body)) promises;
-    val (pthm, proof) = Proofterm.thm_proof thy name shyps hyps prop ps body;
-    val der' = make_deriv [] [] [pthm] proof;
-  in Thm (der', args) end;
-
-fun close_derivation thm =
-  if derivation_closed thm then thm else name_derivation "" thm;
-
-
 
 (** Axioms **)
 
@@ -828,47 +805,97 @@
 
 (*** Theory data ***)
 
-datatype sorts = Sorts of
+(* type classes *)
+
+datatype classes = Classes of
  {classrels: thm Symreltab.table,
   arities: ((class * sort list) * (thm * string)) list Symtab.table};
 
-fun make_sorts (classrels, arities) = Sorts {classrels = classrels, arities = arities};
+fun make_classes (classrels, arities) = Classes {classrels = classrels, arities = arities};
 
-val empty_sorts = make_sorts (Symreltab.empty, Symtab.empty);
+val empty_classes = make_classes (Symreltab.empty, Symtab.empty);
 
-fun merge_sorts
-   (Sorts {classrels = classrels1, arities = arities1},
-    Sorts {classrels = classrels2, arities = arities2}) =
+(*see Theory.at_begin hook for transitive closure of classrels and arity completion*)
+fun merge_classes
+   (Classes {classrels = classrels1, arities = arities1},
+    Classes {classrels = classrels2, arities = arities2}) =
   let
-    (*see Theory.at_begin hook for transitive closure of classrels and arity completion*)
     val classrels' = Symreltab.merge (K true) (classrels1, classrels2);
     val arities' = Symtab.merge_list (eq_fst op =) (arities1, arities2);
-  in make_sorts (classrels', arities') end;
+  in make_classes (classrels', arities') end;
 
 
+(* data *)
+
 structure Data = Theory_Data
 (
   type T =
     unit Name_Space.table *  (*oracles: authentic derivation names*)
-    sorts;  (*sort algebra within the logic*)
+    classes;  (*type classes within the logic*)
 
-  val empty : T = (Name_Space.empty_table "oracle", empty_sorts);
+  val empty : T = (Name_Space.empty_table "oracle", empty_classes);
   val extend = I;
   fun merge ((oracles1, sorts1), (oracles2, sorts2)) : T =
-    (Name_Space.merge_tables (oracles1, oracles2), merge_sorts (sorts1, sorts2));
+    (Name_Space.merge_tables (oracles1, oracles2), merge_classes (sorts1, sorts2));
 );
 
 val get_oracles = #1 o Data.get;
 val map_oracles = Data.map o apfst;
 
-val get_sorts = (fn (_, Sorts args) => args) o Data.get;
-val get_classrels = #classrels o get_sorts;
-val get_arities = #arities o get_sorts;
+val get_classes = (fn (_, Classes args) => args) o Data.get;
+val get_classrels = #classrels o get_classes;
+val get_arities = #arities o get_classes;
+
+fun map_classes f =
+  (Data.map o apsnd) (fn Classes {classrels, arities} => make_classes (f (classrels, arities)));
+fun map_classrels f = map_classes (fn (classrels, arities) => (f classrels, arities));
+fun map_arities f = map_classes (fn (classrels, arities) => (classrels, f arities));
+
+
+(* type classes *)
+
+fun the_classrel thy (c1, c2) =
+  (case Symreltab.lookup (get_classrels thy) (c1, c2) of
+    SOME thm => transfer thy thm
+  | NONE => error ("Unproven class relation " ^
+      Syntax.string_of_classrel (Proof_Context.init_global thy) [c1, c2]));
+
+fun the_arity thy (a, Ss, c) =
+  (case AList.lookup (op =) (Symtab.lookup_list (get_arities thy) a) (c, Ss) of
+    SOME (thm, _) => transfer thy thm
+  | NONE => error ("Unproven type arity " ^
+      Syntax.string_of_arity (Proof_Context.init_global thy) (a, Ss, [c])));
+
+val classrel_proof = proof_of oo the_classrel;
+val arity_proof = proof_of oo the_arity;
 
-fun map_sorts f =
-  (Data.map o apsnd) (fn Sorts {classrels, arities} => make_sorts (f (classrels, arities)));
-fun map_classrels f = map_sorts (fn (classrels, arities) => (f classrels, arities));
-fun map_arities f = map_sorts (fn (classrels, arities) => (classrels, f arities));
+
+
+(*** Theorems with official name ***)
+
+(*non-deterministic, depends on unknown promises*)
+fun derivation_closed (Thm (Deriv {body, ...}, _)) =
+  Proofterm.compact_proof (Proofterm.proof_of body);
+
+(*non-deterministic, depends on unknown promises*)
+fun derivation_name (Thm (Deriv {body, ...}, {shyps, hyps, prop, ...})) =
+  Proofterm.get_name shyps hyps prop (Proofterm.proof_of body);
+
+fun name_derivation name (thm as Thm (der, args)) =
+  let
+    val Deriv {promises, body} = der;
+    val {shyps, hyps, prop, tpairs, ...} = args;
+    val _ = null tpairs orelse raise THM ("put_name: unsolved flex-flex constraints", 0, [thm]);
+    val thy = theory_of_thm thm;
+
+    val ps = map (apsnd (Future.map fulfill_body)) promises;
+    val (pthm, proof) =
+      Proofterm.thm_proof thy (classrel_proof thy) (arity_proof thy) name shyps hyps prop ps body;
+    val der' = make_deriv [] [] [pthm] proof;
+  in Thm (der', args) end;
+
+fun close_derivation thm =
+  if derivation_closed thm then thm else name_derivation "" thm;
 
 
 
@@ -1501,7 +1528,8 @@
     val _ = null tfrees orelse err ("illegal free type variables " ^ commas_quote tfrees);
 
     val ps = map (apsnd (Future.map fulfill_body)) promises;
-    val (pthm, proof) = Proofterm.unconstrain_thm_proof thy shyps prop ps body;
+    val (pthm, proof) =
+      Proofterm.unconstrain_thm_proof thy (classrel_proof thy) (arity_proof thy) shyps prop ps body;
     val der' = make_deriv [] [] [pthm] proof;
     val prop' = Proofterm.thm_node_prop (#2 pthm);
   in
@@ -1977,7 +2005,7 @@
 
 
 
-(*** sort algebra within the logic ***)
+(**** Type classes ****)
 
 fun standard_tvars thm =
   let
@@ -1991,12 +2019,6 @@
 
 (* class relations *)
 
-fun the_classrel thy (c1, c2) =
-  (case Symreltab.lookup (get_classrels thy) (c1, c2) of
-    SOME thm => transfer thy thm
-  | NONE => error ("Unproven class relation " ^
-      Syntax.string_of_classrel (Proof_Context.init_global thy) [c1, c2]));
-
 val is_classrel = Symreltab.defined o get_classrels;
 
 fun complete_classrels thy =
@@ -2029,12 +2051,6 @@
 
 (* type arities *)
 
-fun the_arity thy (a, Ss, c) =
-  (case AList.lookup (op =) (Symtab.lookup_list (get_arities thy) a) (c, Ss) of
-    SOME (thm, _) => transfer thy thm
-  | NONE => error ("Unproven type arity " ^
-      Syntax.string_of_arity (Proof_Context.init_global thy) (a, Ss, [c])));
-
 fun thynames_of_arity thy (a, c) =
   Symtab.lookup_list (get_arities thy) a
   |> map_filter (fn ((c', _), (_, name)) => if c = c' then SOME name else NONE)
@@ -2084,10 +2100,7 @@
 val _ =
   Theory.setup
    (Theory.at_begin complete_classrels #>
-    Theory.at_begin complete_arities #>
-    Proofterm.install_axclass_proofs
-     {classrel_proof = proof_of oo the_classrel,
-      arity_proof = proof_of oo the_arity});
+    Theory.at_begin complete_arities);
 
 
 (* primitive rules *)