improved backwards compatiblity of primrec_new (Isabelle/ML interface, attributes, etc.)
authortraytel
Tue, 01 Oct 2013 17:04:27 +0200
changeset 54013 38c0bbb8348b
parent 54012 7a8263843acb
child 54014 21dac9a60f0c
improved backwards compatiblity of primrec_new (Isabelle/ML interface, attributes, etc.)
src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
src/HOL/BNF/Tools/bnf_lfp.ML
--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Tue Oct 01 15:02:12 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Tue Oct 01 17:04:27 2013 +0200
@@ -7,8 +7,17 @@
 
 signature BNF_FP_REC_SUGAR =
 sig
+  val add_primrec: (binding * typ option * mixfix) list ->
+    (Attrib.binding * term) list -> local_theory -> (term list * thm list list) * local_theory
   val add_primrec_cmd: (binding * string option * mixfix) list ->
-    (Attrib.binding * string) list -> local_theory -> local_theory;
+    (Attrib.binding * string) list -> local_theory -> (term list * thm list list) * local_theory
+  val add_primrec_global: (binding * typ option * mixfix) list ->
+    (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory
+  val add_primrec_overloaded: (string * (string * typ) * bool) list ->
+    (binding * typ option * mixfix) list ->
+    (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory
+  val add_primrec_simple: ((binding * typ) * mixfix) list -> term list ->
+    local_theory -> (string list * (term list * (int list list * thm list list))) * local_theory
   val add_primcorecursive_cmd: bool ->
     (binding * string option * mixfix) list * ((Attrib.binding * string) * string option) list ->
     Proof.context -> Proof.state
@@ -31,8 +40,9 @@
 val selN = "sel"
 
 val nitpick_attrs = @{attributes [nitpick_simp]};
-val code_nitpick_simp_attrs = Code.add_default_eqn_attrib :: nitpick_attrs;
 val simp_attrs = @{attributes [simp]};
+val code_nitpick_attrs = Code.add_default_eqn_attrib :: nitpick_attrs;
+val code_nitpick_simp_attrs = Code.add_default_eqn_attrib :: nitpick_attrs @ simp_attrs;
 
 exception Primrec_Error of string * term list;
 
@@ -300,11 +310,11 @@
     |> (fn [] => NONE | callss => SOME (#ctr eqn_data, callss))
   end;
 
-fun add_primrec fixes specs lthy =
+fun prepare_primrec fixes specs lthy =
   let
     val (bs, mxs) = map_split (apfst fst) fixes;
     val fun_names = map Binding.name_of bs;
-    val eqns_data = map (snd #> dissect_eqn lthy fun_names) specs;
+    val eqns_data = map (dissect_eqn lthy fun_names) specs;
     val funs_data = eqns_data
       |> partition_eq ((op =) o pairself #fun_name)
       |> finds (fn (x, y) => x = #fun_name (hd y)) fun_names |> fst
@@ -330,52 +340,51 @@
 
     val defs = build_defs lthy' bs mxs funs_data rec_specs has_call;
 
-    fun prove def_thms' ({nested_map_idents, nested_map_comps, ctr_specs, ...} : rec_spec)
-        induct_thm (fun_data : eqn_data list) lthy =
+    fun prove lthy def_thms' ({ctr_specs, nested_map_idents, nested_map_comps, ...} : rec_spec)
+        (fun_data : eqn_data list) =
       let
-        val fun_name = #fun_name (hd fun_data);
         val def_thms = map (snd o snd) def_thms';
-        val simp_thms = finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs
+        val simp_thmss = finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs
           |> fst
           |> map_filter (try (fn (x, [y]) =>
             (#user_eqn x, length (#left_args x) + length (#right_args x), #rec_thm y)))
           |> map (fn (user_eqn, num_extra_args, rec_thm) =>
             mk_primrec_tac lthy num_extra_args nested_map_idents nested_map_comps def_thms rec_thm
-            |> K |> Goal.prove lthy [] [] user_eqn)
+            |> K |> Goal.prove lthy [] [] user_eqn);
+        val poss = find_indices (fn (x, y) => #ctr x = #ctr y) fun_data eqns_data;
+      in
+        (poss, simp_thmss)
+      end;
 
-        val notes =
-          [(inductN, if n2m then [induct_thm] else [], []),
-           (simpsN, simp_thms, code_nitpick_simp_attrs @ simp_attrs)]
-          |> filter_out (null o #2)
-          |> map (fn (thmN, thms, attrs) =>
-            ((Binding.qualify true fun_name (Binding.name thmN), attrs), [(thms, [])]));
-      in
-        lthy |> Local_Theory.notes notes
-      end;
+    val notes =
+      (if n2m then map2 (fn name => fn thm =>
+        (name, inductN, [thm], [])) fun_names (take actual_nn induct_thms) else [])
+      |> map (fn (prefix, thmN, thms, attrs) =>
+        ((Binding.qualify true prefix (Binding.name thmN), attrs), [(thms, [])]));
 
     val common_name = mk_common_name fun_names;
 
     val common_notes =
-      [(inductN, if n2m then [induct_thm] else [], [])]
-      |> filter_out (null o #2)
+      (if n2m then [(inductN, [induct_thm], [])] else [])
       |> map (fn (thmN, thms, attrs) =>
         ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])]));
   in
-    lthy'
-    |> fold_map Local_Theory.define defs
-    |-> snd oo (fn def_thms' => fold_map3 (prove def_thms') (take actual_nn rec_specs)
-      (take actual_nn induct_thms) funs_data)
-    |> Local_Theory.notes common_notes |> snd
+    (((fun_names, defs),
+      fn lthy => fn defs =>
+        split_list (map2 (prove lthy defs) (take actual_nn rec_specs) funs_data)),
+      lthy' |> Local_Theory.notes (notes @ common_notes) |> snd)
   end;
 
-fun add_primrec_cmd raw_fixes raw_specs lthy =
+(* primrec definition *)
+
+fun add_primrec_simple fixes ts lthy =
   let
-    val _ = let val d = duplicates (op =) (map (Binding.name_of o #1) raw_fixes) in null d orelse
-      primrec_error ("duplicate function name(s): " ^ commas d) end;
-    val (fixes, specs) = fst (Specification.read_spec raw_fixes raw_specs lthy);
+    val (((names, defs), prove), lthy) = prepare_primrec fixes ts lthy
+      handle ERROR str => primrec_error str;
   in
-    add_primrec fixes specs lthy
-      handle ERROR str => primrec_error str
+    lthy
+    |> fold_map Local_Theory.define defs
+    |-> (fn defs => `(fn lthy => (names, (map fst defs, prove lthy defs))))
   end
   handle Primrec_Error (str, eqns) =>
     if null eqns
@@ -383,6 +392,56 @@
     else error ("primrec_new error:\n  " ^ str ^ "\nin\n  " ^
       space_implode "\n  " (map (quote o Syntax.string_of_term lthy) eqns));
 
+local
+
+fun gen_primrec prep_spec raw_fixes raw_spec lthy =
+  let
+    val d = duplicates (op =) (map (Binding.name_of o #1) raw_fixes)
+    val _ = null d orelse primrec_error ("duplicate function name(s): " ^ commas d);
+
+    val (fixes, specs) = fst (prep_spec raw_fixes raw_spec lthy);
+
+    val mk_notes =
+      flat ooo map3 (fn poss => fn prefix => fn thms =>
+        let
+          val (bs, attrss) = map_split (fst o nth specs) poss;
+          val notes =
+            map3 (fn b => fn attrs => fn thm =>
+              ((Binding.qualify false prefix b, code_nitpick_simp_attrs @ attrs), [([thm], [])]))
+            bs attrss thms;
+        in
+          ((Binding.qualify true prefix (Binding.name simpsN), []), [(thms, [])]) :: notes
+        end);
+  in
+    lthy
+    |> add_primrec_simple fixes (map snd specs)
+    |-> (fn (names, (ts, (posss, simpss))) =>
+      Spec_Rules.add Spec_Rules.Equational (ts, flat simpss)
+      #> Local_Theory.notes (mk_notes posss names simpss)
+      #>> pair ts o map snd)
+  end;
+
+in
+
+val add_primrec = gen_primrec Specification.check_spec;
+val add_primrec_cmd = gen_primrec Specification.read_spec;
+
+end;
+
+fun add_primrec_global fixes specs thy =
+  let
+    val lthy = Named_Target.theory_init thy;
+    val ((ts, simps), lthy') = add_primrec fixes specs lthy;
+    val simps' = burrow (Proof_Context.export lthy' lthy) simps;
+  in ((ts, simps'), Local_Theory.exit_global lthy') end;
+
+fun add_primrec_overloaded ops fixes specs thy =
+  let
+    val lthy = Overloading.overloading ops thy;
+    val ((ts, simps), lthy') = add_primrec fixes specs lthy;
+    val simps' = burrow (Proof_Context.export lthy' lthy) simps;
+  in ((ts, simps'), Local_Theory.exit_global lthy') end;
+
 
 
 (* Primcorec *)
@@ -875,7 +934,7 @@
 
         val notes =
           [(coinductN, map (if n2m then single else K []) coinduct_thms, []),
-           (codeN, ctr_thmss(*FIXME*), code_nitpick_simp_attrs),
+           (codeN, ctr_thmss(*FIXME*), code_nitpick_attrs),
            (ctrN, ctr_thmss, []),
            (discN, disc_thmss, simp_attrs),
            (selN, sel_thmss, simp_attrs),
--- a/src/HOL/BNF/Tools/bnf_lfp.ML	Tue Oct 01 15:02:12 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_lfp.ML	Tue Oct 01 17:04:27 2013 +0200
@@ -1889,6 +1889,6 @@
 
 val _ = Outer_Syntax.local_theory @{command_spec "primrec_new"}
   "define primitive recursive functions"
-  (Parse.fixes -- Parse_Spec.where_alt_specs >> uncurry add_primrec_cmd);
+  (Parse.fixes -- Parse_Spec.where_alt_specs >> (snd oo uncurry add_primrec_cmd));
 
 end;