src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
changeset 49501 acc9635a644a
parent 49498 acc583e14167
child 49502 92a7c1842c78
equal deleted inserted replaced
49500:3cb59fdd69a8 49501:acc9635a644a
   167     val ctr_sum_prod_TsBs = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctr_TsssBs;
   167     val ctr_sum_prod_TsBs = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctr_TsssBs;
   168 
   168 
   169     val fp_eqs =
   169     val fp_eqs =
   170       map dest_TFree Bs ~~ map (Term.typ_subst_atomic (As ~~ unsorted_As)) ctr_sum_prod_TsBs;
   170       map dest_TFree Bs ~~ map (Term.typ_subst_atomic (As ~~ unsorted_As)) ctr_sum_prod_TsBs;
   171 
   171 
   172     val (pre_bnfs, ((unfs0, flds0, fp_iters0, fp_recs0, fp_induct, unf_flds, fld_unfs, fld_injects,
   172     val (pre_bnfs, ((dtors0, ctors0, fp_iters0, fp_recs0, fp_induct, dtor_ctors, ctor_dtors,
   173         fp_iter_thms, fp_rec_thms), lthy)) =
   173            ctor_injects, fp_iter_thms, fp_rec_thms), lthy)) =
   174       fp_bnf construct fp_bs mixfixes (map dest_TFree unsorted_As) fp_eqs no_defs_lthy0;
   174       fp_bnf construct fp_bs mixfixes (map dest_TFree unsorted_As) fp_eqs no_defs_lthy0;
   175 
   175 
   176     fun add_nesty_bnf_names Us =
   176     fun add_nesty_bnf_names Us =
   177       let
   177       let
   178         fun add (Type (s, Ts)) ss =
   178         fun add (Type (s, Ts)) ss =
   188     val nesting_bnfs = nesty_bnfs As;
   188     val nesting_bnfs = nesty_bnfs As;
   189     val nested_bnfs = nesty_bnfs Bs;
   189     val nested_bnfs = nesty_bnfs Bs;
   190 
   190 
   191     val timer = time (Timer.startRealTimer ());
   191     val timer = time (Timer.startRealTimer ());
   192 
   192 
   193     fun mk_unf_or_fld get_T Ts t =
   193     fun mk_ctor_or_dtor get_T Ts t =
   194       let val Type (_, Ts0) = get_T (fastype_of t) in
   194       let val Type (_, Ts0) = get_T (fastype_of t) in
   195         Term.subst_atomic_types (Ts0 ~~ Ts) t
   195         Term.subst_atomic_types (Ts0 ~~ Ts) t
   196       end;
   196       end;
   197 
   197 
   198     val mk_unf = mk_unf_or_fld domain_type;
   198     val mk_ctor = mk_ctor_or_dtor range_type;
   199     val mk_fld = mk_unf_or_fld range_type;
   199     val mk_dtor = mk_ctor_or_dtor domain_type;
   200 
   200 
   201     val unfs = map (mk_unf As) unfs0;
   201     val ctors = map (mk_ctor As) ctors0;
   202     val flds = map (mk_fld As) flds0;
   202     val dtors = map (mk_dtor As) dtors0;
   203 
   203 
   204     val fpTs = map (domain_type o fastype_of) unfs;
   204     val fpTs = map (domain_type o fastype_of) dtors;
   205 
   205 
   206     val exists_fp_subtype = exists_subtype (member (op =) fpTs);
   206     val exists_fp_subtype = exists_subtype (member (op =) fpTs);
   207 
   207 
   208     val ctr_Tsss = map (map (map (Term.typ_subst_atomic (Bs ~~ fpTs)))) ctr_TsssBs;
   208     val ctr_Tsss = map (map (map (Term.typ_subst_atomic (Bs ~~ fpTs)))) ctr_TsssBs;
   209     val ns = map length ctr_Tsss;
   209     val ns = map length ctr_Tsss;
   327           (((([], [], []), ([], [], [])),
   327           (((([], [], []), ([], [], [])),
   328             ([z], cs, cpss, (mk_terms rssss gssss, (g_sum_prod_Ts, pg_Tss)),
   328             ([z], cs, cpss, (mk_terms rssss gssss, (g_sum_prod_Ts, pg_Tss)),
   329              (mk_terms sssss hssss, (h_sum_prod_Ts, ph_Tss)))), lthy)
   329              (mk_terms sssss hssss, (h_sum_prod_Ts, ph_Tss)))), lthy)
   330         end;
   330         end;
   331 
   331 
   332     fun define_ctrs_case_for_type ((((((((((((((((((fp_b, fpT), C), fld), unf), fp_iter), fp_rec),
   332     fun define_ctrs_case_for_type ((((((((((((((((((fp_b, fpT), C), ctor), dtor), fp_iter), fp_rec),
   333           fld_unf), unf_fld), fld_inject), n), ks), ms), ctr_bindings), ctr_mixfixes), ctr_Tss),
   333           ctor_dtor), dtor_ctor), ctor_inject), n), ks), ms), ctr_bindings), ctr_mixfixes), ctr_Tss),
   334         disc_bindings), sel_bindingss), raw_sel_defaultss) no_defs_lthy =
   334         disc_bindings), sel_bindingss), raw_sel_defaultss) no_defs_lthy =
   335       let
   335       let
   336         val fp_b_name = Binding.name_of fp_b;
   336         val fp_b_name = Binding.name_of fp_b;
   337 
   337 
   338         val unfT = domain_type (fastype_of fld);
   338         val dtorT = domain_type (fastype_of ctor);
   339         val ctr_prod_Ts = map HOLogic.mk_tupleT ctr_Tss;
   339         val ctr_prod_Ts = map HOLogic.mk_tupleT ctr_Tss;
   340         val ctr_sum_prod_T = mk_sumTN_balanced ctr_prod_Ts;
   340         val ctr_sum_prod_T = mk_sumTN_balanced ctr_prod_Ts;
   341         val case_Ts = map (fn Ts => Ts ---> C) ctr_Tss;
   341         val case_Ts = map (fn Ts => Ts ---> C) ctr_Tss;
   342 
   342 
   343         val ((((w, fs), xss), u'), _) =
   343         val ((((w, fs), xss), u'), _) =
   344           no_defs_lthy
   344           no_defs_lthy
   345           |> yield_singleton (mk_Frees "w") unfT
   345           |> yield_singleton (mk_Frees "w") dtorT
   346           ||>> mk_Frees "f" case_Ts
   346           ||>> mk_Frees "f" case_Ts
   347           ||>> mk_Freess "x" ctr_Tss
   347           ||>> mk_Freess "x" ctr_Tss
   348           ||>> yield_singleton Variable.variant_fixes fp_b_name;
   348           ||>> yield_singleton Variable.variant_fixes fp_b_name;
   349 
   349 
   350         val u = Free (u', fpT);
   350         val u = Free (u', fpT);
   351 
   351 
   352         val ctr_rhss =
   352         val ctr_rhss =
   353           map2 (fn k => fn xs => fold_rev Term.lambda xs (fld $
   353           map2 (fn k => fn xs => fold_rev Term.lambda xs (ctor $
   354             mk_InN_balanced ctr_sum_prod_T n (HOLogic.mk_tuple xs) k)) ks xss;
   354             mk_InN_balanced ctr_sum_prod_T n (HOLogic.mk_tuple xs) k)) ks xss;
   355 
   355 
   356         val case_binding = Binding.suffix_name ("_" ^ caseN) fp_b;
   356         val case_binding = Binding.suffix_name ("_" ^ caseN) fp_b;
   357 
   357 
   358         val case_rhs =
   358         val case_rhs =
   359           fold_rev Term.lambda (fs @ [u])
   359           fold_rev Term.lambda (fs @ [u])
   360             (mk_sum_caseN_balanced (map2 mk_uncurried_fun fs xss) $ (unf $ u));
   360             (mk_sum_caseN_balanced (map2 mk_uncurried_fun fs xss) $ (dtor $ u));
   361 
   361 
   362         val ((raw_case :: raw_ctrs, raw_case_def :: raw_ctr_defs), (lthy', lthy)) = no_defs_lthy
   362         val ((raw_case :: raw_ctrs, raw_case_def :: raw_ctr_defs), (lthy', lthy)) = no_defs_lthy
   363           |> apfst split_list o fold_map3 (fn b => fn mx => fn rhs =>
   363           |> apfst split_list o fold_map3 (fn b => fn mx => fn rhs =>
   364               Local_Theory.define ((b, mx), ((Thm.def_binding b, []), rhs)) #>> apsnd snd)
   364               Local_Theory.define ((b, mx), ((Thm.def_binding b, []), rhs)) #>> apsnd snd)
   365             (case_binding :: ctr_bindings) (NoSyn :: ctr_mixfixes) (case_rhs :: ctr_rhss)
   365             (case_binding :: ctr_bindings) (NoSyn :: ctr_mixfixes) (case_rhs :: ctr_rhss)
   375 
   375 
   376         val ctrs = map (mk_ctr As) ctrs0;
   376         val ctrs = map (mk_ctr As) ctrs0;
   377 
   377 
   378         fun exhaust_tac {context = ctxt, ...} =
   378         fun exhaust_tac {context = ctxt, ...} =
   379           let
   379           let
   380             val fld_iff_unf_thm =
   380             val ctor_iff_dtor_thm =
   381               let
   381               let
   382                 val goal =
   382                 val goal =
   383                   fold_rev Logic.all [w, u]
   383                   fold_rev Logic.all [w, u]
   384                     (mk_Trueprop_eq (HOLogic.mk_eq (u, fld $ w), HOLogic.mk_eq (unf $ u, w)));
   384                     (mk_Trueprop_eq (HOLogic.mk_eq (u, ctor $ w), HOLogic.mk_eq (dtor $ u, w)));
   385               in
   385               in
   386                 Skip_Proof.prove lthy [] [] goal (fn {context = ctxt, ...} =>
   386                 Skip_Proof.prove lthy [] [] goal (fn {context = ctxt, ...} =>
   387                   mk_fld_iff_unf_tac ctxt (map (SOME o certifyT lthy) [unfT, fpT])
   387                   mk_ctor_iff_dtor_tac ctxt (map (SOME o certifyT lthy) [dtorT, fpT])
   388                     (certify lthy fld) (certify lthy unf) fld_unf unf_fld)
   388                     (certify lthy ctor) (certify lthy dtor) ctor_dtor dtor_ctor)
   389                 |> Thm.close_derivation
   389                 |> Thm.close_derivation
   390                 |> Morphism.thm phi
   390                 |> Morphism.thm phi
   391               end;
   391               end;
   392 
   392 
   393             val sumEN_thm' =
   393             val sumEN_thm' =
   394               unfold_defs lthy @{thms all_unit_eq}
   394               unfold_defs lthy @{thms all_unit_eq}
   395                 (Drule.instantiate' (map (SOME o certifyT lthy) ctr_prod_Ts) []
   395                 (Drule.instantiate' (map (SOME o certifyT lthy) ctr_prod_Ts) []
   396                    (mk_sumEN_balanced n))
   396                    (mk_sumEN_balanced n))
   397               |> Morphism.thm phi;
   397               |> Morphism.thm phi;
   398           in
   398           in
   399             mk_exhaust_tac ctxt n ctr_defs fld_iff_unf_thm sumEN_thm'
   399             mk_exhaust_tac ctxt n ctr_defs ctor_iff_dtor_thm sumEN_thm'
   400           end;
   400           end;
   401 
   401 
   402         val inject_tacss =
   402         val inject_tacss =
   403           map2 (fn 0 => K [] | _ => fn ctr_def => [fn {context = ctxt, ...} =>
   403           map2 (fn 0 => K [] | _ => fn ctr_def => [fn {context = ctxt, ...} =>
   404               mk_inject_tac ctxt ctr_def fld_inject]) ms ctr_defs;
   404               mk_inject_tac ctxt ctr_def ctor_inject]) ms ctr_defs;
   405 
   405 
   406         val half_distinct_tacss =
   406         val half_distinct_tacss =
   407           map (map (fn (def, def') => fn {context = ctxt, ...} =>
   407           map (map (fn (def, def') => fn {context = ctxt, ...} =>
   408             mk_half_distinct_tac ctxt fld_inject [def, def'])) (mk_half_pairss ctr_defs);
   408             mk_half_distinct_tac ctxt ctor_inject [def, def'])) (mk_half_pairss ctr_defs);
   409 
   409 
   410         val case_tacs =
   410         val case_tacs =
   411           map3 (fn k => fn m => fn ctr_def => fn {context = ctxt, ...} =>
   411           map3 (fn k => fn m => fn ctr_def => fn {context = ctxt, ...} =>
   412             mk_case_tac ctxt n k m case_def ctr_def unf_fld) ks ms ctr_defs;
   412             mk_case_tac ctxt n k m case_def ctr_def dtor_ctor) ks ms ctr_defs;
   413 
   413 
   414         val tacss = [exhaust_tac] :: inject_tacss @ half_distinct_tacss @ [case_tacs];
   414         val tacss = [exhaust_tac] :: inject_tacss @ half_distinct_tacss @ [case_tacs];
   415 
   415 
   416         fun define_iter_rec (wrap_res, no_defs_lthy) =
   416         fun define_iter_rec (wrap_res, no_defs_lthy) =
   417           let
   417           let
   599               Library.foldr (Logic.list_implies o apfst (map mk_prem)) (raw_premss,
   599               Library.foldr (Logic.list_implies o apfst (map mk_prem)) (raw_premss,
   600                 HOLogic.mk_Trueprop (Library.foldr1 HOLogic.mk_conj (map2 (curry (op $)) phis us)));
   600                 HOLogic.mk_Trueprop (Library.foldr1 HOLogic.mk_conj (map2 (curry (op $)) phis us)));
   601 
   601 
   602             val kksss = map (map (map (fst o snd) o #2)) raw_premss;
   602             val kksss = map (map (map (fst o snd) o #2)) raw_premss;
   603 
   603 
   604             val fld_induct' = fp_induct OF (map mk_sumEN_tupled_balanced mss);
   604             val ctor_induct' = fp_induct OF (map mk_sumEN_tupled_balanced mss);
   605 
   605 
   606             val induct_thm =
   606             val induct_thm =
   607               Skip_Proof.prove lthy [] [] goal (fn {context = ctxt, ...} =>
   607               Skip_Proof.prove lthy [] [] goal (fn {context = ctxt, ...} =>
   608                 mk_induct_tac ctxt ns mss kksss (flat ctr_defss) fld_induct'
   608                 mk_induct_tac ctxt ns mss kksss (flat ctr_defss) ctor_induct'
   609                   nested_set_natural's pre_set_defss)
   609                   nested_set_natural's pre_set_defss)
   610               |> singleton (Proof_Context.export names_lthy lthy)
   610               |> singleton (Proof_Context.export names_lthy lthy)
   611           in
   611           in
   612             `(conj_dests nn) induct_thm
   612             `(conj_dests nn) induct_thm
   613           end;
   613           end;
   873 
   873 
   874     fun wrap_types_and_define_iter_likes ((wraps, define_iter_likess), lthy) =
   874     fun wrap_types_and_define_iter_likes ((wraps, define_iter_likess), lthy) =
   875       fold_map2 (curry (op o)) define_iter_likess wraps lthy |>> split_list8
   875       fold_map2 (curry (op o)) define_iter_likess wraps lthy |>> split_list8
   876 
   876 
   877     val lthy' = lthy
   877     val lthy' = lthy
   878       |> fold_map define_ctrs_case_for_type (fp_bs ~~ fpTs ~~ Cs ~~ flds ~~ unfs ~~ fp_iters ~~
   878       |> fold_map define_ctrs_case_for_type (fp_bs ~~ fpTs ~~ Cs ~~ ctors ~~ dtors ~~ fp_iters ~~
   879         fp_recs ~~ fld_unfs ~~ unf_flds ~~ fld_injects ~~ ns ~~ kss ~~ mss ~~ ctr_bindingss ~~
   879         fp_recs ~~ ctor_dtors ~~ dtor_ctors ~~ ctor_injects ~~ ns ~~ kss ~~ mss ~~ ctr_bindingss ~~
   880         ctr_mixfixess ~~ ctr_Tsss ~~ disc_bindingss ~~ sel_bindingsss ~~ raw_sel_defaultsss)
   880         ctr_mixfixess ~~ ctr_Tsss ~~ disc_bindingss ~~ sel_bindingsss ~~ raw_sel_defaultsss)
   881       |>> split_list |> wrap_types_and_define_iter_likes
   881       |>> split_list |> wrap_types_and_define_iter_likes
   882       |> (if lfp then derive_induct_iter_rec_thms_for_types
   882       |> (if lfp then derive_induct_iter_rec_thms_for_types
   883           else derive_coinduct_coiter_corec_thms_for_types);
   883           else derive_coinduct_coiter_corec_thms_for_types);
   884 
   884