--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML Tue Oct 01 15:02:12 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML Tue Oct 01 17:04:27 2013 +0200
@@ -7,8 +7,17 @@
signature BNF_FP_REC_SUGAR =
sig
+ val add_primrec: (binding * typ option * mixfix) list ->
+ (Attrib.binding * term) list -> local_theory -> (term list * thm list list) * local_theory
val add_primrec_cmd: (binding * string option * mixfix) list ->
- (Attrib.binding * string) list -> local_theory -> local_theory;
+ (Attrib.binding * string) list -> local_theory -> (term list * thm list list) * local_theory
+ val add_primrec_global: (binding * typ option * mixfix) list ->
+ (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory
+ val add_primrec_overloaded: (string * (string * typ) * bool) list ->
+ (binding * typ option * mixfix) list ->
+ (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory
+ val add_primrec_simple: ((binding * typ) * mixfix) list -> term list ->
+ local_theory -> (string list * (term list * (int list list * thm list list))) * local_theory
val add_primcorecursive_cmd: bool ->
(binding * string option * mixfix) list * ((Attrib.binding * string) * string option) list ->
Proof.context -> Proof.state
@@ -31,8 +40,9 @@
val selN = "sel"
val nitpick_attrs = @{attributes [nitpick_simp]};
-val code_nitpick_simp_attrs = Code.add_default_eqn_attrib :: nitpick_attrs;
val simp_attrs = @{attributes [simp]};
+val code_nitpick_attrs = Code.add_default_eqn_attrib :: nitpick_attrs;
+val code_nitpick_simp_attrs = Code.add_default_eqn_attrib :: nitpick_attrs @ simp_attrs;
exception Primrec_Error of string * term list;
@@ -300,11 +310,11 @@
|> (fn [] => NONE | callss => SOME (#ctr eqn_data, callss))
end;
-fun add_primrec fixes specs lthy =
+fun prepare_primrec fixes specs lthy =
let
val (bs, mxs) = map_split (apfst fst) fixes;
val fun_names = map Binding.name_of bs;
- val eqns_data = map (snd #> dissect_eqn lthy fun_names) specs;
+ val eqns_data = map (dissect_eqn lthy fun_names) specs;
val funs_data = eqns_data
|> partition_eq ((op =) o pairself #fun_name)
|> finds (fn (x, y) => x = #fun_name (hd y)) fun_names |> fst
@@ -330,52 +340,51 @@
val defs = build_defs lthy' bs mxs funs_data rec_specs has_call;
- fun prove def_thms' ({nested_map_idents, nested_map_comps, ctr_specs, ...} : rec_spec)
- induct_thm (fun_data : eqn_data list) lthy =
+ fun prove lthy def_thms' ({ctr_specs, nested_map_idents, nested_map_comps, ...} : rec_spec)
+ (fun_data : eqn_data list) =
let
- val fun_name = #fun_name (hd fun_data);
val def_thms = map (snd o snd) def_thms';
- val simp_thms = finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs
+ val simp_thmss = finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs
|> fst
|> map_filter (try (fn (x, [y]) =>
(#user_eqn x, length (#left_args x) + length (#right_args x), #rec_thm y)))
|> map (fn (user_eqn, num_extra_args, rec_thm) =>
mk_primrec_tac lthy num_extra_args nested_map_idents nested_map_comps def_thms rec_thm
- |> K |> Goal.prove lthy [] [] user_eqn)
+ |> K |> Goal.prove lthy [] [] user_eqn);
+ val poss = find_indices (fn (x, y) => #ctr x = #ctr y) fun_data eqns_data;
+ in
+ (poss, simp_thmss)
+ end;
- val notes =
- [(inductN, if n2m then [induct_thm] else [], []),
- (simpsN, simp_thms, code_nitpick_simp_attrs @ simp_attrs)]
- |> filter_out (null o #2)
- |> map (fn (thmN, thms, attrs) =>
- ((Binding.qualify true fun_name (Binding.name thmN), attrs), [(thms, [])]));
- in
- lthy |> Local_Theory.notes notes
- end;
+ val notes =
+ (if n2m then map2 (fn name => fn thm =>
+ (name, inductN, [thm], [])) fun_names (take actual_nn induct_thms) else [])
+ |> map (fn (prefix, thmN, thms, attrs) =>
+ ((Binding.qualify true prefix (Binding.name thmN), attrs), [(thms, [])]));
val common_name = mk_common_name fun_names;
val common_notes =
- [(inductN, if n2m then [induct_thm] else [], [])]
- |> filter_out (null o #2)
+ (if n2m then [(inductN, [induct_thm], [])] else [])
|> map (fn (thmN, thms, attrs) =>
((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])]));
in
- lthy'
- |> fold_map Local_Theory.define defs
- |-> snd oo (fn def_thms' => fold_map3 (prove def_thms') (take actual_nn rec_specs)
- (take actual_nn induct_thms) funs_data)
- |> Local_Theory.notes common_notes |> snd
+ (((fun_names, defs),
+ fn lthy => fn defs =>
+ split_list (map2 (prove lthy defs) (take actual_nn rec_specs) funs_data)),
+ lthy' |> Local_Theory.notes (notes @ common_notes) |> snd)
end;
-fun add_primrec_cmd raw_fixes raw_specs lthy =
+(* primrec definition *)
+
+fun add_primrec_simple fixes ts lthy =
let
- val _ = let val d = duplicates (op =) (map (Binding.name_of o #1) raw_fixes) in null d orelse
- primrec_error ("duplicate function name(s): " ^ commas d) end;
- val (fixes, specs) = fst (Specification.read_spec raw_fixes raw_specs lthy);
+ val (((names, defs), prove), lthy) = prepare_primrec fixes ts lthy
+ handle ERROR str => primrec_error str;
in
- add_primrec fixes specs lthy
- handle ERROR str => primrec_error str
+ lthy
+ |> fold_map Local_Theory.define defs
+ |-> (fn defs => `(fn lthy => (names, (map fst defs, prove lthy defs))))
end
handle Primrec_Error (str, eqns) =>
if null eqns
@@ -383,6 +392,56 @@
else error ("primrec_new error:\n " ^ str ^ "\nin\n " ^
space_implode "\n " (map (quote o Syntax.string_of_term lthy) eqns));
+local
+
+fun gen_primrec prep_spec raw_fixes raw_spec lthy =
+ let
+ val d = duplicates (op =) (map (Binding.name_of o #1) raw_fixes)
+ val _ = null d orelse primrec_error ("duplicate function name(s): " ^ commas d);
+
+ val (fixes, specs) = fst (prep_spec raw_fixes raw_spec lthy);
+
+ val mk_notes =
+ flat ooo map3 (fn poss => fn prefix => fn thms =>
+ let
+ val (bs, attrss) = map_split (fst o nth specs) poss;
+ val notes =
+ map3 (fn b => fn attrs => fn thm =>
+ ((Binding.qualify false prefix b, code_nitpick_simp_attrs @ attrs), [([thm], [])]))
+ bs attrss thms;
+ in
+ ((Binding.qualify true prefix (Binding.name simpsN), []), [(thms, [])]) :: notes
+ end);
+ in
+ lthy
+ |> add_primrec_simple fixes (map snd specs)
+ |-> (fn (names, (ts, (posss, simpss))) =>
+ Spec_Rules.add Spec_Rules.Equational (ts, flat simpss)
+ #> Local_Theory.notes (mk_notes posss names simpss)
+ #>> pair ts o map snd)
+ end;
+
+in
+
+val add_primrec = gen_primrec Specification.check_spec;
+val add_primrec_cmd = gen_primrec Specification.read_spec;
+
+end;
+
+fun add_primrec_global fixes specs thy =
+ let
+ val lthy = Named_Target.theory_init thy;
+ val ((ts, simps), lthy') = add_primrec fixes specs lthy;
+ val simps' = burrow (Proof_Context.export lthy' lthy) simps;
+ in ((ts, simps'), Local_Theory.exit_global lthy') end;
+
+fun add_primrec_overloaded ops fixes specs thy =
+ let
+ val lthy = Overloading.overloading ops thy;
+ val ((ts, simps), lthy') = add_primrec fixes specs lthy;
+ val simps' = burrow (Proof_Context.export lthy' lthy) simps;
+ in ((ts, simps'), Local_Theory.exit_global lthy') end;
+
(* Primcorec *)
@@ -875,7 +934,7 @@
val notes =
[(coinductN, map (if n2m then single else K []) coinduct_thms, []),
- (codeN, ctr_thmss(*FIXME*), code_nitpick_simp_attrs),
+ (codeN, ctr_thmss(*FIXME*), code_nitpick_attrs),
(ctrN, ctr_thmss, []),
(discN, disc_thmss, simp_attrs),
(selN, sel_thmss, simp_attrs),