src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
changeset 49681 aa66ea552357
parent 49672 902b24e0ffb4
child 49682 f57af1c46f99
equal deleted inserted replaced
49680:00290dc6bfad 49681:aa66ea552357
   359 
   359 
   360           fun flat_preds_predsss_gettersss [] [qss] [fss] = flat_predss_getterss qss fss
   360           fun flat_preds_predsss_gettersss [] [qss] [fss] = flat_predss_getterss qss fss
   361             | flat_preds_predsss_gettersss (p :: ps) (qss :: qsss) (fss :: fsss) =
   361             | flat_preds_predsss_gettersss (p :: ps) (qss :: qsss) (fss :: fsss) =
   362               p :: flat_predss_getterss qss fss @ flat_preds_predsss_gettersss ps qsss fsss;
   362               p :: flat_predss_getterss qss fss @ flat_preds_predsss_gettersss ps qsss fsss;
   363 
   363 
   364           fun mk_types maybe_dest_sumT fun_Ts =
   364           fun mk_types maybe_unzipT fun_Ts =
   365             let
   365             let
   366               val f_sum_prod_Ts = map range_type fun_Ts;
   366               val f_sum_prod_Ts = map range_type fun_Ts;
   367               val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts;
   367               val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts;
       
   368               val f_Tsss = map2 (map2 dest_tupleT) mss' f_prod_Tss;
   368               val f_Tssss =
   369               val f_Tssss =
   369                 map3 (fn C => map2 (map (map (curry (op -->) C) o maybe_dest_sumT) oo dest_tupleT))
   370                 map2 (fn C => map (map (map (curry (op -->) C) o maybe_unzipT))) Cs f_Tsss;
   370                   Cs mss' f_prod_Tss;
       
   371               val q_Tssss =
   371               val q_Tssss =
   372                 map (map (map (fn [_] => [] | [_, C] => [mk_pred1T (domain_type C)]))) f_Tssss;
   372                 map (map (map (fn [_] => [] | [_, C] => [mk_pred1T (domain_type C)]))) f_Tssss;
   373               val pf_Tss = map3 flat_preds_predsss_gettersss p_Tss q_Tssss f_Tssss;
   373               val pf_Tss = map3 flat_preds_predsss_gettersss p_Tss q_Tssss f_Tssss;
   374             in (q_Tssss, f_sum_prod_Ts, f_Tssss, pf_Tss) end;
   374             in (q_Tssss, f_sum_prod_Ts, f_Tsss, f_Tssss, pf_Tss) end;
   375 
   375 
   376           val (r_Tssss, g_sum_prod_Ts, g_Tssss, pg_Tss) = mk_types single fp_fold_fun_Ts;
   376           val (r_Tssss, g_sum_prod_Ts, g_Tsss, g_Tssss, pg_Tss) = mk_types single fp_fold_fun_Ts;
   377 
   377 
   378           val ((((Free (z, _), cs), pss), gssss), lthy) =
   378           val ((((Free (z, _), cs), pss), gssss), lthy) =
   379             lthy
   379             lthy
   380             |> yield_singleton (mk_Frees "z") dummyT
   380             |> yield_singleton (mk_Frees "z") dummyT
   381             ||>> mk_Frees "a" Cs
   381             ||>> mk_Frees "a" Cs
   382             ||>> mk_Freess "p" p_Tss
   382             ||>> mk_Freess "p" p_Tss
   383             ||>> mk_Freessss "g" g_Tssss;
   383             ||>> mk_Freessss "g" g_Tssss;
   384           val rssss = map (map (map (fn [] => []))) r_Tssss;
   384           val rssss = map (map (map (fn [] => []))) r_Tssss;
   385 
   385 
   386           fun dest_corec_sumT (T as Type (@{type_name sum}, Us as [_, U])) =
   386           fun proj_corecT proj (Type (s as @{type_name sum}, Ts as [T, U])) =
   387               if member (op =) Cs U then Us else [T]
   387               if member (op =) fpTs T then proj (T, U) else Type (s, map (proj_corecT proj) Ts)
   388             | dest_corec_sumT T = [T];
   388             | proj_corecT proj (Type (s, Ts)) = Type (s, map (proj_corecT proj) Ts)
   389 
   389             | proj_corecT _ T = T;
   390           val (s_Tssss, h_sum_prod_Ts, h_Tssss, ph_Tss) = mk_types dest_corec_sumT fp_rec_fun_Ts;
   390 
       
   391           fun unzip_corecT T =
       
   392             if exists_fp_subtype T then [proj_corecT fst T, proj_corecT snd T] else [T];
       
   393 
       
   394           val (s_Tssss, h_sum_prod_Ts, h_Tsss, h_Tssss, ph_Tss) =
       
   395             mk_types unzip_corecT fp_rec_fun_Ts;
   391 
   396 
   392           val hssss_hd = map2 (map2 (map2 (fn T :: _ => fn [g] => retype_free T g))) h_Tssss gssss;
   397           val hssss_hd = map2 (map2 (map2 (fn T :: _ => fn [g] => retype_free T g))) h_Tssss gssss;
   393           val ((sssss, hssss_tl), lthy) =
   398           val ((sssss, hssss_tl), lthy) =
   394             lthy
   399             lthy
   395             |> mk_Freessss "q" s_Tssss
   400             |> mk_Freessss "q" s_Tssss
   396             ||>> mk_Freessss "h" (map (map (map tl)) h_Tssss);
   401             ||>> mk_Freessss "h" (map (map (map tl)) h_Tssss);
   397           val hssss = map2 (map2 (map2 cons)) hssss_hd hssss_tl;
   402           val hssss = map2 (map2 (map2 cons)) hssss_hd hssss_tl;
   398 
   403 
   399           val cpss = map2 (fn c => map (fn p => p $ c)) cs pss;
   404           val cpss = map2 (map o rapp) cs pss;
   400 
   405 
   401           fun mk_preds_getters_join [] [cf] = cf
   406           fun build_sum_inj mk_inj (T, U) =
   402             | mk_preds_getters_join [cq] [cf, cf'] =
   407             if T = U then
   403               mk_If cq (mk_Inl (fastype_of cf') cf) (mk_Inr (fastype_of cf) cf');
   408               id_const T
   404 
   409             else
   405           fun mk_terms qssss fssss =
   410               (case (T, U) of
       
   411                 (Type (s, _), Type (s', _)) =>
       
   412                 if s = s' then build_map (build_sum_inj mk_inj) T U
       
   413                 else uncurry mk_inj (dest_sumT U)
       
   414               | _ => uncurry mk_inj (dest_sumT U));
       
   415 
       
   416           fun build_dtor_corec_arg _ [] [cf] = cf
       
   417             | build_dtor_corec_arg T [cq] [cf, cf'] =
       
   418               mk_If cq (build_sum_inj Inl_const (fastype_of cf, T) $ cf)
       
   419                 (build_sum_inj Inr_const (fastype_of cf', T) $ cf')
       
   420 
       
   421           fun mk_terms f_Tsss qssss fssss =
   406             let
   422             let
   407               val pfss = map3 flat_preds_predsss_gettersss pss qssss fssss;
   423               val pfss = map3 flat_preds_predsss_gettersss pss qssss fssss;
   408               val cqssss = map2 (fn c => map (map (map (fn f => f $ c)))) cs qssss;
   424               val cqssss = map2 (map o map o map o rapp) cs qssss;
   409               val cfssss = map2 (fn c => map (map (map (fn f => f $ c)))) cs fssss;
   425               val cfssss = map2 (map o map o map o rapp) cs fssss;
   410               val cqfsss = map2 (map2 (map2 mk_preds_getters_join)) cqssss cfssss;
   426               val cqfsss = map3 (map3 (map3 build_dtor_corec_arg)) f_Tsss cqssss cfssss;
   411             in (pfss, cqfsss) end;
   427             in (pfss, cqfsss) end;
   412         in
   428         in
   413           (((([], [], []), ([], [], [])),
   429           (((([], [], []), ([], [], [])),
   414             ([z], cs, cpss, (mk_terms rssss gssss, (g_sum_prod_Ts, pg_Tss)),
   430             ([z], cs, cpss, (mk_terms g_Tsss rssss gssss, (g_sum_prod_Ts, pg_Tss)),
   415              (mk_terms sssss hssss, (h_sum_prod_Ts, ph_Tss)))), lthy)
   431              (mk_terms h_Tsss sssss hssss, (h_sum_prod_Ts, ph_Tss)))), lthy)
   416         end;
   432         end;
   417 
   433 
   418     fun define_ctrs_case_for_type (((((((((((((((((((((((((fp_bnf, fp_b), fpT), C), ctor), dtor),
   434     fun define_ctrs_case_for_type (((((((((((((((((((((((((fp_bnf, fp_b), fpT), C), ctor), dtor),
   419             fp_fold), fp_rec), ctor_dtor), dtor_ctor), ctor_inject), pre_map_def), pre_set_defs),
   435             fp_fold), fp_rec), ctor_dtor), dtor_ctor), ctor_inject), pre_map_def), pre_set_defs),
   420           pre_rel_def), fp_map_thm), fp_set_thms), fp_rel_thm), n), ks), ms), ctr_bindings),
   436           pre_rel_def), fp_map_thm), fp_set_thms), fp_rel_thm), n), ks), ms), ctr_bindings),
   593 
   609 
   594         fun define_fold_rec no_defs_lthy =
   610         fun define_fold_rec no_defs_lthy =
   595           let
   611           let
   596             val fpT_to_C = fpT --> C;
   612             val fpT_to_C = fpT --> C;
   597 
   613 
   598             fun build_ctor_rec_arg mk_proj (T, U) =
   614             fun build_prod_proj mk_proj (T, U) =
   599               if T = U then
   615               if T = U then
   600                 id_const T
   616                 id_const T
   601               else
   617               else
   602                 (case (T, U) of
   618                 (case (T, U) of
   603                   (Type (s, _), Type (s', _)) =>
   619                   (Type (s, _), Type (s', _)) =>
   604                   if s = s' then build_map (build_ctor_rec_arg mk_proj) T U else mk_proj T
   620                   if s = s' then build_map (build_prod_proj mk_proj) T U else mk_proj T
   605                 | _ => mk_proj T);
   621                 | _ => mk_proj T);
   606 
   622 
       
   623             (* TODO: Avoid these complications; cf. corec case *)
   607             fun mk_U proj (Type (s as @{type_name prod}, Ts as [T', U])) =
   624             fun mk_U proj (Type (s as @{type_name prod}, Ts as [T', U])) =
   608                 if member (op =) fpTs T' then proj (T', U) else Type (s, map (mk_U proj) Ts)
   625                 if member (op =) fpTs T' then proj (T', U) else Type (s, map (mk_U proj) Ts)
   609               | mk_U proj (Type (s, Ts)) = Type (s, map (mk_U proj) Ts)
   626               | mk_U proj (Type (s, Ts)) = Type (s, map (mk_U proj) Ts)
   610               | mk_U _ T = T;
   627               | mk_U _ T = T;
   611 
   628 
   612             fun unzip_rec (x as Free (_, T)) =
   629             fun unzip_rec (x as Free (_, T)) =
   613               if exists_fp_subtype T then
   630               if exists_fp_subtype T then
   614                 [build_ctor_rec_arg fst_const (T, mk_U fst T) $ x,
   631                 [build_prod_proj fst_const (T, mk_U fst T) $ x,
   615                  build_ctor_rec_arg snd_const (T, mk_U snd T) $ x]
   632                  build_prod_proj snd_const (T, mk_U snd T) $ x]
   616               else
   633               else
   617                 [x];
   634                 [x];
   618 
   635 
   619             fun mk_rec_like_arg f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f (maps unzip_rec xs);
   636             fun mk_rec_like_arg f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f (maps unzip_rec xs);
   620 
   637