Add plugin to generate transfer theorem for primrec and primcorec
authordesharna
Fri, 19 Dec 2014 14:06:13 +0100
changeset 59275 77cd4992edcd
parent 59274 67afe7e6a516
child 59276 d207455817e8
Add plugin to generate transfer theorem for primrec and primcorec
src/HOL/Tools/BNF/bnf_fp_def_sugar.ML
src/HOL/Tools/BNF/bnf_fp_rec_sugar_transfer.ML
src/HOL/Tools/BNF/bnf_fp_rec_sugar_util.ML
src/HOL/Tools/BNF/bnf_gfp_rec_sugar.ML
src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML
src/HOL/Transfer.thy
--- a/src/HOL/Tools/BNF/bnf_fp_def_sugar.ML	Fri Dec 19 11:20:07 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_fp_def_sugar.ML	Fri Dec 19 14:06:13 2014 +0100
@@ -154,6 +154,8 @@
       BNF_Def.bnf list -> BNF_Comp.absT_info list -> local_theory ->
       BNF_FP_Util.fp_result * local_theory) ->
     (local_theory -> local_theory) parser
+
+  val mk_parametricity_goal: Proof.context -> term list -> term -> term -> term
 end;
 
 structure BNF_FP_Def_Sugar : BNF_FP_DEF_SUGAR =
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/HOL/Tools/BNF/bnf_fp_rec_sugar_transfer.ML	Fri Dec 19 14:06:13 2014 +0100
@@ -0,0 +1,191 @@
+(*  Title:      HOL/Tools/BNF/bnf_fp_rec_sugar_transfer.ML
+    Author:     Martin Desharnais, TU Muenchen
+    Copyright   2014
+
+Parametricity of primitively (co)recursive functions.
+*)
+
+(* DO NOT FORGET TO DOCUMENT THIS NEW PLUGIN!!! *)
+
+signature BNF_FP_REC_SUGAR_TRANSFER =
+sig
+
+val primrec_transfer_pluginN : string
+val primcorec_transfer_pluginN : string
+
+val primrec_transfer_interpretation:
+  BNF_FP_Rec_Sugar_Util.fp_rec_sugar -> Proof.context -> Proof.context
+val primcorec_transfer_interpretation:
+  BNF_FP_Rec_Sugar_Util.fp_rec_sugar -> Proof.context -> Proof.context
+
+end;
+
+structure BNF_FP_Rec_Sugar_Transfer : BNF_FP_REC_SUGAR_TRANSFER =
+struct
+
+open BNF_Def
+open BNF_FP_Def_Sugar
+open BNF_FP_Rec_Sugar_Util
+open BNF_FP_Util
+open Ctr_Sugar_Tactics
+open Ctr_Sugar_Util
+
+val primrec_transfer_pluginN = Plugin_Name.declare_setup @{binding primrec_transfer};
+val primcorec_transfer_pluginN = Plugin_Name.declare_setup @{binding primcorec_transfer};
+
+fun mk_primrec_transfer_tac ctxt def =
+  Ctr_Sugar_Tactics.unfold_thms_tac ctxt [def] THEN
+  HEADGOAL (Transfer.transfer_prover_tac ctxt);
+
+fun mk_primcorec_transfer_tac apply_transfer ctxt f_def corec_def type_definitions
+  dtor_corec_transfers rel_pre_defs disc_eq_cases cases case_distribs case_congs =
+  let
+    fun instantiate_with_lambda thm =
+      let
+        val prop = Thm.prop_of thm;
+        val @{const Trueprop} $
+          (Const (@{const_name HOL.eq}, _) $
+            (Var (_, fT) $ _) $ _) = prop;
+        val T = range_type fT;
+        val idx = Term.maxidx_of_term prop + 1;
+        val bool_expr = Var (("x", idx), HOLogic.boolT);
+        val then_expr = Var (("t", idx), T);
+        val else_expr = Var (("e", idx), T);
+        val lambda = Term.lambda bool_expr (mk_If bool_expr then_expr else_expr);
+      in
+        cterm_instantiate_pos [SOME (certify ctxt lambda)] thm
+      end;
+
+    val transfer_rules =
+      @{thm Abs_transfer[OF
+        BNF_Composition.type_definition_id_bnf_UNIV
+        BNF_Composition.type_definition_id_bnf_UNIV]} ::
+      map (fn thm => @{thm Abs_transfer} OF [thm, thm]) type_definitions @
+      map (Local_Defs.unfold ctxt rel_pre_defs) dtor_corec_transfers;
+    val add_transfer_rule = Thm.attribute_declaration Transfer.transfer_add
+    val ctxt' = Context.proof_map (fold add_transfer_rule transfer_rules) ctxt
+
+    val case_distribs = map instantiate_with_lambda case_distribs;
+    val simps = case_distribs @ disc_eq_cases @ cases @ @{thms if_True if_False};
+    val simp_ctxt = put_simpset (simpset_of (ss_only simps ctxt)) ctxt';
+  in
+    unfold_thms_tac ctxt ([f_def, corec_def] @ @{thms split_beta if_conn}) THEN
+    HEADGOAL (simp_tac (fold Simplifier.add_cong case_congs simp_ctxt)) THEN
+    (if apply_transfer then HEADGOAL (Transfer.transfer_prover_tac ctxt') else all_tac)
+  end;
+
+fun massage_simple_notes base =
+  filter_out (null o #2)
+  #> map (fn (thmN, thms, f_attrs) =>
+    ((Binding.qualify true base (Binding.name thmN), []),
+     map_index (fn (i, thm) => ([thm], f_attrs i)) thms));
+
+fun fp_sugar_of_bnf ctxt = fp_sugar_of ctxt o (fn Type (s, _) => s) o T_of_bnf;
+
+val cat_somes = map the o filter is_some
+fun maybe_apply z = the_default z oo Option.map
+
+fun bnf_depth_first_traverse ctxt f T z =
+  case T of
+    Type (s, innerTs) =>
+    (case bnf_of ctxt s of
+      NONE => z
+    | SOME bnf => let val z' = f bnf z in
+        fold (bnf_depth_first_traverse ctxt f) innerTs z'
+      end)
+  | _ => z
+
+fun if_all_bnfs ctxt Ts f g =
+  let
+    val bnfs = cat_somes (map (fn T =>
+      case T of Type (s, _) => BNF_Def.bnf_of ctxt s | _ => NONE) Ts);
+  in
+    if length bnfs = length Ts then f bnfs else g
+  end;
+
+fun mk_goal lthy f =
+  let
+    val skematicTs = Term.add_tvarsT (fastype_of f) [];
+
+    val ((As, Bs), names_lthy) = lthy
+      |> Ctr_Sugar_Util.mk_TFrees' (map snd skematicTs)
+      ||>> Ctr_Sugar_Util.mk_TFrees' (map snd skematicTs);
+
+    val (Rs, names_lthy) =
+      Ctr_Sugar_Util.mk_Frees "R" (map2 BNF_Util.mk_pred2T As Bs) names_lthy;
+
+    val fA = Term.subst_TVars (map fst skematicTs ~~ As) f;
+    val fB = Term.subst_TVars (map fst skematicTs ~~ Bs) f;
+  in
+    (BNF_FP_Def_Sugar.mk_parametricity_goal lthy Rs fA fB, names_lthy)
+  end;
+
+fun prove_parametricity_if_bnf prove {transfers, fun_names, funs, fun_defs, fpTs} lthy =
+  fold_index (fn (n, (((transfer, f_names), f), def)) => fn lthy =>
+      if not transfer then lthy
+      else
+        if_all_bnfs lthy fpTs
+          (fn bnfs => fn () => prove n bnfs f_names f def lthy)
+          (fn () => let val _ = error "Function is not parametric." in lthy end) ())
+    (transfers ~~ fun_names ~~ funs ~~ fun_defs) lthy;
+
+fun prim_co_rec_transfer_interpretation prove =
+  prove_parametricity_if_bnf (fn n => fn bnfs => fn f_name => fn f => fn def => fn lthy =>
+    case try (prove n bnfs f def) lthy of
+      NONE => error "Failed to prove parametricity."
+    | SOME thm =>
+      let
+        val notes =
+          [("transfer", [thm], K @{attributes [transfer_rule]})]
+          |> massage_simple_notes f_name;
+      in
+        snd (Local_Theory.notes notes lthy)
+      end);
+
+val primrec_transfer_interpretation = prim_co_rec_transfer_interpretation
+  (fn n => fn bnfs => fn f => fn def => fn lthy =>
+     let
+       val (goal, names_lthy) = mk_goal lthy f;
+     in
+       Goal.prove lthy [] [] goal (fn {context = ctxt, prems = _} =>
+         mk_primrec_transfer_tac ctxt def)
+       |> singleton (Proof_Context.export names_lthy lthy)
+       |> Thm.close_derivation
+     end);
+
+val primcorec_transfer_interpretation = prim_co_rec_transfer_interpretation
+  (fn n => fn bnfs => fn f => fn def => fn lthy =>
+     let
+       val fp_sugars = map (the o fp_sugar_of_bnf lthy) bnfs;
+       val (goal, names_lthy) = mk_goal lthy f;
+       val (disc_eq_cases, case_thms, case_distribs, case_congs) =
+         bnf_depth_first_traverse lthy (fn bnf => fn xs =>
+           let
+             fun add_thms (xs, ys, zs, ws) (fp_sugar : fp_sugar) =
+               let
+                 val ctr_sugar = #ctr_sugar (#fp_ctr_sugar fp_sugar);
+                 val xs' = #disc_eq_cases ctr_sugar;
+                 val ys' = #case_thms ctr_sugar;
+                 val zs' = #case_distribs ctr_sugar;
+                 val w = #case_cong ctr_sugar;
+                 val union' = union Thm.eq_thm;
+                 val insert' = insert Thm.eq_thm;
+               in
+                 (union' xs' xs, union' ys' ys, union' zs' zs, insert' w ws)
+               end;
+           in
+             maybe_apply xs (add_thms xs) (fp_sugar_of_bnf lthy bnf)
+           end) (fastype_of f) ([], [], [], []);
+     in
+       Goal.prove lthy [] [] goal (fn {context = ctxt, prems = _} =>
+         mk_primcorec_transfer_tac true ctxt def
+         (#co_rec_def (#fp_co_induct_sugar (nth fp_sugars n)))
+         (map (#type_definition o #absT_info) fp_sugars)
+         (flat (map (#xtor_co_rec_transfers o #fp_res) fp_sugars))
+         (map (rel_def_of_bnf o #pre_bnf) fp_sugars)
+         disc_eq_cases case_thms case_distribs case_congs)
+       |> singleton (Proof_Context.export names_lthy lthy)
+       |> Thm.close_derivation
+     end);
+
+end
--- a/src/HOL/Tools/BNF/bnf_fp_rec_sugar_util.ML	Fri Dec 19 11:20:07 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_fp_rec_sugar_util.ML	Fri Dec 19 14:06:13 2014 +0100
@@ -8,10 +8,20 @@
 
 signature BNF_FP_REC_SUGAR_UTIL =
 sig
+
   datatype fp_kind = Least_FP | Greatest_FP
 
   val case_fp: fp_kind -> 'a -> 'a -> 'a
 
+  type fp_rec_sugar =
+    {transfers: bool list,
+     fun_names: string list,
+     funs: term list,
+     fun_defs: thm list,
+     fpTs: typ list}
+
+  val morph_fp_rec_sugar: morphism -> fp_rec_sugar -> fp_rec_sugar
+
   val flat_rec_arg_args: 'a list list -> 'a list
 
   val indexed: 'a list -> int -> int list * int
@@ -51,6 +61,20 @@
 fun case_fp Least_FP l _ = l
   | case_fp Greatest_FP _ g = g;
 
+type fp_rec_sugar =
+  {transfers: bool list,
+   fun_names: string list,
+   funs: term list,
+   fun_defs: thm list,
+   fpTs: typ list}
+
+fun morph_fp_rec_sugar phi {transfers, fun_names, funs, fun_defs, fpTs} =
+  {transfers = transfers,
+   fun_names = fun_names,
+   funs = map (Morphism.term phi) funs,
+   fun_defs = map (Morphism.thm phi) fun_defs,
+   fpTs = map (Morphism.typ phi) fpTs};
+
 fun flat_rec_arg_args xss =
   (* FIXME (once the old datatype package is phased out): The first line below gives the preferred
      order. The second line is for compatibility with the old datatype package. *)
--- a/src/HOL/Tools/BNF/bnf_gfp_rec_sugar.ML	Fri Dec 19 11:20:07 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_gfp_rec_sugar.ML	Fri Dec 19 14:06:13 2014 +0100
@@ -8,7 +8,7 @@
 
 signature BNF_GFP_REC_SUGAR =
 sig
-  datatype primcorec_option = Sequential_Option | Exhaustive_Option
+  datatype primcorec_option = Sequential_Option | Exhaustive_Option | Transfer_Option
 
   datatype corec_call =
     Dummy_No_Corec of int |
@@ -31,7 +31,8 @@
      corec_sels: thm list}
 
   type corec_spec =
-    {corec: term,
+    {T: typ,
+     corec: term,
      exhaust_discs: thm list,
      sel_defs: thm list,
      fp_nesting_maps: thm list,
@@ -43,6 +44,11 @@
     (term * term list list) list list -> local_theory ->
     corec_spec list * typ list * thm * thm * thm list * thm list * (Token.src list * Token.src list)
     * bool * local_theory
+
+  val primcorec_interpretation:
+    string -> (BNF_FP_Rec_Sugar_Util.fp_rec_sugar -> local_theory -> local_theory) ->
+    theory -> theory
+
   val add_primcorecursive_cmd: primcorec_option list ->
     (binding * string option * mixfix) list * ((Attrib.binding * string) * string option) list ->
     Proof.context -> Proof.state
@@ -81,7 +87,7 @@
 fun primcorec_error_eqn str eqn = raise PRIMCOREC (str, [eqn]);
 fun primcorec_error_eqns str eqns = raise PRIMCOREC (str, eqns);
 
-datatype primcorec_option = Sequential_Option | Exhaustive_Option;
+datatype primcorec_option = Sequential_Option | Exhaustive_Option | Transfer_Option;
 
 datatype corec_call =
   Dummy_No_Corec of int |
@@ -109,7 +115,8 @@
    corec_sels: thm list};
 
 type corec_spec =
-  {corec: term,
+  {T: typ,
+   corec: term,
    exhaust_discs: thm list,
    sel_defs: thm list,
    fp_nesting_maps: thm list,
@@ -403,6 +410,16 @@
     (case fp_sugar_of ctxt s of SOME {fp_bnf_sugar = {map_thms, ...}, ...} => map_thms | NONE => [])
   | map_thms_of_typ _ _ = [];
 
+val transfer_primcorec = morph_fp_rec_sugar o Morphism.transfer_morphism;
+
+structure Primcorec_Plugin = Plugin(type T = fp_rec_sugar);
+
+fun primcorec_interpretation name f =
+  Primcorec_Plugin.interpretation name (fn fp_rec_sugar => fn lthy =>
+    f (transfer_primcorec (Proof_Context.theory_of lthy) fp_rec_sugar) lthy);
+
+val interpret_primcorec = Primcorec_Plugin.data_default;
+
 fun corec_specs_of bs arg_Ts res_Ts callers callssss0 lthy0 =
   let
     val thy = Proof_Context.theory_of lthy0;
@@ -501,7 +518,7 @@
     fun mk_spec ({T, fp_ctr_sugar = {ctr_sugar as {exhaust_discs, sel_defs, ...}, ...},
         fp_co_induct_sugar = {co_rec = corec, co_rec_thms = corec_thms, co_rec_discs = corec_discs,
         co_rec_selss = corec_selss, ...}, ...} : fp_sugar) p_is q_isss f_isss f_Tsss =
-      {corec = mk_co_rec thy Greatest_FP perm_Cs' (substAT T) corec, exhaust_discs = exhaust_discs,
+      {T = T, corec = mk_co_rec thy Greatest_FP perm_Cs' (substAT T) corec, exhaust_discs = exhaust_discs,
        sel_defs = sel_defs,
        fp_nesting_maps = maps (map_thms_of_typ lthy o T_of_bnf) fp_nesting_bnfs,
        fp_nesting_map_ident0s = map map_ident0_of_bnf fp_nesting_bnfs,
@@ -976,6 +993,7 @@
 
     val sequentials = replicate actual_nn (member (op =) opts Sequential_Option);
     val exhaustives = replicate actual_nn (member (op =) opts Exhaustive_Option);
+    val transfers = replicate actual_nn (member (op =) opts Transfer_Option);
 
     val fun_names = map Binding.name_of bs;
     val basic_ctr_specss = map (basic_corec_specs_of lthy) res_Ts;
@@ -990,7 +1008,7 @@
       let
         val missing = fun_names
           |> filter (map (fn Disc x => #fun_name x | Sel x => #fun_name x) eqns_data
-            |> not oo member (op =))
+            |> not oo member (op =));
       in
         null missing
           orelse primcorec_error_eqns ("missing equations for function(s): " ^ commas missing) []
@@ -1107,6 +1125,7 @@
     fun prove thmss'' def_infos lthy =
       let
         val def_thms = map (snd o snd) def_infos;
+        val ts = map fst def_infos;
 
         val (nchotomy_thmss, exclude_thmss) =
           (map2 append (take actual_nn thmss'') nchotomy_taut_thmss, drop actual_nn thmss'');
@@ -1421,8 +1440,16 @@
         |> Spec_Rules.add Spec_Rules.Equational (map fst def_infos, flat code_thmss)
         |> Local_Theory.notes (anonymous_notes @ common_notes @ notes)
         |> snd
+        |> (fn lthy =>
+          let
+            val phi = Local_Theory.target_morphism lthy;
+            val Ts = take actual_nn (map #T corec_specs);
+          in
+            interpret_primcorec {transfers = transfers, fun_names = fun_names,
+              funs = map (Morphism.term phi) ts, fun_defs = Morphism.fact phi def_thms, fpTs = Ts}
+              lthy
+          end)
       end;
-
     fun after_qed thmss' = fold_map Local_Theory.define defs #-> prove thmss';
   in
     (goalss, after_qed, lthy')
@@ -1458,7 +1485,8 @@
 
 val primcorec_option_parser = Parse.group (fn () => "option")
   (Parse.reserved "sequential" >> K Sequential_Option
-  || Parse.reserved "exhaustive" >> K Exhaustive_Option)
+  || Parse.reserved "exhaustive" >> K Exhaustive_Option
+  || Parse.reserved "transfer" >> K Transfer_Option)
 
 (* FIXME: should use "Parse_Spec.spec" instead of "Parse.prop" and honor binding *)
 val where_alt_specs_of_parser = Parse.where_ |-- Parse.!!! (Parse.enum1 "|"
@@ -1476,4 +1504,8 @@
       Parse.!!! (Parse.list1 primcorec_option_parser) --| @{keyword ")"}) []) --
     (Parse.fixes -- where_alt_specs_of_parser) >> uncurry add_primcorec_cmd);
 
+val _ = Theory.setup (primcorec_interpretation
+  BNF_FP_Rec_Sugar_Transfer.primcorec_transfer_pluginN
+  BNF_FP_Rec_Sugar_Transfer.primcorec_transfer_interpretation)
+
 end;
--- a/src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML	Fri Dec 19 11:20:07 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML	Fri Dec 19 14:06:13 2014 +0100
@@ -8,7 +8,7 @@
 
 signature BNF_LFP_REC_SUGAR =
 sig
-  datatype primrec_option = Nonexhaustive_Option
+  datatype primrec_option = Nonexhaustive_Option | Transfer_Option;
 
   datatype rec_call =
     No_Rec of int * typ |
@@ -55,7 +55,11 @@
     * local_theory
   val rec_specs_of: binding list -> typ list -> typ list -> term list ->
     (term * term list list) list list -> local_theory ->
-    (bool * rec_spec list * typ list * thm * thm list * Token.src list) * local_theory
+    (bool * rec_spec list * typ list * thm * thm list * Token.src list * typ list) * local_theory
+
+  val primrec_interpretation:
+    string -> (BNF_FP_Rec_Sugar_Util.fp_rec_sugar -> local_theory -> local_theory) ->
+    theory -> theory
 
   val add_primrec: (binding * typ option * mixfix) list ->
     (Attrib.binding * term) list -> local_theory -> (term list * thm list list) * local_theory
@@ -88,7 +92,7 @@
 exception OLD_PRIMREC of unit;
 exception PRIMREC of string * term list;
 
-datatype primrec_option = Nonexhaustive_Option;
+datatype primrec_option = Nonexhaustive_Option | Transfer_Option;
 
 datatype rec_call =
   No_Rec of int * typ |
@@ -174,6 +178,16 @@
     SOME {rewrite_nested_rec_call = SOME f, ...} => f ctxt
   | _ => error "Unsupported nested recursion");
 
+val transfer_primrec = morph_fp_rec_sugar o Morphism.transfer_morphism;
+
+structure Primrec_Plugin = Plugin(type T = fp_rec_sugar);
+
+fun primrec_interpretation name f =
+  Primrec_Plugin.interpretation name (fn fp_rec_sugar => fn lthy =>
+    f (transfer_primrec (Proof_Context.theory_of lthy) fp_rec_sugar) lthy);
+
+val interpret_primrec = Primrec_Plugin.data_default;
+
 fun rec_specs_of bs arg_Ts res_Ts callers callssss0 lthy0 =
   let
     val thy = Proof_Context.theory_of lthy0;
@@ -242,7 +256,7 @@
        ctr_specs = mk_ctr_specs fp_res_index ctr_offset ctrs rec_thms};
   in
     ((n2m, map2 mk_spec ctr_offsets basic_lfp_sugars, missing_arg_Ts, common_induct, inducts,
-      induct_attrs), lthy)
+      induct_attrs, map #T basic_lfp_sugars), lthy)
   end;
 
 val undef_const = Const (@{const_name undefined}, dummyT);
@@ -401,23 +415,28 @@
       |> fold_rev lambda (args @ left_args @ right_args)
     end);
 
-fun build_defs ctxt nonexhaustive bs mxs (funs_data : eqn_data list list)
+fun build_defs ctxt nonexhaustives bs mxs (funs_data : eqn_data list list)
     (rec_specs : rec_spec list) has_call =
   let
     val n_funs = length funs_data;
 
     val ctr_spec_eqn_data_list' =
-      map #ctr_specs (take n_funs rec_specs) ~~ funs_data
-      |> maps (uncurry (finds (fn (x, y) => #ctr x = #ctr y))
-        ##> (fn x => null x orelse
-          raise PRIMREC ("excess equations in definition", map #rhs_term x)) #> fst);
-    val _ = ctr_spec_eqn_data_list' |> map (fn ({ctr, ...}, x) =>
+      maps (fn ((xs, ys), z) =>
+        let
+          val zs = replicate (length xs) z
+          val (b, c) = finds (fn ((x,_), y) => #ctr x = #ctr y) (xs ~~ zs) ys
+          val (_ : bool ) = (fn x => null x orelse
+            raise PRIMREC ("excess equations in definition", map #rhs_term x)) c
+        in b end) (map #ctr_specs (take n_funs rec_specs) ~~ funs_data ~~ nonexhaustives);
+
+    val (_ : unit list) = ctr_spec_eqn_data_list' |> map (fn (({ctr, ...}, nonexhaustive), x) =>
       if length x > 1 then raise PRIMREC ("multiple equations for constructor", map #user_eqn x)
       else if length x = 1 orelse nonexhaustive then ()
       else warning ("no equation for constructor " ^ Syntax.string_of_term ctxt ctr));
 
     val ctr_spec_eqn_data_list =
-      ctr_spec_eqn_data_list' @ (drop n_funs rec_specs |> maps #ctr_specs |> map (rpair []));
+      map (fn ((x, y), z) => (x, z)) ctr_spec_eqn_data_list' @
+        (drop n_funs rec_specs |> maps #ctr_specs |> map (rpair []));
 
     val recs = take n_funs rec_specs |> map #recx;
     val rec_args = ctr_spec_eqn_data_list
@@ -472,7 +491,7 @@
   unfold_thms_tac ctxt (nested_simps ctxt @ map_ident0s @ map_comps) THEN
   HEADGOAL (rtac refl);
 
-fun prepare_primrec nonexhaustive fixes specs lthy0 =
+fun prepare_primrec nonexhaustives transfers fixes specs lthy0 =
   let
     val thy = Proof_Context.theory_of lthy0;
 
@@ -502,7 +521,7 @@
         [] => ()
       | (b, _) :: _ => raise PRIMREC ("type of " ^ Binding.print b ^ " contains top sort", []));
 
-    val ((n2m, rec_specs, _, common_induct, inducts, induct_attrs), lthy) =
+    val ((n2m, rec_specs, _, common_induct, inducts, induct_attrs, Ts), lthy) =
       rec_specs_of bs arg_Ts res_Ts frees callssss lthy0;
 
     val actual_nn = length funs_data;
@@ -513,10 +532,10 @@
         raise PRIMREC ("argument " ^ quote (Syntax.string_of_term lthy ctr) ^
           " is not a constructor in left-hand side", [user_eqn])) eqns_data;
 
-    val defs = build_defs lthy nonexhaustive bs mxs funs_data rec_specs has_call;
+    val defs = build_defs lthy nonexhaustives bs mxs funs_data rec_specs has_call;
 
-    fun prove lthy' def_thms' ({ctr_specs, fp_nesting_map_ident0s, fp_nesting_map_comps, ...}
-        : rec_spec) (fun_data : eqn_data list) =
+    fun prove def_thms' ({ctr_specs, fp_nesting_map_ident0s, fp_nesting_map_comps, ...}
+        : rec_spec) (fun_data : eqn_data list) lthy' =
       let
         val js =
           find_indices (op = o apply2 (fn {fun_name, ctr, ...} => (fun_name, ctr)))
@@ -534,7 +553,7 @@
               |> Thm.close_derivation)
             js;
       in
-        (js, simp_thms)
+        ((js, simp_thms), lthy')
       end;
 
     val notes =
@@ -555,19 +574,33 @@
   in
     (((fun_names, defs),
       fn lthy => fn defs =>
-        split_list (map2 (prove lthy defs) (take actual_nn rec_specs) funs_data)),
+        let
+          val def_thms = map (snd o snd) defs;
+          val ts = map fst defs;
+          val phi = Local_Theory.target_morphism lthy;
+        in
+          map_prod split_list
+            (interpret_primrec {transfers = transfers, fun_names = fun_names,
+               funs = (map (Morphism.term phi) ts), fun_defs = (Morphism.fact phi def_thms),
+               fpTs = (take actual_nn Ts)})
+            (@{fold_map 2} (prove defs) (take actual_nn rec_specs) funs_data lthy)
+        end),
       lthy |> Local_Theory.notes (notes @ common_notes) |> snd)
   end;
 
 fun add_primrec_simple' opts fixes ts lthy =
   let
-    val nonexhaustive = member (op =) opts Nonexhaustive_Option;
-    val (((names, defs), prove), lthy') = prepare_primrec nonexhaustive fixes ts lthy
+    val actual_nn = length fixes;
+    val nonexhaustives = replicate actual_nn (member (op =) opts Nonexhaustive_Option);
+    val transfers = replicate actual_nn (member (op =) opts Transfer_Option);
+    val (((names, defs), prove), lthy') = prepare_primrec nonexhaustives transfers fixes ts lthy
       handle ERROR str => raise PRIMREC (str, []);
   in
     lthy'
     |> fold_map Local_Theory.define defs
-    |-> (fn defs => `(fn lthy => (names, (map fst defs, prove lthy defs))))
+    |-> (fn defs => fn lthy =>
+      let val (thms, lthy) = prove lthy defs;
+      in ((names, (map fst defs, thms)), lthy) end)
   end
   handle PRIMREC (str, eqns) =>
          if null eqns then
@@ -626,7 +659,8 @@
   ##> Local_Theory.exit_global;
 
 val primrec_option_parser = Parse.group (fn () => "option")
-  (Parse.reserved "nonexhaustive" >> K Nonexhaustive_Option)
+  (Parse.reserved "nonexhaustive" >> K Nonexhaustive_Option
+  || Parse.reserved "transfer" >> K Transfer_Option)
 
 val _ = Outer_Syntax.local_theory @{command_spec "primrec"}
   "define primitive recursive functions"
--- a/src/HOL/Transfer.thy	Fri Dec 19 11:20:07 2014 +0100
+++ b/src/HOL/Transfer.thy	Fri Dec 19 14:06:13 2014 +0100
@@ -6,7 +6,7 @@
 section {* Generic theorem transfer using relations *}
 
 theory Transfer
-imports Hilbert_Choice Metis Basic_BNF_LFPs
+imports Basic_BNF_LFPs Hilbert_Choice Metis
 begin
 
 subsection {* Relator for function space *}
@@ -361,7 +361,21 @@
 
 end
 
+lemma if_conn:
+  "(if P \<and> Q then t else e) = (if P then if Q then t else e else e)"
+  "(if P \<or> Q then t else e) = (if P then t else if Q then t else e)"
+  "(if P \<longrightarrow> Q then t else e) = (if P then if Q then t else e else t)"
+  "(if \<not> P then t else e) = (if P then e else t)"
+by auto
+
 ML_file "Tools/Transfer/transfer_bnf.ML"
+ML_file "Tools/BNF/bnf_fp_rec_sugar_transfer.ML"
+
+ML {*
+val _ = Theory.setup (BNF_LFP_Rec_Sugar.primrec_interpretation
+  BNF_FP_Rec_Sugar_Transfer.primrec_transfer_pluginN
+  BNF_FP_Rec_Sugar_Transfer.primrec_transfer_interpretation)
+*}
 
 declare pred_fun_def [simp]
 declare rel_fun_eq [relator_eq]