more work on FP sugar
authorblanchet
Tue, 04 Sep 2012 13:05:01 +0200
changeset 49121 9e0acaa470ab
parent 49120 7f8e69fc6ac9
child 49122 83515378d4d7
more work on FP sugar
src/HOL/Codatatype/Tools/bnf_comp.ML
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
src/HOL/Codatatype/Tools/bnf_fp_util.ML
src/HOL/Codatatype/Tools/bnf_gfp.ML
src/HOL/Codatatype/Tools/bnf_gfp_util.ML
src/HOL/Codatatype/Tools/bnf_lfp.ML
src/HOL/Codatatype/Tools/bnf_wrap.ML
--- a/src/HOL/Codatatype/Tools/bnf_comp.ML	Tue Sep 04 13:02:32 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_comp.ML	Tue Sep 04 13:05:01 2012 +0200
@@ -630,7 +630,7 @@
         else qualify' (Binding.prefix_name namei bind)
       end;
 
-    val Ass = map (map dest_TFree) tfreess;
+    val Ass = map (map Term.dest_TFree) tfreess;
     val Ds = fold (fold Term.add_tfreesT) (oDs :: Dss) [];
 
     val ((kill_poss, As), (inners', (unfold', lthy'))) =
@@ -781,7 +781,7 @@
           val odead = dead_of_bnf outer;
           val olive = live_of_bnf outer;
           val oDs_pos = find_indices [TFree ("dead", [])]
-            (snd (dest_Type
+            (snd (Term.dest_Type
               (mk_T_of_bnf (replicate odead (TFree ("dead", []))) (replicate olive dummyT) outer)));
           val oDs = map (nth Ts) oDs_pos;
           val Ts' = map (nth Ts) (subtract (op =) oDs_pos (0 upto length Ts - 1));
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 04 13:02:32 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 04 13:05:01 2012 +0200
@@ -38,29 +38,28 @@
   if length cAs = length cAs' then map2 (merge_type_arg_constrained ctxt) cAs cAs'
   else cannot_merge_types ();
 
-fun type_args_constrained_of_spec (((cAs, _), _), _) = cAs;
-fun type_name_of_spec (((_, b), _), _) = b;
-fun mixfix_of_spec ((_, mx), _) = mx;
-fun ctr_specs_of_spec (_, ctr_specs) = ctr_specs;
+fun type_args_constrained_of (((cAs, _), _), _) = cAs;
+val type_args_of = map fst o type_args_constrained_of;
+fun type_name_of (((_, b), _), _) = b;
+fun mixfix_of_typ ((_, mx), _) = mx;
+fun ctr_specs_of (_, ctr_specs) = ctr_specs;
 
-fun disc_of_ctr_spec (((disc, _), _), _) = disc;
-fun ctr_of_ctr_spec (((_, ctr), _), _) = ctr;
-fun args_of_ctr_spec ((_, args), _) = args;
-fun mixfix_of_ctr_spec (_, mx) = mx;
-
-val mk_prod_sum = mk_sumTN o map HOLogic.mk_tupleT;
+fun disc_of (((disc, _), _), _) = disc;
+fun ctr_of (((_, ctr), _), _) = ctr;
+fun args_of ((_, args), _) = args;
+fun mixfix_of_ctr (_, mx) = mx;
 
 val lfp_info = bnf_lfp;
 val gfp_info = bnf_gfp;
 
-fun prepare_data prepare_typ construct specs lthy =
+fun prepare_data prepare_typ construct specs fake_lthy lthy =
   let
-    val constrained_passiveAs =
-      map (map (apfst (prepare_typ lthy)) o type_args_constrained_of_spec) specs
+    val constrained_As =
+      map (map (apfst (prepare_typ fake_lthy)) o type_args_constrained_of) specs
       |> Library.foldr1 (merge_type_args_constrained lthy);
-    val passiveAs = map fst constrained_passiveAs;
+    val As = map fst constrained_As;
 
-    val _ = (case duplicates (op =) passiveAs of [] => ()
+    val _ = (case duplicates (op =) As of [] => ()
       | T :: _ => error ("Duplicate type parameter " ^ quote (Syntax.string_of_typ lthy T)));
 
     (* TODO: check that no type variables occur in the rhss that's not in the lhss *)
@@ -68,41 +67,116 @@
 
     val N = length specs;
 
-    val bs = map type_name_of_spec specs;
-    val mixfixes = map mixfix_of_spec specs;
+    fun mk_T b =
+      Type (fst (Term.dest_Type (Proof_Context.read_type_name fake_lthy true (Binding.name_of b))),
+        As);
+
+    val bs = map type_name_of specs;
+    val Ts = map mk_T bs;
+
+    val mixfixes = map mixfix_of_typ specs;
 
     val _ = (case duplicates Binding.eq_name bs of [] => ()
       | b :: _ => error ("Duplicate type name declaration " ^ quote (Binding.name_of b)));
 
-    val ctr_specss = map ctr_specs_of_spec specs;
+    val ctr_specss = map ctr_specs_of specs;
 
-    val disc_namess = map (map disc_of_ctr_spec) ctr_specss;
-    val raw_ctr_namess = map (map ctr_of_ctr_spec) ctr_specss;
-    val ctr_argsss = map (map args_of_ctr_spec) ctr_specss;
-    val ctr_mixfixess = map (map mixfix_of_ctr_spec) ctr_specss;
+    val disc_namess = map (map disc_of) ctr_specss;
+    val ctr_namess = map (map ctr_of) ctr_specss;
+    val ctr_argsss = map (map args_of) ctr_specss;
+    val ctr_mixfixess = map (map mixfix_of_ctr) ctr_specss;
 
     val sel_namesss = map (map (map fst)) ctr_argsss;
-    val ctr_Tsss = map (map (map (prepare_typ lthy o snd))) ctr_argsss;
+    val ctr_Tsss = map (map (map (prepare_typ fake_lthy o snd))) ctr_argsss;
+
+    val (Bs, C) =
+      lthy
+      |> fold (fold (fn s => Variable.declare_typ (TFree (s, dummyS))) o type_args_of) specs
+      |> mk_TFrees N
+      ||> the_single o fst o mk_TFrees 1;
 
-    val (activeAs, _) = lthy |> mk_TFrees N;
+    fun freeze_rec (T as Type (s, Ts')) =
+        (case find_index (curry (op =) T) Ts of
+          ~1 => Type (s, map freeze_rec Ts')
+        | i => nth Bs i)
+      | freeze_rec T = T;
+
+    val ctr_TsssBs = map (map (map freeze_rec)) ctr_Tsss;
+    val sum_prod_TsBs = map (mk_sumTN o map HOLogic.mk_tupleT) ctr_TsssBs;
 
-    val eqs = map2 (fn TFree A => fn Tss => (A, mk_prod_sum Tss)) activeAs ctr_Tsss;
+    val eqs = map dest_TFree Bs ~~ sum_prod_TsBs;
+
+    val (raw_flds, lthy') = fp_bnf construct bs eqs lthy;
+
+    fun mk_fld Ts fld =
+      let val Type (_, Ts0) = body_type (fastype_of fld) in
+        Term.subst_atomic_types (Ts0 ~~ Ts) fld
+      end;
 
-    val lthy' = fp_bnf construct bs eqs lthy;
+    val flds = map (mk_fld As) raw_flds;
+
+    fun wrap_type (((((T, fld), ctr_names), ctr_Tss), disc_names), sel_namess) no_defs_lthy =
+      let
+        val n = length ctr_names;
+        val ks = 1 upto n;
+        val ms = map length ctr_Tss;
+
+        val prod_Ts = map HOLogic.mk_tupleT ctr_Tss;
 
-    fun wrap_type ((b, disc_names), sel_namess) lthy =
-      let
-        val ctrs = [];
-        val caseof = @{term True};
-        val tacss = [];
+        val (xss, _) = lthy |> mk_Freess "x" ctr_Tss;
+
+        val rhss =
+          map2 (fn k => fn xs =>
+            fold_rev Term.lambda xs (fld $ mk_InN prod_Ts (HOLogic.mk_tuple xs) k)) ks xss;
+
+        val ((raw_ctrs, raw_ctr_defs), (lthy', lthy)) = no_defs_lthy
+          |> apfst split_list o fold_map2 (fn b => fn rhs =>
+               Local_Theory.define ((b, NoSyn), ((Thm.def_binding b, []), rhs)) #>> apsnd snd)
+             ctr_names rhss
+          ||> `Local_Theory.restore;
+
+        val raw_caseof =
+          Const (@{const_name undefined}, map (fn Ts => Ts ---> C) ctr_Tss ---> T --> C);
+
+        (*transforms defined frees into consts (and more)*)
+        val phi = Proof_Context.export_morphism lthy lthy';
+
+        val ctr_defs = map (Morphism.thm phi) raw_ctr_defs;
+
+        val ctrs = map (Morphism.term phi) raw_ctrs;
+
+        val caseof = Morphism.term phi raw_caseof;
+
+        (* ### *)
+        fun cheat_tac {context = ctxt, ...} = Skip_Proof.cheat_tac (Proof_Context.theory_of ctxt);
+
+        val exhaust_tac = cheat_tac;
+
+        val inject_tacss = map (fn 0 => [] | _ => [cheat_tac]) ms;
+
+        val half_distinct_tacss = map (map (K cheat_tac)) (mk_half_pairss ks);
+
+        val case_tacs = map (K cheat_tac) ks;
+
+        val tacss = [exhaust_tac] :: inject_tacss @ half_distinct_tacss @ [case_tacs];
       in
-        wrap tacss ((ctrs, caseof), (disc_names, sel_namess)) lthy
+        wrap_data tacss ((ctrs, caseof), (disc_names, sel_namess)) lthy'
       end;
   in
-    lthy' |> fold wrap_type (bs ~~ disc_namess ~~ sel_namesss)
+    lthy' |> fold wrap_type (Ts ~~ flds ~~ ctr_namess ~~ ctr_Tsss ~~ disc_namess ~~ sel_namesss)
   end;
 
-val data_cmd = prepare_data Syntax.read_typ;
+fun data_cmd info specs lthy =
+  let
+    val fake_lthy =
+      Proof_Context.theory_of lthy
+      |> Theory.copy
+      |> Sign.add_types_global (map (fn spec =>
+        (type_name_of spec, length (type_args_constrained_of spec), mixfix_of_typ spec)) specs)
+      |> Proof_Context.init_global
+  in
+    prepare_data Syntax.read_typ info specs fake_lthy lthy
+  end;
 
 val parse_opt_binding_colon = Scan.optional (Parse.binding --| Parse.$$$ ":") no_name
 
--- a/src/HOL/Codatatype/Tools/bnf_fp_util.ML	Tue Sep 04 13:02:32 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_util.ML	Tue Sep 04 13:05:01 2012 +0200
@@ -75,6 +75,13 @@
   val split_conj_thm: thm -> thm list
   val split_conj_prems: int -> thm -> thm
 
+  val Inl_const: typ -> typ -> term
+  val Inr_const: typ -> typ -> term
+
+  val mk_Inl: term -> typ -> term
+  val mk_Inr: term -> typ -> term
+  val mk_InN: typ list -> term -> int -> term
+
   val mk_Field: term -> term
   val mk_union: term * term -> term
 
@@ -82,12 +89,10 @@
 
   val fixpoint: ('a * 'a -> bool) -> ('a list -> 'a list) -> 'a list -> 'a list
 
-  val fp_bnf: (binding list -> typ list list -> BNF_Def.BNF list ->
-    Proof.context -> Proof.context) ->
-    binding list -> ((string * sort) * typ) list -> Proof.context -> Proof.context
-  val fp_bnf_cmd: (binding list -> typ list list -> BNF_Def.BNF list ->
-    Proof.context -> Proof.context) ->
-    binding list * (string list * string list) -> Proof.context -> Proof.context
+  val fp_bnf: (binding list -> typ list list -> BNF_Def.BNF list -> Proof.context -> 'a) ->
+    binding list -> ((string * sort) * typ) list -> Proof.context -> 'a
+  val fp_bnf_cmd: (binding list -> typ list list -> BNF_Def.BNF list -> Proof.context -> 'a) ->
+    binding list * (string list * string list) -> Proof.context -> 'a
 end;
 
 structure BNF_FP_Util : BNF_FP_UTIL =
@@ -175,6 +180,17 @@
 val set_inclN = "set_incl"
 val set_set_inclN = "set_set_incl"
 
+fun Inl_const LT RT = Const (@{const_name Inl}, LT --> mk_sumT (LT, RT));
+fun mk_Inl t RT = Inl_const (fastype_of t) RT $ t;
+
+fun Inr_const LT RT = Const (@{const_name Inr}, RT --> mk_sumT (LT, RT));
+fun mk_Inr t LT = Inr_const LT (fastype_of t) $ t;
+
+fun mk_InN [_] t 1 = t
+  | mk_InN (_ :: Ts) t 1 = mk_Inl t (mk_sumTN Ts)
+  | mk_InN (LT :: Ts) t m = mk_Inr (mk_InN Ts t (m - 1)) LT
+  | mk_InN Ts t _ = raise (TYPE ("mk_InN", Ts, [t]));
+
 fun mk_Field r =
   let val T = fst (dest_relT (fastype_of r));
   in Const (@{const_name Field}, mk_relT (T, T) --> HOLogic.mk_setT T) $ r end;
--- a/src/HOL/Codatatype/Tools/bnf_gfp.ML	Tue Sep 04 13:02:32 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_gfp.ML	Tue Sep 04 13:05:01 2012 +0200
@@ -9,7 +9,8 @@
 
 signature BNF_GFP =
 sig
-  val bnf_gfp: binding list -> typ list list -> BNF_Def.BNF list -> Proof.context -> Proof.context
+  val bnf_gfp: binding list -> typ list list -> BNF_Def.BNF list -> Proof.context ->
+    term list * Proof.context
 end;
 
 structure BNF_GFP : BNF_GFP =
@@ -91,7 +92,7 @@
 
     (* typs *)
     fun mk_FTs Ts = map2 (fn Ds => mk_T_of_bnf Ds Ts) Dss bnfs;
-    val (params, params') = `(map dest_TFree) (deads @ passiveAs);
+    val (params, params') = `(map Term.dest_TFree) (deads @ passiveAs);
     val FTsAs = mk_FTs allAs;
     val FTsBs = mk_FTs allBs;
     val FTsCs = mk_FTs allCs;
@@ -2995,13 +2996,13 @@
             ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), []), [(thms, [])]))
           bs thmss)
   in
-    lthy |> Local_Theory.notes (common_notes @ notes) |> snd
+    (flds, lthy |> Local_Theory.notes (common_notes @ notes) |> snd)
   end;
 
 val _ =
   Outer_Syntax.local_theory @{command_spec "codata_raw"} "greatest fixed points for BNF equations"
     (Parse.and_list1
       ((Parse.binding --| Parse.$$$ ":") -- (Parse.typ --| Parse.$$$ "=" -- Parse.typ)) >>
-      (fp_bnf_cmd bnf_gfp o apsnd split_list o split_list));
+      (snd oo fp_bnf_cmd bnf_gfp o apsnd split_list o split_list));
 
 end;
--- a/src/HOL/Codatatype/Tools/bnf_gfp_util.ML	Tue Sep 04 13:02:32 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_gfp_util.ML	Tue Sep 04 13:05:01 2012 +0200
@@ -35,14 +35,6 @@
   val mk_undefined: typ -> term
   val mk_univ: term -> term
 
-  val Inl_const: typ -> typ -> term
-  val Inr_const: typ -> typ -> term
-
-  val mk_Inl: term -> typ -> term
-  val mk_Inr: term -> typ -> term
-
-  val mk_InN: typ list -> term -> int -> term
-
   val mk_sum_case: term -> term -> term
   val mk_sum_caseN: term list -> term
 
@@ -191,17 +183,6 @@
       A $ f1 $ f2 $ b1 $ b2
   end;
 
-fun Inl_const LT RT = Const (@{const_name Inl}, LT --> mk_sumT (LT, RT));
-fun mk_Inl t RT = Inl_const (fastype_of t) RT $ t;
-
-fun Inr_const LT RT = Const (@{const_name Inr}, RT --> mk_sumT (LT, RT));
-fun mk_Inr t LT = Inr_const LT (fastype_of t) $ t;
-
-fun mk_InN [_] t 1 = t
-  | mk_InN (_ :: Ts) t 1 = mk_Inl t (mk_sumTN Ts)
-  | mk_InN (LT :: Ts) t m = mk_Inr (mk_InN Ts t (m - 1)) LT
-  | mk_InN Ts t _ = raise (TYPE ("mk_InN", Ts, [t]));
-
 fun mk_sum_case f g =
   let
     val fT = fastype_of f;
--- a/src/HOL/Codatatype/Tools/bnf_lfp.ML	Tue Sep 04 13:02:32 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_lfp.ML	Tue Sep 04 13:05:01 2012 +0200
@@ -8,7 +8,8 @@
 
 signature BNF_LFP =
 sig
-  val bnf_lfp: binding list -> typ list list -> BNF_Def.BNF list -> Proof.context -> Proof.context
+  val bnf_lfp: binding list -> typ list list -> BNF_Def.BNF list -> Proof.context ->
+    term list * Proof.context
 end;
 
 structure BNF_LFP : BNF_LFP =
@@ -57,7 +58,7 @@
 
     (* typs *)
     fun mk_FTs Ts = map2 (fn Ds => mk_T_of_bnf Ds Ts) Dss bnfs;
-    val (params, params') = `(map dest_TFree) (deads @ passiveAs);
+    val (params, params') = `(map Term.dest_TFree) (deads @ passiveAs);
     val FTsAs = mk_FTs allAs;
     val FTsBs = mk_FTs allBs;
     val FTsCs = mk_FTs allCs;
@@ -1817,13 +1818,13 @@
             ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), []), [(thms, [])]))
           bs thmss)
   in
-    lthy |> Local_Theory.notes (common_notes @ notes) |> snd
+    (flds, lthy |> Local_Theory.notes (common_notes @ notes) |> snd)
   end;
 
 val _ =
   Outer_Syntax.local_theory @{command_spec "data_raw"} "least fixed points for BNF equations"
     (Parse.and_list1
       ((Parse.binding --| Parse.$$$ ":") -- (Parse.typ --| Parse.$$$ "=" -- Parse.typ)) >>
-      (fp_bnf_cmd bnf_lfp o apsnd split_list o split_list));
+      (snd oo fp_bnf_cmd bnf_lfp o apsnd split_list o split_list));
 
 end;
--- a/src/HOL/Codatatype/Tools/bnf_wrap.ML	Tue Sep 04 13:02:32 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_wrap.ML	Tue Sep 04 13:05:01 2012 +0200
@@ -8,7 +8,8 @@
 signature BNF_WRAP =
 sig
   val no_name: binding
-  val wrap: ({prems: thm list, context: Proof.context} -> tactic) list list ->
+  val mk_half_pairss: 'a list -> ('a * 'a) list list
+  val wrap_data: ({prems: thm list, context: Proof.context} -> tactic) list list ->
     (term list * term) * (binding list * binding list list) -> local_theory -> local_theory
 end;
 
@@ -62,7 +63,7 @@
   | Free (s, _) => s
   | _ => error "Cannot extract name of constructor";
 
-fun prepare_wrap prep_term ((raw_ctrs, raw_caseof), (raw_disc_names, raw_sel_namess))
+fun prepare_wrap_data prep_term ((raw_ctrs, raw_caseof), (raw_disc_names, raw_sel_namess))
   no_defs_lthy =
   let
     (* TODO: sanity checks on arguments *)
@@ -76,7 +77,9 @@
     val n = length ctrs0;
     val ks = 1 upto n;
 
-    val (T_name, As0) = dest_Type (body_type (fastype_of (hd ctrs0)));
+    val _ = if n > 0 then () else error "No constructors specified";
+
+    val Type (T_name, As0) = body_type (fastype_of (hd ctrs0));
     val b = Binding.qualified_name T_name;
 
     val (As, B) =
@@ -85,7 +88,7 @@
       ||> the_single o fst o mk_TFrees 1;
 
     fun mk_ctr Ts ctr =
-      let val Ts0 = snd (dest_Type (body_type (fastype_of ctr))) in
+      let val Type (_, Ts0) = body_type (fastype_of ctr) in
         Term.subst_atomic_types (Ts0 ~~ Ts) ctr
       end;
 
@@ -127,9 +130,10 @@
           sel) (1 upto m) o pad_list no_name m) ctrs0 ms;
 
     fun mk_caseof Ts T =
-      let val (binders, body) = strip_type (fastype_of caseof0) in
-        Term.subst_atomic_types ((body, T) :: (snd (dest_Type (List.last binders)) ~~ Ts)) caseof0
-      end;
+      let
+        val (binders, body) = strip_type (fastype_of caseof0)
+        val Type (_, Ts0) = List.last binders
+      in Term.subst_atomic_types ((body, T) :: (Ts0 ~~ Ts)) caseof0 end;
 
     val caseofB = mk_caseof As B;
     val caseofB_Ts = map (fn Ts => Ts ---> B) ctr_Tss;
@@ -207,7 +211,7 @@
     val selss0 = map (map (Morphism.term phi)) raw_selss;
 
     fun mk_disc_or_sel Ts t =
-      Term.subst_atomic_types (snd (dest_Type (domain_type (fastype_of t))) ~~ Ts) t;
+      Term.subst_atomic_types (snd (Term.dest_Type (domain_type (fastype_of t))) ~~ Ts) t;
 
     val discs = map (mk_disc_or_sel As) discs0;
     val selss = map (map (mk_disc_or_sel As)) selss0;
@@ -216,25 +220,33 @@
 
     val goal_exhaust =
       let fun mk_prem xctr xs = fold_rev Logic.all xs (mk_imp_p [mk_Trueprop_eq (v, xctr)]) in
-        mk_imp_p (map2 mk_prem xctrs xss)
+        fold_rev Logic.all [p, v] (mk_imp_p (map2 mk_prem xctrs xss))
       end;
 
     val goal_injectss =
       let
         fun mk_goal _ _ [] [] = []
           | mk_goal xctr yctr xs ys =
-            [mk_Trueprop_eq (HOLogic.mk_eq (xctr, yctr),
-              Library.foldr1 HOLogic.mk_conj (map2 (curry HOLogic.mk_eq) xs ys))];
+            [fold_rev Logic.all (xs @ ys) (mk_Trueprop_eq (HOLogic.mk_eq (xctr, yctr),
+              Library.foldr1 HOLogic.mk_conj (map2 (curry HOLogic.mk_eq) xs ys)))];
       in
         map4 mk_goal xctrs yctrs xss yss
       end;
 
     val goal_half_distinctss =
-      map (map (HOLogic.mk_Trueprop o HOLogic.mk_not o HOLogic.mk_eq)) (mk_half_pairss xctrs);
+      let
+        fun mk_goal ((xs, t), (xs', t')) =
+          fold_rev Logic.all (xs @ xs')
+            (HOLogic.mk_Trueprop (HOLogic.mk_not (HOLogic.mk_eq (t, t'))));
+      in
+        map (map mk_goal) (mk_half_pairss (xss ~~ xctrs))
+      end;
 
-    val goal_cases = map2 (fn xctr => fn xf => mk_Trueprop_eq (caseofB_fs $ xctr, xf)) xctrs xfs;
+    val goal_cases =
+      map3 (fn xs => fn xctr => fn xf =>
+        fold_rev Logic.all (fs @ xs) (mk_Trueprop_eq (caseofB_fs $ xctr, xf))) xss xctrs xfs;
 
-    val goals = [goal_exhaust] :: goal_injectss @ goal_half_distinctss @ [goal_cases];
+    val goalss = [goal_exhaust] :: goal_injectss @ goal_half_distinctss @ [goal_cases];
 
     fun after_qed thmss lthy =
       let
@@ -356,7 +368,7 @@
           else
             let
               fun mk_prem disc = mk_imp_p [HOLogic.mk_Trueprop (betapply (disc, v))];
-              val goal = fold Logic.all [p, v] (mk_imp_p (map mk_prem discs));
+              val goal = fold_rev Logic.all [p, v] (mk_imp_p (map mk_prem discs));
             in
               [Skip_Proof.prove lthy [] [] goal (fn _ =>
                  mk_disc_exhaust_tac n exhaust_thm discI_thms)]
@@ -455,9 +467,9 @@
            (disc_exhaustN, disc_exhaust_thms),
            (distinctN, distinct_thms),
            (exhaustN, [exhaust_thm]),
-           (injectN, (flat inject_thmss)),
+           (injectN, flat inject_thmss),
            (nchotomyN, [nchotomy_thm]),
-           (selsN, (flat sel_thmss)),
+           (selsN, flat sel_thmss),
            (splitN, [split_thm]),
            (split_asmN, [split_asm_thm]),
            (weak_case_cong_thmsN, [weak_case_cong_thm])]
@@ -468,20 +480,20 @@
         lthy |> Local_Theory.notes notes |> snd
       end;
   in
-    (goals, after_qed, lthy')
+    (goalss, after_qed, lthy')
   end;
 
-fun wrap tacss = (fn (goalss, after_qed, lthy) =>
+fun wrap_data tacss = (fn (goalss, after_qed, lthy) =>
   map2 (map2 (Skip_Proof.prove lthy [] [])) goalss tacss
   |> (fn thms => after_qed thms lthy)) oo
-  prepare_wrap (singleton o Type_Infer_Context.infer_types)
+  prepare_wrap_data (K I) (* FIXME? (singleton o Type_Infer_Context.infer_types) *)
 
 val parse_bindings = Parse.$$$ "[" |-- Parse.list Parse.binding --| Parse.$$$ "]";
 val parse_bindingss = Parse.$$$ "[" |-- Parse.list parse_bindings --| Parse.$$$ "]";
 
 val wrap_data_cmd = (fn (goalss, after_qed, lthy) =>
   Proof.theorem NONE after_qed (map (map (rpair [])) goalss) lthy) oo
-  prepare_wrap Syntax.read_term;
+  prepare_wrap_data Syntax.read_term;
 
 val _ =
   Outer_Syntax.local_theory_to_proof @{command_spec "wrap_data"} "wraps an existing datatype"