src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
changeset 49498 acc583e14167
parent 49484 0194a18f80cf
child 49501 acc9635a644a
equal deleted inserted replaced
49497:860b7c6bd913 49498:acc583e14167
    50 fun mk_tupled_fun x f xs = HOLogic.tupled_lambda x (Term.list_comb (f, xs));
    50 fun mk_tupled_fun x f xs = HOLogic.tupled_lambda x (Term.list_comb (f, xs));
    51 fun mk_uncurried_fun f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f xs;
    51 fun mk_uncurried_fun f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f xs;
    52 fun mk_uncurried2_fun f xss =
    52 fun mk_uncurried2_fun f xss =
    53   mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat xss);
    53   mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat xss);
    54 
    54 
    55 fun tick v f = Term.lambda v (HOLogic.mk_prod (v, f $ v));
    55 fun tick u f = Term.lambda u (HOLogic.mk_prod (u, f $ u));
    56 
    56 
    57 fun tack z_name (c, v) f =
    57 fun tack z_name (c, u) f =
    58   let val z = Free (z_name, mk_sumT (fastype_of v, fastype_of c)) in
    58   let val z = Free (z_name, mk_sumT (fastype_of u, fastype_of c)) in
    59     Term.lambda z (mk_sum_case (Term.lambda v v, Term.lambda c (f $ c)) $ z)
    59     Term.lambda z (mk_sum_case (Term.lambda u u, Term.lambda c (f $ c)) $ z)
    60   end;
    60   end;
    61 
    61 
    62 fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters";
    62 fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters";
    63 
    63 
    64 fun merge_type_arg T T' = if T = T' then T else cannot_merge_types ();
    64 fun merge_type_arg T T' = if T = T' then T else cannot_merge_types ();
    90     val _ = if not lfp andalso no_dests then error "Cannot define destructor-less codatatypes"
    90     val _ = if not lfp andalso no_dests then error "Cannot define destructor-less codatatypes"
    91       else ();
    91       else ();
    92 
    92 
    93     val nn = length specs;
    93     val nn = length specs;
    94     val fp_bs = map type_binding_of specs;
    94     val fp_bs = map type_binding_of specs;
    95     val fp_common_name = mk_common_name fp_bs;
    95     val fp_b_names = map Binding.name_of fp_bs;
       
    96     val fp_common_name = mk_common_name fp_b_names;
    96 
    97 
    97     fun prepare_type_arg (ty, c) =
    98     fun prepare_type_arg (ty, c) =
    98       let val TFree (s, _) = prepare_typ no_defs_lthy0 ty in
    99       let val TFree (s, _) = prepare_typ no_defs_lthy0 ty in
    99         TFree (s, prepare_constraint no_defs_lthy0 c)
   100         TFree (s, prepare_constraint no_defs_lthy0 c)
   100       end;
   101       end;
   130 
   131 
   131     val ctr_specss = map ctr_specs_of specs;
   132     val ctr_specss = map ctr_specs_of specs;
   132 
   133 
   133     val disc_bindingss = map (map disc_of) ctr_specss;
   134     val disc_bindingss = map (map disc_of) ctr_specss;
   134     val ctr_bindingss =
   135     val ctr_bindingss =
   135       map2 (fn fp_b => map (Binding.qualify false (Binding.name_of fp_b) o ctr_of))
   136       map2 (fn fp_b_name => map (Binding.qualify false fp_b_name o ctr_of)) fp_b_names ctr_specss;
   136         fp_bs ctr_specss;
       
   137     val ctr_argsss = map (map args_of) ctr_specss;
   137     val ctr_argsss = map (map args_of) ctr_specss;
   138     val ctr_mixfixess = map (map ctr_mixfix_of) ctr_specss;
   138     val ctr_mixfixess = map (map ctr_mixfix_of) ctr_specss;
   139 
   139 
   140     val sel_bindingsss = map (map (map fst)) ctr_argsss;
   140     val sel_bindingsss = map (map (map fst)) ctr_argsss;
   141     val fake_ctr_Tsss0 = map (map (map (prepare_typ fake_lthy o snd))) ctr_argsss;
   141     val fake_ctr_Tsss0 = map (map (map (prepare_typ fake_lthy o snd))) ctr_argsss;
   331 
   331 
   332     fun define_ctrs_case_for_type ((((((((((((((((((fp_b, fpT), C), fld), unf), fp_iter), fp_rec),
   332     fun define_ctrs_case_for_type ((((((((((((((((((fp_b, fpT), C), fld), unf), fp_iter), fp_rec),
   333           fld_unf), unf_fld), fld_inject), n), ks), ms), ctr_bindings), ctr_mixfixes), ctr_Tss),
   333           fld_unf), unf_fld), fld_inject), n), ks), ms), ctr_bindings), ctr_mixfixes), ctr_Tss),
   334         disc_bindings), sel_bindingss), raw_sel_defaultss) no_defs_lthy =
   334         disc_bindings), sel_bindingss), raw_sel_defaultss) no_defs_lthy =
   335       let
   335       let
       
   336         val fp_b_name = Binding.name_of fp_b;
       
   337 
   336         val unfT = domain_type (fastype_of fld);
   338         val unfT = domain_type (fastype_of fld);
   337         val ctr_prod_Ts = map HOLogic.mk_tupleT ctr_Tss;
   339         val ctr_prod_Ts = map HOLogic.mk_tupleT ctr_Tss;
   338         val ctr_sum_prod_T = mk_sumTN_balanced ctr_prod_Ts;
   340         val ctr_sum_prod_T = mk_sumTN_balanced ctr_prod_Ts;
   339         val case_Ts = map (fn Ts => Ts ---> C) ctr_Tss;
   341         val case_Ts = map (fn Ts => Ts ---> C) ctr_Tss;
   340 
   342 
   341         val ((((u, fs), xss), v'), _) =
   343         val ((((w, fs), xss), u'), _) =
   342           no_defs_lthy
   344           no_defs_lthy
   343           |> yield_singleton (mk_Frees "u") unfT
   345           |> yield_singleton (mk_Frees "w") unfT
   344           ||>> mk_Frees "f" case_Ts
   346           ||>> mk_Frees "f" case_Ts
   345           ||>> mk_Freess "x" ctr_Tss
   347           ||>> mk_Freess "x" ctr_Tss
   346           ||>> yield_singleton (Variable.variant_fixes) (Binding.name_of fp_b);
   348           ||>> yield_singleton Variable.variant_fixes fp_b_name;
   347 
   349 
   348         val v = Free (v', fpT);
   350         val u = Free (u', fpT);
   349 
   351 
   350         val ctr_rhss =
   352         val ctr_rhss =
   351           map2 (fn k => fn xs => fold_rev Term.lambda xs (fld $
   353           map2 (fn k => fn xs => fold_rev Term.lambda xs (fld $
   352             mk_InN_balanced ctr_sum_prod_T n (HOLogic.mk_tuple xs) k)) ks xss;
   354             mk_InN_balanced ctr_sum_prod_T n (HOLogic.mk_tuple xs) k)) ks xss;
   353 
   355 
   354         val case_binding = Binding.suffix_name ("_" ^ caseN) fp_b;
   356         val case_binding = Binding.suffix_name ("_" ^ caseN) fp_b;
   355 
   357 
   356         val case_rhs =
   358         val case_rhs =
   357           fold_rev Term.lambda (fs @ [v])
   359           fold_rev Term.lambda (fs @ [u])
   358             (mk_sum_caseN_balanced (map2 mk_uncurried_fun fs xss) $ (unf $ v));
   360             (mk_sum_caseN_balanced (map2 mk_uncurried_fun fs xss) $ (unf $ u));
   359 
   361 
   360         val ((raw_case :: raw_ctrs, raw_case_def :: raw_ctr_defs), (lthy', lthy)) = no_defs_lthy
   362         val ((raw_case :: raw_ctrs, raw_case_def :: raw_ctr_defs), (lthy', lthy)) = no_defs_lthy
   361           |> apfst split_list o fold_map3 (fn b => fn mx => fn rhs =>
   363           |> apfst split_list o fold_map3 (fn b => fn mx => fn rhs =>
   362               Local_Theory.define ((b, mx), ((Thm.def_binding b, []), rhs)) #>> apsnd snd)
   364               Local_Theory.define ((b, mx), ((Thm.def_binding b, []), rhs)) #>> apsnd snd)
   363             (case_binding :: ctr_bindings) (NoSyn :: ctr_mixfixes) (case_rhs :: ctr_rhss)
   365             (case_binding :: ctr_bindings) (NoSyn :: ctr_mixfixes) (case_rhs :: ctr_rhss)
   376         fun exhaust_tac {context = ctxt, ...} =
   378         fun exhaust_tac {context = ctxt, ...} =
   377           let
   379           let
   378             val fld_iff_unf_thm =
   380             val fld_iff_unf_thm =
   379               let
   381               let
   380                 val goal =
   382                 val goal =
   381                   fold_rev Logic.all [u, v]
   383                   fold_rev Logic.all [w, u]
   382                     (mk_Trueprop_eq (HOLogic.mk_eq (v, fld $ u), HOLogic.mk_eq (unf $ v, u)));
   384                     (mk_Trueprop_eq (HOLogic.mk_eq (u, fld $ w), HOLogic.mk_eq (unf $ u, w)));
   383               in
   385               in
   384                 Skip_Proof.prove lthy [] [] goal (fn {context = ctxt, ...} =>
   386                 Skip_Proof.prove lthy [] [] goal (fn {context = ctxt, ...} =>
   385                   mk_fld_iff_unf_tac ctxt (map (SOME o certifyT lthy) [unfT, fpT])
   387                   mk_fld_iff_unf_tac ctxt (map (SOME o certifyT lthy) [unfT, fpT])
   386                     (certify lthy fld) (certify lthy unf) fld_unf unf_fld)
   388                     (certify lthy fld) (certify lthy unf) fld_unf unf_fld)
   387                 |> Thm.close_derivation
   389                 |> Thm.close_derivation
   523         injects @ distincts @ cases @ rec_likes @ iter_likes);
   525         injects @ distincts @ cases @ rec_likes @ iter_likes);
   524 
   526 
   525     fun derive_induct_iter_rec_thms_for_types ((wrap_ress, ctrss, iters, recs, xsss, ctr_defss,
   527     fun derive_induct_iter_rec_thms_for_types ((wrap_ress, ctrss, iters, recs, xsss, ctr_defss,
   526         iter_defs, rec_defs), lthy) =
   528         iter_defs, rec_defs), lthy) =
   527       let
   529       let
   528         val (((phis, phis'), vs'), names_lthy) =
   530         val (((phis, phis'), us'), names_lthy) =
   529           lthy
   531           lthy
   530           |> mk_Frees' "P" (map mk_pred1T fpTs)
   532           |> mk_Frees' "P" (map mk_pred1T fpTs)
   531           ||>> Variable.variant_fixes (map Binding.name_of fp_bs);
   533           ||>> Variable.variant_fixes fp_b_names;
   532 
   534 
   533         val vs = map2 (curry Free) vs' fpTs;
   535         val us = map2 (curry Free) us' fpTs;
   534 
   536 
   535         fun mk_sets_nested bnf =
   537         fun mk_sets_nested bnf =
   536           let
   538           let
   537             val Type (T_name, Us) = T_of_bnf bnf;
   539             val Type (T_name, Us) = T_of_bnf bnf;
   538             val lives = lives_of_bnf bnf;
   540             val lives = lives_of_bnf bnf;
   593 
   595 
   594             val raw_premss = map3 (map2 o mk_raw_prem) phis ctrss ctr_Tsss;
   596             val raw_premss = map3 (map2 o mk_raw_prem) phis ctrss ctr_Tsss;
   595 
   597 
   596             val goal =
   598             val goal =
   597               Library.foldr (Logic.list_implies o apfst (map mk_prem)) (raw_premss,
   599               Library.foldr (Logic.list_implies o apfst (map mk_prem)) (raw_premss,
   598                 HOLogic.mk_Trueprop (Library.foldr1 HOLogic.mk_conj (map2 (curry (op $)) phis vs)));
   600                 HOLogic.mk_Trueprop (Library.foldr1 HOLogic.mk_conj (map2 (curry (op $)) phis us)));
   599 
   601 
   600             val kksss = map (map (map (fst o snd) o #2)) raw_premss;
   602             val kksss = map (map (map (fst o snd) o #2)) raw_premss;
   601 
   603 
   602             val fld_induct' = fp_induct OF (map mk_sumEN_tupled_balanced mss);
   604             val fld_induct' = fp_induct OF (map mk_sumEN_tupled_balanced mss);
   603 
   605 
   627               if T = U then
   629               if T = U then
   628                 id_const T
   630                 id_const T
   629               else
   631               else
   630                 (case find_index (curry (op =) T) fpTs of
   632                 (case find_index (curry (op =) T) fpTs of
   631                   ~1 => build_map (build_call fiter_likes maybe_tick) T U
   633                   ~1 => build_map (build_call fiter_likes maybe_tick) T U
   632                 | j => maybe_tick (nth vs j) (nth fiter_likes j));
   634                 | j => maybe_tick (nth us j) (nth fiter_likes j));
   633 
   635 
   634             fun mk_U maybe_mk_prodT =
   636             fun mk_U maybe_mk_prodT =
   635               typ_subst (map2 (fn fpT => fn C => (fpT, maybe_mk_prodT fpT C)) fpTs Cs);
   637               typ_subst (map2 (fn fpT => fn C => (fpT, maybe_mk_prodT fpT C)) fpTs Cs);
   636 
   638 
   637             fun intr_calls fiter_likes maybe_cons maybe_tick maybe_mk_prodT (x as Free (_, T)) =
   639             fun intr_calls fiter_likes maybe_cons maybe_tick maybe_mk_prodT (x as Free (_, T)) =
   694         val selsss = map #2 wrap_ress;
   696         val selsss = map #2 wrap_ress;
   695         val disc_thmsss = map #6 wrap_ress;
   697         val disc_thmsss = map #6 wrap_ress;
   696         val discIss = map #7 wrap_ress;
   698         val discIss = map #7 wrap_ress;
   697         val sel_thmsss = map #8 wrap_ress;
   699         val sel_thmsss = map #8 wrap_ress;
   698 
   700 
   699         val (vs', _) =
   701         val (us', _) =
   700           lthy
   702           lthy
   701           |> Variable.variant_fixes (map Binding.name_of fp_bs);
   703           |> Variable.variant_fixes fp_b_names;
   702 
   704 
   703         val vs = map2 (curry Free) vs' fpTs;
   705         val us = map2 (curry Free) us' fpTs;
   704 
   706 
   705         val (coinduct_thms, coinduct_thm) =
   707         val (coinduct_thms, coinduct_thm) =
   706           let
   708           let
   707             val coinduct_thm = fp_induct;
   709             val coinduct_thm = fp_induct;
   708           in
   710           in
   726               if T = U then
   728               if T = U then
   727                 id_const T
   729                 id_const T
   728               else
   730               else
   729                 (case find_index (curry (op =) U) fpTs of
   731                 (case find_index (curry (op =) U) fpTs of
   730                   ~1 => build_map (build_call fiter_likes maybe_tack) T U
   732                   ~1 => build_map (build_call fiter_likes maybe_tack) T U
   731                 | j => maybe_tack (nth cs j, nth vs j) (nth fiter_likes j));
   733                 | j => maybe_tack (nth cs j, nth us j) (nth fiter_likes j));
   732 
   734 
   733             fun mk_U maybe_mk_sumT =
   735             fun mk_U maybe_mk_sumT =
   734               typ_subst (map2 (fn C => fn fpT => (maybe_mk_sumT fpT C, fpT)) Cs fpTs);
   736               typ_subst (map2 (fn C => fn fpT => (maybe_mk_sumT fpT C, fpT)) Cs fpTs);
   735 
   737 
   736             fun intr_calls fiter_likes maybe_mk_sumT maybe_tack cqf =
   738             fun intr_calls fiter_likes maybe_mk_sumT maybe_tack cqf =