swapped generic code generator and SML code generator
authorhaftmann
Thu, 22 Apr 2010 12:07:00 +0200
changeset 36298 2d55c4aba1dc
parent 36275 c6ca9e258269
child 36299 a35b83da74ce
swapped generic code generator and SML code generator
src/HOL/Tools/Datatype/datatype_codegen.ML
--- a/src/HOL/Tools/Datatype/datatype_codegen.ML	Thu Apr 22 09:30:39 2010 +0200
+++ b/src/HOL/Tools/Datatype/datatype_codegen.ML	Thu Apr 22 12:07:00 2010 +0200
@@ -12,6 +12,137 @@
 structure Datatype_Codegen : DATATYPE_CODEGEN =
 struct
 
+(** generic code generator **)
+
+(* liberal addition of code data for datatypes *)
+
+fun mk_constr_consts thy vs dtco cos =
+  let
+    val cs = map (fn (c, tys) => (c, tys ---> Type (dtco, map TFree vs))) cos;
+    val cs' = map (fn c_ty as (_, ty) => (AxClass.unoverload_const thy c_ty, ty)) cs;
+  in if is_some (try (Code.constrset_of_consts thy) cs')
+    then SOME cs
+    else NONE
+  end;
+
+
+(* case certificates *)
+
+fun mk_case_cert thy tyco =
+  let
+    val raw_thms =
+      (#case_rewrites o Datatype_Data.the_info thy) tyco;
+    val thms as hd_thm :: _ = raw_thms
+      |> Conjunction.intr_balanced
+      |> Thm.unvarify_global
+      |> Conjunction.elim_balanced (length raw_thms)
+      |> map Simpdata.mk_meta_eq
+      |> map Drule.zero_var_indexes
+    val params = fold_aterms (fn (Free (v, _)) => insert (op =) v
+      | _ => I) (Thm.prop_of hd_thm) [];
+    val rhs = hd_thm
+      |> Thm.prop_of
+      |> Logic.dest_equals
+      |> fst
+      |> Term.strip_comb
+      |> apsnd (fst o split_last)
+      |> list_comb;
+    val lhs = Free (Name.variant params "case", Term.fastype_of rhs);
+    val asm = (Thm.cterm_of thy o Logic.mk_equals) (lhs, rhs);
+  in
+    thms
+    |> Conjunction.intr_balanced
+    |> MetaSimplifier.rewrite_rule [(Thm.symmetric o Thm.assume) asm]
+    |> Thm.implies_intr asm
+    |> Thm.generalize ([], params) 0
+    |> AxClass.unoverload thy
+    |> Thm.varifyT_global
+  end;
+
+
+(* equality *)
+
+fun mk_eq_eqns thy dtco =
+  let
+    val (vs, cos) = Datatype_Data.the_spec thy dtco;
+    val { descr, index, inject = inject_thms, distinct = distinct_thms, ... } =
+      Datatype_Data.the_info thy dtco;
+    val ty = Type (dtco, map TFree vs);
+    fun mk_eq (t1, t2) = Const (@{const_name eq_class.eq}, ty --> ty --> HOLogic.boolT)
+      $ t1 $ t2;
+    fun true_eq t12 = HOLogic.mk_eq (mk_eq t12, HOLogic.true_const);
+    fun false_eq t12 = HOLogic.mk_eq (mk_eq t12, HOLogic.false_const);
+    val triv_injects = map_filter
+     (fn (c, []) => SOME (HOLogic.mk_Trueprop (true_eq (Const (c, ty), Const (c, ty))))
+       | _ => NONE) cos;
+    fun prep_inject (trueprop $ (equiv $ (_ $ t1 $ t2) $ rhs)) =
+      trueprop $ (equiv $ mk_eq (t1, t2) $ rhs);
+    val injects = map prep_inject (nth (Datatype_Prop.make_injs [descr] vs) index);
+    fun prep_distinct (trueprop $ (not $ (_ $ t1 $ t2))) =
+      [trueprop $ false_eq (t1, t2), trueprop $ false_eq (t2, t1)];
+    val distincts = maps prep_distinct (snd (nth (Datatype_Prop.make_distincts [descr] vs) index));
+    val refl = HOLogic.mk_Trueprop (true_eq (Free ("x", ty), Free ("x", ty)));
+    val simpset = Simplifier.global_context thy (HOL_basic_ss addsimps 
+      (map Simpdata.mk_eq (@{thm eq} :: @{thm eq_True} :: inject_thms @ distinct_thms)));
+    fun prove prop = Skip_Proof.prove_global thy [] [] prop (K (ALLGOALS (simp_tac simpset)))
+      |> Simpdata.mk_eq;
+  in (map prove (triv_injects @ injects @ distincts), prove refl) end;
+
+fun add_equality vs dtcos thy =
+  let
+    fun add_def dtco lthy =
+      let
+        val ty = Type (dtco, map TFree vs);
+        fun mk_side const_name = Const (const_name, ty --> ty --> HOLogic.boolT)
+          $ Free ("x", ty) $ Free ("y", ty);
+        val def = HOLogic.mk_Trueprop (HOLogic.mk_eq
+          (mk_side @{const_name eq_class.eq}, mk_side @{const_name "op ="}));
+        val def' = Syntax.check_term lthy def;
+        val ((_, (_, thm)), lthy') = Specification.definition
+          (NONE, (Attrib.empty_binding, def')) lthy;
+        val ctxt_thy = ProofContext.init (ProofContext.theory_of lthy);
+        val thm' = singleton (ProofContext.export lthy' ctxt_thy) thm;
+      in (thm', lthy') end;
+    fun tac thms = Class.intro_classes_tac []
+      THEN ALLGOALS (ProofContext.fact_tac thms);
+    fun add_eq_thms dtco =
+      Theory.checkpoint
+      #> `(fn thy => mk_eq_eqns thy dtco)
+      #-> (fn (thms, thm) =>
+        Code.add_nbe_eqn thm
+        #> fold_rev Code.add_eqn thms);
+  in
+    thy
+    |> Theory_Target.instantiation (dtcos, vs, [HOLogic.class_eq])
+    |> fold_map add_def dtcos
+    |-> (fn def_thms => Class.prove_instantiation_exit_result (map o Morphism.thm)
+         (fn _ => fn def_thms => tac def_thms) def_thms)
+    |-> (fn def_thms => fold Code.del_eqn def_thms)
+    |> fold add_eq_thms dtcos
+  end;
+
+
+(* register a datatype etc. *)
+
+fun add_all_code config dtcos thy =
+  let
+    val (vs :: _, coss) = (split_list o map (Datatype_Data.the_spec thy)) dtcos;
+    val any_css = map2 (mk_constr_consts thy vs) dtcos coss;
+    val css = if exists is_none any_css then []
+      else map_filter I any_css;
+    val case_rewrites = maps (#case_rewrites o Datatype_Data.the_info thy) dtcos;
+    val certs = map (mk_case_cert thy) dtcos;
+  in
+    if null css then thy
+    else thy
+      |> tap (fn _ => Datatype_Aux.message config "Registering datatype for code generator ...")
+      |> fold Code.add_datatype css
+      |> fold_rev Code.add_default_eqn case_rewrites
+      |> fold Code.add_case certs
+      |> add_equality vs dtcos
+   end;
+
+
 (** SML code generator **)
 
 open Codegen;
@@ -288,142 +419,11 @@
   | datatype_tycodegen _ _ _ _ _ _ _ = NONE;
 
 
-(** generic code generator **)
-
-(* liberal addition of code data for datatypes *)
-
-fun mk_constr_consts thy vs dtco cos =
-  let
-    val cs = map (fn (c, tys) => (c, tys ---> Type (dtco, map TFree vs))) cos;
-    val cs' = map (fn c_ty as (_, ty) => (AxClass.unoverload_const thy c_ty, ty)) cs;
-  in if is_some (try (Code.constrset_of_consts thy) cs')
-    then SOME cs
-    else NONE
-  end;
-
-
-(* case certificates *)
-
-fun mk_case_cert thy tyco =
-  let
-    val raw_thms =
-      (#case_rewrites o Datatype_Data.the_info thy) tyco;
-    val thms as hd_thm :: _ = raw_thms
-      |> Conjunction.intr_balanced
-      |> Thm.unvarify_global
-      |> Conjunction.elim_balanced (length raw_thms)
-      |> map Simpdata.mk_meta_eq
-      |> map Drule.zero_var_indexes
-    val params = fold_aterms (fn (Free (v, _)) => insert (op =) v
-      | _ => I) (Thm.prop_of hd_thm) [];
-    val rhs = hd_thm
-      |> Thm.prop_of
-      |> Logic.dest_equals
-      |> fst
-      |> Term.strip_comb
-      |> apsnd (fst o split_last)
-      |> list_comb;
-    val lhs = Free (Name.variant params "case", Term.fastype_of rhs);
-    val asm = (Thm.cterm_of thy o Logic.mk_equals) (lhs, rhs);
-  in
-    thms
-    |> Conjunction.intr_balanced
-    |> MetaSimplifier.rewrite_rule [(Thm.symmetric o Thm.assume) asm]
-    |> Thm.implies_intr asm
-    |> Thm.generalize ([], params) 0
-    |> AxClass.unoverload thy
-    |> Thm.varifyT_global
-  end;
-
-
-(* equality *)
-
-fun mk_eq_eqns thy dtco =
-  let
-    val (vs, cos) = Datatype_Data.the_spec thy dtco;
-    val { descr, index, inject = inject_thms, distinct = distinct_thms, ... } =
-      Datatype_Data.the_info thy dtco;
-    val ty = Type (dtco, map TFree vs);
-    fun mk_eq (t1, t2) = Const (@{const_name eq_class.eq}, ty --> ty --> HOLogic.boolT)
-      $ t1 $ t2;
-    fun true_eq t12 = HOLogic.mk_eq (mk_eq t12, HOLogic.true_const);
-    fun false_eq t12 = HOLogic.mk_eq (mk_eq t12, HOLogic.false_const);
-    val triv_injects = map_filter
-     (fn (c, []) => SOME (HOLogic.mk_Trueprop (true_eq (Const (c, ty), Const (c, ty))))
-       | _ => NONE) cos;
-    fun prep_inject (trueprop $ (equiv $ (_ $ t1 $ t2) $ rhs)) =
-      trueprop $ (equiv $ mk_eq (t1, t2) $ rhs);
-    val injects = map prep_inject (nth (Datatype_Prop.make_injs [descr] vs) index);
-    fun prep_distinct (trueprop $ (not $ (_ $ t1 $ t2))) =
-      [trueprop $ false_eq (t1, t2), trueprop $ false_eq (t2, t1)];
-    val distincts = maps prep_distinct (snd (nth (Datatype_Prop.make_distincts [descr] vs) index));
-    val refl = HOLogic.mk_Trueprop (true_eq (Free ("x", ty), Free ("x", ty)));
-    val simpset = Simplifier.global_context thy (HOL_basic_ss addsimps 
-      (map Simpdata.mk_eq (@{thm eq} :: @{thm eq_True} :: inject_thms @ distinct_thms)));
-    fun prove prop = Skip_Proof.prove_global thy [] [] prop (K (ALLGOALS (simp_tac simpset)))
-      |> Simpdata.mk_eq;
-  in (map prove (triv_injects @ injects @ distincts), prove refl) end;
-
-fun add_equality vs dtcos thy =
-  let
-    fun add_def dtco lthy =
-      let
-        val ty = Type (dtco, map TFree vs);
-        fun mk_side const_name = Const (const_name, ty --> ty --> HOLogic.boolT)
-          $ Free ("x", ty) $ Free ("y", ty);
-        val def = HOLogic.mk_Trueprop (HOLogic.mk_eq
-          (mk_side @{const_name eq_class.eq}, mk_side @{const_name "op ="}));
-        val def' = Syntax.check_term lthy def;
-        val ((_, (_, thm)), lthy') = Specification.definition
-          (NONE, (Attrib.empty_binding, def')) lthy;
-        val ctxt_thy = ProofContext.init (ProofContext.theory_of lthy);
-        val thm' = singleton (ProofContext.export lthy' ctxt_thy) thm;
-      in (thm', lthy') end;
-    fun tac thms = Class.intro_classes_tac []
-      THEN ALLGOALS (ProofContext.fact_tac thms);
-    fun add_eq_thms dtco =
-      Theory.checkpoint
-      #> `(fn thy => mk_eq_eqns thy dtco)
-      #-> (fn (thms, thm) =>
-        Code.add_nbe_eqn thm
-        #> fold_rev Code.add_eqn thms);
-  in
-    thy
-    |> Theory_Target.instantiation (dtcos, vs, [HOLogic.class_eq])
-    |> fold_map add_def dtcos
-    |-> (fn def_thms => Class.prove_instantiation_exit_result (map o Morphism.thm)
-         (fn _ => fn def_thms => tac def_thms) def_thms)
-    |-> (fn def_thms => fold Code.del_eqn def_thms)
-    |> fold add_eq_thms dtcos
-  end;
-
-
-(* register a datatype etc. *)
-
-fun add_all_code config dtcos thy =
-  let
-    val (vs :: _, coss) = (split_list o map (Datatype_Data.the_spec thy)) dtcos;
-    val any_css = map2 (mk_constr_consts thy vs) dtcos coss;
-    val css = if exists is_none any_css then []
-      else map_filter I any_css;
-    val case_rewrites = maps (#case_rewrites o Datatype_Data.the_info thy) dtcos;
-    val certs = map (mk_case_cert thy) dtcos;
-  in
-    if null css then thy
-    else thy
-      |> tap (fn _ => Datatype_Aux.message config "Registering datatype for code generator ...")
-      |> fold Code.add_datatype css
-      |> fold_rev Code.add_default_eqn case_rewrites
-      |> fold Code.add_case certs
-      |> add_equality vs dtcos
-   end;
-
-
 (** theory setup **)
 
 val setup = 
-  add_codegen "datatype" datatype_codegen
-  #> add_tycodegen "datatype" datatype_tycodegen
-  #> Datatype_Data.interpretation add_all_code
+  Datatype_Data.interpretation add_all_code
+  #> add_codegen "datatype" datatype_codegen
+  #> add_tycodegen "datatype" datatype_tycodegen;
 
 end;