generalized recursors, effectively reverting inductive half of c7a034d01936
authorblanchet
Wed May 29 02:35:49 2013 +0200 (2013-05-29)
changeset 522144cc5a80bba80
parent 52213 f4c5c6320cce
child 52215 7facaee8586f
generalized recursors, effectively reverting inductive half of c7a034d01936
src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
src/HOL/BNF/Tools/bnf_fp_def_sugar_tactics.ML
     1.1 --- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Wed May 29 02:35:49 2013 +0200
     1.2 +++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Wed May 29 02:35:49 2013 +0200
     1.3 @@ -32,8 +32,8 @@
     1.4    val indexify_fst: ''a list -> (int -> ''a * 'b -> 'c) -> ''a * 'b -> 'c
     1.5    val mk_un_fold_co_rec_prelims: BNF_FP_Util.fp_kind -> typ list -> typ list -> int list ->
     1.6      int list list -> term list -> term list -> Proof.context ->
     1.7 -    (term list * term list * ((term list list * typ list list * term list list list)
     1.8 -       * (term list list * typ list list * term list list list)) option
     1.9 +    (term list * term list * ((term list list * typ list list * term list list list list)
    1.10 +       * (term list list * typ list list * term list list list list)) option
    1.11       * (term list * term list list
    1.12          * ((term list list * term list list list list * term list list list list)
    1.13             * (typ list * typ list list list * typ list list))
    1.14 @@ -44,9 +44,10 @@
    1.15  
    1.16    val mk_iter_fun_arg_types_pairsss: typ list -> int list -> int list list -> term ->
    1.17      (typ list * typ list) list list list
    1.18 -  val define_fold_rec: (term list list * typ list list * term list list list)
    1.19 -      * (term list list * typ list list * term list list list) -> (string -> binding) -> typ list ->
    1.20 -    typ list -> term -> term -> Proof.context -> (term * term * thm * thm) * Proof.context
    1.21 +  val define_fold_rec: (term list list * typ list list * term list list list list)
    1.22 +      * (term list list * typ list list * term list list list list) -> (string -> binding) ->
    1.23 +    typ list -> typ list -> term -> term -> Proof.context ->
    1.24 +    (term * term * thm * thm) * Proof.context
    1.25    val define_unfold_corec: term list * term list list
    1.26        * ((term list list * term list list list list * term list list list list)
    1.27           * (typ list * typ list list list * typ list list))
    1.28 @@ -182,17 +183,14 @@
    1.29  
    1.30  val lists_bmoc = fold (fn xs => fn t => Term.list_comb (t, xs));
    1.31  
    1.32 -fun mk_tupled_fun x f xs = HOLogic.tupled_lambda x (Term.list_comb (f, xs));
    1.33 -fun mk_uncurried_fun f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f xs;
    1.34 -
    1.35  fun flat_rec unzipf xs =
    1.36    let val ps = map unzipf xs in
    1.37      (* The first line below gives the preferred order. The second line is for compatibility with the
    1.38         old datatype package: *)
    1.39 -(*
    1.40      maps (op @) ps
    1.41 +(* ###
    1.42 +    maps fst ps @ maps snd ps
    1.43  *)
    1.44 -    maps fst ps @ maps snd ps
    1.45    end;
    1.46  
    1.47  fun flat_predss_getterss qss fss = maps (op @) (qss ~~ fss);
    1.48 @@ -201,6 +199,11 @@
    1.49    | flat_preds_predsss_gettersss (p :: ps) (qss :: qsss) (fss :: fsss) =
    1.50      p :: flat_predss_getterss qss fss @ flat_preds_predsss_gettersss ps qsss fsss;
    1.51  
    1.52 +fun mk_tupled_fun x f xs = HOLogic.tupled_lambda x (Term.list_comb (f, xs));
    1.53 +fun mk_uncurried_fun f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f xs;
    1.54 +fun mk_uncurried2_fun f xss =
    1.55 +  mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat xss);
    1.56 +
    1.57  fun mk_flip (x, Type (_, [T1, Type (_, [T2, T3])])) =
    1.58    Abs ("x", T1, Abs ("y", T2, Var (x, T2 --> T1 --> T3) $ Bound 0 $ Bound 1));
    1.59  
    1.60 @@ -245,8 +248,12 @@
    1.61  
    1.62  val mk_fp_iter_fun_types = fst o split_last o binder_types o fastype_of;
    1.63  
    1.64 -fun meta_unzip_rec getT proj1 proj2 fpTs y =
    1.65 -  if exists_subtype_in fpTs (getT y) then ([proj1 y], [proj2 y]) else ([y], []);
    1.66 +fun meta_unzip_rec getT left right nested fpTs y =
    1.67 +  let val T = getT y in
    1.68 +    if member (op =) fpTs T then ([left y], [right y])
    1.69 +    else if exists_subtype_in fpTs T then ([nested y], [])
    1.70 +    else ([y], [])
    1.71 +  end;
    1.72  
    1.73  fun project_co_recT special_Tname fpTs proj =
    1.74    let
    1.75 @@ -259,10 +266,7 @@
    1.76  val project_recT = project_co_recT @{type_name prod};
    1.77  val project_corecT = project_co_recT @{type_name sum};
    1.78  
    1.79 -fun unzip_recT fpTs = meta_unzip_rec I (project_recT fpTs fst) (project_recT fpTs snd) fpTs;
    1.80 -
    1.81 -fun mk_fold_fun_typess y_Tsss Cs = map2 (fn C => map (fn y_Ts => y_Ts ---> C)) Cs y_Tsss;
    1.82 -val mk_rec_fun_typess = mk_fold_fun_typess oo map o map o flat_rec o unzip_recT;
    1.83 +fun unzip_recT fpTs = meta_unzip_rec I (project_recT fpTs fst) (project_recT fpTs snd) I fpTs;
    1.84  
    1.85  fun mk_fun_arg_typess n ms = map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type;
    1.86  
    1.87 @@ -273,21 +277,40 @@
    1.88  
    1.89  fun mk_fold_rec_args_types fpTs Cs ns mss ctor_fold_fun_Ts ctor_rec_fun_Ts lthy =
    1.90    let
    1.91 +    val Css = map2 replicate ns Cs;
    1.92      val y_Tsss = map3 mk_fun_arg_typess ns mss ctor_fold_fun_Ts;
    1.93 -    val g_Tss = mk_fold_fun_typess y_Tsss Cs;
    1.94 +    val g_Tss = map2 (fn C => map (fn y_Ts => y_Ts ---> C)) Cs y_Tsss;
    1.95  
    1.96      val ((gss, ysss), lthy) =
    1.97        lthy
    1.98        |> mk_Freess "f" g_Tss
    1.99        ||>> mk_Freesss "x" y_Tsss;
   1.100 +    val yssss = map (map (map single)) ysss;
   1.101 +
   1.102 +    (* ### *)
   1.103 +    fun dest_rec_prodT (T as Type (@{type_name prod}, Us as [_, U])) =
   1.104 +        if member (op =) Cs U then Us else [T]
   1.105 +      | dest_rec_prodT T = [T];
   1.106 +
   1.107 +    val z_Tssss =
   1.108 +      map3 (fn n => fn ms => map2 (map dest_rec_prodT oo dest_tupleT) ms o
   1.109 +        dest_sumTN_balanced n o domain_type) ns mss ctor_rec_fun_Ts;
   1.110  
   1.111      val z_Tsss = map3 mk_fun_arg_typess ns mss ctor_rec_fun_Ts;
   1.112 -    val h_Tss = mk_rec_fun_typess fpTs z_Tsss Cs;
   1.113 +    val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
   1.114  
   1.115      val hss = map2 (map2 retype_free) h_Tss gss;
   1.116      val zsss = map2 (map2 (map2 retype_free)) z_Tsss ysss;
   1.117 +    val zssss_hd = map2 (map2 (map2 (retype_free o hd))) z_Tssss ysss;
   1.118 +    val (zssss_tl, lthy) =
   1.119 +      lthy
   1.120 +      |> mk_Freessss "y" (map (map (map tl)) z_Tssss);
   1.121 +    val zssss = map2 (map2 (map2 cons)) zssss_hd zssss_tl;
   1.122 +
   1.123 +val _ = tracing (" *** OLD:  " ^ PolyML.makestring (ysss, zsss)) (*###*)
   1.124 +val _ = tracing ("  *** NEW: " ^ PolyML.makestring (yssss, zssss)) (*###*)
   1.125    in
   1.126 -    (((gss, g_Tss, ysss), (hss, h_Tss, zsss)), lthy)
   1.127 +    (((gss, g_Tss, yssss), (hss, h_Tss, zssss)), lthy)
   1.128    end;
   1.129  
   1.130  fun mk_unfold_corec_args_types fpTs Cs ns mss dtor_unfold_fun_Ts dtor_corec_fun_Ts lthy =
   1.131 @@ -438,18 +461,12 @@
   1.132          | _ => build_simple TU);
   1.133    in build end;
   1.134  
   1.135 -fun mk_iter_body lthy fpTs ctor_iter fss xsss =
   1.136 +fun mk_iter_body lthy fpTs ctor_iter fss xssss =
   1.137    let
   1.138      fun build_proj sel sel_const (x as Free (_, T)) =
   1.139        build_map lthy (sel_const o fst) (T, project_recT fpTs sel T) $ x;
   1.140 -
   1.141 -    (* TODO: Avoid these complications; cf. corec case *)
   1.142 -    val unzip_rec = meta_unzip_rec (snd o dest_Free) (build_proj fst fst_const)
   1.143 -      (build_proj snd snd_const) fpTs;
   1.144 -
   1.145 -    fun mk_iter_arg f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f (flat_rec unzip_rec xs);
   1.146    in
   1.147 -    Term.list_comb (ctor_iter, map2 (mk_sum_caseN_balanced oo map2 mk_iter_arg) fss xsss)
   1.148 +    Term.list_comb (ctor_iter, map2 (mk_sum_caseN_balanced oo map2 mk_uncurried2_fun) fss xssss)
   1.149    end;
   1.150  
   1.151  fun mk_preds_getterss_join c cps sum_prod_T cqfss =
   1.152 @@ -480,13 +497,13 @@
   1.153  
   1.154      val fpT_to_C as Type (_, [fpT, _]) = snd (strip_typeN nn (fastype_of ctor_fold));
   1.155  
   1.156 -    fun generate_iter (suf, ctor_iter, (fss, f_Tss, xsss)) =
   1.157 +    fun generate_iter (suf, ctor_iter, (fss, f_Tss, xssss)) =
   1.158        let
   1.159          val res_T = fold_rev (curry (op --->)) f_Tss fpT_to_C;
   1.160          val binding = mk_binding suf;
   1.161          val spec =
   1.162            mk_Trueprop_eq (lists_bmoc fss (Free (Binding.name_of binding, res_T)),
   1.163 -            mk_iter_body lthy0 fpTs ctor_iter fss xsss);
   1.164 +            mk_iter_body lthy0 fpTs ctor_iter fss xssss);
   1.165        in (binding, spec) end;
   1.166  
   1.167      val binding_specs =
   1.168 @@ -558,7 +575,6 @@
   1.169      val pre_map_defs = map map_def_of_bnf pre_bnfs;
   1.170      val pre_set_defss = map set_defs_of_bnf pre_bnfs;
   1.171      val nesting_map_ids'' = map (unfold_thms lthy @{thms id_def} o map_id_of_bnf) nesting_bnfs;
   1.172 -    val nested_map_comp's = map map_comp'_of_bnf nested_bnfs;
   1.173      val nested_map_ids'' = map (unfold_thms lthy @{thms id_def} o map_id_of_bnf) nested_bnfs;
   1.174      val nested_set_map's = maps set_map'_of_bnf nested_bnfs;
   1.175  
   1.176 @@ -671,24 +687,47 @@
   1.177  
   1.178          val mk_U = typ_subst_nonatomic (map2 pair fpTs Cs);
   1.179  
   1.180 -        fun unzip_iters fiters =
   1.181 +        (* ### *)
   1.182 +        fun typ_subst inst (T as Type (s, Ts)) =
   1.183 +            (case AList.lookup (op =) inst T of
   1.184 +              NONE => Type (s, map (typ_subst inst) Ts)
   1.185 +            | SOME T' => T')
   1.186 +          | typ_subst inst T = the_default T (AList.lookup (op =) inst T);
   1.187 +
   1.188 +        fun mk_U' maybe_mk_prodT =
   1.189 +          typ_subst (map2 (fn fpT => fn C => (fpT, maybe_mk_prodT fpT C)) fpTs Cs);
   1.190 +
   1.191 +        (* ### *)
   1.192 +        fun build_rec_like fiters maybe_tick (T, U) =
   1.193 +          if T = U then
   1.194 +            id_const T
   1.195 +          else
   1.196 +            (case find_index (curry (op =) T) fpTs of
   1.197 +              ~1 => build_map lthy (build_rec_like fiters maybe_tick) (T, U)
   1.198 +            | kk => maybe_tick (nth us kk) (nth fiters kk));
   1.199 +
   1.200 +        fun unzip_iters fiters maybe_tick maybe_mk_prodT =
   1.201            meta_unzip_rec (snd o dest_Free) I
   1.202              (fn x as Free (_, T) => build_map lthy (indexify_fst fpTs (K o nth fiters))
   1.203 -              (T, mk_U T) $ x) fpTs;
   1.204 +              (T, mk_U T) $ x)
   1.205 +            (fn x as Free (_, T) => build_rec_like fiters maybe_tick (T, mk_U' maybe_mk_prodT T) $ x)
   1.206 +            fpTs;
   1.207 +
   1.208 +        fun tick u f = Term.lambda u (HOLogic.mk_prod (u, f $ u));
   1.209  
   1.210          val gxsss = map (map (flat_rec ((fn (ts, ts') => ([hd (ts' @ ts)], [])) o
   1.211 -          unzip_iters gfolds))) xsss;
   1.212 -        val hxsss = map (map (flat_rec (unzip_iters hrecs))) xsss;
   1.213 +          unzip_iters gfolds (K I) (K I)))) xsss;
   1.214 +        val hxsss = map (map (flat_rec (unzip_iters hrecs tick (curry HOLogic.mk_prodT)))) xsss;
   1.215  
   1.216          val fold_goalss = map5 (map4 o mk_goal gss) gfolds xctrss gss xsss gxsss;
   1.217          val rec_goalss = map5 (map4 o mk_goal hss) hrecs xctrss hss xsss hxsss;
   1.218  
   1.219          val fold_tacss =
   1.220 -          map2 (map o mk_iter_tac pre_map_defs [] nesting_map_ids'' fold_defs) ctor_fold_thms
   1.221 -            ctr_defss;
   1.222 +          map2 (map o mk_iter_tac pre_map_defs nesting_map_ids'' fold_defs)
   1.223 +            ctor_fold_thms ctr_defss;
   1.224          val rec_tacss =
   1.225 -          map2 (map o mk_iter_tac pre_map_defs nested_map_comp's
   1.226 -            (nested_map_ids'' @ nesting_map_ids'') rec_defs) ctor_rec_thms ctr_defss;
   1.227 +          map2 (map o mk_iter_tac pre_map_defs (nested_map_ids'' @ nesting_map_ids'') rec_defs)
   1.228 +            ctor_rec_thms ctr_defss;
   1.229  
   1.230          fun prove goal tac =
   1.231            Goal.prove_sorry lthy [] [] goal (tac o #context)
     2.1 --- a/src/HOL/BNF/Tools/bnf_fp_def_sugar_tactics.ML	Wed May 29 02:35:49 2013 +0200
     2.2 +++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar_tactics.ML	Wed May 29 02:35:49 2013 +0200
     2.3 @@ -24,8 +24,7 @@
     2.4    val mk_induct_tac: Proof.context -> int -> int list -> int list list -> int list list list ->
     2.5      thm list -> thm -> thm list -> thm list list -> tactic
     2.6    val mk_inject_tac: Proof.context -> thm -> thm -> tactic
     2.7 -  val mk_iter_tac: thm list -> thm list -> thm list -> thm list -> thm -> thm -> Proof.context
     2.8 -    -> tactic
     2.9 +  val mk_iter_tac: thm list -> thm list -> thm list -> thm -> thm -> Proof.context -> tactic
    2.10  end;
    2.11  
    2.12  structure BNF_FP_Def_Sugar_Tactics : BNF_FP_DEF_SUGAR_TACTICS =
    2.13 @@ -103,9 +102,9 @@
    2.14    @{thms comp_def convol_def fst_conv id_def prod_case_Pair_iden snd_conv
    2.15        split_conv unit_case_Unity} @ sum_prod_thms_map;
    2.16  
    2.17 -fun mk_iter_tac pre_map_defs map_comp's map_ids'' iter_defs ctor_iter ctr_def ctxt =
    2.18 -  unfold_thms_tac ctxt (ctr_def :: ctor_iter :: iter_defs @ pre_map_defs @ map_comp's @
    2.19 -    map_ids'' @ iter_unfold_thms) THEN rtac refl 1;
    2.20 +fun mk_iter_tac pre_map_defs map_ids'' iter_defs ctor_iter ctr_def ctxt =
    2.21 +  unfold_thms_tac ctxt (ctr_def :: ctor_iter :: iter_defs @ pre_map_defs @ map_ids'' @
    2.22 +    iter_unfold_thms) THEN rtac refl 1;
    2.23  
    2.24  val coiter_unfold_thms =
    2.25    @{thms id_def ident_o_ident sum_case_if sum_case_o_inj} @ sum_prod_thms_map;