src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
changeset 49214 2a3cb4c71b87
parent 49213 975ccb0130cb
child 49215 1c5d6e2eb0c6
equal deleted inserted replaced
49213:975ccb0130cb 49214:2a3cb4c71b87
    13 structure BNF_FP_Sugar : BNF_FP_SUGAR =
    13 structure BNF_FP_Sugar : BNF_FP_SUGAR =
    14 struct
    14 struct
    15 
    15 
    16 open BNF_Util
    16 open BNF_Util
    17 open BNF_Wrap
    17 open BNF_Wrap
       
    18 open BNF_Def
    18 open BNF_FP_Util
    19 open BNF_FP_Util
    19 open BNF_LFP
    20 open BNF_LFP
    20 open BNF_GFP
    21 open BNF_GFP
    21 open BNF_FP_Sugar_Tactics
    22 open BNF_FP_Sugar_Tactics
    22 
    23 
    24 val coitersN = "coiters";
    25 val coitersN = "coiters";
    25 val corecsN = "corecs";
    26 val corecsN = "corecs";
    26 val itersN = "iters";
    27 val itersN = "iters";
    27 val recsN = "recs";
    28 val recsN = "recs";
    28 
    29 
    29 fun split_list7 xs = (map #1 xs, map #2 xs, map #3 xs, map #4 xs, map #5 xs, map #6 xs, map #7 xs);
    30 fun split_list8 xs =
       
    31   (map #1 xs, map #2 xs, map #3 xs, map #4 xs, map #5 xs, map #6 xs, map #7 xs, map #8 xs);
       
    32 
       
    33 fun typ_subst inst (T as Type (s, Ts)) =
       
    34     (case AList.lookup (op =) inst T of
       
    35       NONE => Type (s, map (typ_subst inst) Ts)
       
    36     | SOME T' => T')
       
    37   | typ_subst inst T = the_default T (AList.lookup (op =) inst T);
    30 
    38 
    31 fun retype_free (Free (s, _)) T = Free (s, T);
    39 fun retype_free (Free (s, _)) T = Free (s, T);
    32 
    40 
    33 val lists_bmoc = fold (fn xs => fn t => Term.list_comb (t, xs))
    41 val lists_bmoc = fold (fn xs => fn t => Term.list_comb (t, xs))
    34 
    42 
    35 fun mk_tupled_fun x f xs = HOLogic.tupled_lambda x (Term.list_comb (f, xs));
    43 fun mk_tupled_fun x f xs = HOLogic.tupled_lambda x (Term.list_comb (f, xs));
    36 fun mk_uncurried_fun f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f xs;
    44 fun mk_uncurried_fun f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f xs;
    37 fun mk_uncurried2_fun f xss =
    45 fun mk_uncurried2_fun f xss =
    38   mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat xss);
    46   mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat xss);
       
    47 
       
    48 fun tick v f = Term.lambda v (HOLogic.mk_prod (v, f $ v))
    39 
    49 
    40 fun popescu_zip [] [fs] = fs
    50 fun popescu_zip [] [fs] = fs
    41   | popescu_zip (p :: ps) (fs :: fss) = p :: fs @ popescu_zip ps fss;
    51   | popescu_zip (p :: ps) (fs :: fss) = p :: fs @ popescu_zip ps fss;
    42 
    52 
    43 fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters";
    53 fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters";
   158     val ns = map length ctr_Tsss;
   168     val ns = map length ctr_Tsss;
   159     val kss = map (fn n => 1 upto n) ns;
   169     val kss = map (fn n => 1 upto n) ns;
   160     val mss = map (map length) ctr_Tsss;
   170     val mss = map (map length) ctr_Tsss;
   161     val Css = map2 replicate ns Cs;
   171     val Css = map2 replicate ns Cs;
   162 
   172 
   163     fun mk_iter_like Ts Us c =
   173     fun mk_iter_like Ts Us t =
   164       let
   174       let
   165         val (binders, body) = strip_type (fastype_of c);
   175         val (binders, body) = strip_type (fastype_of t);
   166         val (f_Us, prebody) = split_last binders;
   176         val (f_Us, prebody) = split_last binders;
   167         val Type (_, Ts0) = if lfp then prebody else body;
   177         val Type (_, Ts0) = if lfp then prebody else body;
   168         val Us0 = distinct (op =) (map (if lfp then body_type else domain_type) f_Us);
   178         val Us0 = distinct (op =) (map (if lfp then body_type else domain_type) f_Us);
   169       in
   179       in
   170         Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) c
   180         Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
   171       end;
   181       end;
   172 
   182 
   173     val fp_iters as fp_iter1 :: _ = map (mk_iter_like As Cs) fp_iters0;
   183     val fp_iters as fp_iter1 :: _ = map (mk_iter_like As Cs) fp_iters0;
   174     val fp_recs as fp_rec1 :: _ = map (mk_iter_like As Cs) fp_recs0;
   184     val fp_recs as fp_rec1 :: _ = map (mk_iter_like As Cs) fp_recs0;
   175 
   185 
   357             val rec0 = Morphism.term phi raw_rec;
   367             val rec0 = Morphism.term phi raw_rec;
   358 
   368 
   359             val iter = mk_iter_like As Cs iter0;
   369             val iter = mk_iter_like As Cs iter0;
   360             val recx = mk_iter_like As Cs rec0;
   370             val recx = mk_iter_like As Cs rec0;
   361           in
   371           in
   362             ((ctrs, iter, recx, xss, ctr_defs, iter_def, rec_def), lthy)
   372             ((ctrs, iter, recx, v, xss, ctr_defs, iter_def, rec_def), lthy)
   363           end;
   373           end;
   364 
   374 
   365         fun some_gfp_sugar no_defs_lthy =
   375         fun some_gfp_sugar no_defs_lthy =
   366           let
   376           let
   367             val B_to_fpT = C --> fpT;
   377             val B_to_fpT = C --> fpT;
   400 
   410 
   401             val [coiter_def, corec_def] = map (Morphism.thm phi) defs;
   411             val [coiter_def, corec_def] = map (Morphism.thm phi) defs;
   402 
   412 
   403             val [coiter, corec] = map (mk_iter_like As Cs o Morphism.term phi) csts;
   413             val [coiter, corec] = map (mk_iter_like As Cs o Morphism.term phi) csts;
   404           in
   414           in
   405             ((ctrs, coiter, corec, xss, ctr_defs, coiter_def, corec_def), lthy)
   415             ((ctrs, coiter, corec, v, xss, ctr_defs, coiter_def, corec_def), lthy)
   406           end;
   416           end;
   407       in
   417       in
   408         wrap_datatype tacss ((ctrs0, casex0), (disc_binders, sel_binderss)) lthy'
   418         wrap_datatype tacss ((ctrs0, casex0), (disc_binders, sel_binderss)) lthy'
   409         |> (if lfp then some_lfp_sugar else some_gfp_sugar)
   419         |> (if lfp then some_lfp_sugar else some_gfp_sugar)
   410       end;
   420       end;
   411 
   421 
   412     fun pour_more_sugar_on_lfps ((ctrss, iters, recs, xsss, ctr_defss, iter_defs, rec_defs),
   422     fun mk_map Ts Us t =
       
   423       let val (Type (_, Ts0), Type (_, Us0)) = strip_type (fastype_of t) |>> List.last in
       
   424         Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
       
   425       end;
       
   426 
       
   427     fun pour_more_sugar_on_lfps ((ctrss, iters, recs, vs, xsss, ctr_defss, iter_defs, rec_defs),
   413         lthy) =
   428         lthy) =
   414       let
   429       let
   415         val xctrss = map2 (map2 (curry Term.list_comb)) ctrss xsss;
   430         val xctrss = map2 (map2 (curry Term.list_comb)) ctrss xsss;
   416         val giters = map (lists_bmoc gss) iters;
   431         val giters = map (lists_bmoc gss) iters;
   417         val hrecs = map (lists_bmoc hss) recs;
   432         val hrecs = map (lists_bmoc hss) recs;
   420           let
   435           let
   421             fun mk_goal_iter_like fss fiter_like xctr f xs fxs =
   436             fun mk_goal_iter_like fss fiter_like xctr f xs fxs =
   422               fold_rev (fold_rev Logic.all) (xs :: fss)
   437               fold_rev (fold_rev Logic.all) (xs :: fss)
   423                 (mk_Trueprop_eq (fiter_like $ xctr, Term.list_comb (f, fxs)));
   438                 (mk_Trueprop_eq (fiter_like $ xctr, Term.list_comb (f, fxs)));
   424 
   439 
   425             fun repair_iter_call (x as Free (_, T)) =
   440             fun build_iter_like fiter_likes maybe_tick =
   426               (case find_index (curry (op =) T) fpTs of ~1 => x | j => nth giters j $ x);
   441               let
       
   442                 fun build (T, U) =
       
   443                   if T = U then
       
   444                     Const (@{const_name id}, T --> T)
       
   445                   else
       
   446                     (case (find_index (curry (op =) T) fpTs, (T, U)) of
       
   447                       (~1, (Type (s, Ts), Type (_, Us))) =>
       
   448                       let
       
   449                         val map0 = map_of_bnf (the (bnf_of lthy (Long_Name.base_name s)));
       
   450                         val mapx = mk_map Ts Us map0;
       
   451                         val TUs = map dest_funT (fst (split_last (binder_types (fastype_of mapx))));
       
   452                         val args = map build TUs;
       
   453                       in Term.list_comb (mapx, args) end
       
   454                     | (j, _) => maybe_tick (nth vs j) (nth fiter_likes j))
       
   455               in build end;
       
   456 
       
   457             fun mk_U maybe_prodT =
       
   458               typ_subst (map2 (fn fpT => fn C => (fpT, maybe_prodT fpT C)) fpTs Cs);
       
   459 
       
   460             fun repair_calls fiter_likes maybe_cons maybe_tick maybe_prodT (x as Free (_, T)) =
       
   461               if member (op =) fpTs T then
       
   462                 maybe_cons x [build_iter_like fiter_likes (K I) (T, mk_U (K I) T) $ x]
       
   463               else if exists_subtype (member (op =) fpTs) T then
       
   464                 [build_iter_like fiter_likes maybe_tick (T, mk_U maybe_prodT T) $ x]
       
   465               else
       
   466                 [x];
       
   467 
   427             fun repair_rec_call (x as Free (_, T)) =
   468             fun repair_rec_call (x as Free (_, T)) =
   428               (case find_index (curry (op =) T) fpTs of ~1 => [x] | j => [x, nth hrecs j $ x]);
   469               (case find_index (curry (op =) T) fpTs of ~1 => [x] | j => [x, nth hrecs j $ x]);
   429 
   470 
   430             val gxsss = map (map (map repair_iter_call)) xsss;
   471             val gxsss = map (map (maps (repair_calls giters (K I) (K I) (K I)))) xsss;
   431             val hxsss = map (map (maps repair_rec_call)) xsss;
   472             val hxsss =
       
   473               map (map (maps (repair_calls hrecs cons tick (curry HOLogic.mk_prodT)))) xsss;
   432 
   474 
   433             val goal_iterss = map5 (map4 o mk_goal_iter_like gss) giters xctrss gss xsss gxsss;
   475             val goal_iterss = map5 (map4 o mk_goal_iter_like gss) giters xctrss gss xsss gxsss;
   434             val goal_recss = map5 (map4 o mk_goal_iter_like hss) hrecs xctrss hss xsss hxsss;
   476             val goal_recss = map5 (map4 o mk_goal_iter_like hss) hrecs xctrss hss xsss hxsss;
   435 
   477 
   436             val iter_tacss =
   478             val iter_tacss =
   453               bs thmss);
   495               bs thmss);
   454       in
   496       in
   455         lthy |> Local_Theory.notes notes |> snd
   497         lthy |> Local_Theory.notes notes |> snd
   456       end;
   498       end;
   457 
   499 
   458     fun pour_more_sugar_on_gfps ((ctrss, coiters, corecs, xsss, ctr_defss, coiter_defs, corec_defs),
   500     fun pour_more_sugar_on_gfps ((ctrss, coiters, corecs, vs, xsss, ctr_defss, coiter_defs,
   459         lthy) =
   501         corec_defs), lthy) =
   460       let
   502       let
   461         val gcoiters = map (lists_bmoc pgss) coiters;
   503         val gcoiters = map (lists_bmoc pgss) coiters;
   462         val hcorecs = map (lists_bmoc phss) corecs;
   504         val hcorecs = map (lists_bmoc phss) corecs;
   463 
   505 
   464         val (coiter_thmss, corec_thmss) =
   506         val (coiter_thmss, corec_thmss) =
   503 
   545 
   504     val lthy' = lthy
   546     val lthy' = lthy
   505       |> fold_map pour_some_sugar_on_type (bs ~~ fpTs ~~ Cs ~~ flds ~~ unfs ~~ fp_iters ~~
   547       |> fold_map pour_some_sugar_on_type (bs ~~ fpTs ~~ Cs ~~ flds ~~ unfs ~~ fp_iters ~~
   506         fp_recs ~~ fld_unfs ~~ unf_flds ~~ fld_injects ~~ ns ~~ kss ~~ mss ~~ ctr_binderss ~~
   548         fp_recs ~~ fld_unfs ~~ unf_flds ~~ fld_injects ~~ ns ~~ kss ~~ mss ~~ ctr_binderss ~~
   507         ctr_mixfixess ~~ ctr_Tsss ~~ disc_binderss ~~ sel_bindersss)
   549         ctr_mixfixess ~~ ctr_Tsss ~~ disc_binderss ~~ sel_bindersss)
   508       |>> split_list7
   550       |>> split_list8
   509       |> (if lfp then pour_more_sugar_on_lfps else pour_more_sugar_on_gfps);
   551       |> (if lfp then pour_more_sugar_on_lfps else pour_more_sugar_on_gfps);
   510 
   552 
   511     val timer = time (timer ("Constructors, discriminators, selectors, etc., for the new " ^
   553     val timer = time (timer ("Constructors, discriminators, selectors, etc., for the new " ^
   512       (if lfp then "" else "co") ^ "datatype"));
   554       (if lfp then "" else "co") ^ "datatype"));
   513   in
   555   in