src/HOL/Codatatype/Tools/bnf_wrap.ML
changeset 49121 9e0acaa470ab
parent 49120 7f8e69fc6ac9
child 49122 83515378d4d7
equal deleted inserted replaced
49120:7f8e69fc6ac9 49121:9e0acaa470ab
     6 *)
     6 *)
     7 
     7 
     8 signature BNF_WRAP =
     8 signature BNF_WRAP =
     9 sig
     9 sig
    10   val no_name: binding
    10   val no_name: binding
    11   val wrap: ({prems: thm list, context: Proof.context} -> tactic) list list ->
    11   val mk_half_pairss: 'a list -> ('a * 'a) list list
       
    12   val wrap_data: ({prems: thm list, context: Proof.context} -> tactic) list list ->
    12     (term list * term) * (binding list * binding list list) -> local_theory -> local_theory
    13     (term list * term) * (binding list * binding list list) -> local_theory -> local_theory
    13 end;
    14 end;
    14 
    15 
    15 structure BNF_Wrap : BNF_WRAP =
    16 structure BNF_Wrap : BNF_WRAP =
    16 struct
    17 struct
    60   case head_of t of
    61   case head_of t of
    61     Const (s, _) => s
    62     Const (s, _) => s
    62   | Free (s, _) => s
    63   | Free (s, _) => s
    63   | _ => error "Cannot extract name of constructor";
    64   | _ => error "Cannot extract name of constructor";
    64 
    65 
    65 fun prepare_wrap prep_term ((raw_ctrs, raw_caseof), (raw_disc_names, raw_sel_namess))
    66 fun prepare_wrap_data prep_term ((raw_ctrs, raw_caseof), (raw_disc_names, raw_sel_namess))
    66   no_defs_lthy =
    67   no_defs_lthy =
    67   let
    68   let
    68     (* TODO: sanity checks on arguments *)
    69     (* TODO: sanity checks on arguments *)
    69     (* TODO: attributes (simp, case_names, etc.) *)
    70     (* TODO: attributes (simp, case_names, etc.) *)
    70     (* TODO: case syntax *)
    71     (* TODO: case syntax *)
    74     val caseof0 = prep_term no_defs_lthy raw_caseof;
    75     val caseof0 = prep_term no_defs_lthy raw_caseof;
    75 
    76 
    76     val n = length ctrs0;
    77     val n = length ctrs0;
    77     val ks = 1 upto n;
    78     val ks = 1 upto n;
    78 
    79 
    79     val (T_name, As0) = dest_Type (body_type (fastype_of (hd ctrs0)));
    80     val _ = if n > 0 then () else error "No constructors specified";
       
    81 
       
    82     val Type (T_name, As0) = body_type (fastype_of (hd ctrs0));
    80     val b = Binding.qualified_name T_name;
    83     val b = Binding.qualified_name T_name;
    81 
    84 
    82     val (As, B) =
    85     val (As, B) =
    83       no_defs_lthy
    86       no_defs_lthy
    84       |> mk_TFrees (length As0)
    87       |> mk_TFrees (length As0)
    85       ||> the_single o fst o mk_TFrees 1;
    88       ||> the_single o fst o mk_TFrees 1;
    86 
    89 
    87     fun mk_ctr Ts ctr =
    90     fun mk_ctr Ts ctr =
    88       let val Ts0 = snd (dest_Type (body_type (fastype_of ctr))) in
    91       let val Type (_, Ts0) = body_type (fastype_of ctr) in
    89         Term.subst_atomic_types (Ts0 ~~ Ts) ctr
    92         Term.subst_atomic_types (Ts0 ~~ Ts) ctr
    90       end;
    93       end;
    91 
    94 
    92     val T = Type (T_name, As);
    95     val T = Type (T_name, As);
    93     val ctrs = map (mk_ctr As) ctrs0;
    96     val ctrs = map (mk_ctr As) ctrs0;
   125           fallback_sel_name m l ctr
   128           fallback_sel_name m l ctr
   126         else
   129         else
   127           sel) (1 upto m) o pad_list no_name m) ctrs0 ms;
   130           sel) (1 upto m) o pad_list no_name m) ctrs0 ms;
   128 
   131 
   129     fun mk_caseof Ts T =
   132     fun mk_caseof Ts T =
   130       let val (binders, body) = strip_type (fastype_of caseof0) in
   133       let
   131         Term.subst_atomic_types ((body, T) :: (snd (dest_Type (List.last binders)) ~~ Ts)) caseof0
   134         val (binders, body) = strip_type (fastype_of caseof0)
   132       end;
   135         val Type (_, Ts0) = List.last binders
       
   136       in Term.subst_atomic_types ((body, T) :: (Ts0 ~~ Ts)) caseof0 end;
   133 
   137 
   134     val caseofB = mk_caseof As B;
   138     val caseofB = mk_caseof As B;
   135     val caseofB_Ts = map (fn Ts => Ts ---> B) ctr_Tss;
   139     val caseofB_Ts = map (fn Ts => Ts ---> B) ctr_Tss;
   136 
   140 
   137     fun mk_caseofB_term eta_fs = Term.list_comb (caseofB, eta_fs);
   141     fun mk_caseofB_term eta_fs = Term.list_comb (caseofB, eta_fs);
   205 
   209 
   206     val discs0 = map (Morphism.term phi) raw_discs;
   210     val discs0 = map (Morphism.term phi) raw_discs;
   207     val selss0 = map (map (Morphism.term phi)) raw_selss;
   211     val selss0 = map (map (Morphism.term phi)) raw_selss;
   208 
   212 
   209     fun mk_disc_or_sel Ts t =
   213     fun mk_disc_or_sel Ts t =
   210       Term.subst_atomic_types (snd (dest_Type (domain_type (fastype_of t))) ~~ Ts) t;
   214       Term.subst_atomic_types (snd (Term.dest_Type (domain_type (fastype_of t))) ~~ Ts) t;
   211 
   215 
   212     val discs = map (mk_disc_or_sel As) discs0;
   216     val discs = map (mk_disc_or_sel As) discs0;
   213     val selss = map (map (mk_disc_or_sel As)) selss0;
   217     val selss = map (map (mk_disc_or_sel As)) selss0;
   214 
   218 
   215     fun mk_imp_p Qs = Logic.list_implies (Qs, HOLogic.mk_Trueprop p);
   219     fun mk_imp_p Qs = Logic.list_implies (Qs, HOLogic.mk_Trueprop p);
   216 
   220 
   217     val goal_exhaust =
   221     val goal_exhaust =
   218       let fun mk_prem xctr xs = fold_rev Logic.all xs (mk_imp_p [mk_Trueprop_eq (v, xctr)]) in
   222       let fun mk_prem xctr xs = fold_rev Logic.all xs (mk_imp_p [mk_Trueprop_eq (v, xctr)]) in
   219         mk_imp_p (map2 mk_prem xctrs xss)
   223         fold_rev Logic.all [p, v] (mk_imp_p (map2 mk_prem xctrs xss))
   220       end;
   224       end;
   221 
   225 
   222     val goal_injectss =
   226     val goal_injectss =
   223       let
   227       let
   224         fun mk_goal _ _ [] [] = []
   228         fun mk_goal _ _ [] [] = []
   225           | mk_goal xctr yctr xs ys =
   229           | mk_goal xctr yctr xs ys =
   226             [mk_Trueprop_eq (HOLogic.mk_eq (xctr, yctr),
   230             [fold_rev Logic.all (xs @ ys) (mk_Trueprop_eq (HOLogic.mk_eq (xctr, yctr),
   227               Library.foldr1 HOLogic.mk_conj (map2 (curry HOLogic.mk_eq) xs ys))];
   231               Library.foldr1 HOLogic.mk_conj (map2 (curry HOLogic.mk_eq) xs ys)))];
   228       in
   232       in
   229         map4 mk_goal xctrs yctrs xss yss
   233         map4 mk_goal xctrs yctrs xss yss
   230       end;
   234       end;
   231 
   235 
   232     val goal_half_distinctss =
   236     val goal_half_distinctss =
   233       map (map (HOLogic.mk_Trueprop o HOLogic.mk_not o HOLogic.mk_eq)) (mk_half_pairss xctrs);
   237       let
   234 
   238         fun mk_goal ((xs, t), (xs', t')) =
   235     val goal_cases = map2 (fn xctr => fn xf => mk_Trueprop_eq (caseofB_fs $ xctr, xf)) xctrs xfs;
   239           fold_rev Logic.all (xs @ xs')
   236 
   240             (HOLogic.mk_Trueprop (HOLogic.mk_not (HOLogic.mk_eq (t, t'))));
   237     val goals = [goal_exhaust] :: goal_injectss @ goal_half_distinctss @ [goal_cases];
   241       in
       
   242         map (map mk_goal) (mk_half_pairss (xss ~~ xctrs))
       
   243       end;
       
   244 
       
   245     val goal_cases =
       
   246       map3 (fn xs => fn xctr => fn xf =>
       
   247         fold_rev Logic.all (fs @ xs) (mk_Trueprop_eq (caseofB_fs $ xctr, xf))) xss xctrs xfs;
       
   248 
       
   249     val goalss = [goal_exhaust] :: goal_injectss @ goal_half_distinctss @ [goal_cases];
   238 
   250 
   239     fun after_qed thmss lthy =
   251     fun after_qed thmss lthy =
   240       let
   252       let
   241         val ([exhaust_thm], (inject_thmss, (half_distinct_thmss, [case_thms]))) =
   253         val ([exhaust_thm], (inject_thmss, (half_distinct_thmss, [case_thms]))) =
   242           (hd thmss, apsnd (chop (n * n)) (chop n (tl thmss)));
   254           (hd thmss, apsnd (chop (n * n)) (chop n (tl thmss)));
   354           if has_not_other_disc_def orelse forall I no_discs then
   366           if has_not_other_disc_def orelse forall I no_discs then
   355             []
   367             []
   356           else
   368           else
   357             let
   369             let
   358               fun mk_prem disc = mk_imp_p [HOLogic.mk_Trueprop (betapply (disc, v))];
   370               fun mk_prem disc = mk_imp_p [HOLogic.mk_Trueprop (betapply (disc, v))];
   359               val goal = fold Logic.all [p, v] (mk_imp_p (map mk_prem discs));
   371               val goal = fold_rev Logic.all [p, v] (mk_imp_p (map mk_prem discs));
   360             in
   372             in
   361               [Skip_Proof.prove lthy [] [] goal (fn _ =>
   373               [Skip_Proof.prove lthy [] [] goal (fn _ =>
   362                  mk_disc_exhaust_tac n exhaust_thm discI_thms)]
   374                  mk_disc_exhaust_tac n exhaust_thm discI_thms)]
   363             end;
   375             end;
   364 
   376 
   453            (discsN, disc_thms),
   465            (discsN, disc_thms),
   454            (disc_exclusN, disc_exclus_thms),
   466            (disc_exclusN, disc_exclus_thms),
   455            (disc_exhaustN, disc_exhaust_thms),
   467            (disc_exhaustN, disc_exhaust_thms),
   456            (distinctN, distinct_thms),
   468            (distinctN, distinct_thms),
   457            (exhaustN, [exhaust_thm]),
   469            (exhaustN, [exhaust_thm]),
   458            (injectN, (flat inject_thmss)),
   470            (injectN, flat inject_thmss),
   459            (nchotomyN, [nchotomy_thm]),
   471            (nchotomyN, [nchotomy_thm]),
   460            (selsN, (flat sel_thmss)),
   472            (selsN, flat sel_thmss),
   461            (splitN, [split_thm]),
   473            (splitN, [split_thm]),
   462            (split_asmN, [split_asm_thm]),
   474            (split_asmN, [split_asm_thm]),
   463            (weak_case_cong_thmsN, [weak_case_cong_thm])]
   475            (weak_case_cong_thmsN, [weak_case_cong_thm])]
   464           |> filter_out (null o snd)
   476           |> filter_out (null o snd)
   465           |> map (fn (thmN, thms) =>
   477           |> map (fn (thmN, thms) =>
   466             ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), []), [(thms, [])]));
   478             ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), []), [(thms, [])]));
   467       in
   479       in
   468         lthy |> Local_Theory.notes notes |> snd
   480         lthy |> Local_Theory.notes notes |> snd
   469       end;
   481       end;
   470   in
   482   in
   471     (goals, after_qed, lthy')
   483     (goalss, after_qed, lthy')
   472   end;
   484   end;
   473 
   485 
   474 fun wrap tacss = (fn (goalss, after_qed, lthy) =>
   486 fun wrap_data tacss = (fn (goalss, after_qed, lthy) =>
   475   map2 (map2 (Skip_Proof.prove lthy [] [])) goalss tacss
   487   map2 (map2 (Skip_Proof.prove lthy [] [])) goalss tacss
   476   |> (fn thms => after_qed thms lthy)) oo
   488   |> (fn thms => after_qed thms lthy)) oo
   477   prepare_wrap (singleton o Type_Infer_Context.infer_types)
   489   prepare_wrap_data (K I) (* FIXME? (singleton o Type_Infer_Context.infer_types) *)
   478 
   490 
   479 val parse_bindings = Parse.$$$ "[" |-- Parse.list Parse.binding --| Parse.$$$ "]";
   491 val parse_bindings = Parse.$$$ "[" |-- Parse.list Parse.binding --| Parse.$$$ "]";
   480 val parse_bindingss = Parse.$$$ "[" |-- Parse.list parse_bindings --| Parse.$$$ "]";
   492 val parse_bindingss = Parse.$$$ "[" |-- Parse.list parse_bindings --| Parse.$$$ "]";
   481 
   493 
   482 val wrap_data_cmd = (fn (goalss, after_qed, lthy) =>
   494 val wrap_data_cmd = (fn (goalss, after_qed, lthy) =>
   483   Proof.theorem NONE after_qed (map (map (rpair [])) goalss) lthy) oo
   495   Proof.theorem NONE after_qed (map (map (rpair [])) goalss) lthy) oo
   484   prepare_wrap Syntax.read_term;
   496   prepare_wrap_data Syntax.read_term;
   485 
   497 
   486 val _ =
   498 val _ =
   487   Outer_Syntax.local_theory_to_proof @{command_spec "wrap_data"} "wraps an existing datatype"
   499   Outer_Syntax.local_theory_to_proof @{command_spec "wrap_data"} "wraps an existing datatype"
   488     (((Parse.$$$ "[" |-- Parse.list Parse.term --| Parse.$$$ "]") -- Parse.term --
   500     (((Parse.$$$ "[" |-- Parse.list Parse.term --| Parse.$$$ "]") -- Parse.term --
   489       Scan.optional (parse_bindings -- Scan.optional parse_bindingss []) ([], []))
   501       Scan.optional (parse_bindings -- Scan.optional parse_bindingss []) ([], []))