src/HOL/Tools/BNF/bnf_fp_n2m.ML
author blanchet
Tue Feb 18 17:52:27 2014 +0100 (2014-02-18)
changeset 55566 ab0a547b5aee
parent 55539 0819931d652d
child 55575 a5e33e18fb5c
permissions -rw-r--r--
made SML/NJ happier
     1 (*  Title:      HOL/Tools/BNF/bnf_fp_n2m.ML
     2     Author:     Dmitriy Traytel, TU Muenchen
     3     Copyright   2013
     4 
     5 Flattening of nested to mutual (co)recursion.
     6 *)
     7 
     8 signature BNF_FP_N2M =
     9 sig
         (* Sole export of this module.  Arguments, in order: the fixpoint kind, the fixpoint
            types, the existing fp_sugar records, the bindings for the new constants, the result
            type variables with their sorts, a pair of dead-type lists (one shared list and one
            list per BNF), the BNFs, and the local theory.  Returns an fp_result together with
            the updated local theory. *)
    10   val construct_mutualized_fp: BNF_FP_Util.fp_kind  -> typ list -> BNF_FP_Def_Sugar.fp_sugar list ->
    11     binding list -> (string * sort) list -> typ list * typ list list -> BNF_Def.bnf list ->
    12     local_theory -> BNF_FP_Util.fp_result * local_theory
    13 end;
    14 
    15 structure BNF_FP_N2M : BNF_FP_N2M =
    16 struct
    17 
    18 open BNF_Def
    19 open BNF_Util
    20 open BNF_FP_Util
    21 open BNF_FP_Def_Sugar
    22 open BNF_Tactics
    23 open BNF_FP_N2M_Tactics
    24 
    25 fun force_typ ctxt T =
         (* Returns a term transformer (point-free pipeline) that coerces a term to type T:
            the term's type variables are first passed through Type_Infer.paramify_vars, the
            term is constrained to T, type-checked in ctxt, and finally passed through
            Variable.polymorphic for ctxt.  Callers pass types containing dummyT as T,
            presumably so that type checking fills in the unspecified parts. *)
    26   map_types Type_Infer.paramify_vars
    27   #> Type.constraint T
    28   #> Syntax.check_term ctxt
    29   #> singleton (Variable.polymorphic ctxt);
    30 
    31 fun mk_prod_map f g =
         (* Builds the term "map_pair f g".  With f : fAT -> fBT and g : gAT -> gBT, the
            constant is given the type
            fT --> gT --> (fAT * gAT) --> (fBT * gBT), where fT and gT are the full function
            types of f and g. *)
    32   let
           (* The backquote combinator pairs the result of dest_funT with its argument, so
              each pattern yields both the (domain, range) pair and the whole function type. *)
    33     val ((fAT, fBT), fT) = `dest_funT (fastype_of f);
    34     val ((gAT, gBT), gT) = `dest_funT (fastype_of g);
    35   in
    36     Const (@{const_name map_pair},
    37       fT --> gT --> HOLogic.mk_prodT (fAT, gAT) --> HOLogic.mk_prodT (fBT, gBT)) $ f $ g
    38   end;
    39 
    40 fun mk_sum_map f g =
         (* Sum-type counterpart of mk_prod_map: builds the term "sum_map f g".  With
            f : fAT -> fBT and g : gAT -> gBT, the constant is given the type
            fT --> gT --> (fAT + gAT) --> (fBT + gBT), using mk_sumT for the sum types. *)
    41   let
    42     val ((fAT, fBT), fT) = `dest_funT (fastype_of f);
    43     val ((gAT, gBT), gT) = `dest_funT (fastype_of g);
    44   in
    45     Const (@{const_name sum_map}, fT --> gT --> mk_sumT (fAT, gAT) --> mk_sumT (fBT, gBT)) $ f $ g
    46   end;
    47 
    48 fun construct_mutualized_fp fp fpTs (fp_sugars : fp_sugar list) bs resBs (resDs, Dss) bnfs lthy =
         (* Main entry point (see the file header: flattening of nested to mutual (co)recursion).
            Builds an fp_result for the fixpoint types fpTs by reusing components of the
            existing fp_sugars and by defining new fold/rec-style iterators.

            fp            fixpoint kind; the code distinguishes Least_FP and Greatest_FP
            fpTs          the fixpoint types; they become the Ts field of the result
            fp_sugars     existing sugar records; fields of their fp_res are reused
            bs            bindings; suffixed with "_" and the iterator name for the new constants
            resBs         (name, sort) pairs turned into TFrees; those not dead are the live As
            (resDs, Dss)  dead types: a shared list and one list per element of bnfs
            bnfs          BNFs whose map/rel/set theorems and constants are used below
            lthy          local theory in which definitions and proofs take place

            Returns (fp_res, lthy) with lthy extended by the new iterator definitions. *)
    49   let
           (* Applies the selector get to each sugar's fp_res and picks the component at that
              sugar's fp_res_index. *)
    50     fun steal_fp_res get =
    51       map (fn {fp_res, fp_res_index, ...} => nth (get fp_res) fp_res_index) fp_sugars;
    52 
           (* Type-variable bookkeeping: As are the live result variables (resBs minus all dead
              types), Xs are n fresh variables (one per BNF), Bs are m fresh copies of As. *)
    53     val n = length bnfs;
    54     val deads = fold (union (op =)) Dss resDs;
    55     val As = subtract (op =) deads (map TFree resBs);
    56     val names_lthy = fold Variable.declare_typ (As @ deads) lthy;
    57     val m = length As;
    58     val live = m + n;
    59     val ((Xs, Bs), names_lthy) = names_lthy
    60       |> mk_TFrees n
    61       ||>> mk_TFrees m;
    62     val allAs = As @ Xs;
    63     val phiTs = map2 mk_pred2T As Bs;
           (* theta renames As to Bs; fpTs' are the fixpoint types under that renaming. *)
    64     val theta = As ~~ Bs;
    65     val fpTs' = map (Term.typ_subst_atomic theta) fpTs;
    66     val pre_phiTs = map2 mk_pred2T fpTs fpTs';
    67 
           (* Helpers parameterised by the fixpoint kind through fp_case.  NOTE(review):
              fp_case presumably selects its first alternative for Least_FP and its second for
              Greatest_FP (cf. ctor_foldN versus dtor_unfoldN below) -- confirm in BNF_FP_Util. *)
    68     fun mk_co_algT T U = fp_case fp (T --> U) (U --> T);
    69     fun co_swap pair = fp_case fp I swap pair;
    70     val dest_co_algT = co_swap o dest_funT;
    71     val co_alg_argT = fp_case fp range_type domain_type;
    72     val co_alg_funT = fp_case fp domain_type range_type;
    73     val mk_co_product = curry (fp_case fp mk_convol mk_case_sum);
    74     val mk_map_co_product = fp_case fp mk_prod_map mk_sum_map;
    75     val co_proj1_const = fp_case fp (fst_const o fst) (uncurry Inl_const o dest_sumT o snd);
    76     val mk_co_productT = curry (fp_case fp HOLogic.mk_prodT mk_sumT);
    77     val dest_co_productT = fp_case fp HOLogic.dest_prodT dest_sumT;
    78 
           (* Constructors and destructors taken from the existing fixpoints and forced to types
              over fpTs.  xtors is whichever of the two fp_case selects; xtor's is the same list
              with theta applied to its types (the backquote pairs the mapped list with the
              original one). *)
    79     val ((ctors, dtors), (xtor's, xtors)) =
    80       let
    81         val ctors = map2 (force_typ names_lthy o (fn T => dummyT --> T)) fpTs (steal_fp_res #ctors);
    82         val dtors = map2 (force_typ names_lthy o (fn T => T --> dummyT)) fpTs (steal_fp_res #dtors);
    83       in
    84         ((ctors, dtors), `(map (Term.subst_atomic_types theta)) (fp_case fp ctors dtors))
    85       end;
    86 
    87     val xTs = map (domain_type o fastype_of) xtors;
    88     val yTs = map (domain_type o fastype_of) xtor's;
    89 
           (* Fresh free variables: relations R over (As, Bs), relations S over (fpTs, fpTs'),
              and x / y of the xtor domain types. *)
    90     val (((((phis, phis'), pre_phis), xs), ys), names_lthy) = names_lthy
    91       |> mk_Frees' "R" phiTs
    92       ||>> mk_Frees "S" pre_phiTs
    93       ||>> mk_Frees "x" xTs
    94       ||>> mk_Frees "y" yTs;
    95 
    96     val fp_bnfs = steal_fp_res #bnfs;
    97     val pre_bnfs = map #pre_bnf fp_sugars;
    98     val nesty_bnfss = map (fn sugar => #nested_bnfs sugar @ #nesting_bnfs sugar) fp_sugars;
    99     val fp_nesty_bnfss = fp_bnfs :: nesty_bnfss;
           (* Flattened list with duplicates removed, comparing BNFs by their type T_of_bnf. *)
   100     val fp_nesty_bnfs = distinct (op = o pairself T_of_bnf) (flat fp_nesty_bnfss);
   101 
           (* One relator term per fixpoint type, built by recursion over the type structure of
              fpTs / fpTs' and abstracted over the relation variables phis'. *)
   102     val rels =
   103       let
               (* Looks for the first BNF (ignoring those named like DEADID_bnf) whose type could
                  unify with T and returns its relator with its number of live variables; falls
                  back to equality on T with 0 live variables when none is found. *)
   104         fun find_rel T As Bs = fp_nesty_bnfss
   105           |> map (filter_out (curry (op = o pairself name_of_bnf) BNF_Comp.DEADID_bnf))
   106           |> get_first (find_first (fn bnf => Type.could_unify (T_of_bnf bnf, T)))
   107           |> Option.map (fn bnf =>
   108             let val live = live_of_bnf bnf;
   109             in (mk_rel live As Bs (rel_of_bnf bnf), live) end)
   110           |> the_default (HOLogic.eq_const T, 0);
   111 
               (* Type constructors: apply the relator found for T to the relators built
                  recursively for its live arguments.  Free type variables: use the phi at the
                  position of T in As, or equality when T is not among As (find_index then yields
                  an index for which nth raises Subscript).  Anything else is rejected. *)
   112         fun mk_rel (T as Type (_, Ts)) (Type (_, Us)) =
   113               let
   114                 val (rel, live) = find_rel T Ts Us;
   115                 val (Ts', Us') = fastype_of rel |> strip_typeN live |> fst |> map_split dest_pred2T;
   116                 val rels = map2 mk_rel Ts' Us';
   117               in
   118                 Term.list_comb (rel, rels)
   119               end
   120           | mk_rel (T as TFree _) _ = (nth phis (find_index (curry op = T) As)
   121               handle General.Subscript => HOLogic.eq_const T)
   122           | mk_rel _ _ = raise Fail "fpTs contains schematic type variables";
   123       in
   124         map2 (fold_rev Term.absfree phis' oo mk_rel) fpTs fpTs'
   125       end;
   126 
   127     val pre_rels = map2 (fn Ds => mk_rel_of_bnf Ds (As @ fpTs) (Bs @ fpTs')) Dss bnfs;
   128 
           (* Existing relator (co)induction theorems, with the relator definitions of the
              pre-BNFs (non-reflexive ones only) and id_apply unfolded. *)
   129     val rel_unfolds = maps (no_refl o single o rel_def_of_bnf) pre_bnfs;
   130     val rel_xtor_co_inducts = steal_fp_res (split_conj_thm o #rel_xtor_co_induct_thm)
   131       |> map (unfold_thms lthy (id_apply :: rel_unfolds));
   132 
   133     val rel_defs = map rel_def_of_bnf bnfs;
   134     val rel_monos = map rel_mono_of_bnf bnfs;
   135 
   136     val rel_xtor_co_induct_thm =
   137       mk_rel_xtor_co_induct_thm fp pre_rels pre_phis rels phis xs ys xtors xtor's
   138         (mk_rel_xtor_co_induct_tactic fp rel_xtor_co_inducts rel_defs rel_monos) lthy;
   139 
   140     val rel_eqs = no_refl (map rel_eq_of_bnf fp_nesty_bnfs);
   141     val map_id0s = no_refl (map map_id0_of_bnf bnfs);
   142 
           (* The plain (co)induction theorem is obtained from rel_xtor_co_induct_thm by
              instantiation followed by a fixed sequence of unfolding / resolution steps. *)
   143     val xtor_co_induct_thm =
   144       (case fp of
   145         Least_FP =>
   146           let
   147             val (Ps, names_lthy) = names_lthy
   148               |> mk_Frees "P" (map (fn T => T --> HOLogic.boolT) fpTs);
                   (* Grp of the set "Collect P" with the identity function on P's domain. *)
   149             fun mk_Grp_id P =
   150               let val T = domain_type (fastype_of P);
   151               in mk_Grp (HOLogic.Collect_const T $ P) (HOLogic.id_const T) end;
                   (* Instantiate positionally: equality for each of As, then Grp terms for Ps. *)
   152             val cts = map (SOME o certify lthy) (map HOLogic.eq_const As @ map mk_Grp_id Ps);
   153           in
   154             cterm_instantiate_pos cts rel_xtor_co_induct_thm
   155             |> singleton (Proof_Context.export names_lthy lthy)
   156             |> unfold_thms lthy (@{thms eq_le_Grp_id_iff all_simps(1,2)[symmetric]} @ rel_eqs)
   157             |> funpow n (fn thm => thm RS spec)
   158             |> unfold_thms lthy (@{thm eq_alt} :: map rel_Grp_of_bnf bnfs @ map_id0s)
   159             |> unfold_thms lthy @{thms Grp_id_mono_subst eqTrueI[OF subset_UNIV] simp_thms(22)}
   160             |> unfold_thms lthy @{thms subset_iff mem_Collect_eq
   161                atomize_conjL[symmetric] atomize_all[symmetric] atomize_imp[symmetric]}
   162             |> unfold_thms lthy (maps set_defs_of_bnf bnfs)
   163           end
   164       | Greatest_FP =>
   165           let
                   (* First position is left uninstantiated (NONE); As are set to equality. *)
   166             val cts = NONE :: map (SOME o certify lthy) (map HOLogic.eq_const As);
   167           in
   168             cterm_instantiate_pos cts rel_xtor_co_induct_thm
   169             |> unfold_thms lthy (@{thms le_fun_def le_bool_def all_simps(1,2)[symmetric]} @ rel_eqs)
   170             |> funpow (2 * n) (fn thm => thm RS spec)
   171             |> Conv.fconv_rule (Object_Logic.atomize lthy)
   172             |> funpow n (fn thm => thm RS mp)
   173           end);
   174 
           (* Types for the new iterators.  The "fold" variants use the BNF types over As @ Xs;
              the "rec" variants replace each X by the co-product of the fixpoint type and X. *)
   175     val fold_preTs = map2 (fn Ds => mk_T_of_bnf Ds allAs) Dss bnfs;
   176     val fold_pre_deads_only_Ts = map2 (fn Ds => mk_T_of_bnf Ds (replicate live dummyT)) Dss bnfs;
   177     val rec_theta = Xs ~~ map2 mk_co_productT fpTs Xs;
   178     val rec_preTs = map (Term.typ_subst_atomic rec_theta) fold_preTs;
   179 
   180     val fold_strTs = map2 mk_co_algT fold_preTs Xs;
   181     val rec_strTs = map2 mk_co_algT rec_preTs Xs;
   182     val resTs = map2 mk_co_algT fpTs Xs;
   183 
           (* Free variables "s" for the fold and rec structure arguments (both name bases are
              "s"; names_lthy is threaded through both calls). *)
   184     val (((fold_strs, fold_strs'), (rec_strs, rec_strs')), names_lthy) = names_lthy
   185       |> mk_Frees' "s" fold_strTs
   186       ||>> mk_Frees' "s" rec_strTs;
   187 
   188     val co_iters = steal_fp_res #xtor_co_iterss;
   189     val ns = map (length o #Ts o #fp_res) fp_sugars;
   190 
           (* Recurses structurally through function types; for any other type constructor it
              applies typ_subst_nonatomic rho to the arguments; other types are unchanged. *)
   191     fun substT rho (Type (@{type_name "fun"}, [T, U])) = substT rho T --> substT rho U
   192       | substT rho (Type (s, Ts)) = Type (s, map (typ_subst_nonatomic rho) Ts)
   193       | substT _ T = T;
   194 
           (* Specialises an existing iterator (rec if is_rec, otherwise fold) from raw_iters to
              result type TU.  The fold is first type-checked against TU (or TU_rec when is_rec)
              only to discover which of the new structure types (indices js) its arguments
              correspond to; the chosen iterator is then forced to the type Tpats ---> TU. *)
   195     fun force_iter is_rec i TU TU_rec raw_iters =
   196       let
   197         val approx_fold = un_fold_of raw_iters
   198           |> force_typ names_lthy
   199             (replicate (nth ns i) dummyT ---> (if is_rec then TU_rec else TU));
   200         val subst = Term.typ_subst_atomic (Xs ~~ fpTs);
   201         val TUs = map_split dest_co_algT (binder_fun_types (fastype_of approx_fold))
   202           |>> map subst
   203           |> uncurry (map2 mk_co_algT);
   204         val js = find_indices Type.could_unify TUs
   205           (map2 (fn T => fn U => mk_co_algT (subst T) U) fold_preTs Xs);
   206         val Tpats = map (fn j => mk_co_algT (nth fold_pre_deads_only_Ts j) (nth Xs j)) js;
   207         val iter = raw_iters |> (if is_rec then co_rec_of else un_fold_of);
   208       in
   209         force_typ names_lthy (Tpats ---> TU) iter
   210       end;
   211 
           (* Builds the iterator term of type TU.  An iterator from iters whose body type is
              already TU is reused; otherwise one of the stolen co_iters is specialised with
              force_iter.  The result is applied to structure arguments computed by mk_s.  With
              b_opt = NONE only the term is returned (paired with a dummy theorem); with SOME b
              a constant named b is defined, abstracted over the structure variables. *)
   212     fun mk_iter b_opt is_rec iters lthy TU =
   213       let
   214         val x = co_alg_argT TU;
   215         val i = find_index (fn T => x = T) Xs;
   216         val TUiter =
   217           (case find_first (fn f => body_fun_type (fastype_of f) = TU) iters of
   218             NONE => nth co_iters i
   219               |> force_iter is_rec i
   220                 (TU |> (is_none b_opt andalso not is_rec) ? substT (fpTs ~~ Xs))
   221                 (TU |> (is_none b_opt) ? substT (map2 mk_co_productT fpTs Xs ~~ Xs))
   222           | SOME f => f);
   223         val TUs = binder_fun_types (fastype_of TUiter);
   224         val iter_preTs = if is_rec then rec_preTs else fold_preTs;
   225         val iter_strs = if is_rec then rec_strs else fold_strs;
               (* Structure argument for one binder type TU' of the iterator: the i-th structure
                  variable s itself when the types already agree, otherwise s composed (in the
                  order given by co_swap) with a map of the i-th BNF that adapts the types. *)
   226         fun mk_s TU' =
   227           let
   228             val i = find_index (fn T => co_alg_argT TU' = T) Xs;
   229             val sF = co_alg_funT TU';
   230             val F = nth iter_preTs i;
   231             val s = nth iter_strs i;
   232           in
   233             (if sF = F then s
   234             else
   235               let
   236                 val smapT = replicate live dummyT ---> mk_co_algT sF F;
                       (* Instantiates every schematic type variable left in t with unit. *)
   237                 fun hidden_to_unit t =
   238                   Term.subst_TVars (map (rpair HOLogic.unitT) (Term.add_tvar_names t [])) t;
   239                 val smap = map_of_bnf (nth bnfs i)
   240                   |> force_typ names_lthy smapT
   241                   |> hidden_to_unit;
   242                 val smap_argTs = strip_typeN live (fastype_of smap) |> fst;
                       (* Map argument for one function type TU: identity when domain and range
                          coincide; for rec, a co-product (or a map of co-products when build_map
                          fails on the full type); otherwise a recursively built iterator. *)
   243                 fun mk_smap_arg TU =
   244                   (if domain_type TU = range_type TU then
   245                     HOLogic.id_const (domain_type TU)
   246                   else if is_rec then
   247                     let
   248                       val (TY, (U, X)) = TU |> dest_co_algT ||> dest_co_productT;
   249                       val T = mk_co_algT TY U;
   250                     in
   251                       (case try (force_typ lthy T o build_map lthy co_proj1_const o dest_funT) T of
   252                         SOME f => mk_co_product f
   253                           (fst (fst (mk_iter NONE is_rec iters lthy (mk_co_algT TY X))))
   254                       | NONE => mk_map_co_product
   255                           (build_map lthy co_proj1_const
   256                             (dest_funT (mk_co_algT (dest_co_productT TY |> fst) U)))
   257                           (HOLogic.id_const X))
   258                     end
   259                   else
   260                     fst (fst (mk_iter NONE is_rec iters lthy TU)))
   261                 val smap_args = map mk_smap_arg smap_argTs;
   262               in
   263                 HOLogic.mk_comp (co_swap (s, Term.list_comb (smap, smap_args)))
   264               end)
   265           end;
   266         val t = Term.list_comb (TUiter, map mk_s TUs);
   267       in
   268         (case b_opt of
   269           NONE => ((t, Drule.dummy_thm), lthy)
   270         | SOME b => Local_Theory.define ((b, NoSyn), ((Binding.conceal (Thm.def_binding b), []),
   271             fold_rev Term.absfree (if is_rec then rec_strs' else fold_strs') t)) lthy |>> apsnd snd)
   272       end;
   273 
           (* Defines one iterator per result type in resTs, named by suffixing each binding in
              bs with "_" ^ name.  Iterators defined earlier are passed on so that later ones can
              reuse them; the accumulated lists are built in reverse and reversed at the end. *)
   274     fun mk_iters is_rec name lthy =
   275       fold2 (fn TU => fn b => fn ((iters, defs), lthy) =>
   276         mk_iter (SOME b) is_rec iters lthy TU |>> (fn (f, d) => (f :: iters, d :: defs)))
   277       resTs (map (Binding.suffix_name ("_" ^ name)) bs) (([], []), lthy)
   278       |>> apfst rev o apsnd rev;
   279     val foldN = fp_case fp ctor_foldN dtor_unfoldN;
   280     val recN = fp_case fp ctor_recN dtor_corecN;
           (* Define all folds, then all recs.  The backquote on Local_Theory.restore keeps both
              contexts: lthy is the restored local theory, raw_lthy the one in which the
              definitions were made. *)
   281     val (((raw_un_folds, raw_un_fold_defs), (raw_co_recs, raw_co_rec_defs)), (lthy, raw_lthy)) =
   282       lthy
   283       |> mk_iters false foldN
   284       ||>> mk_iters true recN
   285       ||> `Local_Theory.restore;
   286 
           (* Morphism used to export terms and theorems from raw_lthy to lthy. *)
   287     val phi = Proof_Context.export_morphism raw_lthy lthy;
   288 
   289     val un_folds = map (Morphism.term phi) raw_un_folds;
   290     val co_recs = map (Morphism.term phi) raw_co_recs;
   291 
           (* Characteristic equations of the new iterators: each goal equates two compositions
              built from (iterator, xtor) and (structure, pre-BNF map), and is proved by
              unfolding definitions followed by reflexivity. *)
   292     val (xtor_un_fold_thms, xtor_co_rec_thms) =
   293       let
   294         val folds = map (fn f => Term.list_comb (f, fold_strs)) raw_un_folds;
   295         val recs = map (fn r => Term.list_comb (r, rec_strs)) raw_co_recs;
   296         val fold_mapTs = co_swap (As @ fpTs, As @ Xs);
   297         val rec_mapTs = co_swap (As @ fpTs, As @ map2 mk_co_productT fpTs Xs);
               (* Pre-BNF maps applied to identities for As followed by the folds, resp. by the
                  co-products of identity and rec. *)
   298         val pre_fold_maps =
   299           map2 (fn Ds => fn bnf =>
   300             Term.list_comb (uncurry (mk_map_of_bnf Ds) fold_mapTs bnf,
   301               map HOLogic.id_const As @ folds))
   302           Dss bnfs;
   303         val pre_rec_maps =
   304           map2 (fn Ds => fn bnf =>
   305             Term.list_comb (uncurry (mk_map_of_bnf Ds) rec_mapTs bnf,
   306               map HOLogic.id_const As @ map2 (mk_co_product o HOLogic.id_const) fpTs recs))
   307           Dss bnfs;
   308 
               (* Equation between the compositions of (f, xtor) and of (s, smap), each pair
                  ordered by co_swap. *)
   309         fun mk_goals f xtor s smap =
   310           ((f, xtor), (s, smap))
   311           |> pairself (HOLogic.mk_comp o co_swap)
   312           |> HOLogic.mk_eq;
   313 
   314         val fold_goals = map4 mk_goals folds xtors fold_strs pre_fold_maps
   315         val rec_goals = map4 mk_goals recs xtors rec_strs pre_rec_maps;
   316 
               (* Proves the conjunction of all goals, universally quantified over the structure
                  variables ss, in raw_lthy; exports the result through phi, splits it into the
                  individual conjuncts and applies comp_eq_dest to each. *)
   317         fun mk_thms ss goals tac =
   318           Library.foldr1 HOLogic.mk_conj goals
   319           |> HOLogic.mk_Trueprop
   320           |> fold_rev Logic.all ss
   321           |> (fn goal => Goal.prove_sorry raw_lthy [] [] goal tac)
   322           |> Thm.close_derivation
   323           |> Morphism.thm phi
   324           |> split_conj_thm
   325           |> map (fn thm => thm RS @{thm comp_eq_dest});
   326 
   327         val pre_map_defs = no_refl (map map_def_of_bnf bnfs);
   328         val fp_pre_map_defs = no_refl (map map_def_of_bnf pre_bnfs);
   329 
   330         val map_unfolds = maps (fn bnf => no_refl [map_def_of_bnf bnf]) pre_bnfs;
   331         val unfold_map = map (unfold_thms lthy (id_apply :: map_unfolds));
   332 
               (* Theorems about the old iterators, made point-free and with the pre-BNF map
                  definitions unfolded. *)
   333         val fp_xtor_co_iterss = steal_fp_res #xtor_co_iter_thmss;
   334         val fp_xtor_un_folds = map (mk_pointfree lthy o un_fold_of) fp_xtor_co_iterss |> unfold_map;
   335         val fp_xtor_co_recs = map (mk_pointfree lthy o co_rec_of) fp_xtor_co_iterss |> unfold_map;
   336 
   337         val fp_co_iter_o_mapss = steal_fp_res #xtor_co_iter_o_map_thmss;
   338         val fp_fold_o_maps = map un_fold_of fp_co_iter_o_mapss |> unfold_map;
   339         val fp_rec_o_maps = map co_rec_of fp_co_iter_o_mapss |> unfold_map;
               (* Rewrite rules: the orientation of comp_assoc depends on the fixpoint kind. *)
   340         val fold_thms = fp_case fp @{thm comp_assoc} @{thm comp_assoc[symmetric]} :: @{thms id_apply
   341           o_apply comp_id id_comp map_pair.comp map_pair.id sum_map.comp sum_map.id};
   342         val rec_thms = fold_thms @ fp_case fp
   343           @{thms fst_convol map_pair_o_convol convol_o}
   344           @{thms case_sum_o_inj(1) case_sum_o_sum_map o_case_sum};
   345         val map_thms = no_refl (maps (fn bnf =>
   346           [map_comp0_of_bnf bnf RS sym, map_id0_of_bnf bnf]) fp_nesty_bnfs);
   347 
               (* Unfold with all collected rules, then close each conjunct (one per element of
                  bnfs) by reflexivity. *)
   348         fun mk_tac defs o_map_thms xtor_thms thms {context = ctxt, prems = _} =
   349           unfold_thms_tac ctxt
   350             (flat [thms, defs, pre_map_defs, fp_pre_map_defs, xtor_thms, o_map_thms, map_thms]) THEN
   351           CONJ_WRAP (K (HEADGOAL (rtac refl))) bnfs;
   352 
   353         val fold_tac = mk_tac raw_un_fold_defs fp_fold_o_maps fp_xtor_un_folds fold_thms;
   354         val rec_tac = mk_tac raw_co_rec_defs fp_rec_o_maps fp_xtor_co_recs rec_thms;
   355       in
   356         (mk_thms fold_strs fold_goals fold_tac, mk_thms rec_strs rec_goals rec_tac)
   357       end;
   358 
   359     (* These results are half broken. This is deliberate. We care only about those fields that are
   360        used by "primrec", "primcorecursive", and "datatype_compat". *)
           (* Newly computed fields: Ts, dtors, ctors, xtor_co_iterss, xtor_co_induct,
              xtor_co_iter_thmss and rel_xtor_co_induct_thm.  All other fields are copied from
              the existing fixpoints.  The whole record is finally mapped through a term
              morphism that applies Variable.polymorphic for lthy. *)
   361     val fp_res =
   362       ({Ts = fpTs,
   363         bnfs = steal_fp_res #bnfs,
   364         dtors = dtors,
   365         ctors = ctors,
   366         xtor_co_iterss = transpose [un_folds, co_recs],
   367         xtor_co_induct = xtor_co_induct_thm,
   368         dtor_ctors = steal_fp_res #dtor_ctors (*too general types*),
   369         ctor_dtors = steal_fp_res #ctor_dtors (*too general types*),
   370         ctor_injects = steal_fp_res #ctor_injects (*too general types*),
   371         dtor_injects = steal_fp_res #dtor_injects (*too general types*),
   372         xtor_map_thms = steal_fp_res #xtor_map_thms (*too general types and terms*),
   373         xtor_set_thmss = steal_fp_res #xtor_set_thmss (*too general types and terms*),
   374         xtor_rel_thms = steal_fp_res #xtor_rel_thms (*too general types and terms*),
   375         xtor_co_iter_thmss = transpose [xtor_un_fold_thms, xtor_co_rec_thms],
   376         xtor_co_iter_o_map_thmss = steal_fp_res #xtor_co_iter_o_map_thmss
   377           (*theorem about old constant*),
   378         rel_xtor_co_induct_thm = rel_xtor_co_induct_thm}
   379        |> morph_fp_result (Morphism.term_morphism "BNF" (singleton (Variable.polymorphic lthy))));
   380   in
   381     (fp_res, lthy)
   382   end;
   383 
   384 end;