src/HOL/Tools/BNF/bnf_fp_def_sugar.ML
changeset 57563 e3e7c86168b4
parent 57562 c1238062184b
child 57565 ab7f39114507
--- a/src/HOL/Tools/BNF/bnf_fp_def_sugar.ML	Wed Jul 16 10:11:25 2014 +0200
+++ b/src/HOL/Tools/BNF/bnf_fp_def_sugar.ML	Wed Jul 16 10:13:00 2014 +0200
@@ -111,8 +111,9 @@
   fun merge data : T = Symtab.merge (K true) data;
 );
 
-fun choose_relator Rs AB = find_first
-  (fastype_of #> binder_types #> (fn [T1, T2] => AB = (T1, T2))) Rs;
+fun choose_relator Rs AB = find_first (fastype_of #> binder_types #> (fn [A, B] => AB = (A, B))) Rs;
+fun build_the_rel ctxt Rs Ts A B = build_rel ctxt Ts (the o choose_relator Rs) (A, B);
+fun build_rel_app ctxt Rs Ts a b = build_the_rel ctxt Rs Ts (fastype_of a) (fastype_of b) $ a $ b;
 
 fun fp_sugar_of ctxt =
   Symtab.lookup (Data.get (Context.Proof ctxt))
@@ -1390,7 +1391,8 @@
                 map2 (fn th => fn 0 => th RS @{thm eq_True[THEN iffD2]} | _ => th)
                   rel_inject_thms ms;
 
-              val (disc_map_iff_thms, sel_map_thms, sel_set_thms, (rel_cases_thm, rel_cases_attrs)) =
+              val (disc_map_iff_thms, sel_map_thms, sel_set_thms, rel_sel_thms,
+                (rel_cases_thm, rel_cases_attrs)) =
                 let
                   val (((Ds, As), Bs), names_lthy) = lthy
                     |> mk_TFrees (dead_of_bnf fp_bnf)
@@ -1399,24 +1401,55 @@
                   val TA as Type (_, ADs) = mk_T_of_bnf Ds As fp_bnf;
                   val TB as Type (_, BDs) = mk_T_of_bnf Ds Bs fp_bnf;
                   val fTs = map2 (curry op -->) As Bs;
-                  val ((fs, ta), names_lthy) = names_lthy
+                  val rel = mk_rel_of_bnf Ds As Bs fp_bnf;
+                  val ((((fs, Rs), ta), tb), names_lthy) = names_lthy
                     |> mk_Frees "f" fTs
-                    ||>> yield_singleton (mk_Frees "a") TA;
+                    ||>> mk_Frees "R" (map2 mk_pred2T As Bs)
+                    ||>> yield_singleton (mk_Frees "a") TA
+                    ||>> yield_singleton (mk_Frees "b") TB;
                   val map_term = mk_map_of_bnf Ds As Bs fp_bnf;
-                  val discsA = map (mk_disc_or_sel ADs) discs;
-                  val selssA = map (map (mk_disc_or_sel ADs)) selss;
-                  val disc_sel_pairs = flat (map2 (map o pair) discsA selssA);
+                  val discAs = map (mk_disc_or_sel ADs) discs;
+                  val selAss = map (map (mk_disc_or_sel ADs)) selss;
+                  val discAs_selAss = flat (map2 (map o pair) discAs selAss);
+
+                  val rel_sel_thms =
+                    let
+                      val discBs = map (mk_disc_or_sel BDs) discs;
+                      val selBss = map (map (mk_disc_or_sel BDs)) selss;
+                      val n = length discAs;
+
+                      fun mk_rhs n k discA selAs discB selBs =
+                        (if k = n then [] else [HOLogic.mk_eq (discA $ ta, discB $ tb)]) @
+                        (case (selAs, selBs) of
+                           ([], []) => []
+                         | (_ :: _, _ :: _) => [Library.foldr HOLogic.mk_imp
+                           (if n = 1 then [] else [discA $ ta, discB $ tb],
+                            Library.foldr1 HOLogic.mk_conj (map2 (build_rel_app names_lthy Rs [])
+                              (map (rapp ta) selAs) (map (rapp tb) selBs)))]);
+
+                      val goals = if n = 0 then []
+                        else [mk_Trueprop_eq
+                          (build_rel_app names_lthy Rs [] ta tb,
+                           Library.foldr1 HOLogic.mk_conj
+                             (flat (map5 (mk_rhs n) (1 upto n) discAs selAss discBs selBss)))];
+                    in
+                      if null goals then
+                        []
+                      else
+                        Goal.prove_sorry lthy [] [] (Logic.mk_conjunction_balanced goals)
+                          (fn {context = ctxt, prems = _} =>
+                            mk_rel_sel_tac ctxt (certify ctxt ta) (certify ctxt tb) exhaust
+                              (flat disc_thmss) (flat sel_thmss) rel_inject_thms distincts
+                              rel_distinct_thms)
+                          |> Conjunction.elim_balanced (length goals)
+                          |> Proof_Context.export names_lthy lthy
+                    end;
 
                   val (rel_cases_thm, rel_cases_attrs) =
                     let
-                      val rel = mk_rel_of_bnf Ds As Bs fp_bnf;
-                      val (((thesis, Rs), tb), names_lthy) =  names_lthy
-                        |> yield_singleton (mk_Frees "thesis") HOLogic.boolT
-                        |>> HOLogic.mk_Trueprop
-                        ||>> mk_Frees "R" (map2 mk_pred2T As Bs)
-                        ||>> yield_singleton (mk_Frees "b") TB;
+                      val (thesis, names_lthy) = apfst HOLogic.mk_Trueprop
+                        (yield_singleton (mk_Frees "thesis") HOLogic.boolT names_lthy);
 
-                      val _ = apfst HOLogic.mk_Trueprop;
                       val rel_Rs_a_b = list_comb (rel, Rs) $ ta $ tb;
                       val ctrAs = map (mk_ctr ADs) ctrs;
                       val ctrBs = map (mk_ctr BDs) ctrs;
@@ -1464,7 +1497,7 @@
                   val disc_map_iff_thms =
                     let
                       val discsB = map (mk_disc_or_sel BDs) discs;
-                      val discsA_t = map (fn disc1 => Term.betapply (disc1, ta)) discsA;
+                      val discsA_t = map (fn disc1 => Term.betapply (disc1, ta)) discAs;
 
                       fun mk_goal (discA_t, discB) =
                         if head_of discA_t aconv HOLogic.Not orelse is_refl_bool discA_t then
@@ -1505,7 +1538,7 @@
                           if is_refl_bool prem then concl
                           else Logic.mk_implies (HOLogic.mk_Trueprop prem, concl)
                         end;
-                      val goals = map mk_goal disc_sel_pairs;
+                      val goals = map mk_goal discAs_selAss;
                     in
                       if null goals then
                         []
@@ -1570,7 +1603,7 @@
                             ([], ctxt)
                         end;
                       val (goals, names_lthy) = apfst (flat o flat) (fold_map (fn (disc, sel) =>
-                        fold_map (mk_goal disc sel) setsA) disc_sel_pairs names_lthy);
+                        fold_map (mk_goal disc sel) setsA) discAs_selAss names_lthy);
                     in
                       if null goals then
                         []
@@ -1583,7 +1616,8 @@
                           |> Proof_Context.export names_lthy lthy
                     end;
                 in
-                  (disc_map_iff_thms, sel_map_thms, sel_set_thms, (rel_cases_thm, rel_cases_attrs))
+                  (disc_map_iff_thms, sel_map_thms, sel_set_thms, rel_sel_thms,
+                    (rel_cases_thm, rel_cases_attrs))
                 end;
 
               val anonymous_notes =
@@ -1598,6 +1632,7 @@
                  (rel_distinctN, rel_distinct_thms, simp_attrs),
                  (rel_injectN, rel_inject_thms, simp_attrs),
                  (rel_introsN, rel_intro_thms, []),
+                 (rel_selN, rel_sel_thms, []),
                  (setN, set_thms, code_nitpicksimp_attrs @ simp_attrs),
                  (sel_mapN, sel_map_thms, []),
                  (sel_setN, sel_set_thms, []),