src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
changeset 49438 5bc80d96241e
parent 49437 c139da00fb4a
child 49450 24029cbec12a
equal deleted inserted replaced
49437:c139da00fb4a 49438:5bc80d96241e
    32 open BNF_FP_Util
    32 open BNF_FP_Util
    33 open BNF_FP_Sugar_Tactics
    33 open BNF_FP_Sugar_Tactics
    34 
    34 
    35 val simp_attrs = @{attributes [simp]};
    35 val simp_attrs = @{attributes [simp]};
    36 
    36 
    37 fun split_list10 xs =
    37 fun split_list8 xs =
    38   (map #1 xs, map #2 xs, map #3 xs, map #4 xs, map #5 xs, map #6 xs, map #7 xs, map #8 xs,
    38   (map #1 xs, map #2 xs, map #3 xs, map #4 xs, map #5 xs, map #6 xs, map #7 xs, map #8 xs);
    39    map #9 xs, map #10 xs);
       
    40 
    39 
    41 fun resort_tfree S (TFree (s, _)) = TFree (s, S);
    40 fun resort_tfree S (TFree (s, _)) = TFree (s, S);
    42 
    41 
    43 fun typ_subst inst (T as Type (s, Ts)) =
    42 fun typ_subst inst (T as Type (s, Ts)) =
    44     (case AList.lookup (op =) inst T of
    43     (case AList.lookup (op =) inst T of
   404           map3 (fn k => fn m => fn ctr_def => fn {context = ctxt, ...} =>
   403           map3 (fn k => fn m => fn ctr_def => fn {context = ctxt, ...} =>
   405             mk_case_tac ctxt n k m case_def ctr_def unf_fld) ks ms ctr_defs;
   404             mk_case_tac ctxt n k m case_def ctr_def unf_fld) ks ms ctr_defs;
   406 
   405 
   407         val tacss = [exhaust_tac] :: inject_tacss @ half_distinct_tacss @ [case_tacs];
   406         val tacss = [exhaust_tac] :: inject_tacss @ half_distinct_tacss @ [case_tacs];
   408 
   407 
   409         fun define_iter_rec ((selss0, discIs, sel_thmss), no_defs_lthy) =
   408         fun define_iter_rec (wrap_res, no_defs_lthy) =
   410           let
   409           let
   411             val fpT_to_C = fpT --> C;
   410             val fpT_to_C = fpT --> C;
   412 
   411 
   413             fun generate_iter_like (suf, fp_iter_like, (fss, f_Tss, xssss)) =
   412             fun generate_iter_like (suf, fp_iter_like, (fss, f_Tss, xssss)) =
   414               let
   413               let
   436 
   435 
   437             val [iter_def, rec_def] = map (Morphism.thm phi) defs;
   436             val [iter_def, rec_def] = map (Morphism.thm phi) defs;
   438 
   437 
   439             val [iter, recx] = map (mk_iter_like As Cs o Morphism.term phi) csts;
   438             val [iter, recx] = map (mk_iter_like As Cs o Morphism.term phi) csts;
   440           in
   439           in
   441             ((ctrs, selss0, iter, recx, xss, ctr_defs, discIs, sel_thmss, iter_def, rec_def), lthy)
   440             ((wrap_res, ctrs, iter, recx, xss, ctr_defs, iter_def, rec_def), lthy)
   442           end;
   441           end;
   443 
   442 
   444         fun define_coiter_corec ((selss0, discIs, sel_thmss), no_defs_lthy) =
   443         fun define_coiter_corec (wrap_res, no_defs_lthy) =
   445           let
   444           let
   446             val B_to_fpT = C --> fpT;
   445             val B_to_fpT = C --> fpT;
   447 
   446 
   448             fun mk_preds_getterss_join c n cps sum_prod_T cqfss =
   447             fun mk_preds_getterss_join c n cps sum_prod_T cqfss =
   449               Term.lambda c (mk_IfN sum_prod_T cps
   448               Term.lambda c (mk_IfN sum_prod_T cps
   476 
   475 
   477             val [coiter_def, corec_def] = map (Morphism.thm phi) defs;
   476             val [coiter_def, corec_def] = map (Morphism.thm phi) defs;
   478 
   477 
   479             val [coiter, corec] = map (mk_iter_like As Cs o Morphism.term phi) csts;
   478             val [coiter, corec] = map (mk_iter_like As Cs o Morphism.term phi) csts;
   480           in
   479           in
   481             ((ctrs, selss0, coiter, corec, xss, ctr_defs, discIs, sel_thmss, coiter_def, corec_def),
   480             ((wrap_res, ctrs, coiter, corec, xss, ctr_defs, coiter_def, corec_def), lthy)
   482              lthy)
       
   483           end;
   481           end;
   484 
   482 
   485         fun wrap lthy =
   483         fun wrap lthy =
   486           let val sel_defaultss = map (map (apsnd (prepare_term lthy))) raw_sel_defaultss in
   484           let val sel_defaultss = map (map (apsnd (prepare_term lthy))) raw_sel_defaultss in
   487             wrap_datatype tacss (((no_dests, ctrs0), casex0), (disc_bindings, (sel_bindingss,
   485             wrap_datatype tacss (((no_dests, ctrs0), casex0), (disc_bindings, (sel_bindingss,
   512         val mapx = mk_map live Ts Us (map_of_bnf bnf);
   510         val mapx = mk_map live Ts Us (map_of_bnf bnf);
   513         val TUs = map dest_funT (fst (strip_typeN live (fastype_of mapx)));
   511         val TUs = map dest_funT (fst (strip_typeN live (fastype_of mapx)));
   514         val args = map build_arg TUs;
   512         val args = map build_arg TUs;
   515       in Term.list_comb (mapx, args) end;
   513       in Term.list_comb (mapx, args) end;
   516 
   514 
   517     fun derive_induct_iter_rec_thms_for_types ((ctrss, _, iters, recs, xsss, ctr_defss, _, _,
   515     fun derive_induct_iter_rec_thms_for_types ((wrap_ress, ctrss, iters, recs, xsss, ctr_defss,
   518         iter_defs, rec_defs), lthy) =
   516         iter_defs, rec_defs), lthy) =
   519       let
   517       let
       
   518         val inject_thmss = map #2 wrap_ress;
       
   519         val distinct_thmss = map #3 wrap_ress;
       
   520         val case_thmss = map #4 wrap_ress;
       
   521 
   520         val (((phis, phis'), vs'), names_lthy) =
   522         val (((phis, phis'), vs'), names_lthy) =
   521           lthy
   523           lthy
   522           |> mk_Frees' "P" (map mk_predT fpTs)
   524           |> mk_Frees' "P" (map mk_predT fpTs)
   523           ||>> Variable.variant_fixes (map Binding.name_of fp_bs);
   525           ||>> Variable.variant_fixes (map Binding.name_of fp_bs);
   524 
   526 
   601               |> singleton (Proof_Context.export names_lthy lthy)
   603               |> singleton (Proof_Context.export names_lthy lthy)
   602           in
   604           in
   603             `(conj_dests nn) induct_thm
   605             `(conj_dests nn) induct_thm
   604           end;
   606           end;
   605 
   607 
       
   608         (* TODO: Generate nicer names in case of clashes *)
   606         val induct_cases = Datatype_Prop.indexify_names (maps (map base_name_of_ctr) ctrss);
   609         val induct_cases = Datatype_Prop.indexify_names (maps (map base_name_of_ctr) ctrss);
   607 
   610 
   608         val (iter_thmss, rec_thmss) =
   611         val (iter_thmss, rec_thmss) =
   609           let
   612           let
   610             val xctrss = map2 (map2 (curry Term.list_comb)) ctrss xsss;
   613             val xctrss = map2 (map2 (curry Term.list_comb)) ctrss xsss;
   651                goal_iterss iter_tacss,
   654                goal_iterss iter_tacss,
   652              map2 (map2 (fn goal => fn tac => Skip_Proof.prove lthy [] [] goal (tac o #context)))
   655              map2 (map2 (fn goal => fn tac => Skip_Proof.prove lthy [] [] goal (tac o #context)))
   653                goal_recss rec_tacss)
   656                goal_recss rec_tacss)
   654           end;
   657           end;
   655 
   658 
       
   659         val simp_thmss =
       
   660           map4 (fn injects => fn distincts => fn cases => fn recs =>
       
   661             injects @ distincts @ cases @ recs) inject_thmss distinct_thmss case_thmss rec_thmss;
       
   662 
   656         val induct_case_names_attr = Attrib.internal (K (Rule_Cases.case_names induct_cases));
   663         val induct_case_names_attr = Attrib.internal (K (Rule_Cases.case_names induct_cases));
   657         fun induct_type_attr T_name = Attrib.internal (K (Induct.induct_type T_name));
   664         fun induct_type_attr T_name = Attrib.internal (K (Induct.induct_type T_name));
   658 
   665 
   659         val common_notes =
   666         val common_notes =
   660           (if nn > 1 then [(inductN, [induct_thm], [induct_case_names_attr])] else [])
   667           (if nn > 1 then [(inductN, [induct_thm], [induct_case_names_attr])] else [])
   663 
   670 
   664         val notes =
   671         val notes =
   665           [(inductN, map single induct_thms,
   672           [(inductN, map single induct_thms,
   666             fn T_name => [induct_case_names_attr, induct_type_attr T_name]),
   673             fn T_name => [induct_case_names_attr, induct_type_attr T_name]),
   667            (itersN, iter_thmss, K simp_attrs),
   674            (itersN, iter_thmss, K simp_attrs),
   668            (recsN, rec_thmss, K (Code.add_default_eqn_attrib :: simp_attrs))]
   675            (recsN, rec_thmss, K (Code.add_default_eqn_attrib :: simp_attrs)),
       
   676            (simpsN, simp_thmss, K [])]
   669           |> maps (fn (thmN, thmss, attrs) =>
   677           |> maps (fn (thmN, thmss, attrs) =>
   670             map3 (fn b => fn Type (T_name, _) => fn thms =>
   678             map3 (fn b => fn Type (T_name, _) => fn thms =>
   671               ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), attrs T_name),
   679               ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), attrs T_name),
   672                 [(thms, [])])) fp_bs fpTs thmss);
   680                 [(thms, [])])) fp_bs fpTs thmss);
   673       in
   681       in
   674         lthy |> Local_Theory.notes (common_notes @ notes) |> snd
   682         lthy |> Local_Theory.notes (common_notes @ notes) |> snd
   675       end;
   683       end;
   676 
   684 
   677     fun derive_coinduct_coiter_corec_thms_for_types ((ctrss, selsss, coiters, corecs, _, ctr_defss,
   685     fun derive_coinduct_coiter_corec_thms_for_types ((wrap_ress, ctrss, coiters, corecs, _,
   678         discIss, sel_thmsss, coiter_defs, corec_defs), lthy) =
   686         ctr_defss, coiter_defs, corec_defs), lthy) =
   679       let
   687       let
       
   688         val selsss0 = map #1 wrap_ress;
       
   689         val discIss = map #5 wrap_ress;
       
   690         val sel_thmsss = map #6 wrap_ress;
       
   691 
   680         val (vs', _) =
   692         val (vs', _) =
   681           lthy
   693           lthy
   682           |> Variable.variant_fixes (map Binding.name_of fp_bs);
   694           |> Variable.variant_fixes (map Binding.name_of fp_bs);
   683 
   695 
   684         val vs = map2 (curry Free) vs' fpTs;
   696         val vs = map2 (curry Free) vs' fpTs;
   763           in
   775           in
   764             coiter_like_thm RS arg_cong' RS sel_thm'
   776             coiter_like_thm RS arg_cong' RS sel_thm'
   765           end;
   777           end;
   766 
   778 
   767         val sel_coiter_thmsss =
   779         val sel_coiter_thmsss =
   768           map3 (map3 (map2 o mk_sel_coiter_like_thm)) coiter_thmss selsss sel_thmsss;
   780           map3 (map3 (map2 o mk_sel_coiter_like_thm)) coiter_thmss selsss0 sel_thmsss;
   769         val sel_corec_thmsss =
   781         val sel_corec_thmsss =
   770           map3 (map3 (map2 o mk_sel_coiter_like_thm)) corec_thmss selsss sel_thmsss;
   782           map3 (map3 (map2 o mk_sel_coiter_like_thm)) corec_thmss selsss0 sel_thmsss;
   771 
   783 
   772         val common_notes =
   784         val common_notes =
   773           (if nn > 1 then [(coinductN, [coinduct_thm], [])] (* FIXME: attribs *) else [])
   785           (if nn > 1 then [(coinductN, [coinduct_thm], [])] (* FIXME: attribs *) else [])
   774           |> map (fn (thmN, thms, attrs) =>
   786           |> map (fn (thmN, thms, attrs) =>
   775               ((Binding.qualify true fp_common_name (Binding.name thmN), attrs), [(thms, [])]));
   787               ((Binding.qualify true fp_common_name (Binding.name thmN), attrs), [(thms, [])]));
   789       in
   801       in
   790         lthy |> Local_Theory.notes (common_notes @ notes) |> snd
   802         lthy |> Local_Theory.notes (common_notes @ notes) |> snd
   791       end;
   803       end;
   792 
   804 
   793     fun wrap_types_and_define_iter_likes ((wraps, define_iter_likess), lthy) =
   805     fun wrap_types_and_define_iter_likes ((wraps, define_iter_likess), lthy) =
   794       fold_map2 (curry (op o)) define_iter_likess wraps lthy |>> split_list10
   806       fold_map2 (curry (op o)) define_iter_likess wraps lthy |>> split_list8
   795 
   807 
   796     val lthy' = lthy
   808     val lthy' = lthy
   797       |> fold_map define_ctrs_case_for_type (fp_bs ~~ fpTs ~~ Cs ~~ flds ~~ unfs ~~ fp_iters ~~
   809       |> fold_map define_ctrs_case_for_type (fp_bs ~~ fpTs ~~ Cs ~~ flds ~~ unfs ~~ fp_iters ~~
   798         fp_recs ~~ fld_unfs ~~ unf_flds ~~ fld_injects ~~ ns ~~ kss ~~ mss ~~ ctr_bindingss ~~
   810         fp_recs ~~ fld_unfs ~~ unf_flds ~~ fld_injects ~~ ns ~~ kss ~~ mss ~~ ctr_bindingss ~~
   799         ctr_mixfixess ~~ ctr_Tsss ~~ disc_bindingss ~~ sel_bindingsss ~~ raw_sel_defaultsss)
   811         ctr_mixfixess ~~ ctr_Tsss ~~ disc_bindingss ~~ sel_bindingsss ~~ raw_sel_defaultsss)