add_primrec(_i): attributes;
authorwenzelm
Thu, 11 Mar 1999 21:58:54 +0100
changeset 6359 6fdb0badc6f4
parent 6358 92dbe243555f
child 6360 83573ae0f22c
add_primrec(_i): attributes; outer syntax for 'primrec';
src/HOL/Tools/primrec_package.ML
--- a/src/HOL/Tools/primrec_package.ML	Thu Mar 11 21:58:12 1999 +0100
+++ b/src/HOL/Tools/primrec_package.ML	Thu Mar 11 21:58:54 1999 +0100
@@ -3,15 +3,19 @@
     Author:     Stefan Berghofer and Norbert Voelker
     Copyright   1998  TU Muenchen
 
-Package for defining functions on datatypes by primitive recursion
+Package for defining functions on datatypes by primitive recursion.
+
+TODO:
+  - add_primrec(_i): improve prep of args;
+  - quiet_mode (!?);
 *)
 
 signature PRIMREC_PACKAGE =
 sig
-  val add_primrec_i : string -> (string * term) list ->
-    theory -> theory * thm list
-  val add_primrec : string -> (string * string) list ->
-    theory -> theory * thm list
+  val add_primrec: string -> ((string * string) * Args.src list) list
+    -> theory -> theory * thm list
+  val add_primrec_i: string -> ((string * term) * theory attribute list) list
+    -> theory -> theory * thm list
 end;
 
 structure PrimrecPackage : PRIMREC_PACKAGE =
@@ -25,6 +29,7 @@
 fun primrec_eq_err sign s eq =
   primrec_err (s ^ "\nin equation\n" ^ Sign.string_of_term sign eq);
 
+
 (* preprocessing of equations *)
 
 fun process_eqn sign (eq, rec_fns) = 
@@ -161,6 +166,7 @@
         else raise RecError ("inconsistent functions for datatype " ^ tname))
   end;
 
+
 (* prepare functions needed for definitions *)
 
 fun get_fns fns (((i, (tname, _, constrs)), rec_name), (fs, defs)) =
@@ -176,6 +182,7 @@
        end
    | Some (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname)::defs);
 
+
 (* make definition *)
 
 fun make_def sign fs (fname, ls, rec_name, tname) =
@@ -190,6 +197,7 @@
     inferT_axm sign defpair
   end;
 
+
 (* find datatypes which contain all datatypes in tnames' *)
 
 fun find_dts (dt_info : datatype_info Symtab.table) _ [] = []
@@ -201,8 +209,9 @@
               (tname, dt)::(find_dts dt_info tnames' tnames)
             else find_dts dt_info tnames' tnames);
 
-fun add_primrec_i alt_name eqns thy =
+fun add_primrec_i alt_name eqns_atts thy =
   let
+    val (eqns, atts) = split_list eqns_atts;
     val sg = sign_of thy;
     val dt_info = DatatypePackage.get_datatypes thy;
     val rec_eqns = foldr (process_eqn sg) (map snd eqns, []);
@@ -215,7 +224,7 @@
 	dts;
     val {descr, rec_names, rec_rewrites, ...} = 
 	if null dts then
-	    primrec_err ("datatypes " ^ commas tnames ^ 
+	    primrec_err ("datatypes " ^ commas_quote tnames ^ 
 			 "\nare not mutually recursive")
 	else snd (hd dts);
     val (fnames, fnss) = foldr (process_fun sg descr rec_eqns)
@@ -228,24 +237,41 @@
       Theory.add_path (if alt_name = "" then (space_implode "_"
         (map (Sign.base_name o #1) defs)) else alt_name) |>
       (if eq_set (names1, names2) then Theory.add_defs_i defs'
-       else primrec_err ("functions " ^ commas names2 ^
+       else primrec_err ("functions " ^ commas_quote names2 ^
          "\nare not mutually recursive"));
     val rewrites = (map mk_meta_eq rec_rewrites) @ (map (get_axiom thy' o fst) defs');
     val _ = writeln ("Proving equations for primrec function(s)\n" ^
-      commas names1 ^ " ...");
+      commas_quote names1 ^ " ...");
     val char_thms = map (fn (_, t) => prove_goalw_cterm rewrites (cterm_of (sign_of thy') t)
         (fn _ => [rtac refl 1])) eqns;
     val simps = char_thms;
-    val thy'' = thy' |>
-      PureThy.add_thmss [(("simps", simps), [Simplifier.simp_add_global])] |>
-      PureThy.add_thms (map (rpair [])
-        (filter_out (equal "" o fst) (map fst eqns ~~ simps))) |>
-      Theory.parent_path;
-  in
-    (thy'', char_thms)
-  end;
+    val thy'' =
+      thy'
+      |> PureThy.add_thmss [(("simps", simps), [Simplifier.simp_add_global])]
+      |> PureThy.add_thms ((map fst eqns ~~ simps) ~~ atts)
+      |> Theory.parent_path;
+  in (thy'', char_thms) end;
+
+
+fun read_eqn thy ((name, s), srcs) =
+  ((name, readtm (sign_of thy) propT s), map (Attrib.global_attribute thy) srcs);
+
+fun add_primrec alt_name eqns thy = add_primrec_i alt_name (map (read_eqn thy) eqns) thy;
+
 
-fun add_primrec alt_name eqns thy =
-  add_primrec_i alt_name (map (apsnd (readtm (sign_of thy) propT)) eqns) thy;
+(* outer syntax *)
+
+open OuterParse;
+
+val primrec_decl =
+  Scan.optional ($$$ "(" |-- name --| $$$ ")") "" --
+    Scan.repeat1 (opt_thm_name ":" -- term);
+
+val primrecP =
+  OuterSyntax.parser false "primrec" "define primitive recursive functions on datatypes"
+    (primrec_decl >> (fn (alt_name, eqns) =>
+      Toplevel.theory (#1 o add_primrec alt_name (map (fn ((x, y), z) => ((x, z), y)) eqns))));
+
+val _ = OuterSyntax.add_parsers [primrecP];
 
 end;