added gen_primrec
authorhaftmann
Mon, 02 Oct 2006 23:01:00 +0200
changeset 20841 d4f94d2a3414
parent 20840 5e92606245b6
child 20842 f5f69a1059f4
added gen_primrec
src/HOL/Tools/primrec_package.ML
--- a/src/HOL/Tools/primrec_package.ML	Mon Oct 02 23:00:58 2006 +0200
+++ b/src/HOL/Tools/primrec_package.ML	Mon Oct 02 23:01:00 2006 +0200
@@ -8,7 +8,6 @@
 signature PRIMREC_PACKAGE =
 sig
   val quiet_mode: bool ref
-  val mk_combdefs: theory -> term list -> (string * term) list
   val add_primrec: string -> ((bstring * string) * Attrib.src list) list
     -> theory -> thm list * theory
   val add_primrec_unchecked: string -> ((bstring * string) * Attrib.src list) list
@@ -17,6 +16,10 @@
     -> theory -> thm list * theory
   val add_primrec_unchecked_i: string -> ((bstring * term) * attribute list) list
     -> theory -> thm list * theory
+  val gen_primrec: ((bstring * attribute list) * thm list -> theory -> (bstring * thm list) * theory)
+    -> ((bstring * attribute list) * term -> theory -> (bstring * thm) * theory)
+    -> string -> ((bstring * attribute list) * term) list
+    -> theory -> thm list * theory;
 end;
 
 structure PrimrecPackage : PRIMREC_PACKAGE =
@@ -225,10 +228,13 @@
     |> RuleCases.save induction
   end;
 
-fun mk_defs thy eqns =
+local
+
+fun gen_primrec_i note def alt_name eqns_atts thy =
   let
+    val (eqns, atts) = split_list eqns_atts;
     val dt_info = DatatypePackage.get_datatypes thy;
-    val rec_eqns = foldr (process_eqn thy) [] eqns;
+    val rec_eqns = foldr (process_eqn thy) [] (map snd eqns);
     val tnames = distinct (op =) (map (#1 o snd) rec_eqns);
     val dts = find_dts dt_info tnames tnames;
     val main_fns = 
@@ -244,42 +250,37 @@
       foldr (process_fun thy descr rec_eqns) ([], []) main_fns;
     val (fs, defs) = foldr (get_fns fnss) ([], []) (descr ~~ rec_names);
     val defs' = map (make_def thy fs) defs;
-  in (fnameTs, rec_eqns, rec_rewrites, dts, defs, defs') end;
-
-fun mk_combdefs thy =
-  #6 o mk_defs thy o map (ObjectLogic.ensure_propT thy);
-
-fun gen_primrec_i unchecked alt_name eqns_atts thy =
-  let
-    val (eqns, atts) = split_list eqns_atts;
-    val (fnameTs, rec_eqns, rec_rewrites, dts, defs, defs') =
-      mk_defs thy (map snd eqns);
     val nameTs1 = map snd fnameTs;
     val nameTs2 = map fst rec_eqns;
+    val _ = if gen_eq_set (op =) (nameTs1, nameTs2) then ()
+            else primrec_err ("functions " ^ commas_quote (map fst nameTs2) ^
+              "\nare not mutually recursive");
     val primrec_name =
       if alt_name = "" then (space_implode "_" (map (Sign.base_name o #1) defs)) else alt_name;
-    val (defs_thms', thy') = thy |> Theory.add_path primrec_name |>
-      (if eq_set (nameTs1, nameTs2) then
-         ((if unchecked then PureThy.add_defs_unchecked_i else PureThy.add_defs_i) false o
-           map Thm.no_attributes) defs'
-       else primrec_err ("functions " ^ commas_quote (map fst nameTs2) ^
-         "\nare not mutually recursive"));
-    val rewrites = (map mk_meta_eq rec_rewrites) @ defs_thms';
+    val (defs_thms', thy') =
+      thy
+      |> Theory.add_path primrec_name
+      |> fold_map def (map (fn (name, t) => ((name, []), t)) defs');
+    val rewrites = (map mk_meta_eq rec_rewrites) @ map snd defs_thms';
     val _ = message ("Proving equations for primrec function(s) " ^
       commas_quote (map fst nameTs1) ^ " ...");
     val simps = map (fn (_, t) => Goal.prove_global thy' [] [] t
         (fn _ => EVERY [rewrite_goals_tac rewrites, rtac refl 1])) eqns;
-    val (simps', thy'') = thy' |> PureThy.add_thms ((map fst eqns ~~ simps) ~~ atts);
+    val (simps', thy'') =
+      thy'
+      |> fold_map note ((map fst eqns ~~ atts) ~~ map single simps);
+    val simps'' = maps snd simps';
   in
     thy''
-    |> (snd o PureThy.add_thmss [(("simps", simps'),
-        [Simplifier.simp_add, RecfunCodegen.add NONE])])
-    |> (snd o PureThy.add_thms [(("induct", prepare_induct (#2 (hd dts)) rec_eqns), [])])
+    |> note (("simps", [Simplifier.simp_add, RecfunCodegen.add NONE]), simps'')
+    |> snd
+    |> note (("induct", []), [prepare_induct (#2 (hd dts)) rec_eqns])
+    |> snd
     |> Theory.parent_path
-    |> pair simps'
+    |> pair simps''
   end;
 
-fun gen_primrec unchecked alt_name eqns thy =
+fun gen_primrec note def alt_name eqns thy =
   let
     val ((names, strings), srcss) = apfst split_list (split_list eqns);
     val atts = map (map (Attrib.attribute thy)) srcss;
@@ -289,13 +290,26 @@
       handle TERM _ => primrec_eq_err thy "not a proper equation" eq) eqn_ts;
     val (_, eqn_ts') = InductivePackage.unify_consts thy rec_ts eqn_ts
   in
-    gen_primrec_i unchecked alt_name (names ~~ eqn_ts' ~~ atts) thy
+    gen_primrec_i note def alt_name (names ~~ eqn_ts' ~~ atts) thy
   end;
 
-val add_primrec = gen_primrec false;
-val add_primrec_unchecked = gen_primrec true;
-val add_primrec_i = gen_primrec_i false;
-val add_primrec_unchecked_i = gen_primrec_i true;
+fun thy_note ((name, atts), thms) =
+  PureThy.add_thmss [((name, thms), atts)] #-> (fn [thms] => pair (name, thms));
+fun thy_def false ((name, atts), t) =
+      PureThy.add_defs_i false [((name, t), atts)] #-> (fn [thm] => pair (name, thm))
+  | thy_def true ((name, atts), t) =
+      PureThy.add_defs_unchecked_i false [((name, t), atts)] #-> (fn [thm] => pair (name, thm));
+
+in
+
+val add_primrec = gen_primrec thy_note (thy_def false);
+val add_primrec_unchecked = gen_primrec thy_note (thy_def true);
+val add_primrec_i = gen_primrec_i thy_note (thy_def false);
+val add_primrec_unchecked_i = gen_primrec_i thy_note (thy_def true);
+fun gen_primrec note def alt_name specs =
+  gen_primrec_i note def alt_name (map (fn ((name, t), atts) => ((name, atts), t)) specs);
+
+end; (*local*)
 
 
 (* outer syntax *)