src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
changeset 49211 239a4fa29ddf
parent 49210 656fb50d33f0
child 49212 ca59649170b0
equal deleted inserted replaced
49210:656fb50d33f0 49211:239a4fa29ddf
   166       end;
   166       end;
   167 
   167 
   168     val fp_iters as fp_iter1 :: _ = map (mk_iter_like As Cs) fp_iters0;
   168     val fp_iters as fp_iter1 :: _ = map (mk_iter_like As Cs) fp_iters0;
   169     val fp_recs as fp_rec1 :: _ = map (mk_iter_like As Cs) fp_recs0;
   169     val fp_recs as fp_rec1 :: _ = map (mk_iter_like As Cs) fp_recs0;
   170 
   170 
   171     val fp_iter_f_Ts = fst (split_last (binder_types (fastype_of fp_iter1)));
   171     val fp_iter_g_Ts = fst (split_last (binder_types (fastype_of fp_iter1)));
   172     val fp_rec_f_Ts = fst (split_last (binder_types (fastype_of fp_rec1)));
   172     val fp_rec_h_Ts = fst (split_last (binder_types (fastype_of fp_rec1)));
   173 
   173 
   174     fun dest_rec_pair (T as Type (@{type_name prod}, Us as [_, U])) =
   174     fun dest_rec_pair (T as Type (@{type_name prod}, Us as [_, U])) =
   175         if member (op =) Cs U then Us else [T]
   175         if member (op =) Cs U then Us else [T]
   176       | dest_rec_pair T = [T];
   176       | dest_rec_pair T = [T];
   177 
   177 
   178     val (((gss, g_Tss, ysss, y_Tsss), (hss, h_Tss, zssss, z_Tssss)),
   178     val (((gss, g_Tss, ysss, y_Tsss), (hss, h_Tss, zssss, z_Tssss)),
   179          (cs, (qss, q_Tss, gsss, g_Tsss), ())) =
   179          (cs, pss, p_Tss, coiter_extra, corec_extra)) =
   180       if lfp then
   180       if lfp then
   181         let
   181         let
   182           val y_Tsss =
   182           val y_Tsss =
   183             map3 (fn n => fn ms => map2 dest_tupleT ms o dest_sumTN n o domain_type)
   183             map3 (fn n => fn ms => map2 dest_tupleT ms o dest_sumTN n o domain_type)
   184               ns mss fp_iter_f_Ts;
   184               ns mss fp_iter_g_Ts;
   185           val g_Tss = map2 (map2 (curry (op --->))) y_Tsss Css;
   185           val g_Tss = map2 (map2 (curry (op --->))) y_Tsss Css;
   186 
   186 
   187           val ((gss, ysss), _) =
   187           val ((gss, ysss), _) =
   188             lthy
   188             lthy
   189             |> mk_Freess "f" g_Tss
   189             |> mk_Freess "f" g_Tss
   190             ||>> mk_Freesss "x" y_Tsss;
   190             ||>> mk_Freesss "x" y_Tsss;
   191 
   191 
   192           val z_Tssss =
   192           val z_Tssss =
   193             map3 (fn n => fn ms => map2 (map dest_rec_pair oo dest_tupleT) ms o dest_sumTN n
   193             map3 (fn n => fn ms => map2 (map dest_rec_pair oo dest_tupleT) ms o dest_sumTN n
   194               o domain_type) ns mss fp_rec_f_Ts;
   194               o domain_type) ns mss fp_rec_h_Ts;
   195           val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
   195           val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
   196 
   196 
   197           val hss = map2 (map2 retype_free) gss h_Tss;
   197           val hss = map2 (map2 retype_free) gss h_Tss;
   198           val (zssss, _) =
   198           val (zssss, _) =
   199             lthy
   199             lthy
   200             |> mk_Freessss "x" z_Tssss;
   200             |> mk_Freessss "x" z_Tssss;
   201         in
   201         in
   202           (((gss, g_Tss, ysss, y_Tsss), (hss, h_Tss, zssss, z_Tssss)),
   202           (((gss, g_Tss, ysss, y_Tsss), (hss, h_Tss, zssss, z_Tssss)),
   203            ([], ([], [], [], []), ()))
   203            ([], [], [], ([], [], [], []), ([], [], [], [])))
   204         end
   204         end
   205       else
   205       else
   206         let
   206         let
   207           val q_Tss =
   207           fun mk_to_dest_prodT C = map2 (map (curry (op -->) C) oo dest_tupleT);
       
   208 
       
   209           val p_Tss =
   208             map2 (fn C => fn n => replicate (Int.max (0, n - 1)) (C --> HOLogic.boolT)) Cs ns;
   210             map2 (fn C => fn n => replicate (Int.max (0, n - 1)) (C --> HOLogic.boolT)) Cs ns;
   209           val g_Tsss =
   211 
   210             map4 (fn C => fn n => fn ms => map2 (map (curry (op -->) C) oo dest_tupleT) ms o
   212           val g_sum_prod_Ts = map range_type fp_iter_g_Ts;
   211               dest_sumTN n o range_type) Cs ns mss fp_iter_f_Ts;
   213           val g_prod_Tss = map2 dest_sumTN ns g_sum_prod_Ts;
   212 
   214           val g_Tsss = map3 mk_to_dest_prodT Cs mss g_prod_Tss;
   213           val (((c, qss), gsss), _) =
   215 
       
   216           val h_sum_prod_Ts = map range_type fp_rec_h_Ts;
       
   217           val h_prod_Tss = map2 dest_sumTN ns h_sum_prod_Ts;
       
   218           val h_Tsss = map3 mk_to_dest_prodT Cs mss h_prod_Tss;
       
   219 
       
   220           val (((c, pss), gsss), _) =
   214             lthy
   221             lthy
   215             |> yield_singleton (mk_Frees "c") dummyT
   222             |> yield_singleton (mk_Frees "c") dummyT
   216             ||>> mk_Freess "p" q_Tss
   223             ||>> mk_Freess "p" p_Tss
   217             ||>> mk_Freesss "g" g_Tsss;
   224             ||>> mk_Freesss "g" g_Tsss;
       
   225 
       
   226           val hsss = map2 (map2 (map2 retype_free)) gsss h_Tsss;
   218 
   227 
   219           val cs = map (retype_free c) Cs;
   228           val cs = map (retype_free c) Cs;
   220         in
   229         in
   221           ((([], [], [], []), ([], [], [], [])),
   230           ((([], [], [], []), ([], [], [], [])),
   222            (cs, (qss, q_Tss, gsss, g_Tsss), ()))
   231            (cs, pss, p_Tss, (gsss, g_sum_prod_Ts, g_prod_Tss, g_Tsss),
       
   232             (hsss, h_sum_prod_Ts, h_prod_Tss, h_Tsss)))
   223         end;
   233         end;
   224 
   234 
   225     fun pour_some_sugar_on_type ((((((((((((((b, fpT), C), fld), unf), fp_iter), fp_rec), fld_unf),
   235     fun pour_some_sugar_on_type ((((((((((((((b, fpT), C), fld), unf), fp_iter), fp_rec), fld_unf),
   226           unf_fld), fld_inject), ctr_binders), ctr_mixfixes), ctr_Tss), disc_binders),
   236           unf_fld), fld_inject), ctr_binders), ctr_mixfixes), ctr_Tss), disc_binders),
   227           sel_binderss) no_defs_lthy =
   237           sel_binderss) no_defs_lthy =
   311             val rec_T = fold_rev (curry (op --->)) h_Tss fpT_to_C;
   321             val rec_T = fold_rev (curry (op --->)) h_Tss fpT_to_C;
   312 
   322 
   313             val iter_binder = Binding.suffix_name ("_" ^ iterN) b;
   323             val iter_binder = Binding.suffix_name ("_" ^ iterN) b;
   314             val rec_binder = Binding.suffix_name ("_" ^ recN) b;
   324             val rec_binder = Binding.suffix_name ("_" ^ recN) b;
   315 
   325 
   316             val iter_free = Free (Binding.name_of iter_binder, iter_T);
       
   317             val rec_free = Free (Binding.name_of rec_binder, rec_T);
       
   318 
       
   319             val iter_spec =
   326             val iter_spec =
   320               mk_Trueprop_eq (flat_list_comb (iter_free, gss),
   327               mk_Trueprop_eq (flat_list_comb (Free (Binding.name_of iter_binder, iter_T), gss),
   321                 Term.list_comb (fp_iter, map2 (mk_sum_caseN oo map2 mk_uncurried_fun) gss ysss));
   328                 Term.list_comb (fp_iter, map2 (mk_sum_caseN oo map2 mk_uncurried_fun) gss ysss));
   322             val rec_spec =
   329             val rec_spec =
   323               mk_Trueprop_eq (flat_list_comb (rec_free, hss),
   330               mk_Trueprop_eq (flat_list_comb (Free (Binding.name_of rec_binder, rec_T), hss),
   324                 Term.list_comb (fp_rec, map2 (mk_sum_caseN oo map2 mk_uncurried2_fun) hss zssss));
   331                 Term.list_comb (fp_rec, map2 (mk_sum_caseN oo map2 mk_uncurried2_fun) hss zssss));
   325 
   332 
   326             val (([raw_iter, raw_rec], [raw_iter_def, raw_rec_def]), (lthy', lthy)) = no_defs_lthy
   333             val (([raw_iter, raw_rec], [raw_iter_def, raw_rec_def]), (lthy', lthy)) = no_defs_lthy
   327               |> apfst split_list o fold_map2 (fn b => fn spec =>
   334               |> apfst split_list o fold_map2 (fn b => fn spec =>
   328                 Specification.definition (SOME (b, NONE, NoSyn), ((Thm.def_binding b, []), spec))
   335                 Specification.definition (SOME (b, NONE, NoSyn), ((Thm.def_binding b, []), spec))
   344             ((ctrs, iter, recx, xss, ctr_defs, iter_def, rec_def), lthy)
   351             ((ctrs, iter, recx, xss, ctr_defs, iter_def, rec_def), lthy)
   345           end;
   352           end;
   346 
   353 
   347         fun some_gfp_sugar no_defs_lthy =
   354         fun some_gfp_sugar no_defs_lthy =
   348           let
   355           let
   349             (* qss, q_Tss, gsss, g_Tsss *)
   356             fun zip_preds_and_getters ps fss = ps @ flat fss;
   350             fun zip_preds_and_getters p_Ts f_Tss = p_Ts @ flat f_Tss;
       
   351 
       
   352             val qg_Tss = map2 zip_preds_and_getters q_Tss g_Tsss;
       
   353 
   357 
   354             val B_to_fpT = C --> fpT;
   358             val B_to_fpT = C --> fpT;
   355             val coiter_T = fold_rev (curry (op --->)) qg_Tss B_to_fpT;
   359 
   356 (*
   360             val cpss = map2 (fn c => map (fn p => p $ c)) cs pss;
   357             val corec_T = fold_rev (curry (op --->)) h_Tss fpT_to_C;
   361 
   358 *)
   362             fun generate_coiter_like (suf, fp_iter_like,
   359 
   363                 (fsss, f_sum_prod_Ts, f_prod_Tss, f_Tsss)) =
   360             val qgss = map2 zip_preds_and_getters qss gsss;
   364               let
   361             val cqss = map2 (fn c => map (fn q => q $ c)) cs qss;
   365                 val pf_Tss = map2 zip_preds_and_getters p_Tss f_Tsss;
   362             val cgsss = map2 (fn c => map (map (fn g => g $ c))) cs gsss;
   366                 val res_T = fold_rev (curry (op --->)) pf_Tss B_to_fpT;
   363 
   367 
   364             val coiter_binder = Binding.suffix_name ("_" ^ coiterN) b;
   368                 val pfss = map2 zip_preds_and_getters pss fsss;
   365             val corec_binder = Binding.suffix_name ("_" ^ corecN) b;
   369                 val cfsss = map2 (fn c => map (map (fn f => f $ c))) cs fsss;
   366 
   370 
   367             val coiter_free = Free (Binding.name_of coiter_binder, coiter_T);
   371                 val binder = Binding.suffix_name ("_" ^ suf) b;
   368 (*
   372 
   369             val corec_free = Free (Binding.name_of corec_binder, corec_T);
   373                 fun mk_join c n cps sum_prod_T prod_Ts cfss =
   370 *)
   374                   Term.lambda c (mk_IfN sum_prod_T cps
   371 
   375                     (map2 (mk_InN prod_Ts) (map HOLogic.mk_tuple cfss) (1 upto n)));
   372             val coiter_sum_prod_Ts = map range_type fp_iter_f_Ts;
   376 
   373             val coiter_prod_Tss = map2 dest_sumTN ns coiter_sum_prod_Ts;
   377                 val spec =
   374 
   378                   mk_Trueprop_eq (flat_list_comb (Free (Binding.name_of binder, res_T), pfss),
   375             fun mk_join c n cqs sum_prod_T prod_Ts cgss =
   379                     Term.list_comb (fp_iter_like,
   376               Term.lambda c (mk_IfN sum_prod_T cqs
   380                       map6 mk_join cs ns cpss f_sum_prod_Ts f_prod_Tss cfsss));
   377                 (map2 (mk_InN prod_Ts) (map HOLogic.mk_tuple cgss) (1 upto n)));
   381               in (binder, spec) end;
   378 
   382 
   379             val coiter_spec =
   383             val coiter_likes = [(coiterN, fp_iter, coiter_extra), (corecN, fp_rec, corec_extra)];
   380               mk_Trueprop_eq (flat_list_comb (coiter_free, qgss),
   384             val (binders, specs) = map generate_coiter_like coiter_likes |> split_list;
   381                 Term.list_comb (fp_iter,
   385 
   382                   map6 mk_join cs ns cqss coiter_sum_prod_Ts coiter_prod_Tss cgsss));
   386             val ((csts, defs), (lthy', lthy)) = no_defs_lthy
   383 (*
       
   384             val corec_spec =
       
   385               mk_Trueprop_eq (flat_list_comb (corec_free, hss),
       
   386                 Term.list_comb (fp_rec, map2 (mk_sum_caseN oo map2 mk_uncurried2_fun) hss zssss));
       
   387 *)
       
   388 
       
   389             val (([raw_coiter (*, raw_corec*)], [raw_coiter_def (*, raw_corec_def*)]), (lthy', lthy)) = no_defs_lthy
       
   390               |> apfst split_list o fold_map2 (fn b => fn spec =>
   387               |> apfst split_list o fold_map2 (fn b => fn spec =>
   391                 Specification.definition (SOME (b, NONE, NoSyn), ((Thm.def_binding b, []), spec))
   388                 Specification.definition (SOME (b, NONE, NoSyn), ((Thm.def_binding b, []), spec))
   392                 #>> apsnd snd) [coiter_binder (*, corec_binder*)] [coiter_spec (*, corec_spec*)]
   389                 #>> apsnd snd) binders specs
   393               ||> `Local_Theory.restore;
   390               ||> `Local_Theory.restore;
   394 
   391 
   395             (*transforms defined frees into consts (and more)*)
   392             (*transforms defined frees into consts (and more)*)
   396             val phi = Proof_Context.export_morphism lthy lthy';
   393             val phi = Proof_Context.export_morphism lthy lthy';
   397 
   394 
   398             val coiter_def = Morphism.thm phi raw_coiter_def;
   395             val [coiter_def, corec_def] = map (Morphism.thm phi) defs;
   399 (*
   396 
   400             val corec_def = Morphism.thm phi raw_corec_def;
   397             val [coiter, corec] = map (mk_iter_like As Cs o Morphism.term phi) csts;
   401 *)
       
   402 
       
   403             val coiter0 = Morphism.term phi raw_coiter;
       
   404 (*
       
   405             val corec0 = Morphism.term phi raw_corec;
       
   406 *)
       
   407 
       
   408             val coiter = mk_iter_like As Cs coiter0;
       
   409 (*
       
   410             val corec = mk_iter_like As Cs corec0;
       
   411 *)
       
   412 
       
   413             (*###*)
       
   414             val corec = @{term True};
       
   415             val corec_def = TrueI;
       
   416           in
   398           in
   417             ((ctrs, coiter, corec, xss, ctr_defs, coiter_def, corec_def), lthy)
   399             ((ctrs, coiter, corec, xss, ctr_defs, coiter_def, corec_def), lthy)
   418           end;
   400           end;
   419       in
   401       in
   420         wrap_datatype tacss ((ctrs0, casex0), (disc_binders, sel_binderss)) lthy'
   402         wrap_datatype tacss ((ctrs0, casex0), (disc_binders, sel_binderss)) lthy'
   428         val giters = map (fn iter => flat_list_comb (iter, gss)) iters;
   410         val giters = map (fn iter => flat_list_comb (iter, gss)) iters;
   429         val hrecs = map (fn recx => flat_list_comb (recx, hss)) recs;
   411         val hrecs = map (fn recx => flat_list_comb (recx, hss)) recs;
   430 
   412 
   431         val (iter_thmss, rec_thmss) =
   413         val (iter_thmss, rec_thmss) =
   432           let
   414           let
   433             fun mk_goal_iter_or_rec fss fc xctr f xs xs' =
   415             fun mk_goal_iter_like fss fc xctr f xs xs' =
   434               fold_rev (fold_rev Logic.all) (xs :: fss)
   416               fold_rev (fold_rev Logic.all) (xs :: fss)
   435                 (mk_Trueprop_eq (fc $ xctr, Term.list_comb (f, xs')));
   417                 (mk_Trueprop_eq (fc $ xctr, Term.list_comb (f, xs')));
   436 
   418 
   437             fun fix_iter_free (x as Free (_, T)) =
   419             fun fix_iter_free (x as Free (_, T)) =
   438               (case find_index (eq_fpT T) fpTs of ~1 => x | j => nth giters j $ x);
   420               (case find_index (eq_fpT T) fpTs of ~1 => x | j => nth giters j $ x);
   440               (case find_index (eq_fpT T) fpTs of ~1 => [x] | j => [x, nth hrecs j $ x]);
   422               (case find_index (eq_fpT T) fpTs of ~1 => [x] | j => [x, nth hrecs j $ x]);
   441 
   423 
   442             val iter_xsss = map (map (map fix_iter_free)) xsss;
   424             val iter_xsss = map (map (map fix_iter_free)) xsss;
   443             val rec_xsss = map (map (maps fix_rec_free)) xsss;
   425             val rec_xsss = map (map (maps fix_rec_free)) xsss;
   444 
   426 
   445             val goal_iterss =
   427             val goal_iterss = map5 (map4 o mk_goal_iter_like gss) giters xctrss gss xsss iter_xsss;
   446               map5 (map4 o mk_goal_iter_or_rec gss) giters xctrss gss xsss iter_xsss;
   428             val goal_recss = map5 (map4 o mk_goal_iter_like hss) hrecs xctrss hss xsss rec_xsss;
   447             val goal_recss =
       
   448               map5 (map4 o mk_goal_iter_or_rec hss) hrecs xctrss hss xsss rec_xsss;
       
   449 
   429 
   450             val iter_tacss =
   430             val iter_tacss =
   451               map2 (map o mk_iter_or_rec_tac pre_map_defs iter_defs) fp_iter_thms ctr_defss;
   431               map2 (map o mk_iter_like_tac pre_map_defs iter_defs) fp_iter_thms ctr_defss;
   452             val rec_tacss =
   432             val rec_tacss =
   453               map2 (map o mk_iter_or_rec_tac pre_map_defs rec_defs) fp_rec_thms ctr_defss;
   433               map2 (map o mk_iter_like_tac pre_map_defs rec_defs) fp_rec_thms ctr_defss;
   454           in
   434           in
   455             (map2 (map2 (fn goal => fn tac => Skip_Proof.prove lthy [] [] goal (tac o #context)))
   435             (map2 (map2 (fn goal => fn tac => Skip_Proof.prove lthy [] [] goal (tac o #context)))
   456                goal_iterss iter_tacss,
   436                goal_iterss iter_tacss,
   457              map2 (map2 (fn goal => fn tac => Skip_Proof.prove lthy [] [] goal (tac o #context)))
   437              map2 (map2 (fn goal => fn tac => Skip_Proof.prove lthy [] [] goal (tac o #context)))
   458                goal_recss rec_tacss)
   438                goal_recss rec_tacss)