src/HOL/Tools/BNF/bnf_fp_def_sugar.ML
changeset 58446 e89f57d1e46c
parent 58435 a379d4531d1a
child 58448 a1d4e7473c98
--- a/src/HOL/Tools/BNF/bnf_fp_def_sugar.ML	Thu Sep 25 16:35:51 2014 +0200
+++ b/src/HOL/Tools/BNF/bnf_fp_def_sugar.ML	Thu Sep 25 16:35:53 2014 +0200
@@ -140,6 +140,7 @@
 val corec_codeN = "corec_code";
 val map_disc_iffN = "map_disc_iff";
 val map_selN = "map_sel";
+val rec_transferN = "rec_transfer";
 val set_casesN = "set_cases";
 val set_introsN = "set_intros";
 val set_inductN = "set_induct";
@@ -418,9 +419,9 @@
       val ks = 1 upto n;
       val ms = map length ctr_Tss;
 
-      val B_ify = Term.typ_subst_atomic (As ~~ Bs);
+      val B_ify_T = Term.typ_subst_atomic (As ~~ Bs);
 
-      val fpBT = B_ify fpT;
+      val fpBT = B_ify_T fpT;
       val live_AsBs = filter (op <>) (As ~~ Bs);
       val fTs = map (op -->) live_AsBs;
 
@@ -428,7 +429,7 @@
         |> fold (fold Variable.declare_typ) [As, Bs]
         |> mk_TFrees 2
         ||>> mk_Freess "x" ctr_Tss
-        ||>> mk_Freess "y" (map (map B_ify) ctr_Tss)
+        ||>> mk_Freess "y" (map (map B_ify_T) ctr_Tss)
         ||>> mk_Frees "f" fTs
         ||>> mk_Frees "R" (map (uncurry mk_pred2T) live_AsBs)
         ||>> yield_singleton (mk_Frees "a") fpT
@@ -461,7 +462,7 @@
         end;
 
       val cxIns = map2 (mk_cIn ctor) ks xss;
-      val cyIns = map2 (mk_cIn (Term.map_types B_ify ctor)) ks yss;
+      val cyIns = map2 (mk_cIn (Term.map_types B_ify_T ctor)) ks yss;
 
       fun mk_map_thm ctr_def' cxIn =
         fold_thms lthy [ctr_def']
@@ -1092,10 +1093,12 @@
 fun derive_rel_induct_thms_for_types lthy fpA_Ts As Bs ctrAss ctrAs_Tsss exhausts ctor_rel_induct
     ctor_defss ctor_injects pre_rel_defs abs_inverses live_nesting_rel_eqs =
   let
-    val B_ify = typ_subst_nonatomic (As ~~ Bs);
-    val fpB_Ts = map B_ify fpA_Ts;
-    val ctrBs_Tsss = map (map (map B_ify)) ctrAs_Tsss;
-    val ctrBss = map (map (subst_nonatomic_types (As ~~ Bs))) ctrAss;
+    val B_ify_T = Term.typ_subst_atomic (As ~~ Bs);
+    val B_ify = Term.subst_atomic_types (As ~~ Bs);
+
+    val fpB_Ts = map B_ify_T fpA_Ts;
+    val ctrBs_Tsss = map (map (map B_ify_T)) ctrAs_Tsss;
+    val ctrBss = map (map B_ify) ctrAss;
 
     val ((((Rs, IRs), ctrAsss), ctrBsss), names_lthy) = lthy
       |> mk_Frees "R" (map2 mk_pred2T As Bs)
@@ -1319,7 +1322,8 @@
     abs_inverses abs_injects ctor_injects dtor_ctors rel_pre_defs ctor_defss dtor_rel_coinduct
     live_nesting_rel_eqs =
   let
-    val fpB_Ts = map (typ_subst_nonatomic (As ~~ Bs)) fpA_Ts;
+    val B_ify_T = Term.typ_subst_atomic (As ~~ Bs);
+    val fpB_Ts = map B_ify_T fpA_Ts;
 
     val (Rs, IRs, fpAs, fpBs, names_lthy) =
       let
@@ -1711,11 +1715,12 @@
     val set_boss = map (map fst o type_args_named_constrained_of_spec) specs;
     val set_bss = map (map (the_default Binding.empty)) set_boss;
 
-    val (((Bs0, Cs), Xs), names_no_defs_lthy) =
+    val ((((Bs0, Cs), Es), Xs), names_no_defs_lthy) =
       no_defs_lthy
       |> fold (Variable.declare_typ o resort_tfree_or_tvar dummyS) unsorted_As
       |> mk_TFrees num_As
       ||>> mk_TFrees nn
+      ||>> mk_TFrees nn
       ||>> variant_tfrees fp_b_names;
 
     fun add_fake_type spec =
@@ -1793,7 +1798,8 @@
     val ((pre_bnfs, absT_infos), (fp_res as {bnfs = fp_bnfs as any_fp_bnf :: _, ctors = ctors0,
              dtors = dtors0, xtor_co_recs = xtor_co_recs0, xtor_co_induct, dtor_ctors,
              ctor_dtors, ctor_injects, dtor_injects, xtor_map_thms, xtor_set_thmss, xtor_rel_thms,
-             xtor_co_rec_thms, rel_xtor_co_induct_thm, dtor_set_induct_thms, ...},
+             xtor_co_rec_thms, rel_xtor_co_induct_thm, dtor_set_induct_thms,
+             ctor_rec_transfer_thms, ...},
            lthy)) =
       fp_bnf (construct_fp mixfixes map_bs rel_bs set_bss) fp_bs (map dest_TFree unsorted_As)
         (map dest_TFree killed_As) fp_eqs no_defs_lthy
@@ -1857,13 +1863,13 @@
           if alive then resort_tfree_or_tvar S B else A)
         (liveness_of_fp_bnf num_As any_fp_bnf) As Bs0;
 
-    val B_ify = Term.typ_subst_atomic (As ~~ Bs);
+    val B_ify_T = Term.typ_subst_atomic (As ~~ Bs);
 
     val ctors = map (mk_ctor As) ctors0;
     val dtors = map (mk_dtor As) dtors0;
 
     val fpTs = map (domain_type o fastype_of) dtors;
-    val fpBTs = map B_ify fpTs;
+    val fpBTs = map B_ify_T fpTs;
 
     val code_attrs = if plugins code_plugin then [Code.add_default_eqn_attrib] else [];
 
@@ -1999,6 +2005,28 @@
         rel_distincts setss =
       injects @ distincts @ case_thms @ co_recs @ mapsx @ rel_injects @ rel_distincts @ flat setss;
 
+    fun derive_rec_transfer_thms lthy recs rec_defs ns =
+      let
+        val liveAsBs = filter (op <>) (As ~~ Bs);
+        val B_ify = Term.subst_atomic_types (liveAsBs @ (Cs ~~ Es));
+
+        val ((Rs, Ss), names_lthy) = lthy
+          |> mk_Frees "R" (map (uncurry mk_pred2T) liveAsBs)
+          ||>> mk_Frees "S" (map2 mk_pred2T Cs Es);
+
+        val recBs = map B_ify recs;
+        val goals = map2 (mk_parametricity_goal lthy (Rs @ Ss)) recs recBs;
+      in
+        Goal.prove_sorry lthy [] [] (Logic.mk_conjunction_balanced goals)
+          (fn {context = ctxt, prems = _} =>
+             mk_rec_transfer_tac names_lthy nn ns (map (certify ctxt) Ss)
+               (map (certify ctxt) Rs) rec_defs ctor_rec_transfer_thms pre_rel_defs
+               live_nesting_rel_eqs)
+        |> Conjunction.elim_balanced nn
+        |> Proof_Context.export names_lthy lthy
+        |> map Thm.close_derivation
+      end;
+
     fun derive_note_induct_recs_thms_for_types
         ((((mapss, rel_injectss, rel_distinctss, setss), (ctrss, _, ctr_defss, ctr_sugars)),
           (recs, rec_defs)), lthy) =
@@ -2008,6 +2036,11 @@
             xtor_co_rec_thms live_nesting_bnfs fp_nesting_bnfs fpTs Cs Xs ctrXs_Tsss abs_inverses
             type_definitions abs_inverses ctrss ctr_defss recs rec_defs lthy;
 
+        val rec_transfer_thmss =
+          if live = 0 then replicate nn []
+          else
+            map single (derive_rec_transfer_thms lthy recs rec_defs ns);
+
         val induct_type_attr = Attrib.internal o K o Induct.induct_type;
         val induct_pred_attr = Attrib.internal o K o Induct.induct_pred;
 
@@ -2040,6 +2073,7 @@
         val notes =
           [(inductN, map single induct_thms, fn T_name => induct_attrs @ [induct_type_attr T_name]),
            (recN, rec_thmss, K rec_attrs),
+           (rec_transferN, rec_transfer_thmss, K []),
            (rel_inductN, rel_induct_thmss, K (rel_induct_attrs @ [induct_pred_attr ""])),
            (simpsN, simp_thmss, K [])]
           |> massage_multi_notes fp_b_names fpTs;