src/HOL/Tools/BNF/bnf_fp_n2m_sugar.ML
changeset 55772 367ec44763fd
parent 55702 63c80031d8dd
child 55803 74d3fe9031d8
equal deleted inserted replaced
55771:a421f1ccfc9f 55772:367ec44763fd
     8 signature BNF_FP_N2M_SUGAR =
     8 signature BNF_FP_N2M_SUGAR =
     9 sig
     9 sig
    10   val unfold_lets_splits: term -> term
    10   val unfold_lets_splits: term -> term
    11   val dest_map: Proof.context -> string -> term -> term * term list
    11   val dest_map: Proof.context -> string -> term -> term * term list
    12 
    12 
    13   val mutualize_fp_sugars: BNF_Util.fp_kind -> binding list -> typ list -> (term -> int list) ->
    13   val mutualize_fp_sugars: BNF_Util.fp_kind -> binding list -> typ list -> term list ->
    14     term list list list list -> BNF_FP_Def_Sugar.fp_sugar list -> local_theory ->
    14     term list list list list -> BNF_FP_Def_Sugar.fp_sugar list -> local_theory ->
    15     (BNF_FP_Def_Sugar.fp_sugar list
    15     (BNF_FP_Def_Sugar.fp_sugar list
    16      * (BNF_FP_Def_Sugar.lfp_sugar_thms option * BNF_FP_Def_Sugar.gfp_sugar_thms option))
    16      * (BNF_FP_Def_Sugar.lfp_sugar_thms option * BNF_FP_Def_Sugar.gfp_sugar_thms option))
    17     * local_theory
    17     * local_theory
    18   val nested_to_mutual_fps: BNF_Util.fp_kind -> binding list -> typ list -> (term -> int list) ->
    18   val nested_to_mutual_fps: BNF_Util.fp_kind -> binding list -> typ list -> term list ->
    19     (term * term list list) list list -> local_theory ->
    19     (term * term list list) list list -> local_theory ->
    20     (typ list * int list * BNF_FP_Def_Sugar.fp_sugar list
    20     (typ list * int list * BNF_FP_Def_Sugar.fp_sugar list
    21      * (BNF_FP_Def_Sugar.lfp_sugar_thms option * BNF_FP_Def_Sugar.gfp_sugar_thms option))
    21      * (BNF_FP_Def_Sugar.lfp_sugar_thms option * BNF_FP_Def_Sugar.gfp_sugar_thms option))
    22     * local_theory
    22     * local_theory
    23 end;
    23 end;
   101     xs ([], ([], []));
   101     xs ([], ([], []));
   102 
   102 
   103 fun key_of_fp_eqs fp fpTs fp_eqs =
   103 fun key_of_fp_eqs fp fpTs fp_eqs =
   104   Type (fp_case fp "l" "g", fpTs @ maps (fn (x, T) => [TFree x, T]) fp_eqs);
   104   Type (fp_case fp "l" "g", fpTs @ maps (fn (x, T) => [TFree x, T]) fp_eqs);
   105 
   105 
       
   106 fun get_indices callers t =
       
   107   callers
       
   108   |> map_index (fn (i, v) => if exists_subterm (equal v) t then SOME i else NONE)
       
   109   |> map_filter I;
       
   110 
   106 (* TODO: test with sort constraints on As *)
   111 (* TODO: test with sort constraints on As *)
   107 fun mutualize_fp_sugars fp bs fpTs get_indices callssss fp_sugars0 no_defs_lthy0 =
   112 fun mutualize_fp_sugars fp bs fpTs callers callssss fp_sugars0 no_defs_lthy0 =
   108   let
   113   let
   109     val thy = Proof_Context.theory_of no_defs_lthy0;
   114     val thy = Proof_Context.theory_of no_defs_lthy0;
   110 
   115 
   111     val qsotm = quote o Syntax.string_of_term no_defs_lthy0;
   116     val qsotm = quote o Syntax.string_of_term no_defs_lthy0;
   112 
   117 
   113     fun incompatible_calls ts =
   118     fun incompatible_calls ts =
   114       error ("Incompatible " ^ co_prefix fp ^ "recursive calls: " ^ commas (map qsotm ts));
   119       error ("Incompatible " ^ co_prefix fp ^ "recursive calls: " ^ commas (map qsotm ts));
       
   120     fun mutual_self_call caller t =
       
   121       error ("Unsupported mutual self-call " ^ qsotm t ^ " from " ^ qsotm caller);
   115     fun nested_self_call t =
   122     fun nested_self_call t =
   116       error ("Unsupported nested self-call " ^ qsotm t);
   123       error ("Unsupported nested self-call " ^ qsotm t);
   117 
   124 
   118     val b_names = map Binding.name_of bs;
   125     val b_names = map Binding.name_of bs;
   119     val fp_b_names = map base_name_of_typ fpTs;
   126     val fp_b_names = map base_name_of_typ fpTs;
   144       |> fold Variable.declare_typ As
   151       |> fold Variable.declare_typ As
   145       |> mk_TFrees nn
   152       |> mk_TFrees nn
   146       ||>> variant_tfrees fp_b_names;
   153       ||>> variant_tfrees fp_b_names;
   147 
   154 
   148     fun check_call_dead live_call call =
   155     fun check_call_dead live_call call =
   149       if null (get_indices call) then () else incompatible_calls [live_call, call];
   156       if null (get_indices callers call) then () else incompatible_calls [live_call, call];
   150 
   157 
   151     fun freeze_fpTs_type_based_default (T as Type (s, Ts)) =
   158     fun freeze_fpTs_type_based_default (T as Type (s, Ts)) =
   152         (case filter (curry (op =) T o snd) (map_index I fpTs) of
   159         (case filter (curry (op =) T o snd) (map_index I fpTs) of
   153           [(kk, _)] => nth Xs kk
   160           [(kk, _)] => nth Xs kk
   154         | _ => Type (s, map freeze_fpTs_type_based_default Ts))
   161         | _ => Type (s, map freeze_fpTs_type_based_default Ts))
   155       | freeze_fpTs_type_based_default T = T;
   162       | freeze_fpTs_type_based_default T = T;
   156 
   163 
   157     fun freeze_fpTs_mutual_call calls T =
   164     fun freeze_fpTs_mutual_call kk fpT calls T =
   158       (case fold (union (op =)) (map get_indices calls) [] of
   165       (case fold (union (op =)) (map (get_indices callers) calls) [] of
   159         [] => freeze_fpTs_type_based_default T
   166         [] => if T = fpT then nth Xs kk else freeze_fpTs_type_based_default T
   160       | [kk] => nth Xs kk
   167       | [kk'] =>
       
   168         if T = fpT andalso kk' <> kk then
       
   169           mutual_self_call (nth callers kk)
       
   170             (the (find_first (not o null o get_indices callers) calls))
       
   171         else
       
   172           nth Xs kk'
   161       | _ => incompatible_calls calls);
   173       | _ => incompatible_calls calls);
   162 
   174 
   163     fun freeze_fpTs_map (fpT as Type (_, Ts')) (callss, (live_call :: _, dead_calls))
   175     fun freeze_fpTs_map kk (fpT as Type (_, Ts')) (callss, (live_call :: _, dead_calls))
   164         (Type (s, Ts)) =
   176         (Type (s, Ts)) =
   165       if Ts' = Ts then
   177       if Ts' = Ts then
   166         nested_self_call live_call
   178         nested_self_call live_call
   167       else
   179       else
   168         (List.app (check_call_dead live_call) dead_calls;
   180         (List.app (check_call_dead live_call) dead_calls;
   169          Type (s, map2 (freeze_fpTs_call fpT)
   181          Type (s, map2 (freeze_fpTs_call kk fpT)
   170            (flatten_type_args_of_bnf (the (bnf_of no_defs_lthy s)) [] (transpose callss)) Ts))
   182            (flatten_type_args_of_bnf (the (bnf_of no_defs_lthy s)) [] (transpose callss)) Ts))
   171     and freeze_fpTs_call fpT calls (T as Type (s, _)) =
   183     and freeze_fpTs_call kk fpT calls (T as Type (s, _)) =
   172         (case map_partition (try (snd o dest_map no_defs_lthy s)) calls of
   184         (case map_partition (try (snd o dest_map no_defs_lthy s)) calls of
   173           ([], _) =>
   185           ([], _) =>
   174           (case map_partition (try (snd o dest_abs_or_applied_map no_defs_lthy s)) calls of
   186           (case map_partition (try (snd o dest_abs_or_applied_map no_defs_lthy s)) calls of
   175             ([], _) => freeze_fpTs_mutual_call calls T
   187             ([], _) => freeze_fpTs_mutual_call kk fpT calls T
   176           | callsp => freeze_fpTs_map fpT callsp T)
   188           | callsp => freeze_fpTs_map kk fpT callsp T)
   177         | callsp => freeze_fpTs_map fpT callsp T)
   189         | callsp => freeze_fpTs_map kk fpT callsp T)
   178       | freeze_fpTs_call _ _ T = T;
   190       | freeze_fpTs_call _ _ _ T = T;
   179 
   191 
   180     val ctr_Tsss = map (map binder_types) ctr_Tss;
   192     val ctr_Tsss = map (map binder_types) ctr_Tss;
   181     val ctrXs_Tsss = map3 (map2 o map2 o freeze_fpTs_call) fpTs callssss ctr_Tsss;
   193     val ctrXs_Tsss = map4 (map2 o map2 oo freeze_fpTs_call) kks fpTs callssss ctr_Tsss;
   182     val ctrXs_sum_prod_Ts = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctrXs_Tsss;
   194     val ctrXs_sum_prod_Ts = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctrXs_Tsss;
   183 
   195 
   184     val ns = map length ctr_Tsss;
   196     val ns = map length ctr_Tsss;
   185     val kss = map (fn n => 1 upto n) ns;
   197     val kss = map (fn n => 1 upto n) ns;
   186     val mss = map (map length) ctr_Tsss;
   198     val mss = map (map length) ctr_Tsss;
   252                sel_corec_thmsss))
   264                sel_corec_thmsss))
   253             ||> (fn info => (NONE, SOME info));
   265             ||> (fn info => (NONE, SOME info));
   254 
   266 
   255         val phi = Proof_Context.export_morphism no_defs_lthy no_defs_lthy0;
   267         val phi = Proof_Context.export_morphism no_defs_lthy no_defs_lthy0;
   256 
   268 
   257         fun mk_target_fp_sugar T kk pre_bnf ctr_defs ctr_sugar co_iters maps co_inducts un_fold_thms
   269         fun mk_target_fp_sugar T X kk pre_bnf ctrXs_Tss ctr_defs ctr_sugar co_iters maps co_inducts
   258             co_rec_thms disc_unfold_thms disc_corec_thms sel_unfold_thmss sel_corec_thmss =
   270             un_fold_thms co_rec_thms disc_unfold_thms disc_corec_thms sel_unfold_thmss
   259           {T = T, fp = fp, fp_res = fp_res, fp_res_index = kk, pre_bnf = pre_bnf,
   271             sel_corec_thmss =
   260            nested_bnfs = nested_bnfs, nesting_bnfs = nesting_bnfs, ctr_defs = ctr_defs,
   272           {T = T, X = X, fp = fp, fp_res = fp_res, fp_res_index = kk, pre_bnf = pre_bnf,
   261            ctr_sugar = ctr_sugar, co_iters = co_iters, maps = maps,
   273            nested_bnfs = nested_bnfs, nesting_bnfs = nesting_bnfs, ctrXs_Tss = ctrXs_Tss,
       
   274            ctr_defs = ctr_defs, ctr_sugar = ctr_sugar, co_iters = co_iters, maps = maps,
   262            common_co_inducts = common_co_inducts, co_inducts = co_inducts,
   275            common_co_inducts = common_co_inducts, co_inducts = co_inducts,
   263            co_iter_thmss = [un_fold_thms, co_rec_thms],
   276            co_iter_thmss = [un_fold_thms, co_rec_thms],
   264            disc_co_iterss = [disc_unfold_thms, disc_corec_thms],
   277            disc_co_iterss = [disc_unfold_thms, disc_corec_thms],
   265            sel_co_itersss = [sel_unfold_thmss, sel_corec_thmss]}
   278            sel_co_itersss = [sel_unfold_thmss, sel_corec_thmss]}
   266           |> morph_fp_sugar phi;
   279           |> morph_fp_sugar phi;
   267 
   280 
   268         val target_fp_sugars =
   281         val target_fp_sugars =
   269           map14 mk_target_fp_sugar fpTs kks pre_bnfs ctr_defss ctr_sugars co_iterss mapss
   282           map16 mk_target_fp_sugar fpTs Xs kks pre_bnfs ctrXs_Tsss ctr_defss ctr_sugars co_iterss
   270             (transpose co_inductss) un_fold_thmss co_rec_thmss disc_unfold_thmss disc_corec_thmss
   283             mapss (transpose co_inductss) un_fold_thmss co_rec_thmss disc_unfold_thmss
   271             sel_unfold_thmsss sel_corec_thmsss;
   284             disc_corec_thmss sel_unfold_thmsss sel_corec_thmsss;
   272 
   285 
   273         val n2m_sugar = (target_fp_sugars, fp_sugar_thms);
   286         val n2m_sugar = (target_fp_sugars, fp_sugar_thms);
   274       in
   287       in
   275         (n2m_sugar, lthy |> register_n2m_sugar key n2m_sugar)
   288         (n2m_sugar, lthy |> register_n2m_sugar key n2m_sugar)
   276       end)
   289       end)
   290 
   303 
   291 fun fold_subtype_pairs f (T as Type (s, Ts), U as Type (s', Us)) =
   304 fun fold_subtype_pairs f (T as Type (s, Ts), U as Type (s', Us)) =
   292     f (T, U) #> (if s = s' then fold (fold_subtype_pairs f) (Ts ~~ Us) else I)
   305     f (T, U) #> (if s = s' then fold (fold_subtype_pairs f) (Ts ~~ Us) else I)
   293   | fold_subtype_pairs f TU = f TU;
   306   | fold_subtype_pairs f TU = f TU;
   294 
   307 
   295 fun nested_to_mutual_fps fp actual_bs actual_Ts get_indices actual_callssss0 lthy =
   308 val impossible_caller = Bound ~1;
       
   309 
       
   310 fun nested_to_mutual_fps fp actual_bs actual_Ts actual_callers actual_callssss0 lthy =
   296   let
   311   let
   297     val qsoty = quote o Syntax.string_of_typ lthy;
   312     val qsoty = quote o Syntax.string_of_typ lthy;
   298     val qsotys = space_implode " or " o map qsoty;
   313     val qsotys = space_implode " or " o map qsoty;
   299 
   314 
   300     fun not_co_datatype0 T = error (qsoty T ^ " is not a " ^ co_prefix fp ^ "datatype");
   315     fun not_co_datatype0 T = error (qsoty T ^ " is not a " ^ co_prefix fp ^ "datatype");
   389 
   404 
   390     val callssss0 = pad_list [] nn actual_callssss0;
   405     val callssss0 = pad_list [] nn actual_callssss0;
   391 
   406 
   392     val common_name = mk_common_name (map Binding.name_of actual_bs);
   407     val common_name = mk_common_name (map Binding.name_of actual_bs);
   393     val bs = pad_list (Binding.name common_name) nn actual_bs;
   408     val bs = pad_list (Binding.name common_name) nn actual_bs;
       
   409     val callers = pad_list impossible_caller nn actual_callers;
   394 
   410 
   395     fun permute xs = permute_like (op =) Ts perm_Ts xs;
   411     fun permute xs = permute_like (op =) Ts perm_Ts xs;
   396     fun unpermute perm_xs = permute_like (op =) perm_Ts Ts perm_xs;
   412     fun unpermute perm_xs = permute_like (op =) perm_Ts Ts perm_xs;
   397 
   413 
   398     val perm_bs = permute bs;
   414     val perm_bs = permute bs;
       
   415     val perm_callers = permute callers;
   399     val perm_kks = permute kks;
   416     val perm_kks = permute kks;
   400     val perm_callssss0 = permute callssss0;
   417     val perm_callssss0 = permute callssss0;
   401     val perm_fp_sugars0 = map (the o fp_sugar_of lthy o fst o dest_Type) perm_Ts;
   418     val perm_fp_sugars0 = map (the o fp_sugar_of lthy o fst o dest_Type) perm_Ts;
   402 
   419 
   403     val perm_callssss = map2 (indexify_callsss o #ctrs o #ctr_sugar) perm_fp_sugars0 perm_callssss0;
   420     val perm_callssss = map2 (indexify_callsss o #ctrs o #ctr_sugar) perm_fp_sugars0 perm_callssss0;
   404 
   421 
   405     val get_perm_indices = map (fn kk => find_index (curry (op =) kk) perm_kks) o get_indices;
       
   406 
       
   407     val ((perm_fp_sugars, fp_sugar_thms), lthy) =
   422     val ((perm_fp_sugars, fp_sugar_thms), lthy) =
   408       if num_groups > 1 then
   423       if num_groups > 1 then
   409         mutualize_fp_sugars fp perm_bs perm_frozen_gen_Ts get_perm_indices perm_callssss
   424         mutualize_fp_sugars fp perm_bs perm_frozen_gen_Ts perm_callers perm_callssss perm_fp_sugars0
   410           perm_fp_sugars0 lthy
   425           lthy
   411       else
   426       else
   412         ((perm_fp_sugars0, (NONE, NONE)), lthy);
   427         ((perm_fp_sugars0, (NONE, NONE)), lthy);
   413 
   428 
   414     val fp_sugars = unpermute perm_fp_sugars;
   429     val fp_sugars = unpermute perm_fp_sugars;
   415   in
   430   in