src/HOL/BNF/Tools/bnf_fp_n2m_sugar.ML
author blanchet
Tue Sep 24 00:01:10 2013 +0200 (2013-09-24)
changeset 53808 b3e2022530e3
parent 53746 bd038e48526d
child 53974 612505263257
permissions -rw-r--r--
register codatatypes with Nitpick
blanchet@53303
     1
(*  Title:      HOL/BNF/Tools/bnf_fp_n2m_sugar.ML
blanchet@53303
     2
    Author:     Jasmin Blanchette, TU Muenchen
blanchet@53303
     3
    Copyright   2013
blanchet@53303
     4
blanchet@53303
     5
Sugared flattening of nested to mutual (co)recursion.
blanchet@53303
     6
*)
blanchet@53303
     7
blanchet@53303
     8
(* Interface of the structure BNF_FP_N2M_Sugar: three functions that turn nested (co)recursion
   into mutual (co)recursion at the level of fp_sugar records. *)
signature BNF_FP_N2M_SUGAR =
blanchet@53303
     9
sig
blanchet@53303
    10
  (* Arguments, in order: lose_co_rec, mutualize, fp, bs, fpTs, get_indices, callssss,
     fp_sugars0 and the local theory. Returns the fp_sugars, a pair of optional theorem
     collections (least fixpoint, greatest fixpoint) and the resulting local theory. *)
  val mutualize_fp_sugars: bool -> bool -> BNF_FP_Util.fp_kind -> binding list -> typ list ->
blanchet@53303
    11
    (term -> int list) -> term list list list list -> BNF_FP_Def_Sugar.fp_sugar list ->
blanchet@53746
    12
    local_theory ->
blanchet@53746
    13
    (BNF_FP_Def_Sugar.fp_sugar list
blanchet@53746
    14
     * (BNF_FP_Def_Sugar.lfp_sugar_thms option * BNF_FP_Def_Sugar.gfp_sugar_thms option))
blanchet@53746
    15
    * local_theory
blanchet@53303
    16
  (* Converts per-type association lists from constructors to calls into one list of calls per
     constructor of each fp_sugar. *)
  val pad_and_indexify_calls: BNF_FP_Def_Sugar.fp_sugar list -> int ->
blanchet@53303
    17
    (term * term list list) list list -> term list list list list
blanchet@53303
    18
  (* Arguments, in order: lose_co_rec, fp, actual_bs, actual_Ts, get_indices, actual_callssss0
     and the local theory. The first two result components are the types that had to be added
     and the permuted indices. *)
  val nested_to_mutual_fps: bool -> BNF_FP_Util.fp_kind -> binding list -> typ list ->
blanchet@53303
    19
    (term -> int list) -> ((term * term list list) list) list -> local_theory ->
blanchet@53746
    20
    (typ list * int list * BNF_FP_Def_Sugar.fp_sugar list
blanchet@53746
    21
     * (BNF_FP_Def_Sugar.lfp_sugar_thms option * BNF_FP_Def_Sugar.gfp_sugar_thms option))
blanchet@53746
    22
    * local_theory
blanchet@53303
    23
end;
blanchet@53303
    24
blanchet@53303
    25
structure BNF_FP_N2M_Sugar : BNF_FP_N2M_SUGAR =
blanchet@53303
    26
struct
blanchet@53303
    27
blanchet@53303
    28
open BNF_Util
blanchet@53303
    29
open BNF_Def
blanchet@53303
    30
open BNF_Ctr_Sugar
blanchet@53303
    31
open BNF_FP_Util
blanchet@53303
    32
open BNF_FP_Def_Sugar
blanchet@53303
    33
open BNF_FP_N2M
blanchet@53303
    34
blanchet@53303
    35
(* Prefix put in front of each base name when the bindings "fp_bs" handed to "fp_bnf" are
   built in "mutualize_fp_sugars". *)
val n2mN = "n2m_"
blanchet@53303
    36
blanchet@53303
    37
(* TODO: test with sort constraints on As *)
blanchet@53303
    38
(* TODO: use right sorting order for "fp_sort" w.r.t. original BNFs (?) -- treat new variables
blanchet@53303
    39
   as deads? *)
blanchet@53303
    40
(* Produces fp_sugar records for "fpTs" viewed as one mutually (co)recursive group.

   If "mutualize" is false and "fpTs" has no duplicates, nothing is constructed: the result is
   "fp_sugars0" with (NONE, NONE) and the unchanged context "no_defs_lthy0".

   Otherwise the constructor argument types are rewritten with "freeze_fp" (driven by "callssss"
   and "get_indices"), a fixpoint is constructed through "fp_bnf", the (co)iterators are defined
   and their theorems derived. The result holds one record per element of "fpTs" plus the
   derived theorems: (SOME _, NONE) when fp = Least_FP and (NONE, SOME _) otherwise.

   Calls "error" on heterogeneous or incompatible recursive calls. *)
fun mutualize_fp_sugars lose_co_rec mutualize fp bs fpTs get_indices callssss fp_sugars0
blanchet@53303
    41
    no_defs_lthy0 =
blanchet@53303
    42
  (* TODO: Also check whether there's any lost recursion? *)
blanchet@53303
    43
  if mutualize orelse has_duplicates (op =) fpTs then
blanchet@53303
    44
    let
blanchet@53303
    45
      val thy = Proof_Context.theory_of no_defs_lthy0;
blanchet@53303
    46
blanchet@53303
    47
      val qsotm = quote o Syntax.string_of_term no_defs_lthy0;
blanchet@53303
    48
blanchet@53303
    49
      fun heterogeneous_call t = error ("Heterogeneous recursive call: " ^ qsotm t);
blanchet@53303
    50
      fun incompatible_calls t1 t2 =
blanchet@53303
    51
        error ("Incompatible recursive calls: " ^ qsotm t1 ^ " vs. " ^ qsotm t2);
blanchet@53303
    52
blanchet@53303
    53
      val b_names = map Binding.name_of bs;
blanchet@53303
    54
      val fp_b_names = map base_name_of_typ fpTs;
blanchet@53303
    55
blanchet@53303
    56
      val nn = length fpTs;
blanchet@53303
    57
blanchet@53303
    58
      (* Selects the ctr_sugar at position "index" and applies to its terms the type
         substitution obtained by matching the record's type T against the target type fpT. *)
      fun target_ctr_sugar_of_fp_sugar fpT {T, index, ctr_sugars, ...} =
blanchet@53303
    59
        let
blanchet@53303
    60
          val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T, fpT) Vartab.empty) [];
blanchet@53303
    61
          val phi = Morphism.term_morphism (Term.subst_TVars rho);
blanchet@53303
    62
        in
blanchet@53303
    63
          morph_ctr_sugar phi (nth ctr_sugars index)
blanchet@53303
    64
        end;
blanchet@53303
    65
blanchet@53303
    66
      val ctr_defss = map (of_fp_sugar #ctr_defss) fp_sugars0;
blanchet@53476
    67
      val mapss = map (of_fp_sugar #mapss) fp_sugars0;
blanchet@53303
    68
      val ctr_sugars0 = map2 target_ctr_sugar_of_fp_sugar fpTs fp_sugars0;
blanchet@53303
    69
blanchet@53303
    70
      val ctrss = map #ctrs ctr_sugars0;
blanchet@53303
    71
      val ctr_Tss = map (map fastype_of) ctrss;
blanchet@53303
    72
blanchet@53303
    73
      (* The type parameters are the free type variables occurring in the constructor types. *)
      val As' = fold (fold Term.add_tfreesT) ctr_Tss [];
blanchet@53303
    74
      val As = map TFree As';
blanchet@53303
    75
blanchet@53303
    76
      val ((Cs, Xs), no_defs_lthy) =
blanchet@53303
    77
        no_defs_lthy0
blanchet@53303
    78
        |> fold Variable.declare_typ As
blanchet@53303
    79
        |> mk_TFrees nn
blanchet@53303
    80
        ||>> variant_tfrees fp_b_names;
blanchet@53303
    81
blanchet@53303
    82
      (* If "lose_co_rec" is "true", the function "null" on "'a list" gives rise to
blanchet@53303
    83
           'list = unit + 'a list
blanchet@53303
    84
         instead of
blanchet@53303
    85
           'list = unit + 'list
blanchet@53303
    86
         resulting in a simpler (co)induction rule and (co)recursor. *)
blanchet@53303
    87
      fun freeze_fp_default (T as Type (s, Ts)) =
blanchet@53303
    88
          (case find_index (curry (op =) T) fpTs of
blanchet@53303
    89
            ~1 => Type (s, map freeze_fp_default Ts)
blanchet@53303
    90
          | kk => nth Xs kk)
blanchet@53303
    91
        | freeze_fp_default T = T;
blanchet@53303
    92
blanchet@53303
    93
      (* A call for which "get_indices" yields two or more indices is rejected. *)
      fun get_indices_checked call =
blanchet@53303
    94
        (case get_indices call of
blanchet@53303
    95
          _ :: _ :: _ => heterogeneous_call call
blanchet@53303
    96
        | kks => kks);
blanchet@53303
    97
blanchet@53303
    98
      (* Rewrites the type T according to "calls". If "dest_map" succeeds on some call for the
         head constructor "s", the type arguments are processed recursively. Otherwise the
         distinct indices of the calls decide: none keeps T (or applies "freeze_fp_default" when
         "lose_co_rec" is false), exactly one yields the corresponding element of Xs, and two
         or more are reported as incompatible. *)
      fun freeze_fp calls (T as Type (s, Ts)) =
blanchet@53303
    99
          (case map_filter (try (snd o dest_map no_defs_lthy s)) calls of
blanchet@53303
   100
            [] =>
blanchet@53303
   101
            (case union (op = o pairself fst)
blanchet@53303
   102
                (maps (fn call => map (rpair call) (get_indices_checked call)) calls) [] of
blanchet@53303
   103
              [] => T |> not lose_co_rec ? freeze_fp_default
blanchet@53303
   104
            | [(kk, _)] => nth Xs kk
blanchet@53303
   105
            | (_, call1) :: (_, call2) :: _ => incompatible_calls call1 call2)
blanchet@53303
   106
          | callss =>
blanchet@53303
   107
            Type (s, map2 freeze_fp (flatten_type_args_of_bnf (the (bnf_of no_defs_lthy s)) []
blanchet@53303
   108
              (transpose callss)) Ts))
blanchet@53303
   109
        | freeze_fp _ T = T;
blanchet@53303
   110
blanchet@53303
   111
      val ctr_Tsss = map (map binder_types) ctr_Tss;
blanchet@53303
   112
      val ctrXs_Tsss = map2 (map2 (map2 freeze_fp)) callssss ctr_Tsss;
blanchet@53303
   113
      val ctrXs_sum_prod_Ts = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctrXs_Tsss;
blanchet@53303
   114
      val Ts = map (body_type o hd) ctr_Tss;
blanchet@53303
   115
blanchet@53303
   116
      val ns = map length ctr_Tsss;
blanchet@53303
   117
      val kss = map (fn n => 1 upto n) ns;
blanchet@53303
   118
      val mss = map (map length) ctr_Tsss;
blanchet@53303
   119
blanchet@53303
   120
      (* One fixpoint equation per type: the variable from Xs paired with the sum-of-products
         type built from the rewritten constructor argument types. *)
      val fp_eqs = map dest_TFree Xs ~~ ctrXs_sum_prod_Ts;
blanchet@53303
   121
blanchet@53303
   122
      val base_fp_names = Name.variant_list [] fp_b_names;
blanchet@53303
   123
      val fp_bs = map2 (fn b_name => fn base_fp_name =>
blanchet@53303
   124
          Binding.qualify true b_name (Binding.name (n2mN ^ base_fp_name)))
blanchet@53303
   125
        b_names base_fp_names;
blanchet@53303
   126
blanchet@53303
   127
      val (pre_bnfs, (fp_res as {xtor_co_iterss = xtor_co_iterss0, xtor_co_induct,
blanchet@53303
   128
             dtor_injects, dtor_ctors, xtor_co_iter_thmss, ...}, lthy)) =
blanchet@53303
   129
        fp_bnf (construct_mutualized_fp fp fpTs fp_sugars0) fp_bs As' fp_eqs no_defs_lthy;
blanchet@53303
   130
blanchet@53303
   131
      val nesting_bnfs = nesty_bnfs lthy ctrXs_Tsss As;
blanchet@53303
   132
      val nested_bnfs = nesty_bnfs lthy ctrXs_Tsss Xs;
blanchet@53303
   133
blanchet@53303
   134
      val ((xtor_co_iterss, iters_args_types, coiters_args_types), _) =
blanchet@53591
   135
        mk_co_iters_prelims fp ctr_Tsss fpTs Cs ns mss xtor_co_iterss0 lthy;
blanchet@53303
   136
blanchet@53303
   137
      fun mk_binding b suf = Binding.suffix_name ("_" ^ suf) b;
blanchet@53303
   138
blanchet@53303
   139
      (* Defines the iterators (fold/rec) or coiterators (unfold/corec) for every binding in
         "fp_bs", threading the local theory through the definitions. *)
      val ((co_iterss, co_iter_defss), lthy) =
blanchet@53303
   140
        fold_map2 (fn b =>
blanchet@53303
   141
          (if fp = Least_FP then define_iters [foldN, recN] (the iters_args_types)
blanchet@53303
   142
           else define_coiters [unfoldN, corecN] (the coiters_args_types))
blanchet@53303
   143
            (mk_binding b) fpTs Cs) fp_bs xtor_co_iterss lthy
blanchet@53303
   144
        |>> split_list;
blanchet@53303
   145
blanchet@53303
   146
      val rho = tvar_subst thy Ts fpTs;
blanchet@53303
   147
      val ctr_sugar_phi =
blanchet@53303
   148
        Morphism.compose (Morphism.typ_morphism (Term.typ_subst_TVars rho))
blanchet@53303
   149
          (Morphism.term_morphism (Term.subst_TVars rho));
blanchet@53303
   150
      val inst_ctr_sugar = morph_ctr_sugar ctr_sugar_phi;
blanchet@53303
   151
blanchet@53303
   152
      val ctr_sugars = map inst_ctr_sugar ctr_sugars0;
blanchet@53303
   153
blanchet@53746
   154
      (* Both branches return the same 7-tuple shape; the least-fixpoint branch fills the
         discriminator and selector components with empty lists. *)
      val ((co_inducts, un_fold_thmss, co_rec_thmss, disc_unfold_thmss, disc_corec_thmss,
blanchet@53746
   155
            sel_unfold_thmsss, sel_corec_thmsss), fp_sugar_thms) =
blanchet@53303
   156
        if fp = Least_FP then
blanchet@53303
   157
          derive_induct_iters_thms_for_types pre_bnfs (the iters_args_types) xtor_co_induct
blanchet@53303
   158
            xtor_co_iter_thmss nesting_bnfs nested_bnfs fpTs Cs Xs ctrXs_Tsss ctrss ctr_defss
blanchet@53303
   159
            co_iterss co_iter_defss lthy
blanchet@53808
   160
          |> `(fn ((_, induct, _), (fold_thmss, rec_thmss, _)) =>
blanchet@53475
   161
            ([induct], fold_thmss, rec_thmss, [], [], [], []))
blanchet@53746
   162
          ||> (fn info => (SOME info, NONE))
blanchet@53303
   163
        else
blanchet@53303
   164
          derive_coinduct_coiters_thms_for_types pre_bnfs (the coiters_args_types) xtor_co_induct
blanchet@53303
   165
            dtor_injects dtor_ctors xtor_co_iter_thmss nesting_bnfs fpTs Cs kss mss ns ctr_defss
blanchet@53303
   166
            ctr_sugars co_iterss co_iter_defss (Proof_Context.export lthy no_defs_lthy) lthy
blanchet@53746
   167
          |> `(fn ((coinduct_thms_pairs, _), (unfold_thmss, corec_thmss, _), _,
blanchet@53678
   168
                  (disc_unfold_thmss, disc_corec_thmss, _), _,
blanchet@53475
   169
                  (sel_unfold_thmsss, sel_corec_thmsss, _)) =>
blanchet@53475
   170
            (map snd coinduct_thms_pairs, unfold_thmss, corec_thmss, disc_unfold_thmss,
blanchet@53746
   171
             disc_corec_thmss, sel_unfold_thmsss, sel_corec_thmsss))
blanchet@53746
   172
          ||> (fn info => (NONE, SOME info));
blanchet@53303
   173
blanchet@53303
   174
      (* Export morphism from "no_defs_lthy" to the initial context "no_defs_lthy0"; it is
         applied to every record built below. *)
      val phi = Proof_Context.export_morphism no_defs_lthy no_defs_lthy0;
blanchet@53303
   175
blanchet@53303
   176
      fun mk_target_fp_sugar (kk, T) =
blanchet@53303
   177
        {T = T, fp = fp, index = kk, pre_bnfs = pre_bnfs, nested_bnfs = nested_bnfs,
blanchet@53303
   178
         nesting_bnfs = nesting_bnfs, fp_res = fp_res, ctr_defss = ctr_defss,
blanchet@53476
   179
         ctr_sugars = ctr_sugars, co_iterss = co_iterss, mapss = mapss, co_inducts = co_inducts,
blanchet@53475
   180
         co_iter_thmsss = transpose [un_fold_thmss, co_rec_thmss],
blanchet@53475
   181
         disc_co_itersss = transpose [disc_unfold_thmss, disc_corec_thmss],
blanchet@53475
   182
         sel_co_iterssss = transpose [sel_unfold_thmsss, sel_corec_thmsss]}
blanchet@53303
   183
        |> morph_fp_sugar phi;
blanchet@53303
   184
    in
blanchet@53746
   185
      ((map_index mk_target_fp_sugar fpTs, fp_sugar_thms), lthy)
blanchet@53303
   186
    end
blanchet@53303
   187
  else
blanchet@53303
   188
    (* TODO: reorder hypotheses and predicates in (co)induction rules? *)
blanchet@53746
   189
    ((fp_sugars0, (NONE, NONE)), no_defs_lthy0);
blanchet@53303
   190
blanchet@53303
   191
(* Turns the association list "callsss" (constructor term -> calls per argument) into a list
   aligned with the constructors of "fp_sugar". Keys are compared with Term.aconv_untyped.
   A constructor without an entry gets one empty call list per binder type of its type; the
   calls of a constructor with an entry are beta-eta-contracted. *)
fun indexify_callsss fp_sugar callsss =
blanchet@53303
   192
  let
blanchet@53303
   193
    val {ctrs, ...} = of_fp_sugar #ctr_sugars fp_sugar;
blanchet@53303
   194
    fun do_ctr ctr =
blanchet@53303
   195
      (case AList.lookup Term.aconv_untyped callsss ctr of
blanchet@53303
   196
        NONE => replicate (num_binder_types (fastype_of ctr)) []
blanchet@53303
   197
      | SOME callss => map (map Envir.beta_eta_contract) callss);
blanchet@53303
   198
  in
blanchet@53303
   199
    map do_ctr ctrs
blanchet@53303
   200
  end;
blanchet@53303
   201
blanchet@53303
   202
(* Feeds the two remaining arguments to "pad_list []" and applies "indexify_callsss" pairwise
   to "fp_sugars0" and the padded list. In "nested_to_mutual_fps" those two arguments are the
   number of types and the user-supplied calls.
   NOTE(review): presumably "pad_list [] n" extends the list with [] up to length n, so that
   "map2" sees lists of equal length -- confirm against BNF_Util. *)
fun pad_and_indexify_calls fp_sugars0 = map2 indexify_callsss fp_sugars0 oo pad_list [];
blanchet@53303
   203
blanchet@53303
   204
(* Front end of "mutualize_fp_sugars" for a caller-supplied list of types.

   "actual_Ts" is sorted by type size, completed with the remaining members of the fixpoint
   clusters its elements belong to, and handed to "mutualize_fp_sugars" in that permuted order.
   The result is ((missing_Ts, perm_kks, fp_sugars, fp_sugar_thms), lthy), where "missing_Ts"
   are the added types and "fp_sugars" is brought back (via "permute_like") to the order of
   actual_Ts @ missing_Ts.

   Calls "error" if a type has no registered fp_sugar of kind "fp", or if a cluster fails the
   "exists_subtype_in" check against the types seen before it. *)
fun nested_to_mutual_fps lose_co_rec fp actual_bs actual_Ts get_indices actual_callssss0 lthy =
blanchet@53303
   205
  let
blanchet@53303
   206
    val qsoty = quote o Syntax.string_of_typ lthy;
blanchet@53303
   207
    val qsotys = space_implode " or " o map qsoty;
blanchet@53303
   208
blanchet@53303
   209
    fun not_co_datatype0 T = error (qsoty T ^ " is not a " ^ co_prefix fp ^ "datatype");
blanchet@53303
   210
    (* Gives a more specific message when a least fixpoint is requested for a type that
       Datatype_Data knows about. *)
    fun not_co_datatype (T as Type (s, _)) =
blanchet@53303
   211
        if fp = Least_FP andalso
blanchet@53303
   212
           is_some (Datatype_Data.get_info (Proof_Context.theory_of lthy) s) then
blanchet@53303
   213
          error (qsoty T ^ " is not a new-style datatype (cf. \"datatype_new\")")
blanchet@53303
   214
        else
blanchet@53303
   215
          not_co_datatype0 T
blanchet@53303
   216
      | not_co_datatype T = not_co_datatype0 T;
blanchet@53303
   217
    fun not_mutually_nested_rec Ts1 Ts2 =
blanchet@53303
   218
      error (qsotys Ts1 ^ " is neither mutually recursive with nor nested recursive via " ^
blanchet@53303
   219
        qsotys Ts2);
blanchet@53303
   220
blanchet@53303
   221
    (* "ty_args0" are the type arguments of the smallest type in "actual_Ts".
       NOTE(review): this pattern binding is not exhaustive; it fails if "actual_Ts" is empty
       or if its smallest element is not a Type -- presumably callers rule this out; confirm. *)
    val perm_actual_Ts as Type (_, ty_args0) :: _ =
blanchet@53303
   222
      sort (int_ord o pairself Term.size_of_typ) actual_Ts;
blanchet@53303
   223
blanchet@53303
   224
    (* Walks the sorted types. Each type is replaced by the types "Ts" of its fp_res,
       instantiated with the type arguments of the type at hand, and the members of that
       cluster are removed from the remaining list. Unless nothing has been seen yet, some
       member of the cluster must satisfy "exists_subtype_in seen"; otherwise
       "not_mutually_nested_rec" raises an error. *)
    fun check_enrich_with_mutuals _ [] = []
blanchet@53303
   225
      | check_enrich_with_mutuals seen ((T as Type (T_name, ty_args)) :: Ts) =
blanchet@53303
   226
        (case fp_sugar_of lthy T_name of
blanchet@53303
   227
          SOME ({fp = fp', fp_res = {Ts = Ts', ...}, ...}) =>
blanchet@53303
   228
          if fp = fp' then
blanchet@53303
   229
            let
blanchet@53303
   230
              val mutual_Ts = map (fn Type (s, _) => Type (s, ty_args)) Ts';
blanchet@53303
   231
              val _ =
blanchet@53303
   232
                seen = [] orelse exists (exists_subtype_in seen) mutual_Ts orelse
blanchet@53303
   233
                not_mutually_nested_rec mutual_Ts seen;
blanchet@53303
   234
              val (seen', Ts') = List.partition (member (op =) mutual_Ts) Ts;
blanchet@53303
   235
            in
blanchet@53303
   236
              mutual_Ts @ check_enrich_with_mutuals (seen @ T :: seen') Ts'
blanchet@53303
   237
            end
blanchet@53303
   238
          else
blanchet@53303
   239
            not_co_datatype T
blanchet@53303
   240
        | NONE => not_co_datatype T)
blanchet@53303
   241
      | check_enrich_with_mutuals _ (T :: _) = not_co_datatype T;
blanchet@53303
   242
blanchet@53303
   243
    val perm_Ts = check_enrich_with_mutuals [] perm_actual_Ts;
blanchet@53303
   244
blanchet@53303
   245
    (* Types required by the clusters but not supplied by the caller are appended at the end. *)
    val missing_Ts = perm_Ts |> subtract (op =) actual_Ts;
blanchet@53303
   246
    val Ts = actual_Ts @ missing_Ts;
blanchet@53303
   247
blanchet@53303
   248
    val nn = length Ts;
blanchet@53303
   249
    val kks = 0 upto nn - 1;
blanchet@53303
   250
blanchet@53303
   251
    (* Bindings for the appended types are filled in with a name derived from the given ones. *)
    val common_name = mk_common_name (map Binding.name_of actual_bs);
blanchet@53303
   252
    val bs = pad_list (Binding.name common_name) nn actual_bs;
blanchet@53303
   253
blanchet@53303
   254
    fun permute xs = permute_like (op =) Ts perm_Ts xs;
blanchet@53303
   255
    fun unpermute perm_xs = permute_like (op =) perm_Ts Ts perm_xs;
blanchet@53303
   256
blanchet@53303
   257
    val perm_bs = permute bs;
blanchet@53303
   258
    val perm_kks = permute kks;
blanchet@53303
   259
    val perm_fp_sugars0 = map (the o fp_sugar_of lthy o fst o dest_Type) perm_Ts;
blanchet@53303
   260
blanchet@53303
   261
    (* Mutualization is forced as soon as some type carries type arguments other than
       "ty_args0"; "mutualize_fp_sugars" additionally checks for duplicate types itself. *)
    val mutualize = exists (fn Type (_, ty_args) => ty_args <> ty_args0) Ts;
blanchet@53303
   262
    val perm_callssss = pad_and_indexify_calls perm_fp_sugars0 nn actual_callssss0;
blanchet@53303
   263
blanchet@53303
   264
    (* Translates each index returned by "get_indices" into its position within "perm_kks". *)
    val get_perm_indices = map (fn kk => find_index (curry (op =) kk) perm_kks) o get_indices;
blanchet@53303
   265
blanchet@53746
   266
    val ((perm_fp_sugars, fp_sugar_thms), lthy) =
blanchet@53303
   267
      mutualize_fp_sugars lose_co_rec mutualize fp perm_bs perm_Ts get_perm_indices perm_callssss
blanchet@53303
   268
        perm_fp_sugars0 lthy;
blanchet@53303
   269
blanchet@53303
   270
    val fp_sugars = unpermute perm_fp_sugars;
blanchet@53303
   271
  in
blanchet@53746
   272
    ((missing_Ts, perm_kks, fp_sugars, fp_sugar_thms), lthy)
blanchet@53303
   273
  end;
blanchet@53303
   274
blanchet@53303
   275
end;