src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
changeset 49670 c7a034d01936
parent 49636 b7256a88a84b
child 49671 61729b149397
     1.1 --- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Sun Sep 30 23:45:03 2012 +0200
     1.2 +++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Mon Oct 01 10:34:58 2012 +0200
     1.3 @@ -67,8 +67,6 @@
     1.4  
     1.5  fun mk_tupled_fun x f xs = HOLogic.tupled_lambda x (Term.list_comb (f, xs));
     1.6  fun mk_uncurried_fun f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f xs;
     1.7 -fun mk_uncurried2_fun f xss =
     1.8 -  mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat xss);
     1.9  
    1.10  fun mk_flip (x, Type (_, [T1, Type (_, [T2, T3])])) =
    1.11    Abs ("x", T1, Abs ("y", T2, Var (x, T2 --> T1 --> T3) $ Bound 0 $ Bound 1));
    1.12 @@ -247,6 +245,22 @@
    1.13  
    1.14      val timer = time (Timer.startRealTimer ());
    1.15  
    1.16 +    fun build_map build_arg (Type (s, Ts)) (Type (_, Us)) =
    1.17 +      let
    1.18 +        val bnf = the (bnf_of lthy s);
    1.19 +        val live = live_of_bnf bnf;
    1.20 +        val mapx = mk_map live Ts Us (map_of_bnf bnf);
    1.21 +        val TUs' = map dest_funT (fst (strip_typeN live (fastype_of mapx)));
    1.22 +      in Term.list_comb (mapx, map build_arg TUs') end;
    1.23 +
    1.24 +    fun build_rel_step build_arg (Type (s, Ts)) =
    1.25 +      let
    1.26 +        val bnf = the (bnf_of lthy s);
    1.27 +        val live = live_of_bnf bnf;
    1.28 +        val rel = mk_rel live Ts Ts (rel_of_bnf bnf);
    1.29 +        val Ts' = map domain_type (fst (strip_typeN live (fastype_of rel)));
    1.30 +      in Term.list_comb (rel, map build_arg Ts') end;
    1.31 +
    1.32      fun add_nesty_bnf_names Us =
    1.33        let
    1.34          fun add (Type (s, Ts)) ss =
    1.35 @@ -265,8 +279,11 @@
    1.36      val pre_map_defs = map map_def_of_bnf pre_bnfs;
    1.37      val pre_set_defss = map set_defs_of_bnf pre_bnfs;
    1.38      val pre_rel_defs = map rel_def_of_bnf pre_bnfs;
    1.39 +    val nested_map_comp's = map map_comp'_of_bnf nested_bnfs;
    1.40 +    val nested_map_ids'' = map (unfold_thms lthy @{thms id_def} o map_id_of_bnf) nested_bnfs;
    1.41 +    val nesting_map_ids = map map_id_of_bnf nesting_bnfs;
    1.42 +    val nesting_map_ids'' = map (unfold_thms lthy @{thms id_def}) nesting_map_ids;
    1.43      val nested_set_natural's = maps set_natural'_of_bnf nested_bnfs;
    1.44 -    val nesting_map_ids = map map_id_of_bnf nesting_bnfs;
    1.45      val nesting_set_natural's = maps set_natural'_of_bnf nesting_bnfs;
    1.46  
    1.47      val live = live_of_bnf any_fp_bnf;
    1.48 @@ -283,6 +300,7 @@
    1.49      val fpTs = map (domain_type o fastype_of) dtors;
    1.50  
    1.51      val exists_fp_subtype = exists_subtype (member (op =) fpTs);
    1.52 +    val exists_Cs_subtype = exists_subtype (member (op =) Cs);
    1.53  
    1.54      val ctr_Tsss = map (map (map (Term.typ_subst_atomic (Xs ~~ fpTs)))) ctr_TsssXs;
    1.55      val ns = map length ctr_Tsss;
    1.56 @@ -310,25 +328,25 @@
    1.57              lthy
    1.58              |> mk_Freess "f" g_Tss
    1.59              ||>> mk_Freesss "x" y_Tsss;
    1.60 -          val yssss = map (map (map single)) ysss;
    1.61 +
    1.62 +          fun proj_recT proj (Type (s as @{type_name prod}, Ts as [T, U])) =
    1.63 +              if member (op =) fpTs T then proj (T, U) else Type (s, map (proj_recT proj) Ts)
    1.64 +            | proj_recT proj (Type (s, Ts)) = Type (s, map (proj_recT proj) Ts)
    1.65 +            | proj_recT _ T = T;
    1.66  
    1.67 -          fun dest_rec_prodT (T as Type (@{type_name prod}, Us as [_, U])) =
    1.68 -              if member (op =) Cs U then Us else [T]
    1.69 -            | dest_rec_prodT T = [T];
    1.70 +          fun unzip_recT T =
    1.71 +            if exists_fp_subtype T then [proj_recT fst T, proj_recT snd T] else [T];
    1.72  
    1.73 -          val z_Tssss =
    1.74 -            map3 (fn n => fn ms => map2 (map dest_rec_prodT oo dest_tupleT) ms o
    1.75 -              dest_sumTN_balanced n o domain_type) ns mss fp_rec_fun_Ts;
    1.76 +          val z_Tsss =
    1.77 +            map3 (fn n => fn ms => map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type)
    1.78 +              ns mss fp_rec_fun_Ts;
    1.79 +          val z_Tssss = map (map (map unzip_recT)) z_Tsss;
    1.80            val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
    1.81  
    1.82            val hss = map2 (map2 retype_free) h_Tss gss;
    1.83 -          val zssss_hd = map2 (map2 (map2 (retype_free o hd))) z_Tssss ysss;
    1.84 -          val (zssss_tl, lthy) =
    1.85 -            lthy
    1.86 -            |> mk_Freessss "y" (map (map (map tl)) z_Tssss);
    1.87 -          val zssss = map2 (map2 (map2 cons)) zssss_hd zssss_tl;
    1.88 +          val zsss = map2 (map2 (map2 retype_free)) z_Tsss ysss;
    1.89          in
    1.90 -          ((((gss, g_Tss, yssss), (hss, h_Tss, zssss)),
    1.91 +          ((((gss, g_Tss, ysss), (hss, h_Tss, zsss)),
    1.92              ([], [], [], (([], []), ([], [])), (([], []), ([], [])))), lthy)
    1.93          end
    1.94        else
    1.95 @@ -578,14 +596,37 @@
    1.96            let
    1.97              val fpT_to_C = fpT --> C;
    1.98  
    1.99 -            fun generate_rec_like (suf, fp_rec_like, (fss, f_Tss, xssss)) =
   1.100 +            fun build_ctor_rec_arg mk_proj (T, U) =
   1.101 +              if T = U then
   1.102 +                id_const T
   1.103 +              else
   1.104 +                (case (T, U) of
   1.105 +                  (Type (s, _), Type (s', _)) =>
   1.106 +                  if s = s' then build_map (build_ctor_rec_arg mk_proj) T U else mk_proj T
   1.107 +                | _ => mk_proj T);
   1.108 +
   1.109 +            fun mk_U proj (T as Type (@{type_name prod}, [T', U])) =
   1.110 +                if member (op =) fpTs T' then proj (T', U) else T
   1.111 +              | mk_U proj (Type (s, Ts)) = Type (s, map (mk_U proj) Ts)
   1.112 +              | mk_U _ T = T;
   1.113 +
   1.114 +            fun unzip_rec (x as Free (_, T)) =
   1.115 +              if exists_fp_subtype T then
   1.116 +                [build_ctor_rec_arg fst_const (T, mk_U fst T) $ x,
   1.117 +                 build_ctor_rec_arg snd_const (T, mk_U snd T) $ x]
   1.118 +              else
   1.119 +                [x];
   1.120 +
   1.121 +            fun mk_rec_like_arg f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f (maps unzip_rec xs);
   1.122 +
   1.123 +            fun generate_rec_like (suf, fp_rec_like, (fss, f_Tss, xsss)) =
   1.124                let
   1.125                  val res_T = fold_rev (curry (op --->)) f_Tss fpT_to_C;
   1.126                  val binding = qualify false fp_b_name (Binding.suffix_name ("_" ^ suf) fp_b);
   1.127                  val spec =
   1.128                    mk_Trueprop_eq (lists_bmoc fss (Free (Binding.name_of binding, res_T)),
   1.129                      Term.list_comb (fp_rec_like,
   1.130 -                      map2 (mk_sum_caseN_balanced oo map2 mk_uncurried2_fun) fss xssss));
   1.131 +                      map2 (mk_sum_caseN_balanced oo map2 mk_rec_like_arg) fss xsss));
   1.132                in (binding, spec) end;
   1.133  
   1.134              val rec_like_infos =
   1.135 @@ -661,14 +702,6 @@
   1.136        fold_map I wrap_types_and_mores lthy
   1.137        |>> apsnd split_list4 o apfst split_list4 o split_list;
   1.138  
   1.139 -    fun build_map build_arg (Type (s, Ts)) (Type (_, Us)) =
   1.140 -      let
   1.141 -        val bnf = the (bnf_of lthy s);
   1.142 -        val live = live_of_bnf bnf;
   1.143 -        val mapx = mk_map live Ts Us (map_of_bnf bnf);
   1.144 -        val TUs' = map dest_funT (fst (strip_typeN live (fastype_of mapx)));
   1.145 -      in Term.list_comb (mapx, map build_arg TUs') end;
   1.146 -
   1.147      (* TODO: Add map, sets, rel simps *)
   1.148      val mk_simp_thmss =
   1.149        map3 (fn (_, _, _, injects, distincts, cases, _, _, _) => fn rec_likes => fn fold_likes =>
   1.150 @@ -787,10 +820,8 @@
   1.151                typ_subst (map2 (fn fpT => fn C => (fpT, maybe_mk_prodT fpT C)) fpTs Cs);
   1.152  
   1.153              fun intr_rec_likes frec_likes maybe_cons maybe_tick maybe_mk_prodT (x as Free (_, T)) =
   1.154 -              if member (op =) fpTs T then
   1.155 +              if exists_fp_subtype T then
   1.156                  maybe_cons x [build_rec_like frec_likes (K I) (T, mk_U (K I) T) $ x]
   1.157 -              else if exists_fp_subtype T then
   1.158 -                [build_rec_like frec_likes maybe_tick (T, mk_U maybe_mk_prodT T) $ x]
   1.159                else
   1.160                  [x];
   1.161  
   1.162 @@ -802,11 +833,11 @@
   1.163              val rec_goalss = map5 (map4 o mk_goal hss) hrecs xctrss hss xsss hxsss;
   1.164  
   1.165              val fold_tacss =
   1.166 -              map2 (map o mk_rec_like_tac pre_map_defs nesting_map_ids fold_defs) fp_fold_thms
   1.167 +              map2 (map o mk_rec_like_tac pre_map_defs [] nesting_map_ids'' fold_defs) fp_fold_thms
   1.168                  ctr_defss;
   1.169              val rec_tacss =
   1.170 -              map2 (map o mk_rec_like_tac pre_map_defs nesting_map_ids rec_defs) fp_rec_thms
   1.171 -                ctr_defss;
   1.172 +              map2 (map o mk_rec_like_tac pre_map_defs nested_map_comp's
   1.173 +                (nested_map_ids'' @ nesting_map_ids'') rec_defs) fp_rec_thms ctr_defss;
   1.174  
   1.175              fun prove goal tac =
   1.176                Skip_Proof.prove lthy [] [] goal (tac o #context)
   1.177 @@ -873,14 +904,6 @@
   1.178                map4 (fn u => fn v => fn uvr => fn uv_eq =>
   1.179                  fold_rev Term.lambda [u, v] (HOLogic.mk_disj (uvr, uv_eq))) us vs uvrs uv_eqs;
   1.180  
   1.181 -            fun build_rel_step build_arg (Type (s, Ts)) =
   1.182 -              let
   1.183 -                val bnf = the (bnf_of lthy s);
   1.184 -                val live = live_of_bnf bnf;
   1.185 -                val rel = mk_rel live Ts Ts (rel_of_bnf bnf);
   1.186 -                val Ts' = map domain_type (fst (strip_typeN live (fastype_of rel)));
   1.187 -              in Term.list_comb (rel, map build_arg Ts') end;
   1.188 -
   1.189              fun build_rel rs' T =
   1.190                (case find_index (curry (op =) T) fpTs of
   1.191                  ~1 =>
   1.192 @@ -963,7 +986,7 @@
   1.193  
   1.194              fun intr_corec_likes fcorec_likes maybe_mk_sumT maybe_tack cqf =
   1.195                let val T = fastype_of cqf in
   1.196 -                if exists_subtype (member (op =) Cs) T then
   1.197 +                if exists_Cs_subtype T then
   1.198                    build_corec_like fcorec_likes maybe_tack (T, mk_U maybe_mk_sumT T) $ cqf
   1.199                  else
   1.200                    cqf