src/HOL/BNF/Tools/bnf_fp_n2m_sugar.ML
changeset 53303 ae49b835ca01
child 53475 185ad6cf6576
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/src/HOL/BNF/Tools/bnf_fp_n2m_sugar.ML	Fri Aug 30 11:27:23 2013 +0200
     1.3 @@ -0,0 +1,261 @@
     1.4 +(*  Title:      HOL/BNF/Tools/bnf_fp_n2m_sugar.ML
     1.5 +    Author:     Jasmin Blanchette, TU Muenchen
     1.6 +    Copyright   2013
     1.7 +
     1.8 +Suggared flattening of nested to mutual (co)recursion.
     1.9 +*)
    1.10 +
    1.11 +signature BNF_FP_N2M_SUGAR =
    1.12 +sig
    1.13 +  val mutualize_fp_sugars: bool -> bool -> BNF_FP_Util.fp_kind -> binding list -> typ list ->
    1.14 +    (term -> int list) -> term list list list list -> BNF_FP_Def_Sugar.fp_sugar list ->
    1.15 +    local_theory -> (bool * BNF_FP_Def_Sugar.fp_sugar list) * local_theory
    1.16 +  val pad_and_indexify_calls: BNF_FP_Def_Sugar.fp_sugar list -> int ->
    1.17 +    (term * term list list) list list -> term list list list list
    1.18 +  val nested_to_mutual_fps: bool -> BNF_FP_Util.fp_kind -> binding list -> typ list ->
    1.19 +    (term -> int list) -> ((term * term list list) list) list -> local_theory ->
    1.20 +    (bool * typ list * int list * BNF_FP_Def_Sugar.fp_sugar list) * local_theory
    1.21 +end;
    1.22 +
    1.23 +structure BNF_FP_N2M_Sugar : BNF_FP_N2M_SUGAR =
    1.24 +struct
    1.25 +
    1.26 +open BNF_Util
    1.27 +open BNF_Def
    1.28 +open BNF_Ctr_Sugar
    1.29 +open BNF_FP_Util
    1.30 +open BNF_FP_Def_Sugar
    1.31 +open BNF_FP_N2M
    1.32 +
    1.33 +val n2mN = "n2m_"
    1.34 +
    1.35 +(* TODO: test with sort constraints on As *)
    1.36 +(* TODO: use right sorting order for "fp_sort" w.r.t. original BNFs (?) -- treat new variables
    1.37 +   as deads? *)
    1.38 +fun mutualize_fp_sugars lose_co_rec mutualize fp bs fpTs get_indices callssss fp_sugars0
    1.39 +    no_defs_lthy0 =
    1.40 +  (* TODO: Also check whether there's any lost recursion? *)
    1.41 +  if mutualize orelse has_duplicates (op =) fpTs then
    1.42 +    let
    1.43 +      val thy = Proof_Context.theory_of no_defs_lthy0;
    1.44 +
    1.45 +      val qsotm = quote o Syntax.string_of_term no_defs_lthy0;
    1.46 +
    1.47 +      fun heterogeneous_call t = error ("Heterogeneous recursive call: " ^ qsotm t);
    1.48 +      fun incompatible_calls t1 t2 =
    1.49 +        error ("Incompatible recursive calls: " ^ qsotm t1 ^ " vs. " ^ qsotm t2);
    1.50 +
    1.51 +      val b_names = map Binding.name_of bs;
    1.52 +      val fp_b_names = map base_name_of_typ fpTs;
    1.53 +
    1.54 +      val nn = length fpTs;
    1.55 +
    1.56 +      fun target_ctr_sugar_of_fp_sugar fpT {T, index, ctr_sugars, ...} =
    1.57 +        let
    1.58 +          val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T, fpT) Vartab.empty) [];
    1.59 +          val phi = Morphism.term_morphism (Term.subst_TVars rho);
    1.60 +        in
    1.61 +          morph_ctr_sugar phi (nth ctr_sugars index)
    1.62 +        end;
    1.63 +
    1.64 +      val ctr_defss = map (of_fp_sugar #ctr_defss) fp_sugars0;
    1.65 +      val ctr_sugars0 = map2 target_ctr_sugar_of_fp_sugar fpTs fp_sugars0;
    1.66 +
    1.67 +      val ctrss = map #ctrs ctr_sugars0;
    1.68 +      val ctr_Tss = map (map fastype_of) ctrss;
    1.69 +
    1.70 +      val As' = fold (fold Term.add_tfreesT) ctr_Tss [];
    1.71 +      val As = map TFree As';
    1.72 +
    1.73 +      val ((Cs, Xs), no_defs_lthy) =
    1.74 +        no_defs_lthy0
    1.75 +        |> fold Variable.declare_typ As
    1.76 +        |> mk_TFrees nn
    1.77 +        ||>> variant_tfrees fp_b_names;
    1.78 +
    1.79 +      (* If "lose_co_rec" is "true", the function "null" on "'a list" gives rise to
    1.80 +           'list = unit + 'a list
    1.81 +         instead of
    1.82 +           'list = unit + 'list
    1.83 +         resulting in a simpler (co)induction rule and (co)recursor. *)
    1.84 +      fun freeze_fp_default (T as Type (s, Ts)) =
    1.85 +          (case find_index (curry (op =) T) fpTs of
    1.86 +            ~1 => Type (s, map freeze_fp_default Ts)
    1.87 +          | kk => nth Xs kk)
    1.88 +        | freeze_fp_default T = T;
    1.89 +
    1.90 +      fun get_indices_checked call =
    1.91 +        (case get_indices call of
    1.92 +          _ :: _ :: _ => heterogeneous_call call
    1.93 +        | kks => kks);
    1.94 +
    1.95 +      fun freeze_fp calls (T as Type (s, Ts)) =
    1.96 +          (case map_filter (try (snd o dest_map no_defs_lthy s)) calls of
    1.97 +            [] =>
    1.98 +            (case union (op = o pairself fst)
    1.99 +                (maps (fn call => map (rpair call) (get_indices_checked call)) calls) [] of
   1.100 +              [] => T |> not lose_co_rec ? freeze_fp_default
   1.101 +            | [(kk, _)] => nth Xs kk
   1.102 +            | (_, call1) :: (_, call2) :: _ => incompatible_calls call1 call2)
   1.103 +          | callss =>
   1.104 +            Type (s, map2 freeze_fp (flatten_type_args_of_bnf (the (bnf_of no_defs_lthy s)) []
   1.105 +              (transpose callss)) Ts))
   1.106 +        | freeze_fp _ T = T;
   1.107 +
   1.108 +      val ctr_Tsss = map (map binder_types) ctr_Tss;
   1.109 +      val ctrXs_Tsss = map2 (map2 (map2 freeze_fp)) callssss ctr_Tsss;
   1.110 +      val ctrXs_sum_prod_Ts = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctrXs_Tsss;
   1.111 +      val Ts = map (body_type o hd) ctr_Tss;
   1.112 +
   1.113 +      val ns = map length ctr_Tsss;
   1.114 +      val kss = map (fn n => 1 upto n) ns;
   1.115 +      val mss = map (map length) ctr_Tsss;
   1.116 +
   1.117 +      val fp_eqs = map dest_TFree Xs ~~ ctrXs_sum_prod_Ts;
   1.118 +
   1.119 +      val base_fp_names = Name.variant_list [] fp_b_names;
   1.120 +      val fp_bs = map2 (fn b_name => fn base_fp_name =>
   1.121 +          Binding.qualify true b_name (Binding.name (n2mN ^ base_fp_name)))
   1.122 +        b_names base_fp_names;
   1.123 +
   1.124 +      val (pre_bnfs, (fp_res as {xtor_co_iterss = xtor_co_iterss0, xtor_co_induct,
   1.125 +             dtor_injects, dtor_ctors, xtor_co_iter_thmss, ...}, lthy)) =
   1.126 +        fp_bnf (construct_mutualized_fp fp fpTs fp_sugars0) fp_bs As' fp_eqs no_defs_lthy;
   1.127 +
   1.128 +      val nesting_bnfs = nesty_bnfs lthy ctrXs_Tsss As;
   1.129 +      val nested_bnfs = nesty_bnfs lthy ctrXs_Tsss Xs;
   1.130 +
   1.131 +      val ((xtor_co_iterss, iters_args_types, coiters_args_types), _) =
   1.132 +        mk_co_iters_prelims fp fpTs Cs ns mss xtor_co_iterss0 lthy;
   1.133 +
   1.134 +      fun mk_binding b suf = Binding.suffix_name ("_" ^ suf) b;
   1.135 +
   1.136 +      val ((co_iterss, co_iter_defss), lthy) =
   1.137 +        fold_map2 (fn b =>
   1.138 +          (if fp = Least_FP then define_iters [foldN, recN] (the iters_args_types)
   1.139 +           else define_coiters [unfoldN, corecN] (the coiters_args_types))
   1.140 +            (mk_binding b) fpTs Cs) fp_bs xtor_co_iterss lthy
   1.141 +        |>> split_list;
   1.142 +
   1.143 +      val rho = tvar_subst thy Ts fpTs;
   1.144 +      val ctr_sugar_phi =
   1.145 +        Morphism.compose (Morphism.typ_morphism (Term.typ_subst_TVars rho))
   1.146 +          (Morphism.term_morphism (Term.subst_TVars rho));
   1.147 +      val inst_ctr_sugar = morph_ctr_sugar ctr_sugar_phi;
   1.148 +
   1.149 +      val ctr_sugars = map inst_ctr_sugar ctr_sugars0;
   1.150 +
   1.151 +      val (co_inducts, un_fold_thmss, co_rec_thmss) =
   1.152 +        if fp = Least_FP then
   1.153 +          derive_induct_iters_thms_for_types pre_bnfs (the iters_args_types) xtor_co_induct
   1.154 +            xtor_co_iter_thmss nesting_bnfs nested_bnfs fpTs Cs Xs ctrXs_Tsss ctrss ctr_defss
   1.155 +            co_iterss co_iter_defss lthy
   1.156 +          |> (fn ((_, induct, _), (fold_thmss, _), (rec_thmss, _)) =>
   1.157 +            ([induct], fold_thmss, rec_thmss))
   1.158 +        else
   1.159 +          derive_coinduct_coiters_thms_for_types pre_bnfs (the coiters_args_types) xtor_co_induct
   1.160 +            dtor_injects dtor_ctors xtor_co_iter_thmss nesting_bnfs fpTs Cs kss mss ns ctr_defss
   1.161 +            ctr_sugars co_iterss co_iter_defss (Proof_Context.export lthy no_defs_lthy) lthy
   1.162 +          |> (fn ((coinduct_thms_pairs, _), (unfold_thmss, corec_thmss, _), _, _, _, _) =>
   1.163 +            (map snd coinduct_thms_pairs, unfold_thmss, corec_thmss));
   1.164 +
   1.165 +      val phi = Proof_Context.export_morphism no_defs_lthy no_defs_lthy0;
   1.166 +
   1.167 +      fun mk_target_fp_sugar (kk, T) =
   1.168 +        {T = T, fp = fp, index = kk, pre_bnfs = pre_bnfs, nested_bnfs = nested_bnfs,
   1.169 +         nesting_bnfs = nesting_bnfs, fp_res = fp_res, ctr_defss = ctr_defss,
   1.170 +         ctr_sugars = ctr_sugars, co_inducts = co_inducts, co_iterss = co_iterss,
   1.171 +         co_iter_thmsss = transpose [un_fold_thmss, co_rec_thmss]}
   1.172 +        |> morph_fp_sugar phi;
   1.173 +    in
   1.174 +      ((true, map_index mk_target_fp_sugar fpTs), lthy)
   1.175 +    end
   1.176 +  else
   1.177 +    (* TODO: reorder hypotheses and predicates in (co)induction rules? *)
   1.178 +    ((false, fp_sugars0), no_defs_lthy0);
   1.179 +
   1.180 +fun indexify_callsss fp_sugar callsss =
   1.181 +  let
   1.182 +    val {ctrs, ...} = of_fp_sugar #ctr_sugars fp_sugar;
   1.183 +    fun do_ctr ctr =
   1.184 +      (case AList.lookup Term.aconv_untyped callsss ctr of
   1.185 +        NONE => replicate (num_binder_types (fastype_of ctr)) []
   1.186 +      | SOME callss => map (map Envir.beta_eta_contract) callss);
   1.187 +  in
   1.188 +    map do_ctr ctrs
   1.189 +  end;
   1.190 +
   1.191 +fun pad_and_indexify_calls fp_sugars0 = map2 indexify_callsss fp_sugars0 oo pad_list [];
   1.192 +
   1.193 +fun nested_to_mutual_fps lose_co_rec fp actual_bs actual_Ts get_indices actual_callssss0 lthy =
   1.194 +  let
   1.195 +    val qsoty = quote o Syntax.string_of_typ lthy;
   1.196 +    val qsotys = space_implode " or " o map qsoty;
   1.197 +
   1.198 +    fun not_co_datatype0 T = error (qsoty T ^ " is not a " ^ co_prefix fp ^ "datatype");
   1.199 +    fun not_co_datatype (T as Type (s, _)) =
   1.200 +        if fp = Least_FP andalso
   1.201 +           is_some (Datatype_Data.get_info (Proof_Context.theory_of lthy) s) then
   1.202 +          error (qsoty T ^ " is not a new-style datatype (cf. \"datatype_new\")")
   1.203 +        else
   1.204 +          not_co_datatype0 T
   1.205 +      | not_co_datatype T = not_co_datatype0 T;
   1.206 +    fun not_mutually_nested_rec Ts1 Ts2 =
   1.207 +      error (qsotys Ts1 ^ " is neither mutually recursive with nor nested recursive via " ^
   1.208 +        qsotys Ts2);
   1.209 +
   1.210 +    val perm_actual_Ts as Type (_, ty_args0) :: _ =
   1.211 +      sort (int_ord o pairself Term.size_of_typ) actual_Ts;
   1.212 +
   1.213 +    fun check_enrich_with_mutuals _ [] = []
   1.214 +      | check_enrich_with_mutuals seen ((T as Type (T_name, ty_args)) :: Ts) =
   1.215 +        (case fp_sugar_of lthy T_name of
   1.216 +          SOME ({fp = fp', fp_res = {Ts = Ts', ...}, ...}) =>
   1.217 +          if fp = fp' then
   1.218 +            let
   1.219 +              val mutual_Ts = map (fn Type (s, _) => Type (s, ty_args)) Ts';
   1.220 +              val _ =
   1.221 +                seen = [] orelse exists (exists_subtype_in seen) mutual_Ts orelse
   1.222 +                not_mutually_nested_rec mutual_Ts seen;
   1.223 +              val (seen', Ts') = List.partition (member (op =) mutual_Ts) Ts;
   1.224 +            in
   1.225 +              mutual_Ts @ check_enrich_with_mutuals (seen @ T :: seen') Ts'
   1.226 +            end
   1.227 +          else
   1.228 +            not_co_datatype T
   1.229 +        | NONE => not_co_datatype T)
   1.230 +      | check_enrich_with_mutuals _ (T :: _) = not_co_datatype T;
   1.231 +
   1.232 +    val perm_Ts = check_enrich_with_mutuals [] perm_actual_Ts;
   1.233 +
   1.234 +    val missing_Ts = perm_Ts |> subtract (op =) actual_Ts;
   1.235 +    val Ts = actual_Ts @ missing_Ts;
   1.236 +
   1.237 +    val nn = length Ts;
   1.238 +    val kks = 0 upto nn - 1;
   1.239 +
   1.240 +    val common_name = mk_common_name (map Binding.name_of actual_bs);
   1.241 +    val bs = pad_list (Binding.name common_name) nn actual_bs;
   1.242 +
   1.243 +    fun permute xs = permute_like (op =) Ts perm_Ts xs;
   1.244 +    fun unpermute perm_xs = permute_like (op =) perm_Ts Ts perm_xs;
   1.245 +
   1.246 +    val perm_bs = permute bs;
   1.247 +    val perm_kks = permute kks;
   1.248 +    val perm_fp_sugars0 = map (the o fp_sugar_of lthy o fst o dest_Type) perm_Ts;
   1.249 +
   1.250 +    val mutualize = exists (fn Type (_, ty_args) => ty_args <> ty_args0) Ts;
   1.251 +    val perm_callssss = pad_and_indexify_calls perm_fp_sugars0 nn actual_callssss0;
   1.252 +
   1.253 +    val get_perm_indices = map (fn kk => find_index (curry (op =) kk) perm_kks) o get_indices;
   1.254 +
   1.255 +    val ((nontriv, perm_fp_sugars), lthy) =
   1.256 +      mutualize_fp_sugars lose_co_rec mutualize fp perm_bs perm_Ts get_perm_indices perm_callssss
   1.257 +        perm_fp_sugars0 lthy;
   1.258 +
   1.259 +    val fp_sugars = unpermute perm_fp_sugars;
   1.260 +  in
   1.261 +    ((nontriv, missing_Ts, perm_kks, fp_sugars), lthy)
   1.262 +  end;
   1.263 +
   1.264 +end;