changed type of corecursor for the nested recursion case
authorblanchet
Tue Oct 02 01:00:18 2012 +0200 (2012-10-02)
changeset 49681aa66ea552357
parent 49680 00290dc6bfad
child 49682 f57af1c46f99
changed type of corecursor for the nested recursion case
src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
     1.1 --- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Tue Oct 02 01:00:18 2012 +0200
     1.2 +++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Tue Oct 02 01:00:18 2012 +0200
     1.3 @@ -361,19 +361,19 @@
     1.4              | flat_preds_predsss_gettersss (p :: ps) (qss :: qsss) (fss :: fsss) =
     1.5                p :: flat_predss_getterss qss fss @ flat_preds_predsss_gettersss ps qsss fsss;
     1.6  
     1.7 -          fun mk_types maybe_dest_sumT fun_Ts =
     1.8 +          fun mk_types maybe_unzipT fun_Ts =
     1.9              let
    1.10                val f_sum_prod_Ts = map range_type fun_Ts;
    1.11                val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts;
    1.12 +              val f_Tsss = map2 (map2 dest_tupleT) mss' f_prod_Tss;
    1.13                val f_Tssss =
    1.14 -                map3 (fn C => map2 (map (map (curry (op -->) C) o maybe_dest_sumT) oo dest_tupleT))
    1.15 -                  Cs mss' f_prod_Tss;
    1.16 +                map2 (fn C => map (map (map (curry (op -->) C) o maybe_unzipT))) Cs f_Tsss;
    1.17                val q_Tssss =
    1.18                  map (map (map (fn [_] => [] | [_, C] => [mk_pred1T (domain_type C)]))) f_Tssss;
    1.19                val pf_Tss = map3 flat_preds_predsss_gettersss p_Tss q_Tssss f_Tssss;
    1.20 -            in (q_Tssss, f_sum_prod_Ts, f_Tssss, pf_Tss) end;
    1.21 +            in (q_Tssss, f_sum_prod_Ts, f_Tsss, f_Tssss, pf_Tss) end;
    1.22  
    1.23 -          val (r_Tssss, g_sum_prod_Ts, g_Tssss, pg_Tss) = mk_types single fp_fold_fun_Ts;
    1.24 +          val (r_Tssss, g_sum_prod_Ts, g_Tsss, g_Tssss, pg_Tss) = mk_types single fp_fold_fun_Ts;
    1.25  
    1.26            val ((((Free (z, _), cs), pss), gssss), lthy) =
    1.27              lthy
    1.28 @@ -383,11 +383,16 @@
    1.29              ||>> mk_Freessss "g" g_Tssss;
    1.30            val rssss = map (map (map (fn [] => []))) r_Tssss;
    1.31  
    1.32 -          fun dest_corec_sumT (T as Type (@{type_name sum}, Us as [_, U])) =
    1.33 -              if member (op =) Cs U then Us else [T]
    1.34 -            | dest_corec_sumT T = [T];
    1.35 +          fun proj_corecT proj (Type (s as @{type_name sum}, Ts as [T, U])) =
    1.36 +              if member (op =) fpTs T then proj (T, U) else Type (s, map (proj_corecT proj) Ts)
    1.37 +            | proj_corecT proj (Type (s, Ts)) = Type (s, map (proj_corecT proj) Ts)
    1.38 +            | proj_corecT _ T = T;
    1.39  
    1.40 -          val (s_Tssss, h_sum_prod_Ts, h_Tssss, ph_Tss) = mk_types dest_corec_sumT fp_rec_fun_Ts;
    1.41 +          fun unzip_corecT T =
    1.42 +            if exists_fp_subtype T then [proj_corecT fst T, proj_corecT snd T] else [T];
    1.43 +
    1.44 +          val (s_Tssss, h_sum_prod_Ts, h_Tsss, h_Tssss, ph_Tss) =
    1.45 +            mk_types unzip_corecT fp_rec_fun_Ts;
    1.46  
    1.47            val hssss_hd = map2 (map2 (map2 (fn T :: _ => fn [g] => retype_free T g))) h_Tssss gssss;
    1.48            val ((sssss, hssss_tl), lthy) =
    1.49 @@ -396,23 +401,34 @@
    1.50              ||>> mk_Freessss "h" (map (map (map tl)) h_Tssss);
    1.51            val hssss = map2 (map2 (map2 cons)) hssss_hd hssss_tl;
    1.52  
    1.53 -          val cpss = map2 (fn c => map (fn p => p $ c)) cs pss;
    1.54 +          val cpss = map2 (map o rapp) cs pss;
    1.55  
    1.56 -          fun mk_preds_getters_join [] [cf] = cf
    1.57 -            | mk_preds_getters_join [cq] [cf, cf'] =
    1.58 -              mk_If cq (mk_Inl (fastype_of cf') cf) (mk_Inr (fastype_of cf) cf');
    1.59 +          fun build_sum_inj mk_inj (T, U) =
    1.60 +            if T = U then
    1.61 +              id_const T
    1.62 +            else
    1.63 +              (case (T, U) of
    1.64 +                (Type (s, _), Type (s', _)) =>
    1.65 +                if s = s' then build_map (build_sum_inj mk_inj) T U
    1.66 +                else uncurry mk_inj (dest_sumT U)
    1.67 +              | _ => uncurry mk_inj (dest_sumT U));
    1.68  
    1.69 -          fun mk_terms qssss fssss =
    1.70 +          fun build_dtor_corec_arg _ [] [cf] = cf
    1.71 +            | build_dtor_corec_arg T [cq] [cf, cf'] =
    1.72 +              mk_If cq (build_sum_inj Inl_const (fastype_of cf, T) $ cf)
    1.73 +                (build_sum_inj Inr_const (fastype_of cf', T) $ cf')
    1.74 +
    1.75 +          fun mk_terms f_Tsss qssss fssss =
    1.76              let
    1.77                val pfss = map3 flat_preds_predsss_gettersss pss qssss fssss;
    1.78 -              val cqssss = map2 (fn c => map (map (map (fn f => f $ c)))) cs qssss;
    1.79 -              val cfssss = map2 (fn c => map (map (map (fn f => f $ c)))) cs fssss;
    1.80 -              val cqfsss = map2 (map2 (map2 mk_preds_getters_join)) cqssss cfssss;
    1.81 +              val cqssss = map2 (map o map o map o rapp) cs qssss;
    1.82 +              val cfssss = map2 (map o map o map o rapp) cs fssss;
    1.83 +              val cqfsss = map3 (map3 (map3 build_dtor_corec_arg)) f_Tsss cqssss cfssss;
    1.84              in (pfss, cqfsss) end;
    1.85          in
    1.86            (((([], [], []), ([], [], [])),
    1.87 -            ([z], cs, cpss, (mk_terms rssss gssss, (g_sum_prod_Ts, pg_Tss)),
    1.88 -             (mk_terms sssss hssss, (h_sum_prod_Ts, ph_Tss)))), lthy)
    1.89 +            ([z], cs, cpss, (mk_terms g_Tsss rssss gssss, (g_sum_prod_Ts, pg_Tss)),
    1.90 +             (mk_terms h_Tsss sssss hssss, (h_sum_prod_Ts, ph_Tss)))), lthy)
    1.91          end;
    1.92  
    1.93      fun define_ctrs_case_for_type (((((((((((((((((((((((((fp_bnf, fp_b), fpT), C), ctor), dtor),
    1.94 @@ -595,15 +611,16 @@
    1.95            let
    1.96              val fpT_to_C = fpT --> C;
    1.97  
    1.98 -            fun build_ctor_rec_arg mk_proj (T, U) =
    1.99 +            fun build_prod_proj mk_proj (T, U) =
   1.100                if T = U then
   1.101                  id_const T
   1.102                else
   1.103                  (case (T, U) of
   1.104                    (Type (s, _), Type (s', _)) =>
   1.105 -                  if s = s' then build_map (build_ctor_rec_arg mk_proj) T U else mk_proj T
   1.106 +                  if s = s' then build_map (build_prod_proj mk_proj) T U else mk_proj T
   1.107                  | _ => mk_proj T);
   1.108  
   1.109 +            (* TODO: Avoid these complications; cf. corec case *)
   1.110              fun mk_U proj (Type (s as @{type_name prod}, Ts as [T', U])) =
   1.111                  if member (op =) fpTs T' then proj (T', U) else Type (s, map (mk_U proj) Ts)
   1.112                | mk_U proj (Type (s, Ts)) = Type (s, map (mk_U proj) Ts)
   1.113 @@ -611,8 +628,8 @@
   1.114  
   1.115              fun unzip_rec (x as Free (_, T)) =
   1.116                if exists_fp_subtype T then
   1.117 -                [build_ctor_rec_arg fst_const (T, mk_U fst T) $ x,
   1.118 -                 build_ctor_rec_arg snd_const (T, mk_U snd T) $ x]
   1.119 +                [build_prod_proj fst_const (T, mk_U fst T) $ x,
   1.120 +                 build_prod_proj snd_const (T, mk_U snd T) $ x]
   1.121                else
   1.122                  [x];
   1.123