construct high-level iterator RHS
authorblanchet
Thu Sep 06 02:56:21 2012 +0200 (2012-09-06)
changeset 491766d29d2db5f88
parent 49175 eab51f249c70
child 49177 db8ce685073f
construct high-level iterator RHS
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
src/HOL/Codatatype/Tools/bnf_fp_util.ML
src/HOL/Codatatype/Tools/bnf_gfp.ML
src/HOL/Codatatype/Tools/bnf_lfp.ML
src/HOL/Codatatype/Tools/bnf_util.ML
     1.1 --- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Thu Sep 06 01:37:24 2012 +0200
     1.2 +++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Thu Sep 06 02:56:21 2012 +0200
     1.3 @@ -52,9 +52,6 @@
     1.4  fun args_of ((_, args), _) = args;
     1.5  fun mixfix_of_ctr (_, mx) = mx;
     1.6  
     1.7 -val uncurry_fs =
     1.8 -  map2 (fn f => fn xs => HOLogic.tupled_lambda (HOLogic.mk_tuple xs) (Term.list_comb (f, xs)));
     1.9 -
    1.10  fun prepare_data prepare_typ gfp specs fake_lthy lthy =
    1.11    let
    1.12      val constrained_As =
    1.13 @@ -75,7 +72,7 @@
    1.14          As);
    1.15  
    1.16      val bs = map type_binder_of specs;
    1.17 -    val Ts = map mk_T bs;
    1.18 +    val fp_Ts = map mk_T bs;
    1.19  
    1.20      val mixfixes = map mixfix_of_typ specs;
    1.21  
    1.22 @@ -98,35 +95,35 @@
    1.23        | A' :: _ => error ("Extra type variables on rhs: " ^
    1.24            quote (Syntax.string_of_typ lthy (TFree A'))));
    1.25  
    1.26 -    val (Bs, C) =
    1.27 +    val ((Cs, Xs), _) =
    1.28        lthy
    1.29        |> fold (fold (fn s => Variable.declare_typ (TFree (s, dummyS))) o type_args_of) specs
    1.30        |> mk_TFrees N
    1.31 -      ||> the_single o fst o mk_TFrees 1;
    1.32 +      ||>> mk_TFrees N;
    1.33  
    1.34 -    fun is_same_rec (T as Type (s, Us)) (Type (s', Us')) =
    1.35 +    fun is_same_recT (T as Type (s, Us)) (Type (s', Us')) =
    1.36          s = s' andalso (Us = Us' orelse error ("Illegal occurrence of recursive type " ^
    1.37            quote (Syntax.string_of_typ fake_lthy T)))
    1.38 -      | is_same_rec _ _ = false
    1.39 +      | is_same_recT _ _ = false;
    1.40  
    1.41 -    fun freeze_rec (T as Type (s, Us)) =
    1.42 -        (case find_index (is_same_rec T) Ts of
    1.43 -          ~1 => Type (s, map freeze_rec Us)
    1.44 -        | i => nth Bs i)
    1.45 -      | freeze_rec T = T;
    1.46 +    fun freeze_recXs (T as Type (s, Us)) =
    1.47 +        (case find_index (is_same_recT T) fp_Ts of
    1.48 +          ~1 => Type (s, map freeze_recXs Us)
    1.49 +        | i => nth Xs i)
    1.50 +      | freeze_recXs T = T;
    1.51  
    1.52 -    val ctr_TsssBs = map (map (map freeze_rec)) ctr_Tsss;
    1.53 -    val sum_prod_TsBs = map (mk_sumTN o map HOLogic.mk_tupleT) ctr_TsssBs;
    1.54 +    val ctr_TsssXs = map (map (map freeze_recXs)) ctr_Tsss;
    1.55 +    val sum_prod_TsXs = map (mk_sumTN o map HOLogic.mk_tupleT) ctr_TsssXs;
    1.56  
    1.57 -    val eqs = map dest_TFree Bs ~~ sum_prod_TsBs;
    1.58 +    val eqs = map dest_TFree Xs ~~ sum_prod_TsXs;
    1.59  
    1.60 -    val ((raw_unfs, raw_flds, unf_flds, fld_unfs, fld_injects), lthy') =
    1.61 +    val ((raw_unfs, raw_flds, raw_fp_iters, raw_fp_recs, unf_flds, fld_unfs, fld_injects), lthy') =
    1.62        fp_bnf (if gfp then bnf_gfp else bnf_lfp) bs mixfixes As' eqs lthy;
    1.63  
    1.64      val timer = time (Timer.startRealTimer ());
    1.65  
    1.66 -    fun mk_unf_or_fld get_foldedT Ts t =
    1.67 -      let val Type (_, Ts0) = get_foldedT (fastype_of t) in
    1.68 +    fun mk_unf_or_fld get_T Ts t =
    1.69 +      let val Type (_, Ts0) = get_T (fastype_of t) in
    1.70          Term.subst_atomic_types (Ts0 ~~ Ts) t
    1.71        end;
    1.72  
    1.73 @@ -136,10 +133,23 @@
    1.74      val unfs = map (mk_unf As) raw_unfs;
    1.75      val flds = map (mk_fld As) raw_flds;
    1.76  
    1.77 -    fun pour_sugar_on_type (((((((((((b, T), fld), unf), fld_unf), unf_fld), fld_inject),
    1.78 -        ctr_binders), ctr_mixfixes), ctr_Tss), disc_binders), sel_binderss) no_defs_lthy =
    1.79 +    fun mk_fp_iter_or_rec Ts Us t =
    1.80        let
    1.81 -        val n = length ctr_binders;
    1.82 +        val (binders, body) = strip_type (fastype_of t);
    1.83 +        val Type (_, Ts0) = if gfp then body else List.last binders;
    1.84 +        val Us0 = map (if gfp then domain_type else body_type) (fst (split_last binders));
    1.85 +      in
    1.86 +        Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
    1.87 +      end;
    1.88 +
    1.89 +    val fp_iters = map (mk_fp_iter_or_rec As Cs) raw_fp_iters;
    1.90 +    val fp_recs = map (mk_fp_iter_or_rec As Cs) raw_fp_recs;
    1.91 +
    1.92 +    fun pour_sugar_on_type ((((((((((((((b, fp_T), C), fld), unf), fp_iter), fp_rec), fld_unf),
    1.93 +          unf_fld), fld_inject), ctr_binders), ctr_mixfixes), ctr_Tss), disc_binders), sel_binderss)
    1.94 +        no_defs_lthy =
    1.95 +      let
    1.96 +        val n = length ctr_Tss;
    1.97          val ks = 1 upto n;
    1.98          val ms = map length ctr_Tss;
    1.99  
   1.100 @@ -147,11 +157,11 @@
   1.101          val prod_Ts = map HOLogic.mk_tupleT ctr_Tss;
   1.102          val case_Ts = map (fn Ts => Ts ---> C) ctr_Tss;
   1.103  
   1.104 -        val ((((fs, u), v), xss), _) =
   1.105 +        val ((((u, v), fs), xss), _) =
   1.106            lthy
   1.107 -          |> mk_Frees "f" case_Ts
   1.108 -          ||>> yield_singleton (mk_Frees "u") unf_T
   1.109 -          ||>> yield_singleton (mk_Frees "v") T
   1.110 +          |> yield_singleton (mk_Frees "u") unf_T
   1.111 +          ||>> yield_singleton (mk_Frees "v") fp_T
   1.112 +          ||>> mk_Frees "f" case_Ts
   1.113            ||>> mk_Freess "x" ctr_Tss;
   1.114  
   1.115          val ctr_rhss =
   1.116 @@ -161,7 +171,7 @@
   1.117          val case_binder = Binding.suffix_name ("_" ^ caseN) b;
   1.118  
   1.119          val case_rhs =
   1.120 -          fold_rev Term.lambda (fs @ [v]) (mk_sum_caseN (uncurry_fs fs xss) $ (unf $ v));
   1.121 +          fold_rev Term.lambda (fs @ [v]) (mk_sum_caseN (map2 mk_uncurried_fun fs xss) $ (unf $ v));
   1.122  
   1.123          val (((raw_ctrs, raw_ctr_defs), (raw_case, raw_case_def)), (lthy', lthy)) = no_defs_lthy
   1.124            |> apfst split_list o fold_map3 (fn b => fn mx => fn rhs =>
   1.125 @@ -189,8 +199,8 @@
   1.126                      (mk_Trueprop_eq (HOLogic.mk_eq (v, fld $ u), HOLogic.mk_eq (unf $ v, u)));
   1.127                in
   1.128                  Skip_Proof.prove lthy [] [] goal (fn {context = ctxt, ...} =>
   1.129 -                  mk_fld_iff_unf_tac ctxt (map (SOME o certifyT lthy) [unf_T, T]) (certify lthy fld)
   1.130 -                    (certify lthy unf) fld_unf unf_fld)
   1.131 +                  mk_fld_iff_unf_tac ctxt (map (SOME o certifyT lthy) [unf_T, fp_T])
   1.132 +                    (certify lthy fld) (certify lthy unf) fld_unf unf_fld)
   1.133                  |> Thm.close_derivation
   1.134                  |> Morphism.thm phi
   1.135                end;
   1.136 @@ -219,24 +229,30 @@
   1.137  
   1.138          val tacss = [exhaust_tac] :: inject_tacss @ half_distinct_tacss @ [case_tacs];
   1.139  
   1.140 +        (* (co)iterators, (co)recursors, (co)induction *)
   1.141 +
   1.142 +        val is_recT = member (op =) fp_Ts;
   1.143 +
   1.144 +        val ns = map length ctr_Tsss;
   1.145 +        val mss = map (map length) ctr_Tsss;
   1.146 +        val Css = map2 replicate ns Cs;
   1.147 +
   1.148          fun sugar_lfp lthy =
   1.149            let
   1.150 -(*###
   1.151 -            val fld_iter = @{term True}; (*###*)
   1.152 +            val fp_y_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_iter))));
   1.153 +            val y_prod_Tss = map2 dest_sumTN ns fp_y_Ts;
   1.154 +            val y_Tsss = map2 (map2 dest_tupleT) mss y_prod_Tss;
   1.155 +            val g_Tss = map2 (map2 (curry (op --->))) y_Tsss Css;
   1.156 +            val iter_T = flat g_Tss ---> fp_T --> C;
   1.157  
   1.158 -            val iter_Tss = map (fn Ts => Ts) (*###*) ctr_Tss;
   1.159 -            val iter_Ts = map (fn Ts => Ts ---> C) iter_Tss;
   1.160 -
   1.161 -            val iter_fs = map2 (fn Free (s, _) => fn T => Free (s, T)) fs iter_Ts
   1.162 +            val ((gss, ysss), _) =
   1.163 +              lthy
   1.164 +              |> mk_Freess "f" g_Tss
   1.165 +              ||>> apfst (unflat y_Tsss) o mk_Freess "x" (flat y_Tsss);
   1.166  
   1.167              val iter_rhs =
   1.168 -              fold_rev Term.lambda fs (fld_iter $ mk_sum_caseN (uncurry_fs fs xss) $ (unf $ v));
   1.169 -
   1.170 -
   1.171 -            val uncurried_fs =
   1.172 -              map2 (fn f => fn xs =>
   1.173 -                HOLogic.tupled_lambda (HOLogic.mk_tuple xs) (Term.list_comb (f, xs))) fs xss;
   1.174 -*)
   1.175 +              fold_rev (fold_rev Term.lambda) gss
   1.176 +                (Term.list_comb (fp_iter, map2 (mk_sum_caseN oo map2 mk_uncurried_fun) gss ysss));
   1.177            in
   1.178              lthy
   1.179            end;
   1.180 @@ -248,8 +264,9 @@
   1.181        end;
   1.182  
   1.183      val lthy'' =
   1.184 -      fold pour_sugar_on_type (bs ~~ Ts ~~ flds ~~ unfs ~~ fld_unfs ~~ unf_flds ~~ fld_injects ~~
   1.185 -        ctr_binderss ~~ ctr_mixfixess ~~ ctr_Tsss ~~ disc_binderss ~~ sel_bindersss) lthy';
   1.186 +      fold pour_sugar_on_type (bs ~~ fp_Ts ~~ Cs ~~ flds ~~ unfs ~~ fp_iters ~~ fp_recs ~~
   1.187 +        fld_unfs ~~ unf_flds ~~ fld_injects ~~ ctr_binderss ~~ ctr_mixfixess ~~ ctr_Tsss ~~
   1.188 +        disc_binderss ~~ sel_bindersss) lthy';
   1.189  
   1.190      val timer = time (timer ("Constructors, discriminators, selectors, etc., for the new " ^
   1.191        (if gfp then "co" else "") ^ "datatype"));
     2.1 --- a/src/HOL/Codatatype/Tools/bnf_fp_util.ML	Thu Sep 06 01:37:24 2012 +0200
     2.2 +++ b/src/HOL/Codatatype/Tools/bnf_fp_util.ML	Thu Sep 06 02:56:21 2012 +0200
     2.3 @@ -88,6 +88,11 @@
     2.4    val mk_sum_case: term -> term -> term
     2.5    val mk_sum_caseN: term list -> term
     2.6  
     2.7 +  val dest_sumTN: int -> typ -> typ list
     2.8 +  val dest_tupleT: int -> typ -> typ list
     2.9 +
    2.10 +  val mk_uncurried_fun: term -> term list -> term
    2.11 +
    2.12    val mk_Field: term -> term
    2.13    val mk_union: term * term -> term
    2.14  
    2.15 @@ -219,6 +224,16 @@
    2.16  fun mk_sum_caseN [f] = f
    2.17    | mk_sum_caseN (f :: fs) = mk_sum_case f (mk_sum_caseN fs);
    2.18  
    2.19 +fun dest_sumTN 1 T = [T]
    2.20 +  | dest_sumTN n (Type (@{type_name sum}, [T, T'])) = T :: dest_sumTN (n - 1) T';
    2.21 +
    2.22 +(* TODO: move something like this to "HOLogic"? *)
    2.23 +fun dest_tupleT 0 @{typ unit} = []
    2.24 +  | dest_tupleT 1 T = [T]
    2.25 +  | dest_tupleT n (Type (@{type_name prod}, [T, T'])) = T :: dest_tupleT (n - 1) T';
    2.26 +
    2.27 +fun mk_uncurried_fun f xs = HOLogic.tupled_lambda (HOLogic.mk_tuple xs) (Term.list_comb (f, xs));
    2.28 +
    2.29  fun mk_Field r =
    2.30    let val T = fst (dest_relT (fastype_of r));
    2.31    in Const (@{const_name Field}, mk_relT (T, T) --> HOLogic.mk_setT T) $ r end;
     3.1 --- a/src/HOL/Codatatype/Tools/bnf_gfp.ML	Thu Sep 06 01:37:24 2012 +0200
     3.2 +++ b/src/HOL/Codatatype/Tools/bnf_gfp.ML	Thu Sep 06 02:56:21 2012 +0200
     3.3 @@ -11,7 +11,7 @@
     3.4  sig
     3.5    val bnf_gfp: binding list -> mixfix list -> (string * sort) list -> typ list list ->
     3.6      BNF_Def.BNF list -> local_theory ->
     3.7 -    (term list * term list * thm list * thm list * thm list) * local_theory
     3.8 +    (term list * term list * term list * term list * thm list * thm list * thm list) * local_theory
     3.9  end;
    3.10  
    3.11  structure BNF_GFP : BNF_GFP =
    3.12 @@ -1965,8 +1965,9 @@
    3.13  
    3.14      (*transforms defined frees into consts*)
    3.15      val phi = Proof_Context.export_morphism lthy_old lthy;
    3.16 -    val coiters = map (fst o dest_Const o Morphism.term phi) coiter_frees;
    3.17 -    fun mk_coiter Ts ss i = Term.list_comb (Const (nth coiters (i - 1), Library.foldr (op -->)
    3.18 +    val coiters = map (Morphism.term phi) coiter_frees;
    3.19 +    val coiter_names = map (fst o dest_Const) coiters;
    3.20 +    fun mk_coiter Ts ss i = Term.list_comb (Const (nth coiter_names (i - 1), Library.foldr (op -->)
    3.21        (map fastype_of ss, domain_type (fastype_of (nth ss (i - 1))) --> nth Ts (i - 1))), ss);
    3.22      val coiter_defs = map ((fn thm => thm RS fun_cong) o Morphism.thm phi) coiter_def_frees;
    3.23  
    3.24 @@ -2158,8 +2159,9 @@
    3.25  
    3.26      (*transforms defined frees into consts*)
    3.27      val phi = Proof_Context.export_morphism lthy_old lthy;
    3.28 -    val corecs = map (fst o dest_Const o Morphism.term phi) corec_frees;
    3.29 -    fun mk_corec ss i = Term.list_comb (Const (nth corecs (i - 1), Library.foldr (op -->)
    3.30 +    val corecs = map (Morphism.term phi) corec_frees;
    3.31 +    val corec_names = map (fst o dest_Const) corecs;
    3.32 +    fun mk_corec ss i = Term.list_comb (Const (nth corec_names (i - 1), Library.foldr (op -->)
    3.33        (map fastype_of ss, domain_type (fastype_of (nth ss (i - 1))) --> nth Ts (i - 1))), ss);
    3.34      val corec_defs = map (Morphism.thm phi) corec_def_frees;
    3.35  
    3.36 @@ -2990,7 +2992,7 @@
    3.37              ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), []), [(thms, [])]))
    3.38            bs thmss)
    3.39    in
    3.40 -    ((unfs, flds, unf_fld_thms, fld_unf_thms, fld_inject_thms),
    3.41 +    ((unfs, flds, coiters, corecs, unf_fld_thms, fld_unf_thms, fld_inject_thms),
    3.42        lthy |> Local_Theory.notes (common_notes @ notes) |> snd)
    3.43    end;
    3.44  
     4.1 --- a/src/HOL/Codatatype/Tools/bnf_lfp.ML	Thu Sep 06 01:37:24 2012 +0200
     4.2 +++ b/src/HOL/Codatatype/Tools/bnf_lfp.ML	Thu Sep 06 02:56:21 2012 +0200
     4.3 @@ -10,7 +10,7 @@
     4.4  sig
     4.5    val bnf_lfp: binding list -> mixfix list -> (string * sort) list -> typ list list ->
     4.6      BNF_Def.BNF list -> local_theory ->
     4.7 -    (term list * term list * thm list * thm list * thm list) * local_theory
     4.8 +    (term list * term list * term list * term list * thm list * thm list * thm list) * local_theory
     4.9  end;
    4.10  
    4.11  structure BNF_LFP : BNF_LFP =
    4.12 @@ -1078,8 +1078,9 @@
    4.13  
    4.14      (*transforms defined frees into consts*)
    4.15      val phi = Proof_Context.export_morphism lthy_old lthy;
    4.16 -    val iters = map (fst o dest_Const o Morphism.term phi) iter_frees;
    4.17 -    fun mk_iter Ts ss i = Term.list_comb (Const (nth iters (i - 1), Library.foldr (op -->)
    4.18 +    val iters = map (Morphism.term phi) iter_frees;
    4.19 +    val iter_names = map (fst o dest_Const) iters;
    4.20 +    fun mk_iter Ts ss i = Term.list_comb (Const (nth iter_names (i - 1), Library.foldr (op -->)
    4.21        (map fastype_of ss, nth Ts (i - 1) --> range_type (fastype_of (nth ss (i - 1))))), ss);
    4.22      val iter_defs = map (Morphism.thm phi) iter_def_frees;
    4.23  
    4.24 @@ -1239,8 +1240,9 @@
    4.25  
    4.26      (*transforms defined frees into consts*)
    4.27      val phi = Proof_Context.export_morphism lthy_old lthy;
    4.28 -    val recs = map (fst o dest_Const o Morphism.term phi) rec_frees;
    4.29 -    fun mk_rec ss i = Term.list_comb (Const (nth recs (i - 1), Library.foldr (op -->)
    4.30 +    val recs = map (Morphism.term phi) rec_frees;
    4.31 +    val rec_names = map (fst o dest_Const) recs;
    4.32 +    fun mk_rec ss i = Term.list_comb (Const (nth rec_names (i - 1), Library.foldr (op -->)
    4.33        (map fastype_of ss, nth Ts (i - 1) --> range_type (fastype_of (nth ss (i - 1))))), ss);
    4.34      val rec_defs = map (Morphism.thm phi) rec_def_frees;
    4.35  
    4.36 @@ -1813,7 +1815,7 @@
    4.37              ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), []), [(thms, [])]))
    4.38            bs thmss)
    4.39    in
    4.40 -    ((unfs, flds, unf_fld_thms, fld_unf_thms, fld_inject_thms),
    4.41 +    ((unfs, flds, iters, recs, unf_fld_thms, fld_unf_thms, fld_inject_thms),
    4.42        lthy |> Local_Theory.notes (common_notes @ notes) |> snd)
    4.43    end;
    4.44  
     5.1 --- a/src/HOL/Codatatype/Tools/bnf_util.ML	Thu Sep 06 01:37:24 2012 +0200
     5.2 +++ b/src/HOL/Codatatype/Tools/bnf_util.ML	Thu Sep 06 02:56:21 2012 +0200
     5.3 @@ -273,8 +273,8 @@
     5.4  fun mk_Frees x Ts ctxt = mk_fresh_names ctxt (length Ts) x
     5.5    |>> (fn names => map2 (curry Free) names Ts);
     5.6  fun mk_Freess x Tss ctxt =
     5.7 -  fold_map2 (fn name => fn Ts => fn ctxt =>
     5.8 -    mk_fresh_names ctxt (length Ts) name) (mk_names (length Tss) x) Tss ctxt
     5.9 +  fold_map2 (fn name => fn Ts => fn ctxt => mk_fresh_names ctxt (length Ts) name)
    5.10 +    (mk_names (length Tss) x) Tss ctxt
    5.11    |>> (fn namess => map2 (map2 (curry Free)) namess Tss);
    5.12  fun mk_Frees' x Ts ctxt = mk_fresh_names ctxt (length Ts) x
    5.13    |>> (fn names => `(map Free) (names ~~ Ts));