src/HOL/BNF/Tools/bnf_fp_n2m_sugar.ML
author blanchet
Fri Sep 20 11:44:30 2013 +0200 (2013-09-20)
changeset 53746 bd038e48526d
parent 53678 e55bb583342e
child 53808 b3e2022530e3
permissions -rw-r--r--
have "datatype_new_compat" register induction and recursion theorems in nested case
     1 (*  Title:      HOL/BNF/Tools/bnf_fp_n2m_sugar.ML
     2     Author:     Jasmin Blanchette, TU Muenchen
     3     Copyright   2013
     4 
     5 Sugared flattening of nested to mutual (co)recursion.
     6 *)
     7 
     8 signature BNF_FP_N2M_SUGAR =
     9 sig
    10   val mutualize_fp_sugars: bool -> bool -> BNF_FP_Util.fp_kind -> binding list -> typ list ->
    11     (term -> int list) -> term list list list list -> BNF_FP_Def_Sugar.fp_sugar list ->
    12     local_theory ->
    13     (BNF_FP_Def_Sugar.fp_sugar list
    14      * (BNF_FP_Def_Sugar.lfp_sugar_thms option * BNF_FP_Def_Sugar.gfp_sugar_thms option))
    15     * local_theory
    16   val pad_and_indexify_calls: BNF_FP_Def_Sugar.fp_sugar list -> int ->
    17     (term * term list list) list list -> term list list list list
    18   val nested_to_mutual_fps: bool -> BNF_FP_Util.fp_kind -> binding list -> typ list ->
    19     (term -> int list) -> ((term * term list list) list) list -> local_theory ->
    20     (typ list * int list * BNF_FP_Def_Sugar.fp_sugar list
    21      * (BNF_FP_Def_Sugar.lfp_sugar_thms option * BNF_FP_Def_Sugar.gfp_sugar_thms option))
    22     * local_theory
    23 end;
    24 
    25 structure BNF_FP_N2M_Sugar : BNF_FP_N2M_SUGAR =
    26 struct
    27 
    28 open BNF_Util
    29 open BNF_Def
    30 open BNF_Ctr_Sugar
    31 open BNF_FP_Util
    32 open BNF_FP_Def_Sugar
    33 open BNF_FP_N2M
    34 
     35 val n2mN = "n2m_" (* prefix of the base names given to the mutualized types (see "fp_bs") *)
    36 
     37 (* TODO: test with sort constraints on As *)
     38 (* TODO: use right sorting order for "fp_sort" w.r.t. original BNFs (?) -- treat new variables
     39    as deads? *)
        (* Mutualizes the (co)datatypes "fpTs" into one fresh mutual (co)datatype, but only if
           "mutualize" is true or "fpTs" contains duplicates; otherwise "fp_sugars0" is returned
           unchanged together with (NONE, NONE) and the unchanged "no_defs_lthy0".
             lose_co_rec: if true, argument types that contain no recursive call are kept as
               they are instead of having occurrences of "fpTs" replaced by the fresh "Xs".
             fp: Least_FP selects fold/rec and induction; otherwise unfold/corec and
               coinduction are defined and derived.
             bs: bindings qualifying the new types, whose names are "n2m_" followed by the
               (distinct-ified) base name of the corresponding element of "fpTs".
             get_indices: maps a recursive call to the indices into "fpTs" it targets; more
               than one index is rejected as a heterogeneous call.
             callssss: the recursive calls per type, per constructor, per constructor argument.
             fp_sugars0: existing sugars, one per element of "fpTs".
           Result: one new fp_sugar per element of "fpTs" (exported back to the context of
           "no_defs_lthy0"), the theorem info (SOME in exactly one of the lfp/gfp components),
           and the extended local theory. *)
     40 fun mutualize_fp_sugars lose_co_rec mutualize fp bs fpTs get_indices callssss fp_sugars0
     41     no_defs_lthy0 =
     42   (* TODO: Also check whether there's any lost recursion? *)
     43   if mutualize orelse has_duplicates (op =) fpTs then
     44     let
     45       val thy = Proof_Context.theory_of no_defs_lthy0;
     46 
     47       val qsotm = quote o Syntax.string_of_term no_defs_lthy0;
     48 
     49       fun heterogeneous_call t = error ("Heterogeneous recursive call: " ^ qsotm t);
     50       fun incompatible_calls t1 t2 =
     51         error ("Incompatible recursive calls: " ^ qsotm t1 ^ " vs. " ^ qsotm t2);
     52 
     53       val b_names = map Binding.name_of bs;
     54       val fp_b_names = map base_name_of_typ fpTs;
     55 
     56       val nn = length fpTs;
     57 
              (* Constructor sugar of the given fp_sugar, with its schematic type variables
                 instantiated by matching the sugar's type "T" against "fpT". *)
     58       fun target_ctr_sugar_of_fp_sugar fpT {T, index, ctr_sugars, ...} =
     59         let
     60           val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T, fpT) Vartab.empty) [];
     61           val phi = Morphism.term_morphism (Term.subst_TVars rho);
     62         in
     63           morph_ctr_sugar phi (nth ctr_sugars index)
     64         end;
     65 
     66       val ctr_defss = map (of_fp_sugar #ctr_defss) fp_sugars0;
     67       val mapss = map (of_fp_sugar #mapss) fp_sugars0;
     68       val ctr_sugars0 = map2 target_ctr_sugar_of_fp_sugar fpTs fp_sugars0;
     69 
     70       val ctrss = map #ctrs ctr_sugars0;
     71       val ctr_Tss = map (map fastype_of) ctrss;
     72 
              (* Free type variables of all constructor types (passed to "fp_bnf" below). *)
     73       val As' = fold (fold Term.add_tfreesT) ctr_Tss [];
     74       val As = map TFree As';
     75 
              (* Fresh type variables: "Cs" (one per type, used for the (co)iterators) and "Xs"
                 (named after the types, standing for the recursive occurrences). *)
     76       val ((Cs, Xs), no_defs_lthy) =
     77         no_defs_lthy0
     78         |> fold Variable.declare_typ As
     79         |> mk_TFrees nn
     80         ||>> variant_tfrees fp_b_names;
     81 
     82       (* If "lose_co_rec" is "true", the function "null" on "'a list" gives rise to
     83            'list = unit + 'a list
     84          instead of
     85            'list = unit + 'list
     86          resulting in a simpler (co)induction rule and (co)recursor. *)
     87       fun freeze_fp_default (T as Type (s, Ts)) =
     88           (case find_index (curry (op =) T) fpTs of
     89             ~1 => Type (s, map freeze_fp_default Ts)
     90           | kk => nth Xs kk)
     91         | freeze_fp_default T = T;
     92 
              (* A single recursive call may target at most one of the types. *)
     93       fun get_indices_checked call =
     94         (case get_indices call of
     95           _ :: _ :: _ => heterogeneous_call call
     96         | kks => kks);
     97 
              (* Replaces recursive occurrences in the argument type "T" by the matching "Xs",
                 guided by the calls made on that argument. Calls that "dest_map" recognizes as
                 maps over the type constructor "s" are split per type argument and handled
                 recursively; otherwise all calls must agree on a single target index. *)
     98       fun freeze_fp calls (T as Type (s, Ts)) =
     99           (case map_filter (try (snd o dest_map no_defs_lthy s)) calls of
    100             [] =>
    101             (case union (op = o pairself fst)
    102                 (maps (fn call => map (rpair call) (get_indices_checked call)) calls) [] of
    103               [] => T |> not lose_co_rec ? freeze_fp_default
    104             | [(kk, _)] => nth Xs kk
    105             | (_, call1) :: (_, call2) :: _ => incompatible_calls call1 call2)
    106           | callss =>
    107             Type (s, map2 freeze_fp (flatten_type_args_of_bnf (the (bnf_of no_defs_lthy s)) []
    108               (transpose callss)) Ts))
    109         | freeze_fp _ T = T;
    110 
              (* Constructor argument types with recursive occurrences frozen to "Xs"; their
                 sums of products form the right-hand sides of the fixpoint equations. *)
    111       val ctr_Tsss = map (map binder_types) ctr_Tss;
    112       val ctrXs_Tsss = map2 (map2 (map2 freeze_fp)) callssss ctr_Tsss;
    113       val ctrXs_sum_prod_Ts = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctrXs_Tsss;
              (* Result types of the constructors. *)
    114       val Ts = map (body_type o hd) ctr_Tss;
    115 
    116       val ns = map length ctr_Tsss;
    117       val kss = map (fn n => 1 upto n) ns;
    118       val mss = map (map length) ctr_Tsss;
    119 
    120       val fp_eqs = map dest_TFree Xs ~~ ctrXs_sum_prod_Ts;
    121 
    122       val base_fp_names = Name.variant_list [] fp_b_names;
    123       val fp_bs = map2 (fn b_name => fn base_fp_name =>
    124           Binding.qualify true b_name (Binding.name (n2mN ^ base_fp_name)))
    125         b_names base_fp_names;
    126 
              (* Construct the new mutual fixpoint from the equations "fp_eqs". *)
    127       val (pre_bnfs, (fp_res as {xtor_co_iterss = xtor_co_iterss0, xtor_co_induct,
    128              dtor_injects, dtor_ctors, xtor_co_iter_thmss, ...}, lthy)) =
    129         fp_bnf (construct_mutualized_fp fp fpTs fp_sugars0) fp_bs As' fp_eqs no_defs_lthy;
    130 
    131       val nesting_bnfs = nesty_bnfs lthy ctrXs_Tsss As;
    132       val nested_bnfs = nesty_bnfs lthy ctrXs_Tsss Xs;
    133 
    134       val ((xtor_co_iterss, iters_args_types, coiters_args_types), _) =
    135         mk_co_iters_prelims fp ctr_Tsss fpTs Cs ns mss xtor_co_iterss0 lthy;
    136 
    137       fun mk_binding b suf = Binding.suffix_name ("_" ^ suf) b;
    138 
              (* Define fold/rec (resp. unfold/corec) for each new type. *)
    139       val ((co_iterss, co_iter_defss), lthy) =
    140         fold_map2 (fn b =>
    141           (if fp = Least_FP then define_iters [foldN, recN] (the iters_args_types)
    142            else define_coiters [unfoldN, corecN] (the coiters_args_types))
    143             (mk_binding b) fpTs Cs) fp_bs xtor_co_iterss lthy
    144         |>> split_list;
    145 
              (* Instantiate the constructor sugars so that the constructors' result types "Ts"
                 become "fpTs" (presumably "tvar_subst" yields that matcher -- TODO confirm). *)
    146       val rho = tvar_subst thy Ts fpTs;
    147       val ctr_sugar_phi =
    148         Morphism.compose (Morphism.typ_morphism (Term.typ_subst_TVars rho))
    149           (Morphism.term_morphism (Term.subst_TVars rho));
    150       val inst_ctr_sugar = morph_ctr_sugar ctr_sugar_phi;
    151 
    152       val ctr_sugars = map inst_ctr_sugar ctr_sugars0;
    153 
              (* Derive the (co)induction rule and the (co)iterator theorems; the inductive case
                 yields no discriminator/selector theorems here. *)
    154       val ((co_inducts, un_fold_thmss, co_rec_thmss, disc_unfold_thmss, disc_corec_thmss,
    155             sel_unfold_thmsss, sel_corec_thmsss), fp_sugar_thms) =
    156         if fp = Least_FP then
    157           derive_induct_iters_thms_for_types pre_bnfs (the iters_args_types) xtor_co_induct
    158             xtor_co_iter_thmss nesting_bnfs nested_bnfs fpTs Cs Xs ctrXs_Tsss ctrss ctr_defss
    159             co_iterss co_iter_defss lthy
    160           |> `(fn ((_, induct, _), (fold_thmss, _), (rec_thmss, _)) =>
    161             ([induct], fold_thmss, rec_thmss, [], [], [], []))
    162           ||> (fn info => (SOME info, NONE))
    163         else
    164           derive_coinduct_coiters_thms_for_types pre_bnfs (the coiters_args_types) xtor_co_induct
    165             dtor_injects dtor_ctors xtor_co_iter_thmss nesting_bnfs fpTs Cs kss mss ns ctr_defss
    166             ctr_sugars co_iterss co_iter_defss (Proof_Context.export lthy no_defs_lthy) lthy
    167           |> `(fn ((coinduct_thms_pairs, _), (unfold_thmss, corec_thmss, _), _,
    168                   (disc_unfold_thmss, disc_corec_thmss, _), _,
    169                   (sel_unfold_thmsss, sel_corec_thmsss, _)) =>
    170             (map snd coinduct_thms_pairs, unfold_thmss, corec_thmss, disc_unfold_thmss,
    171              disc_corec_thmss, sel_unfold_thmsss, sel_corec_thmsss))
    172           ||> (fn info => (NONE, SOME info));
    173 
              (* Morphism exporting results from "no_defs_lthy" back to "no_defs_lthy0". *)
    174       val phi = Proof_Context.export_morphism no_defs_lthy no_defs_lthy0;
    175 
              (* One sugar per type; "index" selects its entry in the shared per-type lists. *)
    176       fun mk_target_fp_sugar (kk, T) =
    177         {T = T, fp = fp, index = kk, pre_bnfs = pre_bnfs, nested_bnfs = nested_bnfs,
    178          nesting_bnfs = nesting_bnfs, fp_res = fp_res, ctr_defss = ctr_defss,
    179          ctr_sugars = ctr_sugars, co_iterss = co_iterss, mapss = mapss, co_inducts = co_inducts,
    180          co_iter_thmsss = transpose [un_fold_thmss, co_rec_thmss],
    181          disc_co_itersss = transpose [disc_unfold_thmss, disc_corec_thmss],
    182          sel_co_iterssss = transpose [sel_unfold_thmsss, sel_corec_thmsss]}
    183         |> morph_fp_sugar phi;
    184     in
    185       ((map_index mk_target_fp_sugar fpTs, fp_sugar_thms), lthy)
    186     end
    187   else
    188     (* TODO: reorder hypotheses and predicates in (co)induction rules? *)
    189     ((fp_sugars0, (NONE, NONE)), no_defs_lthy0);
   190 
   191 fun indexify_callsss fp_sugar callsss =
   192   let
   193     val {ctrs, ...} = of_fp_sugar #ctr_sugars fp_sugar;
   194     fun do_ctr ctr =
   195       (case AList.lookup Term.aconv_untyped callsss ctr of
   196         NONE => replicate (num_binder_types (fastype_of ctr)) []
   197       | SOME callss => map (map Envir.beta_eta_contract) callss);
   198   in
   199     map do_ctr ctrs
   200   end;
   201 
   202 fun pad_and_indexify_calls fp_sugars0 = map2 indexify_callsss fp_sugars0 oo pad_list [];
   203 
   204 fun nested_to_mutual_fps lose_co_rec fp actual_bs actual_Ts get_indices actual_callssss0 lthy =
   205   let
   206     val qsoty = quote o Syntax.string_of_typ lthy;
   207     val qsotys = space_implode " or " o map qsoty;
   208 
   209     fun not_co_datatype0 T = error (qsoty T ^ " is not a " ^ co_prefix fp ^ "datatype");
   210     fun not_co_datatype (T as Type (s, _)) =
   211         if fp = Least_FP andalso
   212            is_some (Datatype_Data.get_info (Proof_Context.theory_of lthy) s) then
   213           error (qsoty T ^ " is not a new-style datatype (cf. \"datatype_new\")")
   214         else
   215           not_co_datatype0 T
   216       | not_co_datatype T = not_co_datatype0 T;
   217     fun not_mutually_nested_rec Ts1 Ts2 =
   218       error (qsotys Ts1 ^ " is neither mutually recursive with nor nested recursive via " ^
   219         qsotys Ts2);
   220 
   221     val perm_actual_Ts as Type (_, ty_args0) :: _ =
   222       sort (int_ord o pairself Term.size_of_typ) actual_Ts;
   223 
   224     fun check_enrich_with_mutuals _ [] = []
   225       | check_enrich_with_mutuals seen ((T as Type (T_name, ty_args)) :: Ts) =
   226         (case fp_sugar_of lthy T_name of
   227           SOME ({fp = fp', fp_res = {Ts = Ts', ...}, ...}) =>
   228           if fp = fp' then
   229             let
   230               val mutual_Ts = map (fn Type (s, _) => Type (s, ty_args)) Ts';
   231               val _ =
   232                 seen = [] orelse exists (exists_subtype_in seen) mutual_Ts orelse
   233                 not_mutually_nested_rec mutual_Ts seen;
   234               val (seen', Ts') = List.partition (member (op =) mutual_Ts) Ts;
   235             in
   236               mutual_Ts @ check_enrich_with_mutuals (seen @ T :: seen') Ts'
   237             end
   238           else
   239             not_co_datatype T
   240         | NONE => not_co_datatype T)
   241       | check_enrich_with_mutuals _ (T :: _) = not_co_datatype T;
   242 
   243     val perm_Ts = check_enrich_with_mutuals [] perm_actual_Ts;
   244 
   245     val missing_Ts = perm_Ts |> subtract (op =) actual_Ts;
   246     val Ts = actual_Ts @ missing_Ts;
   247 
   248     val nn = length Ts;
   249     val kks = 0 upto nn - 1;
   250 
   251     val common_name = mk_common_name (map Binding.name_of actual_bs);
   252     val bs = pad_list (Binding.name common_name) nn actual_bs;
   253 
   254     fun permute xs = permute_like (op =) Ts perm_Ts xs;
   255     fun unpermute perm_xs = permute_like (op =) perm_Ts Ts perm_xs;
   256 
   257     val perm_bs = permute bs;
   258     val perm_kks = permute kks;
   259     val perm_fp_sugars0 = map (the o fp_sugar_of lthy o fst o dest_Type) perm_Ts;
   260 
   261     val mutualize = exists (fn Type (_, ty_args) => ty_args <> ty_args0) Ts;
   262     val perm_callssss = pad_and_indexify_calls perm_fp_sugars0 nn actual_callssss0;
   263 
   264     val get_perm_indices = map (fn kk => find_index (curry (op =) kk) perm_kks) o get_indices;
   265 
   266     val ((perm_fp_sugars, fp_sugar_thms), lthy) =
   267       mutualize_fp_sugars lose_co_rec mutualize fp perm_bs perm_Ts get_perm_indices perm_callssss
   268         perm_fp_sugars0 lthy;
   269 
   270     val fp_sugars = unpermute perm_fp_sugars;
   271   in
   272     ((missing_Ts, perm_kks, fp_sugars, fp_sugar_thms), lthy)
   273   end;
   274 
   275 end;