generate 'case_transfer' for (co)datatypes
authordesharna
Fri, 29 Aug 2014 14:21:24 +0200
changeset 58093 6f37a300c82b
parent 58092 4ae52c60603a
child 58094 117c5d2c2642
generate 'case_transfer' for (co)datatypes
src/HOL/Tools/BNF/bnf_def.ML
src/HOL/Tools/BNF/bnf_fp_def_sugar.ML
src/HOL/Tools/BNF/bnf_fp_def_sugar_tactics.ML
--- a/src/HOL/Tools/BNF/bnf_def.ML	Thu Aug 28 23:57:26 2014 +0200
+++ b/src/HOL/Tools/BNF/bnf_def.ML	Fri Aug 29 14:21:24 2014 +0200
@@ -563,7 +563,7 @@
     Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
   end;
 
-fun build_map_or_rel mk const of_bnf dest ctxt simpleTs build_simple =
+fun build_map_or_rel is_rel mk const of_bnf dest ctxt simpleTs build_simple =
   let
     fun build (TU as (T, U)) =
       if exists (curry (op =) T) simpleTs then
@@ -575,18 +575,19 @@
           (Type (s, Ts), Type (s', Us)) =>
           if s = s' then
             let
-              val bnf = the (bnf_of ctxt s);
-              val live = live_of_bnf bnf;
-              val mapx = mk live Ts Us (of_bnf bnf);
-              val TUs' = map dest (fst (strip_typeN live (fastype_of mapx)));
-            in Term.list_comb (mapx, map build TUs') end
+              val (live, cst0) =
+                if is_rel andalso s = @{type_name fun} then (2, @{term rel_fun})
+                else let val bnf = the (bnf_of ctxt s) in (live_of_bnf bnf, of_bnf bnf) end;
+              val cst = mk live Ts Us cst0;
+              val TUs' = map dest (fst (strip_typeN live (fastype_of cst)));
+            in Term.list_comb (cst, map build TUs') end
           else
             build_simple TU
         | _ => build_simple TU);
   in build end;
 
-val build_map = build_map_or_rel mk_map HOLogic.id_const map_of_bnf dest_funT;
-val build_rel = build_map_or_rel mk_rel HOLogic.eq_const rel_of_bnf dest_pred2T;
+val build_map = build_map_or_rel false mk_map HOLogic.id_const map_of_bnf dest_funT;
+val build_rel = build_map_or_rel true mk_rel HOLogic.eq_const rel_of_bnf dest_pred2T;
 
 fun map_flattened_map_args ctxt s map_args fs =
   let
--- a/src/HOL/Tools/BNF/bnf_fp_def_sugar.ML	Thu Aug 28 23:57:26 2014 +0200
+++ b/src/HOL/Tools/BNF/bnf_fp_def_sugar.ML	Fri Aug 29 14:21:24 2014 +0200
@@ -98,6 +98,7 @@
 
 val EqN = "Eq_";
 
+val case_transferN = "case_transfer";
 val ctr_transferN = "ctr_transfer";
 val corec_codeN = "corec_code";
 val map_disc_iffN = "map_disc_iff";
@@ -131,17 +132,10 @@
 
 val name_of_set = name_of_const "set";
 
-
-fun mk_parametricity_goals ctxt Rs fs gs =
-  let
-    val prems =
-      map (foldr1 (uncurry mk_rel_fun) o
-          uncurry (map2 (build_the_rel ctxt Rs [])) o
-          pairself (fastype_of #> strip_type #> (fn (Ts, T) => Ts @ [T])))
-        (fs ~~ gs);
-  in
-    map3 (fn prem => fn f => fn g => HOLogic.mk_Trueprop (prem $ f $ g)) prems fs gs
-  end
+fun mk_parametricity_goal ctxt Rs f g =
+  let val prem = build_the_rel ctxt Rs [] (fastype_of f) (fastype_of g) in
+    HOLogic.mk_Trueprop (prem $ f $ g)
+  end;
 
 fun fp_sugar_of ctxt =
   Symtab.lookup (Data.get (Context.Proof ctxt))
@@ -1105,11 +1099,12 @@
     val set_boss = map (map fst o type_args_named_constrained_of_spec) specs;
     val set_bss = map (map (the_default Binding.empty)) set_boss;
 
-    val (((Bs0, Cs), Xs), no_defs_lthy) =
+    val ((((Bs0, Cs as C1 :: _), Es as E1 :: _), Xs), no_defs_lthy) =
       no_defs_lthy0
       |> fold (Variable.declare_typ o resort_tfree dummyS) unsorted_As
       |> mk_TFrees num_As
       ||>> mk_TFrees nn
+      ||>> mk_TFrees nn
       ||>> variant_tfrees fp_b_names;
 
     fun add_fake_type spec =
@@ -1374,8 +1369,8 @@
               sel_default_eqs) lthy
           end;
 
-        fun derive_maps_sets_rels (ctr_sugar as {case_cong, discs, selss, ctrs, exhaust, disc_thmss,
-            sel_thmss, injects, distincts, ...} : ctr_sugar, lthy) =
+        fun derive_maps_sets_rels (ctr_sugar as {casex, case_cong, case_thms, discs, selss, ctrs,
+            exhaust, disc_thmss, sel_thmss, injects, distincts, ...} : ctr_sugar, lthy) =
           if live = 0 then
             ((([], [], [], []), ctr_sugar), lthy)
           else
@@ -1475,7 +1470,7 @@
                   rel_inject_thms ms;
 
               val (map_disc_iff_thms, map_sel_thms, set_sel_thms, rel_sel_thms, set_intros_thms,
-                   ctr_transfer_thms, (set_cases_thms, set_cases_attrss),
+                   case_transfer_thms, ctr_transfer_thms, (set_cases_thms, set_cases_attrss),
                    (rel_cases_thm, rel_cases_attrs)) =
                 let
                   val live_AsBs = filter (op <>) (As ~~ Bs);
@@ -1498,7 +1493,7 @@
 
                   val ctr_transfer_thms =
                     let
-                      val goals = mk_parametricity_goals names_lthy Rs ctrAs ctrBs;
+                      val goals = map2 (mk_parametricity_goal names_lthy Rs) ctrAs ctrBs;
                     in
                       Goal.prove_sorry lthy [] [] (Logic.mk_conjunction_balanced goals)
                         (K (mk_ctr_transfer_tac rel_intro_thms))
@@ -1695,6 +1690,22 @@
                       (thm, [consumes_attr, case_names_attr, cases_pred_attr ""])
                     end;
 
+                  val case_transfer_thms =
+                    let
+                      val (R, names_lthy) =
+                        yield_singleton (mk_Frees "R") (mk_pred2T C1 E1) names_lthy;
+
+                      val caseA = mk_case As C1 casex;
+                      val caseB = mk_case Bs E1 casex;
+                      val goal = mk_parametricity_goal names_lthy (R :: Rs) caseA caseB;
+                    in
+                      Goal.prove_sorry lthy [] [] goal
+                        (fn {context = ctxt, prems = _} =>
+                          mk_case_transfer_tac ctxt rel_cases_thm case_thms)
+                      |> singleton (Proof_Context.export names_lthy lthy)
+                      |> Thm.close_derivation
+                    end;
+
                   val map_disc_iff_thms =
                     let
                       val discsB = map (mk_disc_or_sel Bs) discs;
@@ -1821,7 +1832,7 @@
                     end;
                 in
                   (map_disc_iff_thms, map_sel_thms, set_sel_thms, rel_sel_thms, set_intros_thms,
-                    ctr_transfer_thms, (set_cases_thms, set_cases_attrss),
+                    case_transfer_thms, ctr_transfer_thms, (set_cases_thms, set_cases_attrss),
                     (rel_cases_thm, rel_cases_attrs))
                 end;
 
@@ -1831,7 +1842,8 @@
                 |> map (fn (thms, attrs) => ((Binding.empty, attrs), [(thms, [])]));
 
               val notes =
-                [(ctr_transferN, ctr_transfer_thms, K []),
+                [(case_transferN, [case_transfer_thms], K []),
+                 (ctr_transferN, ctr_transfer_thms, K []),
                  (mapN, map_thms, K (code_nitpicksimp_attrs @ simp_attrs)),
                  (map_disc_iffN, map_disc_iff_thms, K simp_attrs),
                  (map_selN, map_sel_thms, K []),
--- a/src/HOL/Tools/BNF/bnf_fp_def_sugar_tactics.ML	Thu Aug 28 23:57:26 2014 +0200
+++ b/src/HOL/Tools/BNF/bnf_fp_def_sugar_tactics.ML	Fri Aug 29 14:21:24 2014 +0200
@@ -13,6 +13,7 @@
   val basic_sumprod_thms_set: thm list
   val sumprod_thms_rel: thm list
 
+  val mk_case_transfer_tac: Proof.context -> thm -> thm list -> tactic
   val mk_coinduct_tac: Proof.context -> thm list -> int -> int list -> thm -> thm list ->
     thm list -> thm list -> thm list -> thm list -> thm list list -> thm list list list ->
     thm list list list -> tactic
@@ -84,6 +85,18 @@
 
 val co_induct_inst_as_projs_tac = PRIMITIVE oo co_induct_inst_as_projs;
 
+fun mk_case_transfer_tac ctxt rel_cases cases =
+  let
+    val n = length (tl (prems_of rel_cases));
+  in
+    REPEAT_DETERM (HEADGOAL (rtac @{thm rel_funI})) THEN
+    HEADGOAL (etac rel_cases) THEN
+    ALLGOALS (hyp_subst_tac ctxt) THEN
+    unfold_thms_tac ctxt cases THEN
+    ALLGOALS (fn k => (select_prem_tac n (dtac asm_rl) k THEN' rotate_tac ~1) k) THEN
+    ALLGOALS (REPEAT_ALL_NEW (atac ORELSE' rtac refl ORELSE' dtac @{thm rel_funD}))
+  end;
+
 fun mk_exhaust_tac ctxt n ctr_defs ctor_iff_dtor sumEN' =
   unfold_thms_tac ctxt (ctor_iff_dtor :: ctr_defs) THEN HEADGOAL (rtac sumEN') THEN
   HEADGOAL (EVERY' (maps (fn k => [select_prem_tac n (rotate_tac 1) k,
@@ -256,7 +269,7 @@
    EVERY (map11 (fn ct => fn assm => fn exhaust => fn discs => fn sels => fn ctor_defs =>
      fn dtor_ctor => fn ctor_inject => fn abs_inject => fn rel_pre_def => fn abs_inverse =>
       (rtac exhaust THEN_ALL_NEW (rtac exhaust THEN_ALL_NEW
-         (dtac (rotate_prems (~1) (cterm_instantiate_pos [NONE, NONE, NONE, NONE, SOME ct]
+         (dtac (rotate_prems ~1 (cterm_instantiate_pos [NONE, NONE, NONE, NONE, SOME ct]
             @{thm arg_cong2} RS iffD1)) THEN'
           atac THEN' atac THEN' hyp_subst_tac ctxt THEN' dtac assm THEN'
           REPEAT_DETERM o etac conjE))) 1 THEN