rough and ready induction
authorblanchet
Wed, 12 Sep 2012 23:06:39 +0200
changeset 49342 8ea4bad49ed5
parent 49341 d406979024d1
child 49343 bcce6988f6fa
rough and ready induction
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
src/HOL/Codatatype/Tools/bnf_fp_util.ML
src/HOL/Codatatype/Tools/bnf_util.ML
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Wed Sep 12 23:06:39 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Wed Sep 12 23:06:39 2012 +0200
@@ -41,17 +41,18 @@
 fun strip_map_type (Type (@{type_name fun}, [T as Type _, T'])) = strip_map_type T' |>> cons T
   | strip_map_type T = ([], T);
 
+fun resort_tfree S (TFree (s, _)) = TFree (s, S);
+
 fun typ_subst inst (T as Type (s, Ts)) =
     (case AList.lookup (op =) inst T of
       NONE => Type (s, map (typ_subst inst) Ts)
     | SOME T' => T')
   | typ_subst inst T = the_default T (AList.lookup (op =) inst T);
 
-fun resort_tfree S (TFree (s, _)) = TFree (s, S);
-
 val lists_bmoc = fold (fn xs => fn t => Term.list_comb (t, xs));
 
 fun mk_id T = Const (@{const_name id}, T --> T);
+fun mk_id_fun T = Abs (Name.uu, T, Bound 0);
 
 fun mk_tupled_fun x f xs = HOLogic.tupled_lambda x (Term.list_comb (f, xs));
 fun mk_uncurried_fun f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f xs;
@@ -65,6 +66,8 @@
     Term.lambda z (mk_sum_case (Term.lambda v v, Term.lambda c (f $ c)) $ z)
   end;
 
+fun fold_def_rule n thm = funpow n (fn thm => thm RS fun_cong) (thm RS meta_eq_to_obj_eq) RS sym;
+
 fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters";
 
 fun merge_type_arg T T' = if T = T' then T else cannot_merge_types ();
@@ -153,7 +156,7 @@
     val rhs_As' = fold (fold (fold Term.add_tfreesT)) fake_ctr_Tsss [];
     val _ = (case subtract (op =) (map dest_TFree As) rhs_As' of
         [] => ()
-      | A' :: _ => error ("Extra type variables on rhs: " ^
+      | A' :: _ => error ("Extra type variable on right-hand side: " ^
           quote (Syntax.string_of_typ no_defs_lthy (TFree A'))));
 
     fun eq_fpT (T as Type (s, Us)) (Type (s', Us')) =
@@ -494,6 +497,7 @@
       end;
 
     val pre_map_defs = map map_def_of_bnf pre_bnfs;
+    val pre_set_defss = map set_defs_of_bnf pre_bnfs;
     val map_ids = map map_id_of_bnf nested_bnfs;
 
     fun mk_map Ts Us t =
@@ -514,7 +518,29 @@
       let
         val (induct_thms, induct_thm) =
           let
-            val induct_thm = fp_induct;
+            val sym_ctr_defss = map2 (map2 fold_def_rule) mss ctr_defss;
+
+            val ss = @{simpset} |> fold Simplifier.add_simp
+              @{thms collect_def[abs_def] sum_setl_def[abs_def] sum_setr_def[abs_def]
+                 fsts_def[abs_def] snds_def[abs_def] False_imp_eq all_point_1};
+
+            val induct_thm0 = fp_induct OF (map mk_sumEN_tupled_balanced mss);
+
+            val spurious_fs =
+              Term.add_vars (prop_of induct_thm0) []
+              |> filter (fn (_, Type (@{type_name fun}, [_, T'])) => T' <> HOLogic.boolT
+                | _ => false);
+
+            val cxs =
+              map (fn s as (_, T) =>
+                (certify lthy (Var s), certify lthy (mk_id_fun (domain_type T)))) spurious_fs;
+
+            val induct_thm =
+              Drule.cterm_instantiate cxs induct_thm0
+              |> Tactic.rule_by_tactic lthy (ALLGOALS (REPEAT_DETERM o bound_hyp_subst_tac))
+              |> Local_Defs.unfold lthy
+                (@{thm triv_forall_equality} :: flat sym_ctr_defss @ flat pre_set_defss)
+              |> Simplifier.full_simplify ss;
           in
             `(conj_dests N) induct_thm
           end;
@@ -540,7 +566,7 @@
             fun mk_U maybe_mk_prodT =
               typ_subst (map2 (fn fpT => fn C => (fpT, maybe_mk_prodT fpT C)) fpTs Cs);
 
-            fun repair_calls fiter_likes maybe_cons maybe_tick maybe_mk_prodT (x as Free (_, T)) =
+            fun intr_calls fiter_likes maybe_cons maybe_tick maybe_mk_prodT (x as Free (_, T)) =
               if member (op =) fpTs T then
                 maybe_cons x [build_call fiter_likes (K I) (T, mk_U (K I) T) $ x]
               else if exists_subtype (member (op =) fpTs) T then
@@ -548,9 +574,8 @@
               else
                 [x];
 
-            val gxsss = map (map (maps (repair_calls giters (K I) (K I) (K I)))) xsss;
-            val hxsss =
-              map (map (maps (repair_calls hrecs cons tick (curry HOLogic.mk_prodT)))) xsss;
+            val gxsss = map (map (maps (intr_calls giters (K I) (K I) (K I)))) xsss;
+            val hxsss = map (map (maps (intr_calls hrecs cons tick (curry HOLogic.mk_prodT)))) xsss;
 
             val goal_iterss = map5 (map4 o mk_goal_iter_like gss) giters xctrss gss xsss gxsss;
             val goal_recss = map5 (map4 o mk_goal_iter_like hss) hrecs xctrss hss xsss hxsss;
@@ -567,13 +592,13 @@
           end;
 
         val common_notes =
-          [(inductN, [induct_thm], []), (*### attribs *)
-           (inductsN, induct_thms, [])] (*### attribs *)
+          (if N > 1 then [(inductN, [induct_thm], [])] (* FIXME: attribs *) else [])
           |> map (fn (thmN, thms, attrs) =>
               ((Binding.qualify true fp_common_name (Binding.name thmN), attrs), [(thms, [])]));
 
         val notes =
-          [(itersN, iter_thmss, simp_attrs),
+          [(inductN, map single induct_thms, []), (* FIXME: attribs *)
+           (itersN, iter_thmss, simp_attrs),
            (recsN, rec_thmss, Code.add_default_eqn_attrib :: simp_attrs)]
           |> maps (fn (thmN, thmss, attrs) =>
             map2 (fn b => fn thms =>
@@ -617,7 +642,7 @@
             fun mk_U maybe_mk_sumT =
               typ_subst (map2 (fn C => fn fpT => (maybe_mk_sumT fpT C, fpT)) Cs fpTs);
 
-            fun repair_calls fiter_likes maybe_mk_sumT maybe_tack cqf =
+            fun intr_calls fiter_likes maybe_mk_sumT maybe_tack cqf =
               let val T = fastype_of cqf in
                 if exists_subtype (member (op =) Cs) T then
                   build_call fiter_likes maybe_tack (T, mk_U maybe_mk_sumT T) $ cqf
@@ -625,8 +650,8 @@
                   cqf
               end;
 
-            val crgsss' = map (map (map (repair_calls gcoiters (K I) (K I)))) crgsss;
-            val cshsss' = map (map (map (repair_calls hcorecs (curry mk_sumT) (tack z)))) cshsss;
+            val crgsss' = map (map (map (intr_calls gcoiters (K I) (K I)))) crgsss;
+            val cshsss' = map (map (map (intr_calls hcorecs (curry mk_sumT) (tack z)))) cshsss;
 
             val goal_coiterss =
               map8 (map4 oooo mk_goal_coiter_like pgss) cs cpss gcoiters ns kss ctrss mss crgsss';
@@ -672,8 +697,14 @@
         val sel_corec_thmsss =
           map3 (map3 (map2 o mk_sel_coiter_like_thm)) corec_thmss selsss sel_thmsss;
 
+        val common_notes =
+          (if N > 1 then [(coinductN, [coinduct_thm], [])] (* FIXME: attribs *) else [])
+          |> map (fn (thmN, thms, attrs) =>
+              ((Binding.qualify true fp_common_name (Binding.name thmN), attrs), [(thms, [])]));
+
         val notes =
-          [(coitersN, coiter_thmss, []),
+          [(coinductN, map single coinduct_thms, []), (* FIXME: attribs *)
+           (coitersN, coiter_thmss, []),
            (disc_coitersN, disc_coiter_thmss, []),
            (sel_coitersN, map flat sel_coiter_thmsss, []),
            (corecsN, corec_thmss, []),
--- a/src/HOL/Codatatype/Tools/bnf_fp_util.ML	Wed Sep 12 23:06:39 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_util.ML	Wed Sep 12 23:06:39 2012 +0200
@@ -18,7 +18,6 @@
   val caseN: string
   val coN: string
   val coinductN: string
-  val coinductsN: string
   val coiterN: string
   val coitersN: string
   val corecN: string
@@ -42,7 +41,6 @@
   val hsetN: string
   val hset_recN: string
   val inductN: string
-  val inductsN: string
   val injectN: string
   val isNodeN: string
   val iterN: string
@@ -159,7 +157,9 @@
 val algN = "alg"
 val IITN = "IITN"
 val iterN = "iter"
+val itersN = iterN ^ "s"
 val coiterN = coN ^ iterN
+val coitersN = coiterN ^ "s"
 val uniqueN = "_unique"
 val fldN = "fld"
 val unfN = "unf"
@@ -190,7 +190,9 @@
 
 val str_initN = "str_init"
 val recN = "rec"
+val recsN = recN ^ "s"
 val corecN = coN ^ recN
+val corecsN = corecN ^ "s"
 val fld_recN = fldN ^ "_" ^ recN
 val fld_recsN = fld_recN ^ "s"
 val unf_corecN = unfN ^ "_" ^ corecN
@@ -226,16 +228,12 @@
 val set_set_inclN = "set_set_incl"
 
 val caseN = "case"
-val coinductsN = "coinducts"
-val coitersN = "coiters"
-val corecsN = "corecs"
-val disc_coitersN = "disc_coiters"
-val disc_corecsN = "disc_corecs"
-val inductsN = "inducts"
-val itersN = "iters"
-val recsN = "recs"
-val sel_coitersN = "sel_coiters"
-val sel_corecsN = "sel_corecs"
+val discN = "disc"
+val disc_coitersN = discN ^ "_" ^ coitersN
+val disc_corecsN = discN ^ "_" ^ corecsN
+val selN = "sel"
+val sel_coitersN = selN ^ "_" ^ coitersN
+val sel_corecsN = selN ^ "_" ^ corecsN
 
 val mk_common_name = space_implode "_" o map Binding.name_of;
 
--- a/src/HOL/Codatatype/Tools/bnf_util.ML	Wed Sep 12 23:06:39 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_util.ML	Wed Sep 12 23:06:39 2012 +0200
@@ -567,5 +567,4 @@
 fun mk_unabs_def 0 thm = thm
   | mk_unabs_def n thm = mk_unabs_def (n - 1) thm RS @{thm spec[OF iffD1[OF fun_eq_iff]]};
 
-
 end;