src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML
changeset 69992 bd3c10813cc4
parent 69593 3dda49e08b9d
child 70494 41108e3e9ca5
--- a/src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML	Tue Mar 26 14:23:18 2019 +0100
+++ b/src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML	Tue Mar 26 22:13:36 2019 +0100
@@ -580,7 +580,7 @@
         ((common_qualify (Binding.qualify true common_name (Binding.name thmN)), attrs),
           [(thms, [])]));
   in
-    (((fun_names, qualifys, defs),
+    (((fun_names, qualifys, arg_Ts, defs),
       fn lthy => fn defs =>
         let
           val def_thms = map (snd o snd) defs;
@@ -605,24 +605,29 @@
     val nonexhaustives = replicate actual_nn nonexhaustive;
     val transfers = replicate actual_nn transfer;
 
-    val (((names, qualifys, defs), prove), lthy') =
+    val (((names, qualifys, arg_Ts, defs), prove), lthy') =
       prepare_primrec plugins nonexhaustives transfers fixes ts lthy;
   in
     lthy'
     |> fold_map Local_Theory.define defs
     |> tap (uncurry (print_def_consts int))
     |-> (fn defs => fn lthy =>
-      let val ((jss, simpss), lthy) = prove lthy defs in
-        (((names, qualifys), (map fst defs, map (snd o snd) defs, (jss, simpss))), lthy)
-      end)
+      let
+        val ((jss, simpss), lthy) = prove lthy defs;
+        val res =
+          {prefix = (names, qualifys),
+           types = map (#1 o dest_Type) arg_Ts,
+           result = (map fst defs, map (snd o snd) defs, (jss, simpss))};
+      in (res, lthy) end)
   end;
 
 fun primrec_simple int fixes ts lthy =
   primrec_simple0 int Plugin_Name.default_filter false false fixes ts lthy
+    |>> (fn {prefix, result, ...} => (prefix, result))
   handle OLD_PRIMREC () =>
     Old_Primrec.primrec_simple int fixes ts lthy
-    |>> apsnd (fn (ts, thms) => (ts, [], ([], [thms]))) o apfst single
-    |>> apfst (map_split (rpair I));
+    |>> (fn {prefix, result = (ts, thms), ...} =>
+          (map_split (rpair I) [prefix], (ts, [], ([], [thms]))))
 
 fun gen_primrec old_primrec prep_spec int opts raw_fixes raw_specs lthy =
   let
@@ -648,8 +653,8 @@
   in
     lthy
     |> primrec_simple0 int plugins nonexhaustive transfer fixes (map snd specs)
-    |-> (fn ((names, qualifys), (ts, defs, (jss, simpss))) =>
-      Spec_Rules.add Spec_Rules.Equational (ts, flat simpss)
+    |-> (fn {prefix = (names, qualifys), types, result = (ts, defs, (jss, simpss))} =>
+      Spec_Rules.add (Spec_Rules.equational_primrec types) (ts, flat simpss)
       #> Local_Theory.notes (mk_notes jss names qualifys simpss)
       #-> (fn notes =>
         plugins code_plugin ? Code.declare_default_eqns (map (rpair true) (maps snd notes))
@@ -657,7 +662,7 @@
   end
   handle OLD_PRIMREC () =>
     old_primrec int raw_fixes raw_specs lthy
-    |>> (fn (ts, thms) => (ts, [], [thms]));
+    |>> (fn {result = (ts, thms), ...} => (ts, [], [thms]));
 
 val primrec = gen_primrec Old_Primrec.primrec Specification.check_multi_specs;
 val primrec_cmd = gen_primrec Old_Primrec.primrec_cmd Specification.read_multi_specs;