src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML
changeset 59281 1b4dc8a9f7d9
parent 59279 f5816b4d6489
child 59283 5ca195783da8
--- a/src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML	Mon Jan 05 09:54:41 2015 +0100
+++ b/src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML	Mon Jan 05 10:09:42 2015 +0100
@@ -8,7 +8,10 @@
 
 signature BNF_LFP_REC_SUGAR =
 sig
-  datatype primrec_option = Nonexhaustive_Option | Transfer_Option;
+  datatype rec_option =
+    Plugins_Option of Proof.context -> Plugin_Name.filter |
+    Nonexhaustive_Option |
+    Transfer_Option
 
   datatype rec_call =
     No_Rec of int * typ |
@@ -57,13 +60,12 @@
     (term * term list list) list list -> local_theory ->
     (bool * rec_spec list * typ list * thm * thm list * Token.src list * typ list) * local_theory
 
-  val primrec_interpretation:
-    string -> (BNF_FP_Rec_Sugar_Util.fp_rec_sugar -> local_theory -> local_theory) ->
-    theory -> theory
+  val lfp_rec_sugar_interpretation: string ->
+    (BNF_FP_Rec_Sugar_Util.fp_rec_sugar -> local_theory -> local_theory) -> theory -> theory
 
   val add_primrec: (binding * typ option * mixfix) list ->
     (Attrib.binding * term) list -> local_theory -> (term list * thm list list) * local_theory
-  val add_primrec_cmd: primrec_option list -> (binding * string option * mixfix) list ->
+  val add_primrec_cmd: rec_option list -> (binding * string option * mixfix) list ->
     (Attrib.binding * string) list -> local_theory -> (term list * thm list list) * local_theory
   val add_primrec_global: (binding * typ option * mixfix) list ->
     (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory
@@ -92,7 +94,10 @@
 exception OLD_PRIMREC of unit;
 exception PRIMREC of string * term list;
 
-datatype primrec_option = Nonexhaustive_Option | Transfer_Option;
+datatype rec_option =
+  Plugins_Option of Proof.context -> Plugin_Name.filter |
+  Nonexhaustive_Option |
+  Transfer_Option;
 
 datatype rec_call =
   No_Rec of int * typ |
@@ -178,15 +183,13 @@
     SOME {rewrite_nested_rec_call = SOME f, ...} => f ctxt
   | _ => error "Unsupported nested recursion");
 
-val transfer_primrec = morph_fp_rec_sugar o Morphism.transfer_morphism;
-
-structure Primrec_Plugin = Plugin(type T = fp_rec_sugar);
+structure LFP_Rec_Sugar_Plugin = Plugin(type T = fp_rec_sugar);
 
-fun primrec_interpretation name f =
-  Primrec_Plugin.interpretation name (fn fp_rec_sugar => fn lthy =>
-    f (transfer_primrec (Proof_Context.theory_of lthy) fp_rec_sugar) lthy);
+fun lfp_rec_sugar_interpretation name f =
+  LFP_Rec_Sugar_Plugin.interpretation name (fn fp_rec_sugar => fn lthy =>
+    f (transfer_fp_rec_sugar (Proof_Context.theory_of lthy) fp_rec_sugar) lthy);
 
-val interpret_primrec = Primrec_Plugin.data_default;
+val interpret_lfp_rec_sugar = LFP_Rec_Sugar_Plugin.data;
 
 fun rec_specs_of bs arg_Ts res_Ts callers callssss0 lthy0 =
   let
@@ -491,7 +494,7 @@
   unfold_thms_tac ctxt (nested_simps ctxt @ map_ident0s @ map_comps) THEN
   HEADGOAL (rtac refl);
 
-fun prepare_primrec nonexhaustives transfers fixes specs lthy0 =
+fun prepare_primrec plugins nonexhaustives transfers fixes specs lthy0 =
   let
     val thy = Proof_Context.theory_of lthy0;
 
@@ -578,11 +581,11 @@
           val def_thms = map (snd o snd) defs;
           val ts = map fst defs;
           val phi = Local_Theory.target_morphism lthy;
+          val fp_rec_sugar =
+            {transfers = transfers, fun_names = fun_names, funs = map (Morphism.term phi) ts,
+             fun_defs = Morphism.fact phi def_thms, fpTs = take actual_nn Ts};
         in
-          map_prod split_list
-            (interpret_primrec {transfers = transfers, fun_names = fun_names,
-               funs = (map (Morphism.term phi) ts), fun_defs = (Morphism.fact phi def_thms),
-               fpTs = (take actual_nn Ts)})
+          map_prod split_list (interpret_lfp_rec_sugar plugins fp_rec_sugar)
             (@{fold_map 2} (prove defs) (take actual_nn rec_specs) funs_data lthy)
         end),
       lthy |> Local_Theory.notes (notes @ common_notes) |> snd)
@@ -591,9 +594,14 @@
 fun add_primrec_simple' opts fixes ts lthy =
   let
     val actual_nn = length fixes;
-    val nonexhaustives = replicate actual_nn (member (op =) opts Nonexhaustive_Option);
-    val transfers = replicate actual_nn (member (op =) opts Transfer_Option);
-    val (((names, defs), prove), lthy') = prepare_primrec nonexhaustives transfers fixes ts lthy
+
+    val plugins = get_first (fn Plugins_Option f => SOME (f lthy) | _ => NONE) (rev opts)
+      |> the_default Plugin_Name.default_filter;
+    val nonexhaustives = replicate actual_nn (exists (can (fn Nonexhaustive_Option => ())) opts);
+    val transfers = replicate actual_nn (exists (can (fn Transfer_Option => ())) opts);
+
+    val (((names, defs), prove), lthy') =
+      prepare_primrec plugins nonexhaustives transfers fixes ts lthy
       handle ERROR str => raise PRIMREC (str, []);
   in
     lthy'
@@ -658,14 +666,15 @@
   #> add_primrec fixes specs
   ##> Local_Theory.exit_global;
 
-val primrec_option_parser = Parse.group (K "option")
-  (Parse.reserved "nonexhaustive" >> K Nonexhaustive_Option
+val rec_option_parser = Parse.group (K "option")
+  (Plugin_Name.parse_filter >> Plugins_Option
+   || Parse.reserved "nonexhaustive" >> K Nonexhaustive_Option
    || Parse.reserved "transfer" >> K Transfer_Option);
 
 val _ = Outer_Syntax.local_theory @{command_spec "primrec"}
   "define primitive recursive functions"
-  ((Scan.optional (@{keyword "("} |--
-      Parse.!!! (Parse.list1 primrec_option_parser) --| @{keyword ")"}) []) --
+  ((Scan.optional (@{keyword "("} |-- Parse.!!! (Parse.list1 rec_option_parser)
+      --| @{keyword ")"}) []) --
     (Parse.fixes -- Parse_Spec.where_alt_specs)
     >> (fn (opts, (fixes, specs)) => snd o add_primrec_cmd opts fixes specs));