more work on FP sugar
authorblanchet
Tue Sep 04 13:05:01 2012 +0200 (2012-09-04)
changeset 491219e0acaa470ab
parent 49120 7f8e69fc6ac9
child 49122 83515378d4d7
more work on FP sugar
src/HOL/Codatatype/Tools/bnf_comp.ML
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
src/HOL/Codatatype/Tools/bnf_fp_util.ML
src/HOL/Codatatype/Tools/bnf_gfp.ML
src/HOL/Codatatype/Tools/bnf_gfp_util.ML
src/HOL/Codatatype/Tools/bnf_lfp.ML
src/HOL/Codatatype/Tools/bnf_wrap.ML
     1.1 --- a/src/HOL/Codatatype/Tools/bnf_comp.ML	Tue Sep 04 13:02:32 2012 +0200
     1.2 +++ b/src/HOL/Codatatype/Tools/bnf_comp.ML	Tue Sep 04 13:05:01 2012 +0200
     1.3 @@ -630,7 +630,7 @@
     1.4          else qualify' (Binding.prefix_name namei bind)
     1.5        end;
     1.6  
     1.7 -    val Ass = map (map dest_TFree) tfreess;
     1.8 +    val Ass = map (map Term.dest_TFree) tfreess;
     1.9      val Ds = fold (fold Term.add_tfreesT) (oDs :: Dss) [];
    1.10  
    1.11      val ((kill_poss, As), (inners', (unfold', lthy'))) =
    1.12 @@ -781,7 +781,7 @@
    1.13            val odead = dead_of_bnf outer;
    1.14            val olive = live_of_bnf outer;
    1.15            val oDs_pos = find_indices [TFree ("dead", [])]
    1.16 -            (snd (dest_Type
    1.17 +            (snd (Term.dest_Type
    1.18                (mk_T_of_bnf (replicate odead (TFree ("dead", []))) (replicate olive dummyT) outer)));
    1.19            val oDs = map (nth Ts) oDs_pos;
    1.20            val Ts' = map (nth Ts) (subtract (op =) oDs_pos (0 upto length Ts - 1));
     2.1 --- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 04 13:02:32 2012 +0200
     2.2 +++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 04 13:05:01 2012 +0200
     2.3 @@ -38,29 +38,28 @@
     2.4    if length cAs = length cAs' then map2 (merge_type_arg_constrained ctxt) cAs cAs'
     2.5    else cannot_merge_types ();
     2.6  
     2.7 -fun type_args_constrained_of_spec (((cAs, _), _), _) = cAs;
     2.8 -fun type_name_of_spec (((_, b), _), _) = b;
     2.9 -fun mixfix_of_spec ((_, mx), _) = mx;
    2.10 -fun ctr_specs_of_spec (_, ctr_specs) = ctr_specs;
    2.11 +fun type_args_constrained_of (((cAs, _), _), _) = cAs;
    2.12 +val type_args_of = map fst o type_args_constrained_of;
    2.13 +fun type_name_of (((_, b), _), _) = b;
    2.14 +fun mixfix_of_typ ((_, mx), _) = mx;
    2.15 +fun ctr_specs_of (_, ctr_specs) = ctr_specs;
    2.16  
    2.17 -fun disc_of_ctr_spec (((disc, _), _), _) = disc;
    2.18 -fun ctr_of_ctr_spec (((_, ctr), _), _) = ctr;
    2.19 -fun args_of_ctr_spec ((_, args), _) = args;
    2.20 -fun mixfix_of_ctr_spec (_, mx) = mx;
    2.21 -
    2.22 -val mk_prod_sum = mk_sumTN o map HOLogic.mk_tupleT;
    2.23 +fun disc_of (((disc, _), _), _) = disc;
    2.24 +fun ctr_of (((_, ctr), _), _) = ctr;
    2.25 +fun args_of ((_, args), _) = args;
    2.26 +fun mixfix_of_ctr (_, mx) = mx;
    2.27  
    2.28  val lfp_info = bnf_lfp;
    2.29  val gfp_info = bnf_gfp;
    2.30  
    2.31 -fun prepare_data prepare_typ construct specs lthy =
    2.32 +fun prepare_data prepare_typ construct specs fake_lthy lthy =
    2.33    let
    2.34 -    val constrained_passiveAs =
    2.35 -      map (map (apfst (prepare_typ lthy)) o type_args_constrained_of_spec) specs
    2.36 +    val constrained_As =
    2.37 +      map (map (apfst (prepare_typ fake_lthy)) o type_args_constrained_of) specs
    2.38        |> Library.foldr1 (merge_type_args_constrained lthy);
    2.39 -    val passiveAs = map fst constrained_passiveAs;
    2.40 +    val As = map fst constrained_As;
    2.41  
    2.42 -    val _ = (case duplicates (op =) passiveAs of [] => ()
    2.43 +    val _ = (case duplicates (op =) As of [] => ()
    2.44        | T :: _ => error ("Duplicate type parameter " ^ quote (Syntax.string_of_typ lthy T)));
    2.45  
    2.46      (* TODO: check that no type variables occur in the rhss that's not in the lhss *)
    2.47 @@ -68,41 +67,116 @@
    2.48  
    2.49      val N = length specs;
    2.50  
    2.51 -    val bs = map type_name_of_spec specs;
    2.52 -    val mixfixes = map mixfix_of_spec specs;
    2.53 +    fun mk_T b =
    2.54 +      Type (fst (Term.dest_Type (Proof_Context.read_type_name fake_lthy true (Binding.name_of b))),
    2.55 +        As);
    2.56 +
    2.57 +    val bs = map type_name_of specs;
    2.58 +    val Ts = map mk_T bs;
    2.59 +
    2.60 +    val mixfixes = map mixfix_of_typ specs;
    2.61  
    2.62      val _ = (case duplicates Binding.eq_name bs of [] => ()
    2.63        | b :: _ => error ("Duplicate type name declaration " ^ quote (Binding.name_of b)));
    2.64  
    2.65 -    val ctr_specss = map ctr_specs_of_spec specs;
    2.66 +    val ctr_specss = map ctr_specs_of specs;
    2.67  
    2.68 -    val disc_namess = map (map disc_of_ctr_spec) ctr_specss;
    2.69 -    val raw_ctr_namess = map (map ctr_of_ctr_spec) ctr_specss;
    2.70 -    val ctr_argsss = map (map args_of_ctr_spec) ctr_specss;
    2.71 -    val ctr_mixfixess = map (map mixfix_of_ctr_spec) ctr_specss;
    2.72 +    val disc_namess = map (map disc_of) ctr_specss;
    2.73 +    val ctr_namess = map (map ctr_of) ctr_specss;
    2.74 +    val ctr_argsss = map (map args_of) ctr_specss;
    2.75 +    val ctr_mixfixess = map (map mixfix_of_ctr) ctr_specss;
    2.76  
    2.77      val sel_namesss = map (map (map fst)) ctr_argsss;
    2.78 -    val ctr_Tsss = map (map (map (prepare_typ lthy o snd))) ctr_argsss;
    2.79 +    val ctr_Tsss = map (map (map (prepare_typ fake_lthy o snd))) ctr_argsss;
    2.80 +
    2.81 +    val (Bs, C) =
    2.82 +      lthy
    2.83 +      |> fold (fold (fn s => Variable.declare_typ (TFree (s, dummyS))) o type_args_of) specs
    2.84 +      |> mk_TFrees N
    2.85 +      ||> the_single o fst o mk_TFrees 1;
    2.86  
    2.87 -    val (activeAs, _) = lthy |> mk_TFrees N;
    2.88 +    fun freeze_rec (T as Type (s, Ts')) =
    2.89 +        (case find_index (curry (op =) T) Ts of
    2.90 +          ~1 => Type (s, map freeze_rec Ts')
    2.91 +        | i => nth Bs i)
    2.92 +      | freeze_rec T = T;
    2.93 +
    2.94 +    val ctr_TsssBs = map (map (map freeze_rec)) ctr_Tsss;
    2.95 +    val sum_prod_TsBs = map (mk_sumTN o map HOLogic.mk_tupleT) ctr_TsssBs;
    2.96  
    2.97 -    val eqs = map2 (fn TFree A => fn Tss => (A, mk_prod_sum Tss)) activeAs ctr_Tsss;
    2.98 +    val eqs = map dest_TFree Bs ~~ sum_prod_TsBs;
    2.99 +
   2.100 +    val (raw_flds, lthy') = fp_bnf construct bs eqs lthy;
   2.101 +
   2.102 +    fun mk_fld Ts fld =
   2.103 +      let val Type (_, Ts0) = body_type (fastype_of fld) in
   2.104 +        Term.subst_atomic_types (Ts0 ~~ Ts) fld
   2.105 +      end;
   2.106  
   2.107 -    val lthy' = fp_bnf construct bs eqs lthy;
   2.108 +    val flds = map (mk_fld As) raw_flds;
   2.109 +
   2.110 +    fun wrap_type (((((T, fld), ctr_names), ctr_Tss), disc_names), sel_namess) no_defs_lthy =
   2.111 +      let
   2.112 +        val n = length ctr_names;
   2.113 +        val ks = 1 upto n;
   2.114 +        val ms = map length ctr_Tss;
   2.115 +
   2.116 +        val prod_Ts = map HOLogic.mk_tupleT ctr_Tss;
   2.117  
   2.118 -    fun wrap_type ((b, disc_names), sel_namess) lthy =
   2.119 -      let
   2.120 -        val ctrs = [];
   2.121 -        val caseof = @{term True};
   2.122 -        val tacss = [];
   2.123 +        val (xss, _) = lthy |> mk_Freess "x" ctr_Tss;
   2.124 +
   2.125 +        val rhss =
   2.126 +          map2 (fn k => fn xs =>
   2.127 +            fold_rev Term.lambda xs (fld $ mk_InN prod_Ts (HOLogic.mk_tuple xs) k)) ks xss;
   2.128 +
   2.129 +        val ((raw_ctrs, raw_ctr_defs), (lthy', lthy)) = no_defs_lthy
   2.130 +          |> apfst split_list o fold_map2 (fn b => fn rhs =>
   2.131 +               Local_Theory.define ((b, NoSyn), ((Thm.def_binding b, []), rhs)) #>> apsnd snd)
   2.132 +             ctr_names rhss
   2.133 +          ||> `Local_Theory.restore;
   2.134 +
   2.135 +        val raw_caseof =
   2.136 +          Const (@{const_name undefined}, map (fn Ts => Ts ---> C) ctr_Tss ---> T --> C);
   2.137 +
   2.138 +        (*transforms defined frees into consts (and more)*)
   2.139 +        val phi = Proof_Context.export_morphism lthy lthy';
   2.140 +
   2.141 +        val ctr_defs = map (Morphism.thm phi) raw_ctr_defs;
   2.142 +
   2.143 +        val ctrs = map (Morphism.term phi) raw_ctrs;
   2.144 +
   2.145 +        val caseof = Morphism.term phi raw_caseof;
   2.146 +
   2.147 +        (* ### *)
   2.148 +        fun cheat_tac {context = ctxt, ...} = Skip_Proof.cheat_tac (Proof_Context.theory_of ctxt);
   2.149 +
   2.150 +        val exhaust_tac = cheat_tac;
   2.151 +
   2.152 +        val inject_tacss = map (fn 0 => [] | _ => [cheat_tac]) ms;
   2.153 +
   2.154 +        val half_distinct_tacss = map (map (K cheat_tac)) (mk_half_pairss ks);
   2.155 +
   2.156 +        val case_tacs = map (K cheat_tac) ks;
   2.157 +
   2.158 +        val tacss = [exhaust_tac] :: inject_tacss @ half_distinct_tacss @ [case_tacs];
   2.159        in
   2.160 -        wrap tacss ((ctrs, caseof), (disc_names, sel_namess)) lthy
   2.161 +        wrap_data tacss ((ctrs, caseof), (disc_names, sel_namess)) lthy'
   2.162        end;
   2.163    in
   2.164 -    lthy' |> fold wrap_type (bs ~~ disc_namess ~~ sel_namesss)
   2.165 +    lthy' |> fold wrap_type (Ts ~~ flds ~~ ctr_namess ~~ ctr_Tsss ~~ disc_namess ~~ sel_namesss)
   2.166    end;
   2.167  
   2.168 -val data_cmd = prepare_data Syntax.read_typ;
   2.169 +fun data_cmd info specs lthy =
   2.170 +  let
   2.171 +    val fake_lthy =
   2.172 +      Proof_Context.theory_of lthy
   2.173 +      |> Theory.copy
   2.174 +      |> Sign.add_types_global (map (fn spec =>
   2.175 +        (type_name_of spec, length (type_args_constrained_of spec), mixfix_of_typ spec)) specs)
   2.176 +      |> Proof_Context.init_global
   2.177 +  in
   2.178 +    prepare_data Syntax.read_typ info specs fake_lthy lthy
   2.179 +  end;
   2.180  
   2.181  val parse_opt_binding_colon = Scan.optional (Parse.binding --| Parse.$$$ ":") no_name
   2.182  
     3.1 --- a/src/HOL/Codatatype/Tools/bnf_fp_util.ML	Tue Sep 04 13:02:32 2012 +0200
     3.2 +++ b/src/HOL/Codatatype/Tools/bnf_fp_util.ML	Tue Sep 04 13:05:01 2012 +0200
     3.3 @@ -75,6 +75,13 @@
     3.4    val split_conj_thm: thm -> thm list
     3.5    val split_conj_prems: int -> thm -> thm
     3.6  
     3.7 +  val Inl_const: typ -> typ -> term
     3.8 +  val Inr_const: typ -> typ -> term
     3.9 +
    3.10 +  val mk_Inl: term -> typ -> term
    3.11 +  val mk_Inr: term -> typ -> term
    3.12 +  val mk_InN: typ list -> term -> int -> term
    3.13 +
    3.14    val mk_Field: term -> term
    3.15    val mk_union: term * term -> term
    3.16  
    3.17 @@ -82,12 +89,10 @@
    3.18  
    3.19    val fixpoint: ('a * 'a -> bool) -> ('a list -> 'a list) -> 'a list -> 'a list
    3.20  
    3.21 -  val fp_bnf: (binding list -> typ list list -> BNF_Def.BNF list ->
    3.22 -    Proof.context -> Proof.context) ->
    3.23 -    binding list -> ((string * sort) * typ) list -> Proof.context -> Proof.context
    3.24 -  val fp_bnf_cmd: (binding list -> typ list list -> BNF_Def.BNF list ->
    3.25 -    Proof.context -> Proof.context) ->
    3.26 -    binding list * (string list * string list) -> Proof.context -> Proof.context
    3.27 +  val fp_bnf: (binding list -> typ list list -> BNF_Def.BNF list -> Proof.context -> 'a) ->
    3.28 +    binding list -> ((string * sort) * typ) list -> Proof.context -> 'a
    3.29 +  val fp_bnf_cmd: (binding list -> typ list list -> BNF_Def.BNF list -> Proof.context -> 'a) ->
    3.30 +    binding list * (string list * string list) -> Proof.context -> 'a
    3.31  end;
    3.32  
    3.33  structure BNF_FP_Util : BNF_FP_UTIL =
    3.34 @@ -175,6 +180,17 @@
    3.35  val set_inclN = "set_incl"
    3.36  val set_set_inclN = "set_set_incl"
    3.37  
    3.38 +fun Inl_const LT RT = Const (@{const_name Inl}, LT --> mk_sumT (LT, RT));
    3.39 +fun mk_Inl t RT = Inl_const (fastype_of t) RT $ t;
    3.40 +
    3.41 +fun Inr_const LT RT = Const (@{const_name Inr}, RT --> mk_sumT (LT, RT));
    3.42 +fun mk_Inr t LT = Inr_const LT (fastype_of t) $ t;
    3.43 +
    3.44 +fun mk_InN [_] t 1 = t
    3.45 +  | mk_InN (_ :: Ts) t 1 = mk_Inl t (mk_sumTN Ts)
    3.46 +  | mk_InN (LT :: Ts) t m = mk_Inr (mk_InN Ts t (m - 1)) LT
    3.47 +  | mk_InN Ts t _ = raise (TYPE ("mk_InN", Ts, [t]));
    3.48 +
    3.49  fun mk_Field r =
    3.50    let val T = fst (dest_relT (fastype_of r));
    3.51    in Const (@{const_name Field}, mk_relT (T, T) --> HOLogic.mk_setT T) $ r end;
     4.1 --- a/src/HOL/Codatatype/Tools/bnf_gfp.ML	Tue Sep 04 13:02:32 2012 +0200
     4.2 +++ b/src/HOL/Codatatype/Tools/bnf_gfp.ML	Tue Sep 04 13:05:01 2012 +0200
     4.3 @@ -9,7 +9,8 @@
     4.4  
     4.5  signature BNF_GFP =
     4.6  sig
     4.7 -  val bnf_gfp: binding list -> typ list list -> BNF_Def.BNF list -> Proof.context -> Proof.context
     4.8 +  val bnf_gfp: binding list -> typ list list -> BNF_Def.BNF list -> Proof.context ->
     4.9 +    term list * Proof.context
    4.10  end;
    4.11  
    4.12  structure BNF_GFP : BNF_GFP =
    4.13 @@ -91,7 +92,7 @@
    4.14  
    4.15      (* typs *)
    4.16      fun mk_FTs Ts = map2 (fn Ds => mk_T_of_bnf Ds Ts) Dss bnfs;
    4.17 -    val (params, params') = `(map dest_TFree) (deads @ passiveAs);
    4.18 +    val (params, params') = `(map Term.dest_TFree) (deads @ passiveAs);
    4.19      val FTsAs = mk_FTs allAs;
    4.20      val FTsBs = mk_FTs allBs;
    4.21      val FTsCs = mk_FTs allCs;
    4.22 @@ -2995,13 +2996,13 @@
    4.23              ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), []), [(thms, [])]))
    4.24            bs thmss)
    4.25    in
    4.26 -    lthy |> Local_Theory.notes (common_notes @ notes) |> snd
    4.27 +    (flds, lthy |> Local_Theory.notes (common_notes @ notes) |> snd)
    4.28    end;
    4.29  
    4.30  val _ =
    4.31    Outer_Syntax.local_theory @{command_spec "codata_raw"} "greatest fixed points for BNF equations"
    4.32      (Parse.and_list1
    4.33        ((Parse.binding --| Parse.$$$ ":") -- (Parse.typ --| Parse.$$$ "=" -- Parse.typ)) >>
    4.34 -      (fp_bnf_cmd bnf_gfp o apsnd split_list o split_list));
    4.35 +      (snd oo fp_bnf_cmd bnf_gfp o apsnd split_list o split_list));
    4.36  
    4.37  end;
     5.1 --- a/src/HOL/Codatatype/Tools/bnf_gfp_util.ML	Tue Sep 04 13:02:32 2012 +0200
     5.2 +++ b/src/HOL/Codatatype/Tools/bnf_gfp_util.ML	Tue Sep 04 13:05:01 2012 +0200
     5.3 @@ -35,14 +35,6 @@
     5.4    val mk_undefined: typ -> term
     5.5    val mk_univ: term -> term
     5.6  
     5.7 -  val Inl_const: typ -> typ -> term
     5.8 -  val Inr_const: typ -> typ -> term
     5.9 -
    5.10 -  val mk_Inl: term -> typ -> term
    5.11 -  val mk_Inr: term -> typ -> term
    5.12 -
    5.13 -  val mk_InN: typ list -> term -> int -> term
    5.14 -
    5.15    val mk_sum_case: term -> term -> term
    5.16    val mk_sum_caseN: term list -> term
    5.17  
    5.18 @@ -191,17 +183,6 @@
    5.19        A $ f1 $ f2 $ b1 $ b2
    5.20    end;
    5.21  
    5.22 -fun Inl_const LT RT = Const (@{const_name Inl}, LT --> mk_sumT (LT, RT));
    5.23 -fun mk_Inl t RT = Inl_const (fastype_of t) RT $ t;
    5.24 -
    5.25 -fun Inr_const LT RT = Const (@{const_name Inr}, RT --> mk_sumT (LT, RT));
    5.26 -fun mk_Inr t LT = Inr_const LT (fastype_of t) $ t;
    5.27 -
    5.28 -fun mk_InN [_] t 1 = t
    5.29 -  | mk_InN (_ :: Ts) t 1 = mk_Inl t (mk_sumTN Ts)
    5.30 -  | mk_InN (LT :: Ts) t m = mk_Inr (mk_InN Ts t (m - 1)) LT
    5.31 -  | mk_InN Ts t _ = raise (TYPE ("mk_InN", Ts, [t]));
    5.32 -
    5.33  fun mk_sum_case f g =
    5.34    let
    5.35      val fT = fastype_of f;
     6.1 --- a/src/HOL/Codatatype/Tools/bnf_lfp.ML	Tue Sep 04 13:02:32 2012 +0200
     6.2 +++ b/src/HOL/Codatatype/Tools/bnf_lfp.ML	Tue Sep 04 13:05:01 2012 +0200
     6.3 @@ -8,7 +8,8 @@
     6.4  
     6.5  signature BNF_LFP =
     6.6  sig
     6.7 -  val bnf_lfp: binding list -> typ list list -> BNF_Def.BNF list -> Proof.context -> Proof.context
     6.8 +  val bnf_lfp: binding list -> typ list list -> BNF_Def.BNF list -> Proof.context ->
     6.9 +    term list * Proof.context
    6.10  end;
    6.11  
    6.12  structure BNF_LFP : BNF_LFP =
    6.13 @@ -57,7 +58,7 @@
    6.14  
    6.15      (* typs *)
    6.16      fun mk_FTs Ts = map2 (fn Ds => mk_T_of_bnf Ds Ts) Dss bnfs;
    6.17 -    val (params, params') = `(map dest_TFree) (deads @ passiveAs);
    6.18 +    val (params, params') = `(map Term.dest_TFree) (deads @ passiveAs);
    6.19      val FTsAs = mk_FTs allAs;
    6.20      val FTsBs = mk_FTs allBs;
    6.21      val FTsCs = mk_FTs allCs;
    6.22 @@ -1817,13 +1818,13 @@
    6.23              ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), []), [(thms, [])]))
    6.24            bs thmss)
    6.25    in
    6.26 -    lthy |> Local_Theory.notes (common_notes @ notes) |> snd
    6.27 +    (flds, lthy |> Local_Theory.notes (common_notes @ notes) |> snd)
    6.28    end;
    6.29  
    6.30  val _ =
    6.31    Outer_Syntax.local_theory @{command_spec "data_raw"} "least fixed points for BNF equations"
    6.32      (Parse.and_list1
    6.33        ((Parse.binding --| Parse.$$$ ":") -- (Parse.typ --| Parse.$$$ "=" -- Parse.typ)) >>
    6.34 -      (fp_bnf_cmd bnf_lfp o apsnd split_list o split_list));
    6.35 +      (snd oo fp_bnf_cmd bnf_lfp o apsnd split_list o split_list));
    6.36  
    6.37  end;
     7.1 --- a/src/HOL/Codatatype/Tools/bnf_wrap.ML	Tue Sep 04 13:02:32 2012 +0200
     7.2 +++ b/src/HOL/Codatatype/Tools/bnf_wrap.ML	Tue Sep 04 13:05:01 2012 +0200
     7.3 @@ -8,7 +8,8 @@
     7.4  signature BNF_WRAP =
     7.5  sig
     7.6    val no_name: binding
     7.7 -  val wrap: ({prems: thm list, context: Proof.context} -> tactic) list list ->
     7.8 +  val mk_half_pairss: 'a list -> ('a * 'a) list list
     7.9 +  val wrap_data: ({prems: thm list, context: Proof.context} -> tactic) list list ->
    7.10      (term list * term) * (binding list * binding list list) -> local_theory -> local_theory
    7.11  end;
    7.12  
    7.13 @@ -62,7 +63,7 @@
    7.14    | Free (s, _) => s
    7.15    | _ => error "Cannot extract name of constructor";
    7.16  
    7.17 -fun prepare_wrap prep_term ((raw_ctrs, raw_caseof), (raw_disc_names, raw_sel_namess))
    7.18 +fun prepare_wrap_data prep_term ((raw_ctrs, raw_caseof), (raw_disc_names, raw_sel_namess))
    7.19    no_defs_lthy =
    7.20    let
    7.21      (* TODO: sanity checks on arguments *)
    7.22 @@ -76,7 +77,9 @@
    7.23      val n = length ctrs0;
    7.24      val ks = 1 upto n;
    7.25  
    7.26 -    val (T_name, As0) = dest_Type (body_type (fastype_of (hd ctrs0)));
    7.27 +    val _ = if n > 0 then () else error "No constructors specified";
    7.28 +
    7.29 +    val Type (T_name, As0) = body_type (fastype_of (hd ctrs0));
    7.30      val b = Binding.qualified_name T_name;
    7.31  
    7.32      val (As, B) =
    7.33 @@ -85,7 +88,7 @@
    7.34        ||> the_single o fst o mk_TFrees 1;
    7.35  
    7.36      fun mk_ctr Ts ctr =
    7.37 -      let val Ts0 = snd (dest_Type (body_type (fastype_of ctr))) in
    7.38 +      let val Type (_, Ts0) = body_type (fastype_of ctr) in
    7.39          Term.subst_atomic_types (Ts0 ~~ Ts) ctr
    7.40        end;
    7.41  
    7.42 @@ -127,9 +130,10 @@
    7.43            sel) (1 upto m) o pad_list no_name m) ctrs0 ms;
    7.44  
    7.45      fun mk_caseof Ts T =
    7.46 -      let val (binders, body) = strip_type (fastype_of caseof0) in
    7.47 -        Term.subst_atomic_types ((body, T) :: (snd (dest_Type (List.last binders)) ~~ Ts)) caseof0
    7.48 -      end;
    7.49 +      let
    7.50 +        val (binders, body) = strip_type (fastype_of caseof0)
    7.51 +        val Type (_, Ts0) = List.last binders
    7.52 +      in Term.subst_atomic_types ((body, T) :: (Ts0 ~~ Ts)) caseof0 end;
    7.53  
    7.54      val caseofB = mk_caseof As B;
    7.55      val caseofB_Ts = map (fn Ts => Ts ---> B) ctr_Tss;
    7.56 @@ -207,7 +211,7 @@
    7.57      val selss0 = map (map (Morphism.term phi)) raw_selss;
    7.58  
    7.59      fun mk_disc_or_sel Ts t =
    7.60 -      Term.subst_atomic_types (snd (dest_Type (domain_type (fastype_of t))) ~~ Ts) t;
    7.61 +      Term.subst_atomic_types (snd (Term.dest_Type (domain_type (fastype_of t))) ~~ Ts) t;
    7.62  
    7.63      val discs = map (mk_disc_or_sel As) discs0;
    7.64      val selss = map (map (mk_disc_or_sel As)) selss0;
    7.65 @@ -216,25 +220,33 @@
    7.66  
    7.67      val goal_exhaust =
    7.68        let fun mk_prem xctr xs = fold_rev Logic.all xs (mk_imp_p [mk_Trueprop_eq (v, xctr)]) in
    7.69 -        mk_imp_p (map2 mk_prem xctrs xss)
    7.70 +        fold_rev Logic.all [p, v] (mk_imp_p (map2 mk_prem xctrs xss))
    7.71        end;
    7.72  
    7.73      val goal_injectss =
    7.74        let
    7.75          fun mk_goal _ _ [] [] = []
    7.76            | mk_goal xctr yctr xs ys =
    7.77 -            [mk_Trueprop_eq (HOLogic.mk_eq (xctr, yctr),
    7.78 -              Library.foldr1 HOLogic.mk_conj (map2 (curry HOLogic.mk_eq) xs ys))];
    7.79 +            [fold_rev Logic.all (xs @ ys) (mk_Trueprop_eq (HOLogic.mk_eq (xctr, yctr),
    7.80 +              Library.foldr1 HOLogic.mk_conj (map2 (curry HOLogic.mk_eq) xs ys)))];
    7.81        in
    7.82          map4 mk_goal xctrs yctrs xss yss
    7.83        end;
    7.84  
    7.85      val goal_half_distinctss =
    7.86 -      map (map (HOLogic.mk_Trueprop o HOLogic.mk_not o HOLogic.mk_eq)) (mk_half_pairss xctrs);
    7.87 +      let
    7.88 +        fun mk_goal ((xs, t), (xs', t')) =
    7.89 +          fold_rev Logic.all (xs @ xs')
    7.90 +            (HOLogic.mk_Trueprop (HOLogic.mk_not (HOLogic.mk_eq (t, t'))));
    7.91 +      in
    7.92 +        map (map mk_goal) (mk_half_pairss (xss ~~ xctrs))
    7.93 +      end;
    7.94  
    7.95 -    val goal_cases = map2 (fn xctr => fn xf => mk_Trueprop_eq (caseofB_fs $ xctr, xf)) xctrs xfs;
    7.96 +    val goal_cases =
    7.97 +      map3 (fn xs => fn xctr => fn xf =>
    7.98 +        fold_rev Logic.all (fs @ xs) (mk_Trueprop_eq (caseofB_fs $ xctr, xf))) xss xctrs xfs;
    7.99  
   7.100 -    val goals = [goal_exhaust] :: goal_injectss @ goal_half_distinctss @ [goal_cases];
   7.101 +    val goalss = [goal_exhaust] :: goal_injectss @ goal_half_distinctss @ [goal_cases];
   7.102  
   7.103      fun after_qed thmss lthy =
   7.104        let
   7.105 @@ -356,7 +368,7 @@
   7.106            else
   7.107              let
   7.108                fun mk_prem disc = mk_imp_p [HOLogic.mk_Trueprop (betapply (disc, v))];
   7.109 -              val goal = fold Logic.all [p, v] (mk_imp_p (map mk_prem discs));
   7.110 +              val goal = fold_rev Logic.all [p, v] (mk_imp_p (map mk_prem discs));
   7.111              in
   7.112                [Skip_Proof.prove lthy [] [] goal (fn _ =>
   7.113                   mk_disc_exhaust_tac n exhaust_thm discI_thms)]
   7.114 @@ -455,9 +467,9 @@
   7.115             (disc_exhaustN, disc_exhaust_thms),
   7.116             (distinctN, distinct_thms),
   7.117             (exhaustN, [exhaust_thm]),
   7.118 -           (injectN, (flat inject_thmss)),
   7.119 +           (injectN, flat inject_thmss),
   7.120             (nchotomyN, [nchotomy_thm]),
   7.121 -           (selsN, (flat sel_thmss)),
   7.122 +           (selsN, flat sel_thmss),
   7.123             (splitN, [split_thm]),
   7.124             (split_asmN, [split_asm_thm]),
   7.125             (weak_case_cong_thmsN, [weak_case_cong_thm])]
   7.126 @@ -468,20 +480,20 @@
   7.127          lthy |> Local_Theory.notes notes |> snd
   7.128        end;
   7.129    in
   7.130 -    (goals, after_qed, lthy')
   7.131 +    (goalss, after_qed, lthy')
   7.132    end;
   7.133  
   7.134 -fun wrap tacss = (fn (goalss, after_qed, lthy) =>
   7.135 +fun wrap_data tacss = (fn (goalss, after_qed, lthy) =>
   7.136    map2 (map2 (Skip_Proof.prove lthy [] [])) goalss tacss
   7.137    |> (fn thms => after_qed thms lthy)) oo
   7.138 -  prepare_wrap (singleton o Type_Infer_Context.infer_types)
   7.139 +  prepare_wrap_data (K I) (* FIXME? (singleton o Type_Infer_Context.infer_types) *)
   7.140  
   7.141  val parse_bindings = Parse.$$$ "[" |-- Parse.list Parse.binding --| Parse.$$$ "]";
   7.142  val parse_bindingss = Parse.$$$ "[" |-- Parse.list parse_bindings --| Parse.$$$ "]";
   7.143  
   7.144  val wrap_data_cmd = (fn (goalss, after_qed, lthy) =>
   7.145    Proof.theorem NONE after_qed (map (map (rpair [])) goalss) lthy) oo
   7.146 -  prepare_wrap Syntax.read_term;
   7.147 +  prepare_wrap_data Syntax.read_term;
   7.148  
   7.149  val _ =
   7.150    Outer_Syntax.local_theory_to_proof @{command_spec "wrap_data"} "wraps an existing datatype"