--- a/src/HOL/Tools/primrec_package.ML Tue May 26 17:29:33 2009 +0200
+++ b/src/HOL/Tools/primrec_package.ML Tue May 26 17:29:34 2009 +0200
@@ -16,6 +16,8 @@
val add_primrec_overloaded: (string * (string * typ) * bool) list ->
(binding * typ option * mixfix) list ->
(Attrib.binding * term) list -> theory -> thm list * theory
+ val add_primrec_simple: ((binding * typ) * mixfix) list -> (binding * term) list ->
+ local_theory -> (string * thm list list) * local_theory
end;
structure PrimrecPackage : PRIMREC_PACKAGE =
@@ -211,22 +213,12 @@
else find_dts dt_info tnames' tnames);
-(* primrec definition *)
+(* distill primitive definition(s) from primrec specification *)
-local
-
-fun prove_spec ctxt names rec_rewrites defs eqs =
+fun distill lthy fixes eqs =
let
- val rewrites = map mk_meta_eq rec_rewrites @ map (snd o snd) defs;
- fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1];
- val _ = message ("Proving equations for primrec function(s) " ^ commas_quote names);
- in map (fn (a, t) => (a, [Goal.prove ctxt [] [] t tac])) eqs end;
-
-fun gen_primrec set_group prep_spec raw_fixes raw_spec lthy =
- let
- val (fixes, spec) = fst (prep_spec raw_fixes raw_spec lthy);
val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed lthy v
- orelse exists (fn ((w, _), _) => v = Binding.name_of w) fixes) o snd) spec [];
+ orelse exists (fn ((w, _), _) => v = Binding.name_of w) fixes)) eqs [];
val tnames = distinct (op =) (map (#1 o snd) eqns);
val dts = find_dts (DatatypePackage.get_datatypes (ProofContext.theory_of lthy)) tnames tnames;
val main_fns = map (fn (tname, {index, ...}) =>
@@ -236,31 +228,59 @@
("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
else snd (hd dts);
val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []);
- val (fs, defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
- val names1 = map snd fnames;
- val names2 = map fst eqns;
- val _ = if gen_eq_set (op =) (names1, names2) then ()
- else primrec_error ("functions " ^ commas_quote names2 ^
+ val (fs, raw_defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
+ val defs = map (make_def lthy fixes fs) raw_defs;
+ val names = map snd fnames;
+ val names_eqns = map fst eqns;
+ val _ = if gen_eq_set (op =) (names, names_eqns) then ()
+ else primrec_error ("functions " ^ commas_quote names_eqns ^
"\nare not mutually recursive");
- val prefix = space_implode "_" (map (Long_Name.base_name o #1) defs);
- val qualify = Binding.qualify false prefix;
- val spec' = (map o apfst)
- (fn (b, attrs) => (qualify b, Code.add_default_eqn_attrib :: attrs)) spec;
- val simp_atts = map (Attrib.internal o K)
- [Simplifier.simp_add, Nitpick_Const_Simp_Thms.add, Quickcheck_RecFun_Simp_Thms.add];
+ val rec_rewrites' = map mk_meta_eq rec_rewrites;
+ val prefix = space_implode "_" (map (Long_Name.base_name o #1) raw_defs);
+ fun prove lthy defs =
+ let
+ val rewrites = rec_rewrites' @ map (snd o snd) defs;
+ fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1];
+ val _ = message ("Proving equations for primrec function(s) " ^ commas_quote names);
+ in map (fn eq => [Goal.prove lthy [] [] eq tac]) eqs end;
+ in ((prefix, (fs, defs)), prove) end
+ handle PrimrecError (msg, some_eqn) =>
+ error ("Primrec definition error:\n" ^ msg ^ (case some_eqn
+ of SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn)
+ | NONE => ""));
+
+
+(* primrec definition *)
+
+fun add_primrec_simple fixes spec lthy =
+ let
+ val ((prefix, (fs, defs)), prove) = distill lthy fixes (map snd spec);
+ in
+ lthy
+ |> fold_map (LocalTheory.define Thm.definitionK) defs
+ |-> (fn defs => `(fn lthy => (prefix, prove lthy defs)))
+ end;
+
+local
+
+fun gen_primrec set_group prep_spec raw_fixes raw_spec lthy =
+ let
+ val (fixes, spec) = fst (prep_spec raw_fixes raw_spec lthy);
+ fun attr_bindings prefix = map (fn ((b, attrs), _) =>
+ (Binding.qualify false prefix b, Code.add_default_eqn_attrib :: attrs)) spec;
+ fun simp_attr_binding prefix = (Binding.qualify false prefix (Binding.name "simps"),
+ map (Attrib.internal o K)
+ [Simplifier.simp_add, Nitpick_Const_Simp_Thms.add, Quickcheck_RecFun_Simp_Thms.add]);
in
lthy
|> set_group ? LocalTheory.set_group (serial_string ())
- |> fold_map (LocalTheory.define Thm.definitionK o make_def lthy fixes fs) defs
- |-> (fn defs => `(fn ctxt => prove_spec ctxt names1 rec_rewrites defs spec'))
- |-> (fn simps => fold_map (LocalTheory.note Thm.generatedK) simps)
- |-> (fn simps' => LocalTheory.note Thm.generatedK
- ((qualify (Binding.qualified_name "simps"), simp_atts), maps snd simps'))
+ |> add_primrec_simple fixes spec
+ |-> (fn (prefix, simps) => fold_map (LocalTheory.note Thm.generatedK)
+ (attr_bindings prefix ~~ simps)
+ #-> (fn simps' => LocalTheory.note Thm.generatedK
+ (simp_attr_binding prefix, maps snd simps')))
|>> snd
- end handle PrimrecError (msg, some_eqn) =>
- error ("Primrec definition error:\n" ^ msg ^ (case some_eqn
- of SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn)
- | NONE => ""));
+ end;
in