src/HOL/Tools/datatype_codegen.ML
changeset 20835 27d049062b56
parent 20681 0e4df994ad34
child 20855 9f60d493c8fe
--- a/src/HOL/Tools/datatype_codegen.ML	Mon Oct 02 23:00:50 2006 +0200
+++ b/src/HOL/Tools/datatype_codegen.ML	Mon Oct 02 23:00:51 2006 +0200
@@ -9,10 +9,8 @@
 sig
   val get_eq: theory -> string -> thm list
   val get_eq_datatype: theory -> string -> thm list
-  val get_eq_typecopy: theory -> string -> thm list
   val get_cert: theory -> bool * string -> thm list
   val get_cert_datatype: theory -> string -> thm list
-  val get_cert_typecopy: theory -> string -> thm list
   val dest_case_expr: theory -> term
     -> ((string * typ) list * ((term * typ) * (term * term) list)) option
   val add_datatype_case_const: string -> theory -> theory
@@ -29,11 +27,13 @@
   val get_codetypes_arities: theory -> (string * bool) list -> sort
     -> (string * (((string * sort list) * sort) * term list)) list option
   val prove_codetypes_arities: (thm list -> tactic) -> (string * bool) list -> sort
-    -> (theory -> ((string * sort list) * sort) list -> (string * term list) list
-    -> ((bstring * attribute list) * term) list) -> (theory -> theory) -> theory -> theory
+    -> (((string * sort list) * sort) list -> (string * term list) list -> theory
+    -> ((bstring * attribute list) * term) list * theory)
+    -> (((string * sort list) * sort) list -> (string * term list) list -> theory -> theory)
+    -> theory -> theory
 
   val setup: theory -> theory
-  val setup2: theory -> theory
+  val setup_hooks: theory -> theory
 end;
 
 structure DatatypeCodegen : DATATYPE_CODEGEN =
@@ -364,7 +364,7 @@
         val names = Name.invents ctxt "a" (length tys);
         val ctxt' = fold Name.declare names ctxt;
         val vs = map2 (curry Free) names tys;
-      in (vs, ctxt) end;
+      in (vs, ctxt') end;
     fun mk_dist ((co1, tys1), (co2, tys2)) =
       let
         val ((xs1, xs2), _) = Name.context
@@ -400,17 +400,39 @@
       |> map (fn t => Goal.prove_global thy [] [] t (K tac))
       |> map (fn thm => not_eq_quodlibet OF [thm])
   in inject @ distinct end
-and get_cert_typecopy thy dtco =
-  let
-    val SOME { inject, ... } = TypecopyPackage.get_typecopy_info thy dtco;
-    val thm = Tactic.rewrite_rule [rew_eq] (bool_eq_implies OF [inject]);
-  in
-    [thm]
-  end;
 end (*local*);
 
-fun get_cert thy (true, dtco) = get_cert_datatype thy dtco
-  | get_cert thy (false, dtco) = get_cert_typecopy thy dtco;
+local
+  val not_sym = thm "HOL.not_sym";
+  val not_false_true = iffD2 OF [nth (thms "HOL.simp_thms") 7, TrueI];
+in fun get_eq_datatype thy dtco =
+  let
+    val SOME (vs, cs) = DatatypePackage.get_datatype_spec thy dtco;
+    fun mk_triv_inject co =
+      let
+        val ct' = Thm.cterm_of thy
+          (Const (co, Type (dtco, map (fn (v, sort) => TVar ((v, 0), sort)) vs)))
+        val cty' = Thm.ctyp_of_term ct';
+        val refl = Thm.prop_of HOL.refl;
+        val SOME (ct, cty) = fold_aterms (fn Var (v, ty) =>
+          (K o SOME) (Thm.cterm_of thy (Var (v, Thm.typ_of cty')), Thm.ctyp_of thy ty) | _ => I)
+          refl NONE;
+      in eqTrueI OF [Thm.instantiate ([(cty, cty')], [(ct, ct')]) HOL.refl] end;
+    val inject1 = map_filter (fn (co, []) => SOME (mk_triv_inject co) | _ => NONE) cs
+    val inject2 = (#inject o DatatypePackage.the_datatype thy) dtco;
+    val ctxt = Context.init_proof thy;
+    val simpset = Simplifier.context ctxt
+      (MetaSimplifier.empty_ss addsimprocs [distinct_simproc]);
+    val cos = map (fn (co, tys) =>
+        (Const (co, tys ---> Type (dtco, map TFree vs)), tys)) cs;
+    val tac = ALLGOALS (simp_tac simpset)
+      THEN ALLGOALS (ProofContext.fact_tac [not_false_true, TrueI]);
+    val distinct =
+      mk_distinct cos
+      |> map (fn t => Goal.prove_global thy [] [] t (K tac))
+      |> (fn thms => thms @ map (fn thm => not_sym OF [thm]) thms)
+  in inject1 @ inject2 @ distinct end;
+end (*local*);
 
 fun add_datatype_case_const dtco thy =
   let
@@ -429,8 +451,7 @@
 
 (** codetypes for code 2nd generation **)
 
-type hook = (string * (bool * ((string * sort) list * (string * typ list) list))) list
-  -> theory -> theory;
+(* abstraction over datatypes vs. type copies *)
 
 fun codetypes_dependency thy =
   let
@@ -461,23 +482,54 @@
     |> map (AList.make (the o AList.lookup (op =) names))
   end;
 
-fun mk_typecopy_spec ({ vs, constr, typ, ... } : TypecopyPackage.info) =
-  (vs, [(constr, [typ])]);
-
 fun get_spec thy (dtco, true) =
       (the o DatatypePackage.get_datatype_spec thy) dtco
   | get_spec thy (tyco, false) =
-      (mk_typecopy_spec o the o TypecopyPackage.get_typecopy_info thy) tyco;
+      TypecopyPackage.get_spec thy tyco;
+
+fun get_cert thy (true, dtco) = get_cert_datatype thy dtco
+  | get_cert thy (false, dtco) = [TypecopyPackage.get_cert thy dtco];
 
-fun add_spec thy (tyco, is_dt) =
-  (tyco, (is_dt, get_spec thy (tyco, is_dt)));
+local
+  val eq_def_sym = thm "eq_def" |> Thm.symmetric;
+  val class_eq = "OperationalEquality.eq";
+  fun get_eq_thms thy tyco = case DatatypePackage.get_datatype thy tyco
+   of SOME _ => get_eq_datatype thy tyco
+    | NONE => [TypecopyPackage.get_eq thy tyco];
+  fun constrain_op_eq_thms thy thms =
+    let
+      fun add_eq (Const ("op =", ty)) =
+            fold (insert (eq_fst (op =)))
+              (Term.add_tvarsT ty [])
+        | add_eq _ =
+            I
+      val eqs = fold (fold_aterms add_eq o Thm.prop_of) thms [];
+      val instT = map (fn (v_i, sort) =>
+        (Thm.ctyp_of thy (TVar (v_i, sort)),
+           Thm.ctyp_of thy (TVar (v_i, Sorts.inter_sort (Sign.classes_of thy) (sort, [class_eq]))))) eqs;
+    in
+      thms
+      |> map (Thm.instantiate (instT, []))
+    end;
+in
+  fun get_eq thy tyco =
+    get_eq_thms thy tyco
+    |> maps ((#mk o #mk_rews o snd o MetaSimplifier.rep_ss o Simplifier.simpset_of) thy)
+    |> constrain_op_eq_thms thy
+    |> map (Tactic.rewrite_rule [eq_def_sym])
+end;
+
+type hook = (string * (bool * ((string * sort) list * (string * typ list) list))) list
+  -> theory -> theory;
 
 fun add_codetypes_hook_bootstrap hook thy =
   let
+    fun add_spec thy (tyco, is_dt) =
+      (tyco, (is_dt, get_spec thy (tyco, is_dt)));
     fun datatype_hook dtcos thy =
       hook (map (add_spec thy) (map (rpair true) dtcos)) thy;
-    fun typecopy_hook ((tyco, info )) thy =
-      hook ([(tyco, (false, mk_typecopy_spec info))]) thy;
+    fun typecopy_hook ((tyco, _)) thy =
+      hook ([(tyco, (false, TypecopyPackage.get_spec thy tyco))]) thy;
   in
     thy
     |> fold hook ((map o map) (add_spec thy) (codetypes_dependency thy))
@@ -498,6 +550,23 @@
         val (vs::_, css) = split_list (map (the o DatatypePackage.get_datatype_spec thy) tycos);
       in (vs, map2 (fn (tyco, is_dt) => fn cs => (tyco, (is_dt, cs))) tycos' css) end;
 
+
+(* registering code types in code generator *)
+
+fun codetype_hook specs =
+  let
+    fun add (dtco, (flag, spec)) thy =
+      let
+        fun cert thy_ref = (fn () => get_cert (Theory.deref thy_ref) (flag, dtco));
+      in
+        CodegenData.add_datatype
+          (dtco, (spec, CodegenData.lazy (cert (Theory.self_ref thy)))) thy
+      end;
+  in fold add specs end;
+
+
+(* instrumentalizing the sort algebra *)
+
 fun get_codetypes_arities thy tycos sort =
   let
     val algebra = Sign.classes_of thy;
@@ -538,114 +607,35 @@
             then NONE else SOME (arity, (tyco, cs)))) insts;
       in
         thy
-        |> K ((not o null) arities) ? (ClassPackage.prove_instance_arity tac
-             arities ("", []) (f thy arities css) #> after_qed)
+        |> K ((not o null) arities) ? (
+            f arities css
+            #-> (fn defs =>
+              ClassPackage.prove_instance_arity tac arities ("", []) defs
+            #> after_qed arities css))
       end;
 
+
+(* operational equality *)
+
 local
   val class_eq = "OperationalEquality.eq";
-in fun add_eq_instance specs =
-  prove_codetypes_arities
-    (K (ClassPackage.intro_classes_tac []))
-    (map (fn (tyco, (is_dt, _)) => (tyco, is_dt)) specs)
-    [class_eq] ((K o K o K) [])
-end; (*local*)
-
-local
-  val not_sym = thm "HOL.not_sym";
-  val not_false_true = iffD2 OF [nth (thms "HOL.simp_thms") 7, TrueI];
-in fun get_eq_datatype thy dtco =
+in fun eq_hook specs =
   let
-(*     val _ = writeln "01";  *)
-    val SOME (vs, cs) = DatatypePackage.get_datatype_spec (Context.check_thy thy) dtco;
-(*     val _ = writeln "02";  *)
-    fun mk_triv_inject co =
+    fun add_eq_thms (dtco, (_, (vs, cs))) thy =
       let
-        val ct' = Thm.cterm_of (Context.check_thy thy)
-          (Const (co, Type (dtco, map (fn (v, sort) => TVar ((v, 0), sort)) vs)))
-        val cty' = Thm.ctyp_of_term ct';
-        val refl = Thm.prop_of HOL.refl;
-        val SOME (ct, cty) = fold_aterms (fn Var (v, ty) =>
-          (K o SOME) (Thm.cterm_of (Context.check_thy thy) (Var (v, Thm.typ_of cty')), Thm.ctyp_of (Context.check_thy thy) ty) | _ => I)
-          refl NONE;
-      in eqTrueI OF [Thm.instantiate ([(cty, cty')], [(ct, ct')]) HOL.refl] end;
-(*     val _ = writeln "03";  *)
-    val inject1 = map_filter (fn (co, []) => SOME (mk_triv_inject co) | _ => NONE) cs
-(*     val _ = writeln "04";  *)
-    val inject2 = (#inject o DatatypePackage.the_datatype (Context.check_thy thy)) dtco;
-(*     val _ = writeln "05";  *)
-    val ctxt = Context.init_proof (Context.check_thy thy);
-(*     val _ = writeln "06";  *)
-    val simpset = Simplifier.context ctxt
-      (MetaSimplifier.empty_ss addsimprocs [distinct_simproc]);
-(*     val _ = writeln "07";  *)
-    val cos = map (fn (co, tys) =>
-        (Const (co, tys ---> Type (dtco, map TFree vs)), tys)) cs;
-    val tac = ALLGOALS (simp_tac simpset)
-      THEN ALLGOALS (ProofContext.fact_tac [not_false_true, TrueI]);
-(*     val _ = writeln "08";  *)
-    val distinct =
-      mk_distinct cos
-      |> map (fn t => Goal.prove_global (Context.check_thy thy) [] [] t (K tac))
-      |> (fn thms => thms @ map (fn thm => not_sym OF [thm]) thms)
-(*     val _ = writeln "09";  *)
-  in inject1 @ inject2 @ distinct end;
-
-fun get_eq_typecopy thy tyco =
-  case TypecopyPackage.get_typecopy_info thy tyco
-   of SOME { inject, ... } => [inject]
-    | NONE => [];
-
-local
-  val lift_not_thm = thm "HOL.Eq_FalseI";
-  val lift_thm = thm "HOL.eq_reflection";
-  val eq_def_sym = thm "eq_def" |> Thm.symmetric;
-  fun get_eq_thms thy tyco = case DatatypePackage.get_datatype (Context.check_thy thy) tyco
-   of SOME _ => get_eq_datatype (Context.check_thy thy) tyco
-    | NONE => case TypecopyPackage.get_typecopy_info thy tyco
-       of SOME _ => get_eq_typecopy thy tyco
-        | NONE => [];
-in
-  fun get_eq thy tyco =
-    get_eq_thms (Context.check_thy thy) tyco
-(*     |> tap (fn _ => writeln "10")  *)
-    |> maps ((#mk o #mk_rews o snd o MetaSimplifier.rep_ss o Simplifier.simpset_of) (Context.check_thy thy))
-(*     |> tap (fn _ => writeln "11")  *)
-    |> constrain_op_eq (Context.check_thy thy)
-(*     |> tap (fn _ => writeln "12")  *)
-    |> map (Tactic.rewrite_rule [eq_def_sym])
-(*     |> tap (fn _ => writeln "13")  *)
-end;
-
-end;
-
-fun add_eq_thms (dtco, (_, (vs, cs))) thy =
-  let
-    val thy_ref = Theory.self_ref thy;
-    val ty = Type (dtco, map TFree vs) |> Logic.varifyT;
-    val c = CodegenConsts.norm thy ("OperationalEquality.eq", [ty]);
-    val get_thms = (fn () => get_eq (Theory.deref thy_ref) dtco |> rev);
-  in
-    CodegenData.add_funcl
-      (c, CodegenData.lazy get_thms) thy
-  end;
-
-fun codetype_hook dtcos theory =
-  let
-    fun add (dtco, (flag, spec)) thy =
-      let
-        fun cert thy_ref = (fn () => get_cert (Theory.deref thy_ref) (flag, dtco));
+        val thy_ref = Theory.self_ref thy;
+        val ty = Type (dtco, map TFree vs) |> Logic.varifyT;
+        val c = CodegenConsts.norm thy ("OperationalEquality.eq", [ty]);
+        val get_thms = (fn () => get_eq (Theory.deref thy_ref) dtco |> rev);
       in
-        CodegenData.add_datatype
-          (dtco, (spec, CodegenData.lazy (cert (Theory.self_ref thy)))) thy
+        CodegenData.add_funcl (c, CodegenData.lazy get_thms) thy
       end;
   in
-    theory
-    |> fold add dtcos
+    prove_codetypes_arities (K (ClassPackage.intro_classes_tac []))
+      (map (fn (tyco, (is_dt, _)) => (tyco, is_dt)) specs)
+      [class_eq] ((K o K o pair) []) ((K o K) (fold add_eq_thms specs))
   end;
-
-fun eq_hook dtcos =
-  add_eq_instance dtcos (fold add_eq_thms dtcos);
+end; (*local*)
 
 
 
@@ -657,7 +647,7 @@
   #> DatatypeHooks.add (fold add_datatype_case_const)
   #> DatatypeHooks.add (fold add_datatype_case_defs)
 
-val setup2 =
+val setup_hooks =
   add_codetypes_hook_bootstrap codetype_hook
   #> add_codetypes_hook_bootstrap eq_hook