add_primrec_simple
authorhaftmann
Tue, 26 May 2009 17:29:34 +0200
changeset 31262 580510315dda
parent 31261 900ebbc35e30
child 31263 4dbe0b4c313b
add_primrec_simple
src/HOL/Tools/primrec_package.ML
--- a/src/HOL/Tools/primrec_package.ML	Tue May 26 17:29:33 2009 +0200
+++ b/src/HOL/Tools/primrec_package.ML	Tue May 26 17:29:34 2009 +0200
@@ -16,6 +16,8 @@
   val add_primrec_overloaded: (string * (string * typ) * bool) list ->
     (binding * typ option * mixfix) list ->
     (Attrib.binding * term) list -> theory -> thm list * theory
+  val add_primrec_simple: ((binding * typ) * mixfix) list -> (binding * term) list ->
+    local_theory -> (string * thm list list) * local_theory
 end;
 
 structure PrimrecPackage : PRIMREC_PACKAGE =
@@ -211,22 +213,12 @@
             else find_dts dt_info tnames' tnames);
 
 
-(* primrec definition *)
+(* distill primitive definition(s) from primrec specification *)
 
-local
-
-fun prove_spec ctxt names rec_rewrites defs eqs =
+fun distill lthy fixes eqs = 
   let
-    val rewrites = map mk_meta_eq rec_rewrites @ map (snd o snd) defs;
-    fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1];
-    val _ = message ("Proving equations for primrec function(s) " ^ commas_quote names);
-  in map (fn (a, t) => (a, [Goal.prove ctxt [] [] t tac])) eqs end;
-
-fun gen_primrec set_group prep_spec raw_fixes raw_spec lthy =
-  let
-    val (fixes, spec) = fst (prep_spec raw_fixes raw_spec lthy);
     val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed lthy v
-      orelse exists (fn ((w, _), _) => v = Binding.name_of w) fixes) o snd) spec [];
+      orelse exists (fn ((w, _), _) => v = Binding.name_of w) fixes)) eqs [];
     val tnames = distinct (op =) (map (#1 o snd) eqns);
     val dts = find_dts (DatatypePackage.get_datatypes (ProofContext.theory_of lthy)) tnames tnames;
     val main_fns = map (fn (tname, {index, ...}) =>
@@ -236,31 +228,59 @@
         ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
       else snd (hd dts);
     val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []);
-    val (fs, defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
-    val names1 = map snd fnames;
-    val names2 = map fst eqns;
-    val _ = if gen_eq_set (op =) (names1, names2) then ()
-      else primrec_error ("functions " ^ commas_quote names2 ^
+    val (fs, raw_defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
+    val defs = map (make_def lthy fixes fs) raw_defs;
+    val names = map snd fnames;
+    val names_eqns = map fst eqns;
+    val _ = if gen_eq_set (op =) (names, names_eqns) then ()
+      else primrec_error ("functions " ^ commas_quote names_eqns ^
         "\nare not mutually recursive");
-    val prefix = space_implode "_" (map (Long_Name.base_name o #1) defs);
-    val qualify = Binding.qualify false prefix;
-    val spec' = (map o apfst)
-      (fn (b, attrs) => (qualify b, Code.add_default_eqn_attrib :: attrs)) spec;
-    val simp_atts = map (Attrib.internal o K)
-      [Simplifier.simp_add, Nitpick_Const_Simp_Thms.add, Quickcheck_RecFun_Simp_Thms.add];
+    val rec_rewrites' = map mk_meta_eq rec_rewrites;
+    val prefix = space_implode "_" (map (Long_Name.base_name o #1) raw_defs);
+    fun prove lthy defs =
+      let
+        val rewrites = rec_rewrites' @ map (snd o snd) defs;
+        fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1];
+        val _ = message ("Proving equations for primrec function(s) " ^ commas_quote names);
+      in map (fn eq => [Goal.prove lthy [] [] eq tac]) eqs end;
+  in ((prefix, (fs, defs)), prove) end
+  handle PrimrecError (msg, some_eqn) =>
+    error ("Primrec definition error:\n" ^ msg ^ (case some_eqn
+     of SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn)
+      | NONE => ""));
+
+
+(* primrec definition *)
+
+fun add_primrec_simple fixes spec lthy =
+  let
+    val ((prefix, (fs, defs)), prove) = distill lthy fixes (map snd spec);
+  in
+    lthy
+    |> fold_map (LocalTheory.define Thm.definitionK) defs
+    |-> (fn defs => `(fn lthy => (prefix, prove lthy defs)))
+  end;
+
+local
+
+fun gen_primrec set_group prep_spec raw_fixes raw_spec lthy =
+  let
+    val (fixes, spec) = fst (prep_spec raw_fixes raw_spec lthy);
+    fun attr_bindings prefix = map (fn ((b, attrs), _) =>
+      (Binding.qualify false prefix b, Code.add_default_eqn_attrib :: attrs)) spec;
+    fun simp_attr_binding prefix = (Binding.qualify false prefix (Binding.name "simps"),
+      map (Attrib.internal o K)
+        [Simplifier.simp_add, Nitpick_Const_Simp_Thms.add, Quickcheck_RecFun_Simp_Thms.add]);
   in
     lthy
     |> set_group ? LocalTheory.set_group (serial_string ())
-    |> fold_map (LocalTheory.define Thm.definitionK o make_def lthy fixes fs) defs
-    |-> (fn defs => `(fn ctxt => prove_spec ctxt names1 rec_rewrites defs spec'))
-    |-> (fn simps => fold_map (LocalTheory.note Thm.generatedK) simps)
-    |-> (fn simps' => LocalTheory.note Thm.generatedK
-          ((qualify (Binding.qualified_name "simps"), simp_atts), maps snd simps'))
+    |> add_primrec_simple fixes spec
+    |-> (fn (prefix, simps) => fold_map (LocalTheory.note Thm.generatedK)
+          (attr_bindings prefix ~~ simps)
+    #-> (fn simps' => LocalTheory.note Thm.generatedK
+          (simp_attr_binding prefix, maps snd simps')))
     |>> snd
-  end handle PrimrecError (msg, some_eqn) =>
-    error ("Primrec definition error:\n" ^ msg ^ (case some_eqn
-     of SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn)
-      | NONE => ""));
+  end;
 
 in