repaired constant types
authorblanchet
Sat, 08 Sep 2012 21:04:26 +0200
changeset 49203 262ab1ac38b9
parent 49202 f493cd25737f
child 49204 0b735fb2602e
repaired constant types
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
src/HOL/Codatatype/Tools/bnf_wrap.ML
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Sat Sep 08 21:04:26 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Sat Sep 08 21:04:26 2012 +0200
@@ -128,7 +128,7 @@
 
     val eqs = map dest_TFree Xs ~~ sum_prod_TsXs;
 
-    val ((raw_unfs, raw_flds, raw_fp_iters, raw_fp_recs, unf_flds, fld_unfs, fld_injects), lthy') =
+    val ((unfs0, flds0, fp_iters0, fp_recs0, unf_flds, fld_unfs, fld_injects), lthy') =
       fp_bnf (if gfp then bnf_gfp else bnf_lfp) bs mixfixes As' eqs lthy;
 
     val timer = time (Timer.startRealTimer ());
@@ -141,23 +141,29 @@
     val mk_unf = mk_unf_or_fld domain_type;
     val mk_fld = mk_unf_or_fld range_type;
 
-    val unfs = map (mk_unf As) raw_unfs;
-    val flds = map (mk_fld As) raw_flds;
+    val unfs = map (mk_unf As) unfs0;
+    val flds = map (mk_fld As) flds0;
 
     val fpTs = map (domain_type o fastype_of) unfs;
     val ctr_Tsss = map (map (map (Term.typ_subst_atomic (Xs ~~ fpTs)))) ctr_TsssXs;
 
-    fun mk_fp_iter_or_rec Ts Us c =
+    val ns = map length ctr_Tsss;
+    val mss = map (map length) ctr_Tsss;
+    val Css = map2 replicate ns Cs;
+    val Cs' = flat Css;
+
+    fun mk_iter_or_rec Ts Us c =
       let
         val (binders, body) = strip_type (fastype_of c);
-        val Type (_, Ts0) = if gfp then body else List.last binders;
-        val Us0 = map (if gfp then domain_type else body_type) (fst (split_last binders));
+        val (fst_binders, last_binder) = split_last binders;
+        val Type (_, Ts0) = if gfp then body else last_binder;
+        val Us0 = map (if gfp then domain_type else body_type) fst_binders;
       in
         Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) c
       end;
 
-    val fp_iters = map (mk_fp_iter_or_rec As Cs) raw_fp_iters;
-    val fp_recs = map (mk_fp_iter_or_rec As Cs) raw_fp_recs;
+    val fp_iters = map (mk_iter_or_rec As Cs) fp_iters0;
+    val fp_recs = map (mk_iter_or_rec As Cs) fp_recs0;
 
     fun pour_sugar_on_type ((((((((((((((b, fpT), C), fld), unf), fp_iter), fp_rec), fld_unf),
           unf_fld), fld_inject), ctr_binders), ctr_mixfixes), ctr_Tss), disc_binders), sel_binderss)
@@ -199,8 +205,10 @@
         val ctr_defs = map (Morphism.thm phi) raw_ctr_defs;
         val case_def = Morphism.thm phi raw_case_def;
 
-        val ctrs = map (Morphism.term phi) raw_ctrs;
-        val casex = Morphism.term phi raw_case;
+        val ctrs0 = map (Morphism.term phi) raw_ctrs;
+        val casex0 = Morphism.term phi raw_case;
+
+        val ctrs = map (mk_ctr As) ctrs0;
 
         fun exhaust_tac {context = ctxt, ...} =
           let
@@ -245,10 +253,6 @@
 
         val is_fpT = member (op =) fpTs;
 
-        val ns = map length ctr_Tsss;
-        val mss = map (map length) ctr_Tsss;
-        val Css = map2 replicate ns Cs;
-
         fun dest_rec_pair (T as Type (@{type_name prod}, Us as [_, U])) =
             if member (op =) Cs U then Us else [T]
           | dest_rec_pair T = [T];
@@ -303,15 +307,18 @@
             val iter_def = Morphism.thm phi raw_iter_def;
             val rec_def = Morphism.thm phi raw_rec_def;
 
-            val iter = Morphism.term phi raw_iter;
-            val recx = Morphism.term phi raw_rec;
+            val iter0 = Morphism.term phi raw_iter;
+            val rec0 = Morphism.term phi raw_rec;
+
+            val iter = mk_iter_or_rec As Cs' iter0;
+            val recx = mk_iter_or_rec As Cs' rec0;
           in
             ([[ctrs], [[iter]], [[recx]], xss, gss, hss], lthy)
           end;
 
         fun sugar_codatatype no_defs_lthy = ([], no_defs_lthy);
       in
-        wrap_datatype tacss ((ctrs, casex), (disc_binders, sel_binderss)) lthy'
+        wrap_datatype tacss ((ctrs0, casex0), (disc_binders, sel_binderss)) lthy'
         |> (if gfp then sugar_codatatype else sugar_datatype)
       end;
 
@@ -327,9 +334,12 @@
               mk_Trueprop_eq (fc $ xctr, fc $ xctr);
 
             val goal_iterss = map2 (fn giter => map (mk_goal_iter_or_rec giter)) giters xctrss;
-            val goal_recss = [];
-            val iter_tacss = []; (* ### map (map mk_iter_or_rec_tac); (* needs ctr_def, iter_def, fld_iter *) *)
-            val rec_tacss = [];
+            val goal_recss = map2 (fn hrec => map (mk_goal_iter_or_rec hrec)) hrecs xctrss;
+            val iter_tacss =
+              map (map (K (fn _ => Skip_Proof.cheat_tac (Proof_Context.theory_of lthy)))) goal_iterss;
+              (* ### map (map mk_iter_or_rec_tac); (* needs ctr_def, iter_def, fld_iter *) *)
+            val rec_tacss =
+              map (map (K (fn _ => Skip_Proof.cheat_tac (Proof_Context.theory_of lthy)))) goal_recss;
           in
             (map2 (map2 (Skip_Proof.prove lthy [] [])) goal_iterss iter_tacss,
              map2 (map2 (Skip_Proof.prove lthy [] [])) goal_recss rec_tacss)
--- a/src/HOL/Codatatype/Tools/bnf_wrap.ML	Sat Sep 08 21:04:26 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_wrap.ML	Sat Sep 08 21:04:26 2012 +0200
@@ -9,6 +9,7 @@
 sig
   val no_binder: binding
   val mk_half_pairss: 'a list -> ('a * 'a) list list
+  val mk_ctr: typ list -> term -> term
   val wrap_datatype: ({prems: thm list, context: Proof.context} -> tactic) list list ->
     (term list * term) * (binding list * binding list list) -> local_theory -> local_theory
 end;
@@ -54,10 +55,15 @@
 (* TODO: provide a way to have a different default value, e.g. "tl Nil = Nil" *)
 fun mk_undef T Ts = Const (@{const_name undefined}, Ts ---> T);
 
+fun mk_ctr Ts ctr =
+  let val Type (_, Ts0) = body_type (fastype_of ctr) in
+    Term.subst_atomic_types (Ts0 ~~ Ts) ctr
+  end;
+
 fun eta_expand_case_arg xs f_xs = fold_rev Term.lambda xs f_xs;
 
-fun name_of_ctr t =
-  case head_of t of
+fun name_of_ctr c =
+  case head_of c of
     Const (s, _) => s
   | Free (s, _) => s
   | _ => error "Cannot extract name of constructor";
@@ -86,11 +92,6 @@
       |> mk_TFrees (length As0)
       ||> the_single o fst o mk_TFrees 1;
 
-    fun mk_ctr Ts ctr =
-      let val Type (_, Ts0) = body_type (fastype_of ctr) in
-        Term.subst_atomic_types (Ts0 ~~ Ts) ctr
-      end;
-
     val T = Type (T_name, As);
     val ctrs = map (mk_ctr As) ctrs0;
     val ctr_Tss = map (binder_types o fastype_of) ctrs;
@@ -220,8 +221,8 @@
     val discs0 = map (Morphism.term phi) raw_discs;
     val selss0 = map (map (Morphism.term phi)) raw_selss;
 
-    fun mk_disc_or_sel Ts t =
-      Term.subst_atomic_types (snd (Term.dest_Type (domain_type (fastype_of t))) ~~ Ts) t;
+    fun mk_disc_or_sel Ts c =
+      Term.subst_atomic_types (snd (Term.dest_Type (domain_type (fastype_of c))) ~~ Ts) c;
 
     val discs = map (mk_disc_or_sel As) discs0;
     val selss = map (map (mk_disc_or_sel As)) selss0;
@@ -245,9 +246,9 @@
 
     val goal_half_distinctss =
       let
-        fun mk_goal ((xs, t), (xs', t')) =
+        fun mk_goal ((xs, xc), (xs', xc')) =
           fold_rev Logic.all (xs @ xs')
-            (HOLogic.mk_Trueprop (HOLogic.mk_not (HOLogic.mk_eq (t, t'))));
+            (HOLogic.mk_Trueprop (HOLogic.mk_not (HOLogic.mk_eq (xc, xc'))));
       in
         map (map mk_goal) (mk_half_pairss (xss ~~ xctrs))
       end;