src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML
changeset 60004 e27e7be1f2f6
parent 60003 ba8fa0c38d66
child 60007 41a117825097
--- a/src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML	Fri Apr 10 18:23:01 2015 +0200
+++ b/src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML	Fri Apr 10 19:05:00 2015 +0200
@@ -62,16 +62,17 @@
     (BNF_FP_Rec_Sugar_Util.fp_rec_sugar -> local_theory -> local_theory) -> theory -> theory
 
   val primrec: (binding * typ option * mixfix) list -> (Attrib.binding * term) list ->
-    local_theory -> (term list * thm list list) * local_theory
+    local_theory -> (term list * thm list * thm list list) * local_theory
   val primrec_cmd: rec_option list -> (binding * string option * mixfix) list ->
-    (Attrib.binding * string) list -> local_theory -> (term list * thm list list) * local_theory
+    (Attrib.binding * string) list -> local_theory ->
+    (term list * thm list * thm list list) * local_theory
   val primrec_global: (binding * typ option * mixfix) list ->
-    (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory
+    (Attrib.binding * term) list -> theory -> (term list * thm list * thm list list) * theory
   val primrec_overloaded: (string * (string * typ) * bool) list ->
     (binding * typ option * mixfix) list ->
-    (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory
-  val primrec_simple: ((binding * typ) * mixfix) list -> term list ->
-    local_theory -> (string list * (term list * (int list list * thm list list))) * local_theory
+    (Attrib.binding * term) list -> theory -> (term list * thm list * thm list list) * theory
+  val primrec_simple: ((binding * typ) * mixfix) list -> term list -> local_theory ->
+    (string list * (term list * thm list * (int list list * thm list list))) * local_theory
 end;
 
 structure BNF_LFP_Rec_Sugar : BNF_LFP_REC_SUGAR =
@@ -546,7 +547,7 @@
           find_indices (op = o apply2 (fn {fun_name, ctr, ...} => (fun_name, ctr)))
             fun_data eqns_data;
 
-        val simp_thms = finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs
+        val simps = finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs
           |> fst
           |> map_filter (try (fn (x, [y]) =>
             (#user_eqn x, length (#left_args x) + length (#right_args x), #rec_thm y)))
@@ -556,7 +557,7 @@
               |> K |> Goal.prove_sorry lthy' [] [] user_eqn
               |> Thm.close_derivation);
       in
-        ((js, simp_thms), lthy')
+        ((js, simps), lthy')
       end;
 
     val notes =
@@ -604,18 +605,17 @@
     lthy'
     |> fold_map Local_Theory.define defs
     |-> (fn defs => fn lthy =>
-      let val (thms, lthy) = prove lthy defs;
-      in ((names, (map fst defs, thms)), lthy) end)
+      let val ((jss, simpss), lthy) = prove lthy defs;
+      in ((names, (map fst defs, map (snd o snd) defs, (jss, simpss))), lthy) end)
   end;
 
 fun primrec_simple fixes ts lthy =
   primrec_simple0 Plugin_Name.default_filter false false fixes ts lthy
   handle OLD_PRIMREC () =>
     Old_Primrec.primrec_simple fixes ts lthy
-    |>> apsnd (apsnd (pair [] o single)) o apfst single;
+    |>> apsnd (fn (ts, thms) => (ts, [], ([], [thms]))) o apfst single;
 
-fun gen_primrec old_primrec prep_spec opts
-    (raw_fixes : (binding * 'a option * mixfix) list) raw_specs lthy =
+fun gen_primrec old_primrec prep_spec opts raw_fixes raw_specs lthy =
   let
     val dups = duplicates (op =) (map (Binding.name_of o #1) raw_fixes);
     val _ = null dups orelse error ("Duplicate function name " ^ quote (hd dups));
@@ -644,12 +644,14 @@
   in
     lthy
     |> primrec_simple0 plugins nonexhaustive transfer fixes (map snd specs)
-    |-> (fn (names, (ts, (jss, simpss))) =>
+    |-> (fn (names, (ts, defs, (jss, simpss))) =>
       Spec_Rules.add Spec_Rules.Equational (ts, flat simpss)
       #> Local_Theory.notes (mk_notes jss names simpss)
-      #>> pair ts o map snd)
+      #>> (fn notes => (ts, defs, map snd notes)))
   end
-  handle OLD_PRIMREC () => old_primrec raw_fixes raw_specs lthy |>> apsnd single;
+  handle OLD_PRIMREC () =>
+    old_primrec raw_fixes raw_specs lthy
+    |>> (fn (ts, thms) => (ts, [], [thms]));
 
 val primrec = gen_primrec Old_Primrec.primrec Specification.check_spec [];
 val primrec_cmd = gen_primrec Old_Primrec.primrec_cmd Specification.read_spec;