simplified preparation and outer parsing of specification;
authorwenzelm
Thu, 12 Mar 2009 21:55:02 +0100
changeset 30487 a14ff49d3083
parent 30486 9cdc7ce0e389
child 30490 d09b7f0c2c14
simplified preparation and outer parsing of specification; export extern cmd interfaces as well;
src/HOL/Nominal/nominal_primrec.ML
src/HOL/Tools/primrec_package.ML
--- a/src/HOL/Nominal/nominal_primrec.ML	Thu Mar 12 21:51:02 2009 +0100
+++ b/src/HOL/Nominal/nominal_primrec.ML	Thu Mar 12 21:55:02 2009 +0100
@@ -12,6 +12,10 @@
     (binding * typ option * mixfix) list ->
     (binding * typ option * mixfix) list ->
     (Attrib.binding * term) list -> local_theory -> Proof.state
+  val add_primrec_cmd: string list option -> string option ->
+    (binding * string option * mixfix) list ->
+    (binding * string option * mixfix) list ->
+    (Attrib.binding * string) list -> local_theory -> Proof.state
 end;
 
 structure NominalPrimrec : NOMINAL_PRIMREC =
@@ -36,10 +40,10 @@
       (fn Free (v, _) => insert (op =) v | _ => I) body []))
   in (curry subst_bounds (map2 (curry Free) vs' Ts |> rev) body) end;
 
-fun process_eqn lthy is_fixed spec rec_fns = 
+fun process_eqn lthy is_fixed spec rec_fns =
   let
     val eq = unquantify spec;
-    val (lhs, rhs) = 
+    val (lhs, rhs) =
       HOLogic.dest_eq (HOLogic.dest_Trueprop (Logic.strip_imp_concl eq))
       handle TERM _ => raise RecError "not a proper equation";
 
@@ -67,7 +71,7 @@
     fun check_vars _ [] = ()
       | check_vars s vars = raise RecError (s ^ commas_quote (map fst vars))
   in
-    if length middle > 1 then 
+    if length middle > 1 then
       raise RecError "more than one non-variable in pattern"
     else
      (check_vars "repeated variable names in pattern: " (duplicates (op =) lfrees);
@@ -159,7 +163,7 @@
               val (rhs', (fnames'', fnss'')) = subst (map2 (fn (x, y) => fn z =>
                 (Free x, (body_index y, Free z))) recs subs) rhs (fnames', fnss')
                   handle RecError s => primrec_eq_err lthy s eq
-            in (fnames'', fnss'', 
+            in (fnames'', fnss'',
                 (list_abs_free (cargs' @ subs, rhs'))::fns)
             end)
 
@@ -172,7 +176,7 @@
             val SOME (_, _, eqns' as (_, (ls, _, rs, _, _)) :: _) =
               AList.lookup (op =) eqns fname;
             val (fnames', fnss', fns) = fold_rev (trans eqns') constrs
-              ((i, fname)::fnames, fnss, []) 
+              ((i, fname)::fnames, fnss, [])
           in
             (fnames', (i, (fname, ls, rs, fns))::fnss')
           end
@@ -235,15 +239,9 @@
 
 local
 
-fun prepare_spec prep_spec ctxt raw_fixes raw_spec =
-  let
-    val ((fixes, spec), _) = prep_spec
-      raw_fixes (map (single o apsnd single) raw_spec) ctxt
-  in (fixes, map (apsnd the_single) spec) end;
-
 fun gen_primrec set_group prep_spec prep_term invs fctxt raw_fixes raw_params raw_spec lthy =
   let
-    val (fixes', spec) = prepare_spec prep_spec lthy (raw_fixes @ raw_params) raw_spec;
+    val (fixes', spec) = fst (prep_spec (raw_fixes @ raw_params) raw_spec lthy);
     val fixes = List.take (fixes', length raw_fixes);
     val (names_atts, spec') = split_list spec;
     val eqns' = map unquantify spec'
@@ -261,12 +259,12 @@
        then () else primrec_err param_err);
     val tnames = distinct (op =) (map (#1 o snd) eqns);
     val dts = find_dts dt_info tnames tnames;
-    val main_fns = 
+    val main_fns =
       map (fn (tname, {index, ...}) =>
-        (index, 
+        (index,
           (fst o the o find_first (fn (_, x) => #1 x = tname)) eqns))
       dts;
-    val {descr, rec_names, rec_rewrites, ...} = 
+    val {descr, rec_names, rec_rewrites, ...} =
       if null dts then
         primrec_err ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
       else snd (hd dts);
@@ -388,15 +386,15 @@
 
 in
 
-val add_primrec = gen_primrec false Specification.check_specification (K I);
-val add_primrec_cmd = gen_primrec true Specification.read_specification Syntax.read_term;
+val add_primrec = gen_primrec false Specification.check_spec (K I);
+val add_primrec_cmd = gen_primrec true Specification.read_spec Syntax.read_term;
 
 end;
 
 
 (* outer syntax *)
 
-local structure P = OuterParse and K = OuterKeyword in
+local structure P = OuterParse in
 
 val freshness_context = P.reserved "freshness_context";
 val invariant = P.reserved "invariant";
@@ -408,28 +406,16 @@
     (Scan.repeat1 (unless_flag P.term) >> SOME) -- Scan.optional parser1 NONE ||
   (parser1 >> pair NONE);
 val options =
-  Scan.optional (P.$$$ "(" |-- P.!!!
-    (parser2 --| P.$$$ ")")) (NONE, NONE);
-
-fun pipe_error t = P.!!! (Scan.fail_with (K
-  (cat_lines ["Equations must be separated by " ^ quote "|", quote t])));
-
-val statement = SpecParse.opt_thm_name ":" -- P.prop --| Scan.ahead
-  ((P.term :-- pipe_error) || Scan.succeed ("",""));
-
-val statements = P.enum1 "|" statement;
-
-val primrec_decl = P.opt_target -- options --
-  P.fixes -- P.for_fixes --| P.$$$ "where" -- statements;
+  Scan.optional (P.$$$ "(" |-- P.!!! (parser2 --| P.$$$ ")")) (NONE, NONE);
 
 val _ =
-  OuterSyntax.command "nominal_primrec" "define primitive recursive functions on nominal datatypes" K.thy_goal
-    (primrec_decl >> (fn ((((opt_target, (invs, fctxt)), raw_fixes), raw_params), raw_spec) =>
-      Toplevel.print o Toplevel.local_theory_to_proof opt_target
-        (add_primrec_cmd invs fctxt raw_fixes raw_params raw_spec)));
+  OuterSyntax.local_theory_to_proof "nominal_primrec"
+    "define primitive recursive functions on nominal datatypes" OuterKeyword.thy_goal
+    (options -- P.fixes -- P.for_fixes -- SpecParse.where_alt_specs
+      >> (fn ((((invs, fctxt), fixes), params), specs) =>
+        add_primrec_cmd invs fctxt fixes params specs));
 
 end;
 
-
 end;
 
--- a/src/HOL/Tools/primrec_package.ML	Thu Mar 12 21:51:02 2009 +0100
+++ b/src/HOL/Tools/primrec_package.ML	Thu Mar 12 21:55:02 2009 +0100
@@ -9,6 +9,8 @@
 sig
   val add_primrec: (binding * typ option * mixfix) list ->
     (Attrib.binding * term) list -> local_theory -> thm list * local_theory
+  val add_primrec_cmd: (binding * string option * mixfix) list ->
+    (Attrib.binding * string) list -> local_theory -> thm list * local_theory
   val add_primrec_global: (binding * typ option * mixfix) list ->
     (Attrib.binding * term) list -> theory -> thm list * theory
   val add_primrec_overloaded: (string * (string * typ) * bool) list ->
@@ -213,12 +215,6 @@
 
 local
 
-fun prepare_spec prep_spec ctxt raw_fixes raw_spec =
-  let
-    val ((fixes, spec), _) = prep_spec
-      raw_fixes (map (single o apsnd single) raw_spec) ctxt
-  in (fixes, map (apsnd the_single) spec) end;
-
 fun prove_spec ctxt names rec_rewrites defs eqs =
   let
     val rewrites = map mk_meta_eq rec_rewrites @ map (snd o snd) defs;
@@ -228,7 +224,7 @@
 
 fun gen_primrec set_group prep_spec raw_fixes raw_spec lthy =
   let
-    val (fixes, spec) = prepare_spec prep_spec lthy raw_fixes raw_spec;
+    val (fixes, spec) = fst (prep_spec raw_fixes raw_spec lthy);
     val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed lthy v
       orelse exists (fn ((w, _), _) => v = Binding.name_of w) fixes) o snd) spec [];
     val tnames = distinct (op =) (map (#1 o snd) eqns);
@@ -268,8 +264,8 @@
 
 in
 
-val add_primrec = gen_primrec false Specification.check_specification;
-val add_primrec_cmd = gen_primrec true Specification.read_specification;
+val add_primrec = gen_primrec false Specification.check_spec;
+val add_primrec_cmd = gen_primrec true Specification.read_spec;
 
 end;
 
@@ -300,24 +296,16 @@
 val old_primrec_decl =
   opt_unchecked_name -- Scan.repeat1 ((SpecParse.opt_thm_name ":" >> apfst Binding.name_of) -- P.prop);
 
-fun pipe_error t = P.!!! (Scan.fail_with (K
-  (cat_lines ["Equations must be separated by " ^ quote "|", quote t])));
-
-val statement = SpecParse.opt_thm_name ":" -- P.prop --| Scan.ahead
-  ((P.term :-- pipe_error) || Scan.succeed ("",""));
-
-val statements = P.enum1 "|" statement;
-
-val primrec_decl = P.opt_target -- P.fixes --| P.$$$ "where" -- statements;
+val primrec_decl = P.opt_target -- P.fixes -- SpecParse.where_alt_specs;
 
 val _ =
   OuterSyntax.command "primrec" "define primitive recursive functions on datatypes" K.thy_decl
-    ((primrec_decl >> (fn ((opt_target, raw_fixes), raw_spec) =>
-      Toplevel.local_theory opt_target (add_primrec_cmd raw_fixes raw_spec #> snd)))
+    ((primrec_decl >> (fn ((opt_target, fixes), specs) =>
+      Toplevel.local_theory opt_target (add_primrec_cmd fixes specs #> snd)))
     || (old_primrec_decl >> (fn ((unchecked, alt_name), eqns) =>
       Toplevel.theory (snd o
-        (if unchecked then OldPrimrecPackage.add_primrec_unchecked else OldPrimrecPackage.add_primrec) alt_name
-          (map P.triple_swap eqns)))));
+        (if unchecked then OldPrimrecPackage.add_primrec_unchecked else OldPrimrecPackage.add_primrec)
+          alt_name (map P.triple_swap eqns)))));
 
 end;