derive induction via backward proof, to ensure that the premises are in the right order for constructors like "X x y x" where x and y are mutually recursive
authorblanchet
Fri, 14 Sep 2012 12:09:27 +0200
changeset 49361 cc1d39529dd1
parent 49353 023be49d7fb8
child 49362 1271aca16aed
derive induction via backward proof, to ensure that the premises are in the right order for constructors like "X x y x" where x and y are mutually recursive
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Fri Sep 14 10:01:42 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Fri Sep 14 12:09:27 2012 +0200
@@ -95,6 +95,8 @@
       else ();
 
     val N = length specs;
+    val fp_bs = map type_binding_of specs;
+    val fp_common_name = mk_common_name fp_bs;
 
     fun prepare_type_arg (ty, c) =
       let val TFree (s, _) = prepare_typ no_defs_lthy0 ty in
@@ -105,11 +107,12 @@
     val unsorted_Ass0 = map (map (resort_tfree HOLogic.typeS)) Ass0;
     val unsorted_As = Library.foldr1 merge_type_args unsorted_Ass0;
 
-    val ((Bs, Cs), no_defs_lthy) =
+    val (((Bs, Cs), vs'), no_defs_lthy) =
       no_defs_lthy0
       |> fold (Variable.declare_typ o resort_tfree dummyS) unsorted_As
       |> mk_TFrees N
-      ||>> mk_TFrees N;
+      ||>> mk_TFrees N
+      ||>> Variable.variant_fixes (map Binding.name_of fp_bs);
 
     (* TODO: cleaner handling of fake contexts, without "background_theory" *)
     (*the "perhaps o try" below helps gracefully handles the case where the new type is defined in a
@@ -123,9 +126,6 @@
       Type (fst (Term.dest_Type (Proof_Context.read_type_name fake_lthy true (Binding.name_of b))),
         unsorted_As);
 
-    val fp_bs = map type_binding_of specs;
-    val fp_common_name = mk_common_name fp_bs;
-
     val fake_Ts = map mk_fake_T fp_bs;
 
     val mixfixes = map mixfix_of specs;
@@ -204,6 +204,7 @@
     val flds = map (mk_fld As) flds0;
 
     val fpTs = map (domain_type o fastype_of) unfs;
+    val vs = map2 (curry Free) vs' fpTs;
 
     val ctr_Tsss = map (map (map (Term.typ_subst_atomic (Bs ~~ fpTs)))) ctr_TsssBs;
     val ns = map length ctr_Tsss;
@@ -328,19 +329,18 @@
             (mk_terms sssss hssss, (h_sum_prod_Ts, ph_Tss))))
         end;
 
-    fun define_ctrs_case_for_type ((((((((((((((((((fp_b, fpT), C), fld), unf), fp_iter), fp_rec),
-          fld_unf), unf_fld), fld_inject), n), ks), ms), ctr_bindings), ctr_mixfixes), ctr_Tss),
-        disc_bindings), sel_bindingss), raw_sel_defaultss) no_defs_lthy =
+    fun define_ctrs_case_for_type (((((((((((((((((((fp_b, fpT), C), v), fld), unf), fp_iter),
+          fp_rec), fld_unf), unf_fld), fld_inject), n), ks), ms), ctr_bindings), ctr_mixfixes),
+        ctr_Tss), disc_bindings), sel_bindingss), raw_sel_defaultss) no_defs_lthy =
       let
         val unfT = domain_type (fastype_of fld);
         val ctr_prod_Ts = map HOLogic.mk_tupleT ctr_Tss;
         val ctr_sum_prod_T = mk_sumTN_balanced ctr_prod_Ts;
         val case_Ts = map (fn Ts => Ts ---> C) ctr_Tss;
 
-        val ((((u, v), fs), xss), _) =
+        val (((u, fs), xss), _) =
           no_defs_lthy
           |> yield_singleton (mk_Frees "u") unfT
-          ||>> yield_singleton (mk_Frees "v") fpT
           ||>> mk_Frees "f" case_Ts
           ||>> mk_Freess "x" ctr_Tss;
 
@@ -518,29 +518,32 @@
       let
         val (induct_thms, induct_thm) =
           let
-            val sym_ctr_defss = map2 (map2 fold_def_rule) mss ctr_defss;
+            val (ps, names_lthy) =
+              lthy
+              |> mk_Frees "P" (map mk_predT fpTs);
 
-            val ss = @{simpset} |> fold Simplifier.add_simp
-              @{thms collect_def[abs_def] sum_setl_def[abs_def] sum_setr_def[abs_def]
-                 fsts_def[abs_def] snds_def[abs_def] False_imp_eq all_point_1};
-
-            val induct_thm0 = fp_induct OF (map mk_sumEN_tupled_balanced mss);
+            fun mk_prem_prem (x as Free (_, T)) =
+              map HOLogic.mk_Trueprop
+                (case find_index (curry (op =) T) fpTs of
+                  ~1 => []
+                | i => [nth ps i $ x]);
 
-            val spurious_fs =
-              Term.add_vars (prop_of induct_thm0) []
-              |> filter (fn (_, Type (@{type_name fun}, [_, T'])) => T' <> HOLogic.boolT
-                | _ => false);
+            fun mk_prem p ctr ctr_Ts =
+              let val (xs, _) = names_lthy |> mk_Frees "x" ctr_Ts in
+                fold_rev Logic.all xs
+                  (Logic.list_implies (maps mk_prem_prem xs,
+                     HOLogic.mk_Trueprop (p $ Term.list_comb (ctr, xs))))
+              end;
 
-            val cxs =
-              map (fn s as (_, T) =>
-                (certify lthy (Var s), certify lthy (mk_id_fun (domain_type T)))) spurious_fs;
+            val goal =
+              fold_rev (fold_rev Logic.all) [ps, vs]
+                (Library.foldr Logic.list_implies (map3 (map2 o mk_prem) ps ctrss ctr_Tsss,
+                   HOLogic.mk_Trueprop (Library.foldr1 HOLogic.mk_conj
+                     (map2 (curry (op $)) ps vs))));
 
             val induct_thm =
-              Drule.cterm_instantiate cxs induct_thm0
-              |> Tactic.rule_by_tactic lthy (ALLGOALS (REPEAT_DETERM o bound_hyp_subst_tac))
-              |> Local_Defs.unfold lthy
-                (@{thm triv_forall_equality} :: flat sym_ctr_defss @ flat pre_set_defss)
-              |> Simplifier.full_simplify ss;
+              Skip_Proof.prove lthy [] [] goal (fn {context = ctxt, ...} =>
+                Skip_Proof.cheat_tac (Proof_Context.theory_of ctxt));
           in
             `(conj_dests N) induct_thm
           end;
@@ -715,16 +718,17 @@
               SOME ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), attrs),
                 [(thms, [])])) (fp_bs ~~ thmss));
       in
-        lthy |> Local_Theory.notes notes |> snd
+        lthy |> Local_Theory.notes (common_notes @ notes) |> snd
       end;
 
     fun wrap_types_and_define_iter_likes ((wraps, define_iter_likess), lthy) =
       fold_map2 (curry (op o)) define_iter_likess wraps lthy |>> split_list11
 
     val lthy' = lthy
-      |> fold_map define_ctrs_case_for_type (fp_bs ~~ fpTs ~~ Cs ~~ flds ~~ unfs ~~ fp_iters ~~
-        fp_recs ~~ fld_unfs ~~ unf_flds ~~ fld_injects ~~ ns ~~ kss ~~ mss ~~ ctr_bindingss ~~
-        ctr_mixfixess ~~ ctr_Tsss ~~ disc_bindingss ~~ sel_bindingsss ~~ raw_sel_defaultsss)
+      |> fold_map define_ctrs_case_for_type (fp_bs ~~ fpTs ~~ Cs ~~ vs ~~ flds ~~ unfs ~~
+        fp_iters ~~ fp_recs ~~ fld_unfs ~~ unf_flds ~~ fld_injects ~~ ns ~~ kss ~~ mss ~~
+        ctr_bindingss ~~ ctr_mixfixess ~~ ctr_Tsss ~~ disc_bindingss ~~ sel_bindingsss ~~
+        raw_sel_defaultsss)
       |>> split_list |> wrap_types_and_define_iter_likes
       |> (if lfp then derive_induct_iter_rec_thms_for_types
           else derive_coinduct_coiter_corec_thms_for_types);