src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML
changeset 53303 ae49b835ca01
child 53329 c31c0c311cf0
equal deleted inserted replaced
53302:98fdf6c34142 53303:ae49b835ca01
       
     1 (*  Title:      HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML
       
     2     Author:     Lorenz Panny, TU Muenchen
       
     3     Author:     Jasmin Blanchette, TU Muenchen
       
     4     Copyright   2013
       
     5 
       
     6 Library for recursor and corecursor sugar.
       
     7 *)
       
     8 
       
(*Interface of the utility library shared by the recursor and corecursor sugar:
  descriptions of the calls occurring in constructor arguments, per-constructor and
  per-type specifications, and functions that rewrite (co)recursive calls.*)
signature BNF_FP_REC_SUGAR_UTIL =
sig
  (*Kind of call in one constructor argument of a recursive specification. The integers
    are positions among the recursor's function arguments, as computed in rec_specs_of.*)
  datatype rec_call =
    No_Rec of int |
    Direct_Rec of int (*before*) * int (*after*) |
    Indirect_Rec of int

  (*Kind of call in one constructor argument of a corecursive specification. The integers
    are positions among the corecursor's function arguments, as computed in corec_specs_of.*)
  datatype corec_call =
    Dummy_No_Corec of int |
    No_Corec of int |
    Direct_Corec of int (*stop?*) * int (*end*) * int (*continue*) |
    Indirect_Corec of int

  (*Per-constructor data for a recursor.*)
  type rec_ctr_spec =
    {ctr: term,
     offset: int,
     calls: rec_call list,
     rec_thm: thm}

  (*Per-constructor data for a corecursor; in corec_specs_of, pred is NONE exactly for the
    last constructor of a type.*)
  type corec_ctr_spec =
    {ctr: term,
     disc: term,
     sels: term list,
     pred: int option,
     calls: corec_call list,
     corec_thm: thm}

  (*Per-type data for a recursor.*)
  type rec_spec =
    {recx: term,
     nested_map_id's: thm list,
     nested_map_comps: thm list,
     ctr_specs: rec_ctr_spec list}

  (*Per-type data for a corecursor.*)
  type corec_spec =
    {corec: term,
     ctr_specs: corec_ctr_spec list}

  (*Arguments: context, call test, rewriter for unapplied direct calls, types of loose
    bound variables, y, y', and finally the term to rewrite.*)
  val massage_indirect_rec_call: Proof.context -> (term -> bool) -> (typ -> typ -> term -> term) ->
    typ list -> term -> term -> term -> term
  (*Arguments: context, call test, rewriter for direct calls, types of loose bound
    variables, target type, and the term to rewrite.*)
  val massage_direct_corec_call: Proof.context -> (term -> bool) -> (typ -> typ -> term -> term) ->
    typ list -> typ -> term -> term
  val massage_indirect_corec_call: Proof.context -> (term -> bool) ->
    (typ -> typ -> term -> term) -> typ list -> typ -> term -> term
  (*Result: the nontriv flag from nested_to_mutual_fps, the specifications, the missing
    argument types, the induction theorem and its per-type parts.*)
  val rec_specs_of: binding list -> typ list -> typ list -> (term -> int list) ->
    ((term * term list list) list) list -> local_theory ->
    (bool * rec_spec list * typ list * thm * thm list) * local_theory
  (*Result: the nontriv flag, the specifications, the missing result types, and the
    (strong) coinduction theorems, both combined and split per type.*)
  val corec_specs_of: binding list -> typ list -> typ list -> (term -> int list) ->
    ((term * term list list) list) list -> local_theory ->
    (bool * corec_spec list * typ list * thm * thm * thm list * thm list) * local_theory
end;
       
    59 
       
    60 structure BNF_FP_Rec_Sugar_Util : BNF_FP_REC_SUGAR_UTIL =
       
    61 struct
       
    62 
       
    63 open BNF_Util
       
    64 open BNF_Def
       
    65 open BNF_Ctr_Sugar
       
    66 open BNF_FP_Util
       
    67 open BNF_FP_Def_Sugar
       
    68 open BNF_FP_N2M_Sugar
       
    69 
       
(*Call kind of one constructor argument (built by call_of in rec_specs_of). The integers
  are positions in the list produced by flat_rec_arg_args:
    No_Rec: argument type mentions none of the result types Cs;
    Indirect_Rec: argument type mentions some result type (nested recursion);
    Direct_Rec: the argument corresponds to two positions.*)
datatype rec_call =
  No_Rec of int |
  Direct_Rec of int * int |
  Indirect_Rec of int;

(*Call kind of one constructor argument (built by call_of in corec_specs_of):
    Dummy_No_Corec / No_Corec: no call, for a nullary resp. non-nullary constructor;
    Indirect_Corec: the codomain mentions some type in Cs;
    Direct_Corec: one position from the q family and two from the f family.*)
datatype corec_call =
  Dummy_No_Corec of int |
  No_Corec of int |
  Direct_Corec of int * int * int |
  Indirect_Corec of int;

(*ctr: constructor with type variables instantiated; offset: number of constructors of the
  preceding mutual types plus the constructor's own position; calls: one per argument.*)
type rec_ctr_spec =
  {ctr: term,
   offset: int,
   calls: rec_call list,
   rec_thm: thm};

(*ctr, disc, sels: instantiated constructor, discriminator and selectors; pred: position
  of the constructor's predicate argument, NONE for the last constructor.*)
type corec_ctr_spec =
  {ctr: term,
   disc: term,
   sels: term list,
   pred: int option,
   calls: corec_call list,
   corec_thm: thm};

(*recx: instantiated recursor; nested_map_id's / nested_map_comps: map identity (with
  id_def unfolded) and composition theorems of the nested BNFs.*)
type rec_spec =
  {recx: term,
   nested_map_id's: thm list,
   nested_map_comps: thm list,
   ctr_specs: rec_ctr_spec list};

(*corec: instantiated corecursor.*)
type corec_spec =
  {corec: term,
   ctr_specs: corec_ctr_spec list};
       
   104 
       
(*Definition of the identity function; rec_specs_of unfolds it in the map_id0 theorems of
  the nested BNFs.*)
val id_def = @{thm id_def};

(*Raised by the massage_map functions when a term is not an application of a recognized
  map function; carries the offending term.*)
exception AINT_NO_MAP of term;
       
   108 
       
   109 fun ill_formed_rec_call ctxt t =
       
   110   error ("Ill-formed recursive call: " ^ quote (Syntax.string_of_term ctxt t));
       
   111 fun ill_formed_corec_call ctxt t =
       
   112   error ("Ill-formed corecursive call: " ^ quote (Syntax.string_of_term ctxt t));
       
   113 fun invalid_map ctxt t =
       
   114   error ("Invalid map function in " ^ quote (Syntax.string_of_term ctxt t));
       
   115 fun unexpected_rec_call ctxt t =
       
   116   error ("Unexpected recursive call: " ^ quote (Syntax.string_of_term ctxt t));
       
   117 fun unexpected_corec_call ctxt t =
       
   118   error ("Unexpected corecursive call: " ^ quote (Syntax.string_of_term ctxt t));
       
   119 
       
   120 fun factor_out_types ctxt massage destU U T =
       
   121   (case try destU U of
       
   122     SOME (U1, U2) => if U1 = T then massage T U2 else invalid_map ctxt
       
   123   | NONE => invalid_map ctxt);
       
   124 
       
   125 fun map_flattened_map_args ctxt s map_args fs =
       
   126   let
       
   127     val flat_fs = flatten_type_args_of_bnf (the (bnf_of ctxt s)) Term.dummy fs;
       
   128     val flat_fs' = map_args flat_fs;
       
   129   in
       
   130     permute_like (op aconv) flat_fs fs flat_fs'
       
   131   end;
       
   132 
       
(*Rewrite a term in which the argument y occurs under map functions into a term over y'.
    has_call: tests whether a term contains a recursive call;
    massage_unapplied_direct_call: rewrites an unapplied direct call, given the types
      produced by factor_out_types;
    bound_Ts: types of the loose bound variables, used by fastype_of1.
  The returned function beta-eta-contracts its input before massaging it. Shapes it cannot
  handle end in ill_formed_rec_call, invalid_map or unexpected_rec_call.*)
fun massage_indirect_rec_call ctxt has_call massage_unapplied_direct_call bound_Ts y y' =
  let
    val typof = curry fastype_of1 bound_Ts;
    (*map builder that uses fst_const at the leaves
      (NOTE(review): exact behavior of build_map assumed from this use)*)
    val build_map_fst = build_map ctxt (fst_const o fst);

    val yT = typof y;
    val yU = typof y';

    (*y expressed in terms of y': the fst-map from yU to yT applied to y'*)
    fun y_of_y' () = build_map_fst (yU, yT) $ y';
    (*replace every atomic occurrence of y by y_of_y' ()*)
    val elim_y = Term.map_aterms (fn t => if t = y then y_of_y' () else t);

    (*t is a map argument that is not itself a map application. If it contains a call,
      U must be a product whose first component is T; otherwise t is composed with the
      fst-map from U to T.*)
    fun check_and_massage_unapplied_direct_call U T t =
      if has_call t then
        factor_out_types ctxt massage_unapplied_direct_call HOLogic.dest_prodT U T t
      else
        HOLogic.mk_comp (t, build_map_fst (U, T));

    (*t must be a map function for the type constructor of T. It is rebuilt with domain
      type arguments Us and its original range type arguments, and each of its arguments
      is massaged. Raises AINT_NO_MAP otherwise.*)
    fun massage_map (Type (_, Us)) (Type (s, Ts)) t =
        (case try (dest_map ctxt s) t of
          SOME (map0, fs) =>
          let
            val Type (_, ran_Ts) = range_type (typof t);
            val map' = mk_map (length fs) Us ran_Ts map0;
            val fs' = map_flattened_map_args ctxt s (map3 massage_map_or_map_arg Us Ts) fs;
          in
            list_comb (map', fs')
          end
        | NONE => raise AINT_NO_MAP t)
      | massage_map _ _ t = raise AINT_NO_MAP t
    (*An argument at unchanged type must not contain calls. Others are massaged as maps
      or, failing that, as unapplied direct calls.*)
    and massage_map_or_map_arg U T t =
      if T = U then
        if has_call t then unexpected_rec_call ctxt t else t
      else
        massage_map U T t
        handle AINT_NO_MAP _ => check_and_massage_unapplied_direct_call U T t;

    (*Top level: either t1 applied to y, with t1 a map to be rebuilt and applied to y', or
      y itself.*)
    fun massage_call (t as t1 $ t2) =
        if t2 = y then
          massage_map yU yT (elim_y t1) $ y'
          handle AINT_NO_MAP t' => invalid_map ctxt t'
        else
          ill_formed_rec_call ctxt t
      | massage_call t = if t = y then y_of_y' () else ill_formed_rec_call ctxt t;
  in
    massage_call o Envir.beta_eta_contract
  end;
       
   179 
       
   180 fun massage_let_and_if ctxt has_call massage_rec massage_else U T t =
       
   181   (case Term.strip_comb t of
       
   182     (Const (@{const_name Let}, _), [arg1, arg2]) => massage_rec U T (betapply (arg2, arg1))
       
   183   | (Const (@{const_name If}, _), arg :: args) =>
       
   184     if has_call arg then unexpected_corec_call ctxt arg
       
   185     else list_comb (If_const U $ arg, map (massage_rec U T) args)
       
   186   | _ => massage_else U T t);
       
   187 
       
   188 fun massage_direct_corec_call ctxt has_call massage_direct_call bound_Ts res_U t =
       
   189   let
       
   190     val typof = curry fastype_of1 bound_Ts;
       
   191 
       
   192     fun massage_call U T =
       
   193       massage_let_and_if ctxt has_call massage_call massage_direct_call U T;
       
   194   in
       
   195     massage_call res_U (typof t) (Envir.beta_eta_contract t)
       
   196   end;
       
   197 
       
(*Massage a corecursive right-hand side t towards the target type res_U, handling calls
  nested under constructors, map functions, Let and If.
    has_call: tests whether a term contains a corecursive call;
    massage_direct_call: rewrites a direct call, given the types produced by
      factor_out_types with dest_sumT;
    bound_Ts: types of the loose bound variables, used by fastype_of1.
  Shapes that cannot be handled end in ill_formed_corec_call, invalid_map or
  unexpected_corec_call.*)
fun massage_indirect_corec_call ctxt has_call massage_direct_call bound_Ts res_U t =
  let
    val typof = curry fastype_of1 bound_Ts;
    (*map builder that injects with Inl into the sum type taken from the first type of the
      pair (NOTE(review): exact behavior of build_map assumed from this use)*)
    val build_map_Inl = build_map ctxt (uncurry Inl_const o dest_sumT o fst);

    (*If t contains a call, U must be a sum whose left summand is T. Otherwise t is
      injected with the Inl-map from T to U.*)
    fun check_and_massage_direct_call U T t =
      if has_call t then factor_out_types ctxt massage_direct_call dest_sumT U T t
      else build_map_Inl (U, T) $ t;

    (*Eta-expand t with a fresh schematic variable of its domain type, process the
      application, and abstract over the variable again.*)
    fun check_and_massage_unapplied_direct_call U T t =
      let val var = Var ((Name.uu, Term.maxidx_of_term t + 1), domain_type (typof t)) in
        Term.lambda var (check_and_massage_direct_call U T (t $ var))
      end;

    (*t must be a map function for the type constructor of T. It is rebuilt with its
      original domain type arguments and range type arguments Us, and each of its
      arguments is massaged. Raises AINT_NO_MAP otherwise.*)
    fun massage_map (Type (_, Us)) (Type (s, Ts)) t =
        (case try (dest_map ctxt s) t of
          SOME (map0, fs) =>
          let
            val Type (_, dom_Ts) = domain_type (typof t);
            val map' = mk_map (length fs) dom_Ts Us map0;
            val fs' = map_flattened_map_args ctxt s (map3 massage_map_or_map_arg Us Ts) fs;
          in
            list_comb (map', fs')
          end
        | NONE => raise AINT_NO_MAP t)
      | massage_map _ _ t = raise AINT_NO_MAP t
    (*An argument at unchanged type must not contain calls. Others are massaged as maps
      or, failing that, as unapplied direct calls.*)
    and massage_map_or_map_arg U T t =
      if T = U then
        if has_call t then unexpected_corec_call ctxt t else t
      else
        massage_map U T t
        handle AINT_NO_MAP _ => check_and_massage_unapplied_direct_call U T t;

    (*After Let and If are pushed through, a term of target type U (which must be a type
      constructor application) is handled as follows:
        - a constructor application is rebuilt with mk_ctr and its arguments are massaged
          recursively;
        - an application t1 $ t2 is a direct call if t2 contains a call; otherwise t1 is
          tried as a map, falling back to a direct call;
        - any other term is treated as a direct call.*)
    fun massage_call U T =
      massage_let_and_if ctxt has_call massage_call
        (fn U => fn T => fn t =>
            (case U of
              Type (s, Us) =>
              (case try (dest_ctr ctxt s) t of
                SOME (f, args) =>
                let val f' = mk_ctr Us f in
                  list_comb (f', map3 massage_call (binder_types (typof f')) (map typof args) args)
                end
              | NONE =>
                (case t of
                  t1 $ t2 =>
                  if has_call t2 then
                    check_and_massage_direct_call U T t
                  else
                    massage_map U T t1 $ t2
                    handle AINT_NO_MAP _ => check_and_massage_direct_call U T t
                | _ => check_and_massage_direct_call U T t))
            | _ => ill_formed_corec_call ctxt t))
        U T
  in
    massage_call res_U (typof t) (Envir.beta_eta_contract t)
  end;
       
   255 
       
   256 fun indexed xs h = let val h' = h + length xs in (h upto h' - 1, h') end;
       
   257 fun indexedd xss = fold_map indexed xss;
       
   258 fun indexeddd xsss = fold_map indexedd xsss;
       
   259 fun indexedddd xssss = fold_map indexeddd xssss;
       
   260 
       
   261 fun find_index_eq hs h = find_index (curry (op =) h) hs;
       
   262 
       
(*Flag passed as the first argument of nested_to_mutual_fps by rec_specs_of and
  corec_specs_of; its meaning is determined by that function.*)
val lose_co_rec = false (*FIXME: try true?*);
       
   264 
       
(*Compute recursor specifications for the (possibly nested) datatypes arg_Ts, with result
  types res_Ts. Nested recursion is first reduced with nested_to_mutual_fps.
  Returns the nontriv flag, one rec_spec per fp_sugar, the missing argument types, the
  induction theorem and its per-type parts, together with the updated local theory.*)
fun rec_specs_of bs arg_Ts res_Ts get_indices callssss0 lthy =
  let
    val thy = Proof_Context.theory_of lthy;

    val ((nontriv, missing_arg_Ts, perm0_kks,
          fp_sugars as {nested_bnfs, fp_res = {xtor_co_iterss = ctor_iters1 :: _, ...},
            co_inducts = [induct_thm], ...} :: _), lthy') =
      nested_to_mutual_fps lose_co_rec Least_FP bs arg_Ts get_indices callssss0 lthy;

    (*the perm_ names refer to the fp_sugars sorted by their index field*)
    val perm_fp_sugars = sort (int_ord o pairself #index) fp_sugars;

    val indices = map #index fp_sugars;
    val perm_indices = map #index perm_fp_sugars;

    val perm_ctrss = map (#ctrs o of_fp_sugar #ctr_sugars) perm_fp_sugars;
    val perm_ctr_Tsss = map (map (binder_types o fastype_of)) perm_ctrss;
    val perm_fpTs = map (body_type o fastype_of o hd) perm_ctrss;

    val nn0 = length arg_Ts;
    val nn = length perm_fpTs;
    val kks = 0 upto nn - 1;
    val perm_ns = map length perm_ctr_Tsss;
    val perm_mss = map (map length) perm_ctr_Tsss;

    (*result types of the recursors*)
    val perm_Cs = map (body_type o fastype_of o co_rec_of o of_fp_sugar (#xtor_co_iterss o #fp_res))
      perm_fp_sugars;
    val perm_fun_arg_Tssss = mk_iter_fun_arg_types perm_Cs perm_ns perm_mss (co_rec_of ctor_iters1);

    (*undo the permutations back to the original order*)
    fun unpermute0 perm0_xs = permute_like (op =) perm0_kks kks perm0_xs;
    fun unpermute perm_xs = permute_like (op =) perm_indices indices perm_xs;

    val induct_thms = unpermute0 (conj_dests nn induct_thm);

    val fpTs = unpermute perm_fpTs;
    val Cs = unpermute perm_Cs;

    (*instantiate the first nn0 fixpoint types with arg_Ts and the result type variables
      with res_Ts, padded with unit for the additional types*)
    val As_rho = tvar_subst thy (take nn0 fpTs) arg_Ts;
    val Cs_rho = map (fst o dest_TVar) Cs ~~ pad_list HOLogic.unitT nn res_Ts;

    val substA = Term.subst_TVars As_rho;
    val substAT = Term.typ_subst_TVars As_rho;
    val substCT = Term.typ_subst_TVars Cs_rho;

    val perm_Cs' = map substCT perm_Cs;

    (*number of constructors of the first n ctr_sugars; raises Match if there are fewer*)
    fun offset_of_ctr 0 _ = 0
      | offset_of_ctr n ({ctrs, ...} :: ctr_sugars) =
        length ctrs + offset_of_ctr (n - 1) ctr_sugars;

    (*one position: no call or an indirect call, depending on whether T mentions Cs;
      two positions: a direct call. Other shapes raise Match.*)
    fun call_of [i] [T] = (if exists_subtype_in Cs T then Indirect_Rec else No_Rec) i
      | call_of [i, i'] _ = Direct_Rec (i, i');

    fun mk_ctr_spec ctr offset fun_arg_Tss rec_thm =
      let
        val (fun_arg_hss, _) = indexedd fun_arg_Tss 0;
        val fun_arg_hs = flat_rec_arg_args fun_arg_hss;
        val fun_arg_iss = map (map (find_index_eq fun_arg_hs)) fun_arg_hss;
      in
        {ctr = substA ctr, offset = offset, calls = map2 call_of fun_arg_iss fun_arg_Tss,
         rec_thm = rec_thm}
      end;

    fun mk_ctr_specs index ctr_sugars iter_thmsss =
      let
        val ctrs = #ctrs (nth ctr_sugars index);
        val rec_thmss = co_rec_of (nth iter_thmsss index);
        val k = offset_of_ctr index ctr_sugars;
        val n = length ctrs;
      in
        map4 mk_ctr_spec ctrs (k upto k + n - 1) (nth perm_fun_arg_Tssss index) rec_thmss
      end;

    fun mk_spec {T, index, ctr_sugars, co_iterss = iterss, co_iter_thmsss = iter_thmsss, ...} =
      {recx = mk_co_iter thy Least_FP (substAT T) perm_Cs' (co_rec_of (nth iterss index)),
       nested_map_id's = map (unfold_thms lthy [id_def] o map_id0_of_bnf) nested_bnfs,
       nested_map_comps = map map_comp_of_bnf nested_bnfs,
       ctr_specs = mk_ctr_specs index ctr_sugars iter_thmsss};
  in
    ((nontriv, map mk_spec fp_sugars, missing_arg_Ts, induct_thm, induct_thms), lthy')
  end;
       
   345 
       
(*Compute corecursor specifications for the (possibly nested) codatatypes res_Ts, with
  argument types arg_Ts. Nested corecursion is first reduced with nested_to_mutual_fps.
  Returns the nontriv flag, one corec_spec per fp_sugar, the missing result types, and the
  coinduction and strong coinduction theorems (combined and split per type), together
  with the updated local theory.*)
fun corec_specs_of bs arg_Ts res_Ts get_indices callssss0 lthy =
  let
    val thy = Proof_Context.theory_of lthy;

    val ((nontriv, missing_res_Ts, perm0_kks,
          fp_sugars as {fp_res = {xtor_co_iterss = dtor_coiters1 :: _, ...},
          co_inducts = coinduct_thms, ...} :: _), lthy') =
      nested_to_mutual_fps lose_co_rec Greatest_FP bs res_Ts get_indices callssss0 lthy;

    (*the perm_ names refer to the fp_sugars sorted by their index field*)
    val perm_fp_sugars = sort (int_ord o pairself #index) fp_sugars;

    val indices = map #index fp_sugars;
    val perm_indices = map #index perm_fp_sugars;

    val perm_ctrss = map (#ctrs o of_fp_sugar #ctr_sugars) perm_fp_sugars;
    val perm_ctr_Tsss = map (map (binder_types o fastype_of)) perm_ctrss;
    val perm_fpTs = map (body_type o fastype_of o hd) perm_ctrss;

    val nn0 = length res_Ts;
    val nn = length perm_fpTs;
    val kks = 0 upto nn - 1;
    val perm_ns = map length perm_ctr_Tsss;
    val perm_mss = map (map length) perm_ctr_Tsss;

    (*argument types of the corecursors*)
    val perm_Cs = map (domain_type o body_fun_type o fastype_of o co_rec_of o
      of_fp_sugar (#xtor_co_iterss o #fp_res)) perm_fp_sugars;
    val (perm_p_Tss, (perm_q_Tssss, _, perm_f_Tssss, _)) =
      mk_coiter_fun_arg_types perm_Cs perm_ns perm_mss (co_rec_of dtor_coiters1);

    (*number the p, q and f function arguments consecutively across the three families*)
    val (perm_p_hss, h) = indexedd perm_p_Tss 0;
    val (perm_q_hssss, h') = indexedddd perm_q_Tssss h;
    val (perm_f_hssss, _) = indexedddd perm_f_Tssss h';

    val fun_arg_hs =
      flat (map3 flat_corec_preds_predsss_gettersss perm_p_hss perm_q_hssss perm_f_hssss);

    (*undo the permutations back to the original order*)
    fun unpermute0 perm0_xs = permute_like (op =) perm0_kks kks perm0_xs;
    fun unpermute perm_xs = permute_like (op =) perm_indices indices perm_xs;

    val coinduct_thmss = map (unpermute0 o conj_dests nn) coinduct_thms;

    (*positions of the numbered arguments within fun_arg_hs*)
    val p_iss = map (map (find_index_eq fun_arg_hs)) (unpermute perm_p_hss);
    val q_issss = map (map (map (map (find_index_eq fun_arg_hs)))) (unpermute perm_q_hssss);
    val f_issss = map (map (map (map (find_index_eq fun_arg_hs)))) (unpermute perm_f_hssss);

    val f_Tssss = unpermute perm_f_Tssss;
    val fpTs = unpermute perm_fpTs;
    val Cs = unpermute perm_Cs;

    (*instantiate the first nn0 fixpoint types with res_Ts and the argument type variables
      with arg_Ts, padded with unit for the additional types*)
    val As_rho = tvar_subst thy (take nn0 fpTs) res_Ts;
    val Cs_rho = map (fst o dest_TVar) Cs ~~ pad_list HOLogic.unitT nn arg_Ts;

    val substA = Term.subst_TVars As_rho;
    val substAT = Term.typ_subst_TVars As_rho;
    val substCT = Term.typ_subst_TVars Cs_rho;

    val perm_Cs' = map substCT perm_Cs;

    (*no q position and one f position of function type: an indirect call if the codomain
      mentions Cs, otherwise no call (dummy for nullary constructors); one q and two f
      positions: a direct call. Other shapes raise Match.*)
    fun call_of nullary [] [g_i] [Type (@{type_name fun}, [_, T])] =
        (if exists_subtype_in Cs T then Indirect_Corec
         else if nullary then Dummy_No_Corec
         else No_Corec) g_i
      | call_of _ [q_i] [g_i, g_i'] _ = Direct_Corec (q_i, g_i, g_i');

    (*p_ho: the predicate position option taken from p_ios in mk_ctr_specs*)
    fun mk_ctr_spec ctr disc sels p_ho q_iss f_iss f_Tss corec_thm =
      let val nullary = not (can dest_funT (fastype_of ctr)) in
        {ctr = substA ctr, disc = substA disc, sels = map substA sels, pred = p_ho,
         calls = map3 (call_of nullary) q_iss f_iss f_Tss, corec_thm = corec_thm}
      end;

    fun mk_ctr_specs index ctr_sugars p_is q_isss f_isss f_Tsss coiter_thmsss =
      let
        val ctrs = #ctrs (nth ctr_sugars index);
        val discs = #discs (nth ctr_sugars index);
        val selss = #selss (nth ctr_sugars index);
        (*the last constructor has no predicate*)
        val p_ios = map SOME p_is @ [NONE];
        val corec_thmss = co_rec_of (nth coiter_thmsss index);
      in
        map8 mk_ctr_spec ctrs discs selss p_ios q_isss f_isss f_Tsss corec_thmss
      end;

    fun mk_spec {T, index, ctr_sugars, co_iterss = coiterss, co_iter_thmsss = coiter_thmsss, ...}
        p_is q_isss f_isss f_Tsss =
      {corec = mk_co_iter thy Greatest_FP (substAT T) perm_Cs' (co_rec_of (nth coiterss index)),
       ctr_specs = mk_ctr_specs index ctr_sugars p_is q_isss f_isss f_Tsss coiter_thmsss};
  in
    ((nontriv, map5 mk_spec fp_sugars p_iss q_issss f_issss f_Tssss, missing_res_Ts,
      co_induct_of coinduct_thms, strong_co_induct_of coinduct_thms, co_induct_of coinduct_thmss,
      strong_co_induct_of coinduct_thmss), lthy')
  end;
       
   436 
       
   437 end;