src/HOL/Codatatype/Tools/bnf_wrap.ML
changeset 49336 a2e6473145e4
parent 49311 56fcd826f90c
child 49364 838b5e8ede73
equal deleted inserted replaced
49335:096967bf3940 49336:a2e6473145e4
     5 Wrapping existing datatypes.
     5 Wrapping existing datatypes.
     6 *)
     6 *)
     7 
     7 
     8 signature BNF_WRAP =
     8 signature BNF_WRAP =
     9 sig
     9 sig
    10   val no_binder: binding
    10   val no_binding: binding
    11   val mk_half_pairss: 'a list -> ('a * 'a) list list
    11   val mk_half_pairss: 'a list -> ('a * 'a) list list
    12   val mk_ctr: typ list -> term -> term
    12   val mk_ctr: typ list -> term -> term
    13   val wrap_datatype: ({prems: thm list, context: Proof.context} -> tactic) list list ->
    13   val wrap_datatype: ({prems: thm list, context: Proof.context} -> tactic) list list ->
    14     ((bool * term list) * term) *
    14     ((bool * term list) * term) *
    15       (binding list * (binding list list * (binding * term) list list)) -> local_theory ->
    15       (binding list * (binding list list * (binding * term) list list)) -> local_theory ->
    43 val selsN = "sels";
    43 val selsN = "sels";
    44 val splitN = "split";
    44 val splitN = "split";
    45 val split_asmN = "split_asm";
    45 val split_asmN = "split_asm";
    46 val weak_case_cong_thmsN = "weak_case_cong";
    46 val weak_case_cong_thmsN = "weak_case_cong";
    47 
    47 
    48 val no_binder = @{binding ""};
    48 val no_binding = @{binding ""};
    49 val std_binder = @{binding _};
    49 val std_binding = @{binding _};
    50 
    50 
    51 val induct_simp_attrs = @{attributes [induct_simp]};
    51 val induct_simp_attrs = @{attributes [induct_simp]};
    52 val cong_attrs = @{attributes [cong]};
    52 val cong_attrs = @{attributes [cong]};
    53 val iff_attrs = @{attributes [iff]};
    53 val iff_attrs = @{attributes [iff]};
    54 val safe_elim_attrs = @{attributes [elim!]};
    54 val safe_elim_attrs = @{attributes [elim!]};
    78       Const (s, _) => s
    78       Const (s, _) => s
    79     | Free (s, _) => s
    79     | Free (s, _) => s
    80     | _ => error "Cannot extract name of constructor");
    80     | _ => error "Cannot extract name of constructor");
    81 
    81 
    82 fun prepare_wrap_datatype prep_term (((no_dests, raw_ctrs), raw_case),
    82 fun prepare_wrap_datatype prep_term (((no_dests, raw_ctrs), raw_case),
    83     (raw_disc_binders, (raw_sel_binderss, raw_sel_defaultss))) no_defs_lthy =
    83     (raw_disc_bindings, (raw_sel_bindingss, raw_sel_defaultss))) no_defs_lthy =
    84   let
    84   let
    85     (* TODO: sanity checks on arguments *)
    85     (* TODO: sanity checks on arguments *)
    86     (* TODO: attributes (simp, case_names, etc.) *)
    86     (* TODO: attributes (simp, case_names, etc.) *)
    87     (* TODO: case syntax *)
    87     (* TODO: case syntax *)
    88     (* TODO: integration with function package ("size") *)
    88     (* TODO: integration with function package ("size") *)
   109     val ctrs = map (mk_ctr As) ctrs0;
   109     val ctrs = map (mk_ctr As) ctrs0;
   110     val ctr_Tss = map (binder_types o fastype_of) ctrs;
   110     val ctr_Tss = map (binder_types o fastype_of) ctrs;
   111 
   111 
   112     val ms = map length ctr_Tss;
   112     val ms = map length ctr_Tss;
   113 
   113 
   114     val raw_disc_binders' = pad_list no_binder n raw_disc_binders;
   114     val raw_disc_bindings' = pad_list no_binding n raw_disc_bindings;
   115 
   115 
   116     fun can_really_rely_on_disc k =
   116     fun can_really_rely_on_disc k =
   117       not (Binding.eq_name (nth raw_disc_binders' (k - 1), no_binder)) orelse nth ms (k - 1) = 0;
   117       not (Binding.eq_name (nth raw_disc_bindings' (k - 1), no_binding)) orelse nth ms (k - 1) = 0;
   118     fun can_rely_on_disc k =
   118     fun can_rely_on_disc k =
   119       can_really_rely_on_disc k orelse (k = 1 andalso not (can_really_rely_on_disc 2));
   119       can_really_rely_on_disc k orelse (k = 1 andalso not (can_really_rely_on_disc 2));
   120     fun can_omit_disc_binder k m =
   120     fun can_omit_disc_binding k m =
   121       n = 1 orelse m = 0 orelse (n = 2 andalso can_rely_on_disc (3 - k));
   121       n = 1 orelse m = 0 orelse (n = 2 andalso can_rely_on_disc (3 - k));
   122 
   122 
   123     val std_disc_binder =
   123     val std_disc_binding =
   124       Binding.qualify false (Binding.name_of data_b) o Binding.name o prefix isN o base_name_of_ctr;
   124       Binding.qualify false (Binding.name_of data_b) o Binding.name o prefix isN o base_name_of_ctr;
   125 
   125 
   126     val disc_binders =
   126     val disc_bindings =
   127       raw_disc_binders'
   127       raw_disc_bindings'
   128       |> map4 (fn k => fn m => fn ctr => fn disc =>
   128       |> map4 (fn k => fn m => fn ctr => fn disc =>
   129         Option.map (Binding.qualify false (Binding.name_of data_b))
   129         Option.map (Binding.qualify false (Binding.name_of data_b))
   130           (if Binding.eq_name (disc, no_binder) then
   130           (if Binding.eq_name (disc, no_binding) then
   131              if can_omit_disc_binder k m then NONE else SOME (std_disc_binder ctr)
   131              if can_omit_disc_binding k m then NONE else SOME (std_disc_binding ctr)
   132            else if Binding.eq_name (disc, std_binder) then
   132            else if Binding.eq_name (disc, std_binding) then
   133              SOME (std_disc_binder ctr)
   133              SOME (std_disc_binding ctr)
   134            else
   134            else
   135              SOME disc)) ks ms ctrs0;
   135              SOME disc)) ks ms ctrs0;
   136 
   136 
   137     val no_discs = map is_none disc_binders;
   137     val no_discs = map is_none disc_bindings;
   138     val no_discs_at_all = forall I no_discs;
   138     val no_discs_at_all = forall I no_discs;
   139 
   139 
   140     fun std_sel_binder m l = Binding.name o mk_unN m l o base_name_of_ctr;
   140     fun std_sel_binding m l = Binding.name o mk_unN m l o base_name_of_ctr;
   141 
   141 
   142     val sel_binderss =
   142     val sel_bindingss =
   143       pad_list [] n raw_sel_binderss
   143       pad_list [] n raw_sel_bindingss
   144       |> map3 (fn ctr => fn m => map2 (fn l => fn sel =>
   144       |> map3 (fn ctr => fn m => map2 (fn l => fn sel =>
   145         Binding.qualify false (Binding.name_of data_b)
   145         Binding.qualify false (Binding.name_of data_b)
   146           (if Binding.eq_name (sel, no_binder) orelse Binding.eq_name (sel, std_binder) then
   146           (if Binding.eq_name (sel, no_binding) orelse Binding.eq_name (sel, std_binding) then
   147             std_sel_binder m l ctr
   147             std_sel_binding m l ctr
   148           else
   148           else
   149             sel)) (1 upto m) o pad_list no_binder m) ctrs0 ms;
   149             sel)) (1 upto m) o pad_list no_binding m) ctrs0 ms;
   150 
   150 
   151     fun mk_case Ts T =
   151     fun mk_case Ts T =
   152       let
   152       let
   153         val (binders, body) = strip_type (fastype_of case0)
   153         val (bindings, body) = strip_type (fastype_of case0)
   154         val Type (_, Ts0) = List.last binders
   154         val Type (_, Ts0) = List.last bindings
   155       in Term.subst_atomic_types ((body, T) :: (Ts0 ~~ Ts)) case0 end;
   155       in Term.subst_atomic_types ((body, T) :: (Ts0 ~~ Ts)) case0 end;
   156 
   156 
   157     val casex = mk_case As B;
   157     val casex = mk_case As B;
   158     val case_Ts = map (fn Ts => Ts ---> B) ctr_Tss;
   158     val case_Ts = map (fn Ts => Ts ---> B) ctr_Tss;
   159 
   159 
   189     val unique_disc_no_def = TrueI; (*arbitrary marker*)
   189     val unique_disc_no_def = TrueI; (*arbitrary marker*)
   190     val alternate_disc_no_def = FalseE; (*arbitrary marker*)
   190     val alternate_disc_no_def = FalseE; (*arbitrary marker*)
   191 
   191 
   192     fun alternate_disc_lhs get_disc k =
   192     fun alternate_disc_lhs get_disc k =
   193       HOLogic.mk_not
   193       HOLogic.mk_not
   194         (case nth disc_binders (k - 1) of
   194         (case nth disc_bindings (k - 1) of
   195           NONE => nth exist_xs_v_eq_ctrs (k - 1)
   195           NONE => nth exist_xs_v_eq_ctrs (k - 1)
   196         | SOME b => get_disc b (k - 1) $ v);
   196         | SOME b => get_disc b (k - 1) $ v);
   197 
   197 
   198     val (all_sels_distinct, discs, selss, disc_defs, sel_defs, sel_defss, lthy') =
   198     val (all_sels_distinct, discs, selss, disc_defs, sel_defs, sel_defss, lthy') =
   199       if no_dests then
   199       if no_dests then
   235             in
   235             in
   236               mk_Trueprop_eq (Free (Binding.name_of b, dataT --> T) $ v,
   236               mk_Trueprop_eq (Free (Binding.name_of b, dataT --> T) $ v,
   237                 Term.list_comb (mk_case As T, mk_sel_case_args b proto_sels T) $ v)
   237                 Term.list_comb (mk_case As T, mk_sel_case_args b proto_sels T) $ v)
   238             end;
   238             end;
   239 
   239 
   240           val sel_binders = flat sel_binderss;
   240           val sel_bindings = flat sel_bindingss;
   241           val uniq_sel_binders = distinct Binding.eq_name sel_binders;
   241           val uniq_sel_bindings = distinct Binding.eq_name sel_bindings;
   242           val all_sels_distinct = (length uniq_sel_binders = length sel_binders);
   242           val all_sels_distinct = (length uniq_sel_bindings = length sel_bindings);
   243 
   243 
   244           val sel_binder_index =
   244           val sel_binding_index =
   245             if all_sels_distinct then 1 upto length sel_binders
   245             if all_sels_distinct then 1 upto length sel_bindings
   246             else map (fn b => find_index (curry Binding.eq_name b) uniq_sel_binders) sel_binders;
   246             else map (fn b => find_index (curry Binding.eq_name b) uniq_sel_bindings) sel_bindings;
   247 
   247 
   248           val proto_sels = flat (map3 (fn k => fn xs => map (fn x => (k, (xs, x)))) ks xss xss);
   248           val proto_sels = flat (map3 (fn k => fn xs => map (fn x => (k, (xs, x)))) ks xss xss);
   249           val sel_infos =
   249           val sel_infos =
   250             AList.group (op =) (sel_binder_index ~~ proto_sels)
   250             AList.group (op =) (sel_binding_index ~~ proto_sels)
   251             |> sort (int_ord o pairself fst)
   251             |> sort (int_ord o pairself fst)
   252             |> map snd |> curry (op ~~) uniq_sel_binders;
   252             |> map snd |> curry (op ~~) uniq_sel_bindings;
   253           val sel_binders = map fst sel_infos;
   253           val sel_bindings = map fst sel_infos;
   254 
   254 
   255           fun unflat_selss xs = unflat_lookup Binding.eq_name sel_binders xs sel_binderss;
   255           fun unflat_selss xs = unflat_lookup Binding.eq_name sel_bindings xs sel_bindingss;
   256 
   256 
   257           val (((raw_discs, raw_disc_defs), (raw_sels, raw_sel_defs)), (lthy', lthy)) =
   257           val (((raw_discs, raw_disc_defs), (raw_sels, raw_sel_defs)), (lthy', lthy)) =
   258             no_defs_lthy
   258             no_defs_lthy
   259             |> apfst split_list o fold_map4 (fn k => fn m => fn exist_xs_v_eq_ctr =>
   259             |> apfst split_list o fold_map4 (fn k => fn m => fn exist_xs_v_eq_ctr =>
   260               fn NONE =>
   260               fn NONE =>
   261                  if n = 1 then pair (Term.lambda v (mk_v_eq_v ()), unique_disc_no_def)
   261                  if n = 1 then pair (Term.lambda v (mk_v_eq_v ()), unique_disc_no_def)
   262                  else if m = 0 then pair (Term.lambda v exist_xs_v_eq_ctr, refl)
   262                  else if m = 0 then pair (Term.lambda v exist_xs_v_eq_ctr, refl)
   263                  else pair (alternate_disc k, alternate_disc_no_def)
   263                  else pair (alternate_disc k, alternate_disc_no_def)
   264                | SOME b => Specification.definition (SOME (b, NONE, NoSyn),
   264                | SOME b => Specification.definition (SOME (b, NONE, NoSyn),
   265                    ((Thm.def_binding b, []), disc_spec b exist_xs_v_eq_ctr)) #>> apsnd snd)
   265                    ((Thm.def_binding b, []), disc_spec b exist_xs_v_eq_ctr)) #>> apsnd snd)
   266               ks ms exist_xs_v_eq_ctrs disc_binders
   266               ks ms exist_xs_v_eq_ctrs disc_bindings
   267             ||>> apfst split_list o fold_map (fn (b, proto_sels) =>
   267             ||>> apfst split_list o fold_map (fn (b, proto_sels) =>
   268               Specification.definition (SOME (b, NONE, NoSyn),
   268               Specification.definition (SOME (b, NONE, NoSyn),
   269                 ((Thm.def_binding b, []), sel_spec b proto_sels)) #>> apsnd snd) sel_infos
   269                 ((Thm.def_binding b, []), sel_spec b proto_sels)) #>> apsnd snd) sel_infos
   270             ||> `Local_Theory.restore;
   270             ||> `Local_Theory.restore;
   271 
   271