tuned and generalized construction of code equations for eq
authorhaftmann
Wed, 13 May 2009 18:41:36 +0200
changeset 31134 1a5591ecb764
parent 31133 a9f728dc5c8e
child 31135 e2d777dcf161
tuned and generalized construction of code equations for eq
src/HOL/Tools/datatype_codegen.ML
--- a/src/HOL/Tools/datatype_codegen.ML	Wed May 13 18:41:36 2009 +0200
+++ b/src/HOL/Tools/datatype_codegen.ML	Wed May 13 18:41:36 2009 +0200
@@ -6,7 +6,7 @@
 
 signature DATATYPE_CODEGEN =
 sig
-  val mk_eq: theory -> string -> thm list
+  val mk_eq_eqns: theory -> string -> (thm * bool) list
   val mk_case_cert: theory -> string -> thm
   val setup: theory -> theory
 end;
@@ -309,18 +309,6 @@
 
 (** generic code generator **)
 
-(* specification *)
-
-fun add_datatype_spec vs dtco cos thy =
-  let
-    val cs = map (fn (c, tys) => (c, tys ---> Type (dtco, map TFree vs))) cos;
-  in
-    thy
-    |> try (Code.add_datatype cs)
-    |> the_default thy
-  end;
-
-
 (* case certificates *)
 
 fun mk_case_cert thy tyco =
@@ -354,88 +342,41 @@
     |> Thm.varifyT
   end;
 
-fun add_datatype_cases dtco thy =
-  let
-    val {case_rewrites, ...} = DatatypePackage.the_datatype thy dtco;
-    val cert = mk_case_cert thy dtco;
-    fun add_case_liberal thy = thy
-      |> try (Code.add_case cert)
-      |> the_default thy;
-  in
-    thy
-    |> add_case_liberal
-    |> fold_rev Code.add_default_eqn case_rewrites
-  end;
-
 
 (* equality *)
 
-local
-
-val not_sym = @{thm HOL.not_sym};
-val not_false_true = iffD2 OF [nth @{thms HOL.simp_thms} 7, TrueI];
-val refl = @{thm refl};
-val eqTrueI = @{thm eqTrueI};
-
-fun mk_distinct cos =
-  let
-    fun sym_product [] = []
-      | sym_product (x::xs) = map (pair x) xs @ sym_product xs;
-    fun mk_co_args (co, tys) ctxt =
-      let
-        val names = Name.invents ctxt "a" (length tys);
-        val ctxt' = fold Name.declare names ctxt;
-        val vs = map2 (curry Free) names tys;
-      in (vs, ctxt') end;
-    fun mk_dist ((co1, tys1), (co2, tys2)) =
-      let
-        val ((xs1, xs2), _) = Name.context
-          |> mk_co_args (co1, tys1)
-          ||>> mk_co_args (co2, tys2);
-        val prem = HOLogic.mk_eq
-          (list_comb (co1, xs1), list_comb (co2, xs2));
-        val t = HOLogic.mk_not prem;
-      in HOLogic.mk_Trueprop t end;
-  in map mk_dist (sym_product cos) end;
-
-in
-
-fun mk_eq thy dtco =
+fun mk_eq_eqns thy dtco =
   let
-    val (vs, cs) = DatatypePackage.the_datatype_spec thy dtco;
-    fun mk_triv_inject co =
-      let
-        val ct' = Thm.cterm_of thy
-          (Const (co, Type (dtco, map (fn (v, sort) => TVar ((v, 0), sort)) vs)))
-        val cty' = Thm.ctyp_of_term ct';
-        val SOME (ct, cty) = fold_aterms (fn Var (v, ty) =>
-          (K o SOME) (Thm.cterm_of thy (Var (v, Thm.typ_of cty')), Thm.ctyp_of thy ty) | _ => I)
-          (Thm.prop_of refl) NONE;
-      in eqTrueI OF [Thm.instantiate ([(cty, cty')], [(ct, ct')]) refl] end;
-    val inject1 = map_filter (fn (co, []) => SOME (mk_triv_inject co) | _ => NONE) cs
-    val inject2 = (#inject o DatatypePackage.the_datatype thy) dtco;
-    val ctxt = ProofContext.init thy;
-    val simpset = Simplifier.context ctxt
-      (Simplifier.empty_ss addsimprocs [DatatypePackage.distinct_simproc]);
-    val cos = map (fn (co, tys) =>
-        (Const (co, tys ---> Type (dtco, map TFree vs)), tys)) cs;
-    val tac = ALLGOALS (simp_tac simpset)
-      THEN ALLGOALS (ProofContext.fact_tac [not_false_true, TrueI]);
-    val distinct =
-      mk_distinct cos
-      |> map (fn t => Goal.prove_global thy [] [] t (K tac))
-      |> (fn thms => thms @ map (fn thm => not_sym OF [thm]) thms)
-  in inject1 @ inject2 @ distinct end;
+    val (vs, cos) = DatatypePackage.the_datatype_spec thy dtco;
+    val { descr, index, inject = inject_thms, ... } = DatatypePackage.the_datatype thy dtco;
+    val ty = Type (dtco, map TFree vs);
+    fun mk_eq (t1, t2) = Const (@{const_name eq_class.eq}, ty --> ty --> HOLogic.boolT)
+      $ t1 $ t2;
+    fun true_eq t12 = HOLogic.mk_eq (mk_eq t12, HOLogic.true_const);
+    fun false_eq t12 = HOLogic.mk_eq (mk_eq t12, HOLogic.false_const);
+    val triv_injects = map_filter
+     (fn (c, []) => SOME (HOLogic.mk_Trueprop (true_eq (Const (c, ty), Const (c, ty))))
+       | _ => NONE) cos;
+    fun prep_inject (trueprop $ (equiv $ (_ $ t1 $ t2) $ rhs)) =
+      trueprop $ (equiv $ mk_eq (t1, t2) $ rhs);
+    val injects = map prep_inject (nth (DatatypeProp.make_injs [descr] vs) index);
+    fun prep_distinct (trueprop $ (not $ (_ $ t1 $ t2))) =
+      [trueprop $ false_eq (t1, t2), trueprop $ false_eq (t2, t1)];
+    val distincts = maps prep_distinct (snd (nth (DatatypeProp.make_distincts [descr] vs) index));
+    val refl = HOLogic.mk_Trueprop (true_eq (Free ("x", ty), Free ("x", ty)));
+    val simpset = Simplifier.context (ProofContext.init thy) (HOL_basic_ss
+      addsimps (map Simpdata.mk_eq (@{thm eq} :: @{thm eq_True} :: inject_thms))
+      addsimprocs [DatatypePackage.distinct_simproc]);
+    fun prove prop = Goal.prove_global thy [] [] prop (K (ALLGOALS (simp_tac simpset)))
+      |> Simpdata.mk_eq
+      |> Simplifier.rewrite_rule [@{thm equals_eq}];
+  in map (rpair true o prove) (triv_injects @ injects @ distincts) @ [(prove refl, false)] end;
 
-end;
-
-fun add_datatypes_equality vs dtcos thy =
+fun add_equality vs dtcos thy =
   let
-    val vs' = (map o apsnd)
-      (curry (Sorts.inter_sort (Sign.classes_of thy)) [HOLogic.class_eq]) vs;
     fun add_def dtco lthy =
       let
-        val ty = Type (dtco, map TFree vs');
+        val ty = Type (dtco, map TFree vs);
         fun mk_side const_name = Const (const_name, ty --> ty --> HOLogic.boolT)
           $ Free ("x", ty) $ Free ("y", ty);
         val def = HOLogic.mk_Trueprop (HOLogic.mk_eq
@@ -448,52 +389,60 @@
       in (thm', lthy') end;
     fun tac thms = Class.intro_classes_tac []
       THEN ALLGOALS (ProofContext.fact_tac thms);
-    fun mk_eq' thy dtco = mk_eq thy dtco
-      |> map (Code_Unit.constrain_thm thy [HOLogic.class_eq])
-      |> map Simpdata.mk_eq
-      |> map (MetaSimplifier.rewrite_rule [Thm.transfer thy @{thm equals_eq}])
-      |> map (AxClass.unoverload thy);
     fun add_eq_thms dtco thy =
       let
-        val ty = Type (dtco, map TFree vs');
+        val const = AxClass.param_of_inst thy (@{const_name eq_class.eq}, dtco);
         val thy_ref = Theory.check_thy thy;
-        val const = AxClass.param_of_inst thy (@{const_name eq_class.eq}, dtco);
-        val eq_refl = @{thm HOL.eq_refl}
-          |> Thm.instantiate
-              ([pairself (Thm.ctyp_of thy) (TVar (("'a", 0), @{sort eq}), Logic.varifyT ty)], [])
-          |> Simpdata.mk_eq
-          |> AxClass.unoverload thy;
-        fun mk_thms () = (eq_refl, false)
-          :: rev (map (rpair true) (mk_eq' (Theory.deref thy_ref) dtco));
+        fun mk_thms () = rev ((mk_eq_eqns (Theory.deref thy_ref) dtco));
       in
         Code.add_eqnl (const, Lazy.lazy mk_thms) thy
       end;
   in
     thy
-    |> TheoryTarget.instantiation (dtcos, vs', [HOLogic.class_eq])
+    |> TheoryTarget.instantiation (dtcos, vs, [HOLogic.class_eq])
     |> fold_map add_def dtcos
-    |-> (fn thms => Class.prove_instantiation_instance (K (tac thms))
-    #> LocalTheory.exit_global
-    #> fold Code.del_eqn thms
-    #> fold add_eq_thms dtcos)
+    |-> (fn def_thms => Class.prove_instantiation_exit_result (map o Morphism.thm)
+         (fn _ => fn def_thms => tac def_thms) def_thms)
+    |-> (fn def_thms => fold Code.del_eqn def_thms)
+    |> fold add_eq_thms dtcos
+  end;
+
+
+(* liberal addition of code data for datatypes *)
+
+fun mk_constr_consts thy vs dtco cos =
+  let
+    val cs = map (fn (c, tys) => (c, tys ---> Type (dtco, map TFree vs))) cos;
+    val cs' = map (fn c_ty as (_, ty) => (AxClass.unoverload_const thy c_ty, ty)) cs;
+  in if is_some (try (Code_Unit.constrset_of_consts thy) cs')
+    then SOME cs
+    else NONE
   end;
 
+fun add_all_code dtcos thy =
+  let
+    val (vs :: _, coss) = (split_list o map (DatatypePackage.the_datatype_spec thy)) dtcos;
+    val any_css = map2 (mk_constr_consts thy vs) dtcos coss;
+    val css = if exists is_none any_css then []
+      else map_filter I any_css;
+    val case_rewrites = maps (#case_rewrites o DatatypePackage.the_datatype thy) dtcos;
+    val certs = map (mk_case_cert thy) dtcos;
+  in
+    if null css then thy
+    else thy
+      |> fold Code.add_datatype css
+      |> fold_rev Code.add_default_eqn case_rewrites
+      |> fold Code.add_case certs
+      |> add_equality vs dtcos
+   end;
+
+
 
 (** theory setup **)
 
-fun add_datatype_code dtcos thy =
-  let
-    val (vs :: _, coss) = (split_list o map (DatatypePackage.the_datatype_spec thy)) dtcos;
-  in
-    thy
-    |> fold2 (add_datatype_spec vs) dtcos coss
-    |> fold add_datatype_cases dtcos
-    |> add_datatypes_equality vs dtcos
-  end;
-
 val setup = 
   add_codegen "datatype" datatype_codegen
   #> add_tycodegen "datatype" datatype_tycodegen
-  #> DatatypePackage.interpretation add_datatype_code
+  #> DatatypePackage.interpretation add_all_code
 
 end;