src/HOL/BNF/Tools/bnf_fp_n2m_sugar.ML
changeset 53303 ae49b835ca01
child 53475 185ad6cf6576
equal deleted inserted replaced
53302:98fdf6c34142 53303:ae49b835ca01
       
     1 (*  Title:      HOL/BNF/Tools/bnf_fp_n2m_sugar.ML
       
     2     Author:     Jasmin Blanchette, TU Muenchen
       
     3     Copyright   2013
       
     4 
       
     5 Sugared flattening of nested to mutual (co)recursion.
       
     6 *)
       
     7 
       
     (* Interface of the sugar layer that flattens nested (co)recursion to mutual (co)recursion. *)
     8 signature BNF_FP_N2M_SUGAR =
       
     9 sig
       
         (* mutualize_fp_sugars lose_co_rec mutualize fp bs fpTs get_indices callssss fp_sugars0 lthy:
            if "mutualize" is true or "fpTs" contains duplicates, builds a new fixpoint for "fpTs"
            and returns ((true, new fp_sugars), updated lthy); otherwise returns
            ((false, fp_sugars0), lthy) unchanged. *)
    10   val mutualize_fp_sugars: bool -> bool -> BNF_FP_Util.fp_kind -> binding list -> typ list ->
       
    11     (term -> int list) -> term list list list list -> BNF_FP_Def_Sugar.fp_sugar list ->
       
    12     local_theory -> (bool * BNF_FP_Def_Sugar.fp_sugar list) * local_theory
       
         (* pad_and_indexify_calls fp_sugars0 n callssss: pads the list of per-type call tables with
            empty tables up to length n, then turns each table (an association list from
            constructors to calls) into a list of calls indexed by the constructors of the
            corresponding fp_sugar. *)
    13   val pad_and_indexify_calls: BNF_FP_Def_Sugar.fp_sugar list -> int ->
       
    14     (term * term list list) list list -> term list list list list
       
         (* nested_to_mutual_fps lose_co_rec fp actual_bs actual_Ts get_indices actual_callssss0 lthy:
            completes "actual_Ts" with the other types of their mutual clusters and calls
            "mutualize_fp_sugars". The result is
            ((nontrivial, missing types, permuted indices, fp_sugars), lthy). *)
    15   val nested_to_mutual_fps: bool -> BNF_FP_Util.fp_kind -> binding list -> typ list ->
       
    16     (term -> int list) -> ((term * term list list) list) list -> local_theory ->
       
    17     (bool * typ list * int list * BNF_FP_Def_Sugar.fp_sugar list) * local_theory
       
    18 end;
       
    19 
       
    20 structure BNF_FP_N2M_Sugar : BNF_FP_N2M_SUGAR =
       
    21 struct
       
    22 
       
    23 open BNF_Util
       
    24 open BNF_Def
       
    25 open BNF_Ctr_Sugar
       
    26 open BNF_FP_Util
       
    27 open BNF_FP_Def_Sugar
       
    28 open BNF_FP_N2M
       
    29 
       
    (* Prefix put in front of the base type names when "mutualize_fp_sugars" builds the bindings
       ("fp_bs") of a mutualized fixpoint. *)
    30 val n2mN = "n2m_"
       
    31 
       
    32 (* TODO: test with sort constraints on As *)
       
    33 (* TODO: use right sorting order for "fp_sort" w.r.t. original BNFs (?) -- treat new variables
       
    34    as deads? *)
       
    (* Rebuilds the (co)datatypes "fpTs" as a single fixpoint via "fp_bnf" and
       "construct_mutualized_fp" and assembles fresh fp_sugars for them.

       The construction only takes place if "mutualize" is true or "fpTs" contains duplicates;
       otherwise the result is ((false, fp_sugars0), no_defs_lthy0), i.e. the inputs unchanged.
       In the former case the result is ((true, one new fp_sugar per element of "fpTs"), lthy).

       "callssss" lists the recursive calls per type, constructor and constructor argument;
       "get_indices" maps a call to positions in "fpTs". With "lose_co_rec", argument types
       without recorded calls are kept as they are instead of being abstracted by default.

       Raises an error for "heterogeneous" calls (more than one index for one call) and for
       "incompatible" calls (calls on the same argument type that point to different indices). *)
    35 fun mutualize_fp_sugars lose_co_rec mutualize fp bs fpTs get_indices callssss fp_sugars0
       
    36     no_defs_lthy0 =
       
    37   (* TODO: Also check whether there's any lost recursion? *)
       
    38   if mutualize orelse has_duplicates (op =) fpTs then
       
    39     let
       
    40       val thy = Proof_Context.theory_of no_defs_lthy0;
       
    41 
       
             (* Quoted, pretty-printed term for the error messages below. *)
    42       val qsotm = quote o Syntax.string_of_term no_defs_lthy0;
       
    43 
       
    44       fun heterogeneous_call t = error ("Heterogeneous recursive call: " ^ qsotm t);
       
    45       fun incompatible_calls t1 t2 =
       
    46         error ("Incompatible recursive calls: " ^ qsotm t1 ^ " vs. " ^ qsotm t2);
       
    47 
       
    48       val b_names = map Binding.name_of bs;
       
    49       val fp_b_names = map base_name_of_typ fpTs;
       
    50 
       
    51       val nn = length fpTs;
       
    52 
       
             (* Picks the constructor sugar at position "index" and instantiates its schematic type
                variables with the substitution obtained by matching the sugar's type "T" against
                the target type "fpT". *)
    53       fun target_ctr_sugar_of_fp_sugar fpT {T, index, ctr_sugars, ...} =
       
    54         let
       
    55           val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T, fpT) Vartab.empty) [];
       
    56           val phi = Morphism.term_morphism (Term.subst_TVars rho);
       
    57         in
       
    58           morph_ctr_sugar phi (nth ctr_sugars index)
       
    59         end;
       
    60 
       
    61       val ctr_defss = map (of_fp_sugar #ctr_defss) fp_sugars0;
       
    62       val ctr_sugars0 = map2 target_ctr_sugar_of_fp_sugar fpTs fp_sugars0;
       
    63 
       
    64       val ctrss = map #ctrs ctr_sugars0;
       
    65       val ctr_Tss = map (map fastype_of) ctrss;
       
    66 
       
             (* Free type variables occurring in the constructor types. *)
    67       val As' = fold (fold Term.add_tfreesT) ctr_Tss [];
       
    68       val As = map TFree As';
       
    69 
       
             (* Two lists of type variables, obtained after declaring "As" in the context
                (presumably fresh -- confirm against "mk_TFrees"/"variant_tfrees"). The "Xs" stand
                for the types being defined: they replace the "fpTs" in the constructor argument
                types and form the left-hand sides of "fp_eqs" below. *)
    70       val ((Cs, Xs), no_defs_lthy) =
       
    71         no_defs_lthy0
       
    72         |> fold Variable.declare_typ As
       
    73         |> mk_TFrees nn
       
    74         ||>> variant_tfrees fp_b_names;
       
    75 
       
    76       (* If "lose_co_rec" is "true", the function "null" on "'a list" gives rise to
       
    77            'list = unit + 'a list
       
    78          instead of
       
    79            'list = unit + 'list
       
    80          resulting in a simpler (co)induction rule and (co)recursor. *)
       
             (* Replaces every occurrence of a type from "fpTs" by the "X" at the same position,
                descending through the arguments of other type constructors. *)
    81       fun freeze_fp_default (T as Type (s, Ts)) =
       
    82           (case find_index (curry (op =) T) fpTs of
       
    83             ~1 => Type (s, map freeze_fp_default Ts)
       
    84           | kk => nth Xs kk)
       
    85         | freeze_fp_default T = T;
       
    86 
       
             (* Like "get_indices", but a call that yields two or more indices is an error. *)
    87       fun get_indices_checked call =
       
    88         (case get_indices call of
       
    89           _ :: _ :: _ => heterogeneous_call call
       
    90         | kks => kks);
       
    91 
       
             (* Abstracts the argument type "T" according to the calls recorded for it.
                - If none of the calls can be destructed as a map of the type constructor "s", the
                  (index, call) pairs of all calls are collected, keeping one pair per index:
                  no pair leaves "T" alone (or applies "freeze_fp_default" unless "lose_co_rec"),
                  one pair yields the corresponding "X", several pairs are an error.
                - Otherwise the arguments of the map calls are redistributed over the type
                  arguments of "s" and the type arguments are abstracted recursively. *)
    92       fun freeze_fp calls (T as Type (s, Ts)) =
       
    93           (case map_filter (try (snd o dest_map no_defs_lthy s)) calls of
       
    94             [] =>
       
    95             (case union (op = o pairself fst)
       
    96                 (maps (fn call => map (rpair call) (get_indices_checked call)) calls) [] of
       
    97               [] => T |> not lose_co_rec ? freeze_fp_default
       
    98             | [(kk, _)] => nth Xs kk
       
    99             | (_, call1) :: (_, call2) :: _ => incompatible_calls call1 call2)
       
   100           | callss =>
       
   101             Type (s, map2 freeze_fp (flatten_type_args_of_bnf (the (bnf_of no_defs_lthy s)) []
       
   102               (transpose callss)) Ts))
       
   103         | freeze_fp _ T = T;
       
   104 
       
             (* Constructor argument types, before and after abstraction of the recursive
                positions; "Ts" are the constructors' result types. *)
   105       val ctr_Tsss = map (map binder_types) ctr_Tss;
       
   106       val ctrXs_Tsss = map2 (map2 (map2 freeze_fp)) callssss ctr_Tsss;
       
   107       val ctrXs_sum_prod_Ts = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctrXs_Tsss;
       
   108       val Ts = map (body_type o hd) ctr_Tss;
       
   109 
       
             (* ns: number of constructors per type; kss: the numbers 1..n for each type;
                mss: number of arguments of each constructor. *)
   110       val ns = map length ctr_Tsss;
       
   111       val kss = map (fn n => 1 upto n) ns;
       
   112       val mss = map (map length) ctr_Tsss;
       
   113 
       
             (* One equation per type: the type variable "X" paired with the sum (over the
                constructors) of the tuple types of the abstracted constructor arguments. *)
   114       val fp_eqs = map dest_TFree Xs ~~ ctrXs_sum_prod_Ts;
       
   115 
       
             (* Bindings of the new fixpoint: "n2m_" followed by a disambiguated base type name,
                qualified by the name of the corresponding binding from "bs". *)
   116       val base_fp_names = Name.variant_list [] fp_b_names;
       
   117       val fp_bs = map2 (fn b_name => fn base_fp_name =>
       
   118           Binding.qualify true b_name (Binding.name (n2mN ^ base_fp_name)))
       
   119         b_names base_fp_names;
       
   120 
       
             (* Run "fp_bnf" with "construct_mutualized_fp" on the equations above. *)
   121       val (pre_bnfs, (fp_res as {xtor_co_iterss = xtor_co_iterss0, xtor_co_induct,
       
   122              dtor_injects, dtor_ctors, xtor_co_iter_thmss, ...}, lthy)) =
       
   123         fp_bnf (construct_mutualized_fp fp fpTs fp_sugars0) fp_bs As' fp_eqs no_defs_lthy;
       
   124 
       
   125       val nesting_bnfs = nesty_bnfs lthy ctrXs_Tsss As;
       
   126       val nested_bnfs = nesty_bnfs lthy ctrXs_Tsss Xs;
       
   127 
       
   128       val ((xtor_co_iterss, iters_args_types, coiters_args_types), _) =
       
   129         mk_co_iters_prelims fp fpTs Cs ns mss xtor_co_iterss0 lthy;
       
   130 
       
   131       fun mk_binding b suf = Binding.suffix_name ("_" ^ suf) b;
       
   132 
       
             (* Define the (co)iterators of each type: "foldN"/"recN" for least fixpoints,
                "unfoldN"/"corecN" otherwise. *)
   133       val ((co_iterss, co_iter_defss), lthy) =
       
   134         fold_map2 (fn b =>
       
   135           (if fp = Least_FP then define_iters [foldN, recN] (the iters_args_types)
       
   136            else define_coiters [unfoldN, corecN] (the coiters_args_types))
       
   137             (mk_binding b) fpTs Cs) fp_bs xtor_co_iterss lthy
       
   138         |>> split_list;
       
   139 
       
             (* Instantiate types and terms of the constructor sugars with the substitution
                computed by "tvar_subst" from "Ts" and "fpTs". *)
   140       val rho = tvar_subst thy Ts fpTs;
       
   141       val ctr_sugar_phi =
       
   142         Morphism.compose (Morphism.typ_morphism (Term.typ_subst_TVars rho))
       
   143           (Morphism.term_morphism (Term.subst_TVars rho));
       
   144       val inst_ctr_sugar = morph_ctr_sugar ctr_sugar_phi;
       
   145 
       
   146       val ctr_sugars = map inst_ctr_sugar ctr_sugars0;
       
   147 
       
             (* Derive the (co)induction rules and the (co)iterator theorems; only the components
                stored in the fp_sugar records are kept. *)
   148       val (co_inducts, un_fold_thmss, co_rec_thmss) =
       
   149         if fp = Least_FP then
       
   150           derive_induct_iters_thms_for_types pre_bnfs (the iters_args_types) xtor_co_induct
       
   151             xtor_co_iter_thmss nesting_bnfs nested_bnfs fpTs Cs Xs ctrXs_Tsss ctrss ctr_defss
       
   152             co_iterss co_iter_defss lthy
       
   153           |> (fn ((_, induct, _), (fold_thmss, _), (rec_thmss, _)) =>
       
   154             ([induct], fold_thmss, rec_thmss))
       
   155         else
       
   156           derive_coinduct_coiters_thms_for_types pre_bnfs (the coiters_args_types) xtor_co_induct
       
   157             dtor_injects dtor_ctors xtor_co_iter_thmss nesting_bnfs fpTs Cs kss mss ns ctr_defss
       
   158             ctr_sugars co_iterss co_iter_defss (Proof_Context.export lthy no_defs_lthy) lthy
       
   159           |> (fn ((coinduct_thms_pairs, _), (unfold_thmss, corec_thmss, _), _, _, _, _) =>
       
   160             (map snd coinduct_thms_pairs, unfold_thmss, corec_thmss));
       
   161 
       
             (* Morphism from the context with the auxiliary type variables back to the context
                the function was called with. *)
   162       val phi = Proof_Context.export_morphism no_defs_lthy no_defs_lthy0;
       
   163 
       
             (* The fp_sugar of the kk-th type; apart from "T" and "index", all fields are shared
                by the types of the fixpoint. *)
   164       fun mk_target_fp_sugar (kk, T) =
       
   165         {T = T, fp = fp, index = kk, pre_bnfs = pre_bnfs, nested_bnfs = nested_bnfs,
       
   166          nesting_bnfs = nesting_bnfs, fp_res = fp_res, ctr_defss = ctr_defss,
       
   167          ctr_sugars = ctr_sugars, co_inducts = co_inducts, co_iterss = co_iterss,
       
   168          co_iter_thmsss = transpose [un_fold_thmss, co_rec_thmss]}
       
   169         |> morph_fp_sugar phi;
       
   170     in
       
   171       ((true, map_index mk_target_fp_sugar fpTs), lthy)
       
   172     end
       
   173   else
       
   174     (* TODO: reorder hypotheses and predicates in (co)induction rules? *)
       
   175     ((false, fp_sugars0), no_defs_lthy0);
       
   176 
       
   (* Turns the association list "callsss" (constructor |-> calls per argument) into a list that
      is indexed by the constructors of "fp_sugar", in their order.

      Constructors are looked up modulo "Term.aconv_untyped". The calls of a constructor that is
      found are beta-eta-contracted; a constructor without entry gets one empty call list per
      argument. *)
   fun indexify_callsss fp_sugar callsss =
     let
       val contract_calls = map (map Envir.beta_eta_contract);
       fun no_calls ctr = replicate (num_binder_types (fastype_of ctr)) [];

       fun calls_of_ctr ctr =
         (case AList.lookup Term.aconv_untyped callsss ctr of
           SOME callss => contract_calls callss
         | NONE => no_calls ctr);
     in
       map calls_of_ctr (#ctrs (of_fp_sugar #ctr_sugars fp_sugar))
     end;
       
   187 
       
   (* Pads the per-type call tables "callssss0" with empty tables up to length "n" and indexes
      each table by the constructors of the fp_sugar at the same position in "fp_sugars0". *)
   fun pad_and_indexify_calls fp_sugars0 n callssss0 =
     map2 indexify_callsss fp_sugars0 (pad_list [] n callssss0);
       
   189 
       
   (* Prepares the types "actual_Ts" for "mutualize_fp_sugars" and calls it.

      The types are sorted by size and completed with the remaining types of the mutual clusters
      they belong to; the types added this way are returned as "missing_Ts". Bindings and call
      tables are padded accordingly ("actual_bs" with the common name of the given bindings,
      "actual_callssss0" with empty tables). All per-type data is passed to
      "mutualize_fp_sugars" in the sorted order ("perm_Ts"), and the resulting fp_sugars are
      permuted back to the order of "actual_Ts @ missing_Ts".

      Result: ((nontriv, missing_Ts, perm_kks, fp_sugars), lthy), where "nontriv" is the flag
      returned by "mutualize_fp_sugars" and "perm_kks" is the list 0, ..., nn - 1 in sorted order.

      Errors:
      - a type is not a (co)datatype of kind "fp" (with a dedicated message for old-style
        datatypes if "fp" is "Least_FP");
      - a cluster is neither mutually recursive with nor nested recursive via the types
        processed before it. *)
   fun nested_to_mutual_fps lose_co_rec fp actual_bs actual_Ts get_indices actual_callssss0 lthy =
     let
       val qsoty = quote o Syntax.string_of_typ lthy;
       val qsotys = space_implode " or " o map qsoty;

       fun not_co_datatype0 T = error (qsoty T ^ " is not a " ^ co_prefix fp ^ "datatype");
       fun not_co_datatype (T as Type (s, _)) =
           if fp = Least_FP andalso
              is_some (Datatype_Data.get_info (Proof_Context.theory_of lthy) s) then
             error (qsoty T ^ " is not a new-style datatype (cf. \"datatype_new\")")
           else
             not_co_datatype0 T
         | not_co_datatype T = not_co_datatype0 T;
       fun not_mutually_nested_rec Ts1 Ts2 =
         error (qsotys Ts1 ^ " is neither mutually recursive with nor nested recursive via " ^
           qsotys Ts2);

       val perm_actual_Ts = sort (int_ord o pairself Term.size_of_typ) actual_Ts;

       (* Type arguments of the smallest type. A smallest type that is not a type constructor
          application (e.g. a type variable) is reported like any other non-datatype instead of
          failing the pattern match; an empty "actual_Ts" still raises "Bind", as before. *)
       val ty_args0 =
         (case perm_actual_Ts of
           Type (_, ty_args) :: _ => ty_args
         | T :: _ => not_co_datatype T
         | [] => raise Bind);

       (* Replaces each type by the types of its mutual cluster (instantiated with the type's own
          arguments) and checks that every cluster but the first is connected to the types seen
          so far. Types of the worklist that belong to the current cluster are consumed. *)
       fun check_enrich_with_mutuals _ [] = []
         | check_enrich_with_mutuals seen ((T as Type (T_name, ty_args)) :: Ts) =
           (case fp_sugar_of lthy T_name of
             SOME ({fp = fp', fp_res = {Ts = Ts', ...}, ...}) =>
             if fp = fp' then
               let
                 val mutual_Ts = map (fn Type (s, _) => Type (s, ty_args)) Ts';
                 val _ =
                   seen = [] orelse exists (exists_subtype_in seen) mutual_Ts orelse
                   not_mutually_nested_rec mutual_Ts seen;
                 val (seen', Ts') = List.partition (member (op =) mutual_Ts) Ts;
               in
                 mutual_Ts @ check_enrich_with_mutuals (seen @ T :: seen') Ts'
               end
             else
               not_co_datatype T
           | NONE => not_co_datatype T)
         | check_enrich_with_mutuals _ (T :: _) = not_co_datatype T;

       val perm_Ts = check_enrich_with_mutuals [] perm_actual_Ts;

       val missing_Ts = perm_Ts |> subtract (op =) actual_Ts;
       val Ts = actual_Ts @ missing_Ts;

       val nn = length Ts;
       val kks = 0 upto nn - 1;

       val common_name = mk_common_name (map Binding.name_of actual_bs);
       val bs = pad_list (Binding.name common_name) nn actual_bs;

       (* Conversions between the order of "Ts" and the order of "perm_Ts". *)
       fun permute xs = permute_like (op =) Ts perm_Ts xs;
       fun unpermute perm_xs = permute_like (op =) perm_Ts Ts perm_xs;

       val perm_bs = permute bs;
       val perm_kks = permute kks;
       val perm_fp_sugars0 = map (the o fp_sugar_of lthy o fst o dest_Type) perm_Ts;

       (* Types with differing type arguments cannot be left as they are. *)
       val mutualize = exists (fn Type (_, ty_args) => ty_args <> ty_args0) Ts;

       (* The padded call tables are in the order of "Ts" (the padding stands for the trailing
          "missing_Ts"), so they must be permuted like "bs" and "kks" before they are paired with
          "perm_fp_sugars0"; otherwise the tables would be looked up against the constructors of
          the wrong types. *)
       val perm_callssss =
         map2 indexify_callsss perm_fp_sugars0 (permute (pad_list [] nn actual_callssss0));

       val get_perm_indices = map (fn kk => find_index (curry (op =) kk) perm_kks) o get_indices;

       val ((nontriv, perm_fp_sugars), lthy) =
         mutualize_fp_sugars lose_co_rec mutualize fp perm_bs perm_Ts get_perm_indices perm_callssss
           perm_fp_sugars0 lthy;

       val fp_sugars = unpermute perm_fp_sugars;
     in
       ((nontriv, missing_Ts, perm_kks, fp_sugars), lthy)
     end;
       
   260 
       
   261 end;