reimplement proof automation for coinduct rules
authorhuffman
Sat, 16 Oct 2010 15:26:30 -0700
changeset 40025 876689e6bbdf
parent 40024 a0f760ef6995
child 40026 8f8f18a88685
reimplement proof automation for coinduct rules
src/HOLCF/Library/Stream.thy
src/HOLCF/Tools/Domain/domain_theorems.ML
--- a/src/HOLCF/Library/Stream.thy	Sat Oct 16 14:41:11 2010 -0700
+++ b/src/HOLCF/Library/Stream.thy	Sat Oct 16 15:26:30 2010 -0700
@@ -265,7 +265,7 @@
  apply (simp add: stream.bisim_def,clarsimp)
  apply (drule spec, drule spec, drule (1) mp)
  apply (case_tac "x", simp)
- apply (case_tac "x'", simp)
+ apply (case_tac "y", simp)
 by auto
 
 
--- a/src/HOLCF/Tools/Domain/domain_theorems.ML	Sat Oct 16 14:41:11 2010 -0700
+++ b/src/HOLCF/Tools/Domain/domain_theorems.ML	Sat Oct 16 15:26:30 2010 -0700
@@ -297,143 +297,125 @@
 (******************************************************************************)
 
 fun prove_coinduction
-    (comp_dbind : binding, eqs : Domain_Library.eq list)
-    (take_rews : thm list)
-    (take_lemmas : thm list)
+    (comp_dbind : binding, dbinds : binding list)
+    (constr_infos : Domain_Constructors.constr_info list)
+    (take_info : Domain_Take_Proofs.take_induct_info)
+    (take_rews : thm list list)
     (thy : theory) : theory =
 let
-open Domain_Library;
+  val comp_dname = Sign.full_name thy comp_dbind;
 
-val dnames = map (fst o fst) eqs;
-val comp_dname = Sign.full_name thy comp_dbind;
-fun dc_take dn = %%:(dn^"_take");
-val x_name = idx_name dnames "x"; 
-val n_eqs = length eqs;
+  val iso_infos = map #iso_info constr_infos;
+  val newTs = map #absT iso_infos;
+
+  val {take_consts, take_0_thms, take_lemma_thms, ...} = take_info;
 
-(* ----- define bisimulation predicate -------------------------------------- *)
+  val R_names = Datatype_Prop.indexify_names (map (K "R") newTs);
+  val R_types = map (fn T => T --> T --> boolT) newTs;
+  val Rs = map Free (R_names ~~ R_types);
+  val n = Free ("n", natT);
+  val reserved = "x" :: "y" :: R_names;
 
-local
-  open HOLCF_Library
-  val dtypes  = map (Type o fst) eqs;
-  val relprod = mk_tupleT (map (fn tp => tp --> tp --> boolT) dtypes);
+  (* declare bisimulation predicate *)
   val bisim_bind = Binding.suffix_name "_bisim" comp_dbind;
-  val bisim_type = relprod --> boolT;
-in
+  val bisim_type = R_types ---> boolT;
   val (bisim_const, thy) =
       Sign.declare_const ((bisim_bind, bisim_type), NoSyn) thy;
-end;
-
-local
 
-  fun legacy_infer_term thy t =
-      singleton (Syntax.check_terms (ProofContext.init_global thy)) (intern_term thy t);
-  fun legacy_infer_prop thy t = legacy_infer_term thy (Type.constraint propT t);
-  fun infer_props thy = map (apsnd (legacy_infer_prop thy));
-  fun add_defs_i x = Global_Theory.add_defs false (map Thm.no_attributes x);
-  fun add_defs_infer defs thy = add_defs_i (infer_props thy defs) thy;
+  (* define bisimulation predicate *)
+  local
+    fun one_con T (con, args) =
+      let
+        val Ts = map snd args;
+        val ns1 = Name.variant_list reserved (Datatype_Prop.make_tnames Ts);
+        val ns2 = map (fn n => n^"'") ns1;
+        val vs1 = map Free (ns1 ~~ Ts);
+        val vs2 = map Free (ns2 ~~ Ts);
+        val eq1 = mk_eq (Free ("x", T), list_ccomb (con, vs1));
+        val eq2 = mk_eq (Free ("y", T), list_ccomb (con, vs2));
+        fun rel ((v1, v2), T) =
+            case AList.lookup (op =) (newTs ~~ Rs) T of
+              NONE => mk_eq (v1, v2) | SOME r => r $ v1 $ v2;
+        val eqs = foldr1 mk_conj (map rel (vs1 ~~ vs2 ~~ Ts) @ [eq1, eq2]);
+      in
+        Library.foldr mk_ex (vs1 @ vs2, eqs)
+      end;
+    fun one_eq ((T, R), cons) =
+      let
+        val x = Free ("x", T);
+        val y = Free ("y", T);
+        val disj1 = mk_conj (mk_eq (x, mk_bottom T), mk_eq (y, mk_bottom T));
+        val disjs = disj1 :: map (one_con T) cons;
+      in
+        mk_all (x, mk_all (y, mk_imp (R $ x $ y, foldr1 mk_disj disjs)))
+      end;
+    val conjs = map one_eq (newTs ~~ Rs ~~ map #con_specs constr_infos);
+    val bisim_rhs = lambdas Rs (Library.foldr1 mk_conj conjs);
+    val bisim_eqn = Logic.mk_equals (bisim_const, bisim_rhs);
+  in
+    val (bisim_def_thm, thy) = thy |>
+        yield_singleton (Global_Theory.add_defs false)
+         ((Binding.qualified true "bisim_def" comp_dbind, bisim_eqn), []);
+  end (* local *)
 
-  fun one_con (con, args) =
+  (* prove coinduction lemma *)
+  val coind_lemma =
     let
-      val nonrec_args = filter_out is_rec args;
-      val    rec_args = filter is_rec args;
-      val    recs_cnt = length rec_args;
-      val allargs     = nonrec_args @ rec_args
-                        @ map (upd_vname (fn s=> s^"'")) rec_args;
-      val allvns      = map vname allargs;
-      fun vname_arg s arg = if is_rec arg then vname arg^s else vname arg;
-      val vns1        = map (vname_arg "" ) args;
-      val vns2        = map (vname_arg "'") args;
-      val allargs_cnt = length nonrec_args + 2*recs_cnt;
-      val rec_idxs    = (recs_cnt-1) downto 0;
-      val nonlazy_idxs = map snd (filter_out (fn (arg,_) => is_lazy arg)
-                                             (allargs~~((allargs_cnt-1) downto 0)));
-      fun rel_app i ra = proj (Bound(allargs_cnt+2)) eqs (rec_of ra) $ 
-                              Bound (2*recs_cnt-i) $ Bound (recs_cnt-i);
-      val capps =
-          List.foldr
-            mk_conj
-            (mk_conj(
-             Bound(allargs_cnt+1)===list_ccomb(%%:con,map (bound_arg allvns) vns1),
-             Bound(allargs_cnt+0)===list_ccomb(%%:con,map (bound_arg allvns) vns2)))
-            (mapn rel_app 1 rec_args);
+      val assm = mk_trp (list_comb (bisim_const, Rs));
+      fun one ((T, R), take_const) =
+        let
+          val x = Free ("x", T);
+          val y = Free ("y", T);
+          val lhs = mk_capply (take_const $ n, x);
+          val rhs = mk_capply (take_const $ n, y);
+        in
+          mk_all (x, mk_all (y, mk_imp (R $ x $ y, mk_eq (lhs, rhs))))
+        end;
+      val goal =
+          mk_trp (foldr1 mk_conj (map one (newTs ~~ Rs ~~ take_consts)));
+      val rules = @{thm Rep_CFun_strict1} :: take_0_thms;
+      fun tacf {prems, context} =
+        let
+          val prem' = rewrite_rule [bisim_def_thm] (hd prems);
+          val prems' = Project_Rule.projections context prem';
+          val dests = map (fn th => th RS spec RS spec RS mp) prems';
+          fun one_tac (dest, rews) =
+              dtac dest 1 THEN safe_tac HOL_cs THEN
+              ALLGOALS (asm_simp_tac (HOL_basic_ss addsimps rews));
+        in
+          rtac @{thm nat.induct} 1 THEN
+          simp_tac (HOL_ss addsimps rules) 1 THEN
+          safe_tac HOL_cs THEN
+          EVERY (map one_tac (dests ~~ take_rews))
+        end
     in
-      List.foldr
-        mk_ex
-        (Library.foldr mk_conj
-                       (map (defined o Bound) nonlazy_idxs,capps)) allvns
+      Goal.prove_global thy [] [assm] goal tacf
     end;
-  fun one_comp n (_,cons) =
-      mk_all (x_name(n+1),
-      mk_all (x_name(n+1)^"'",
-      mk_imp (proj (Bound 2) eqs n $ Bound 1 $ Bound 0,
-      foldr1 mk_disj (mk_conj(Bound 1 === UU,Bound 0 === UU)
-                      ::map one_con cons))));
-  val bisim_eqn =
-      %%:(comp_dname^"_bisim") ==
-         mk_lam("R", foldr1 mk_conj (mapn one_comp 0 eqs));
+
+  (* prove individual coinduction rules *)
+  fun prove_coind ((T, R), take_lemma) =
+    let
+      val x = Free ("x", T);
+      val y = Free ("y", T);
+      val assm1 = mk_trp (list_comb (bisim_const, Rs));
+      val assm2 = mk_trp (R $ x $ y);
+      val goal = mk_trp (mk_eq (x, y));
+      fun tacf {prems, context} =
+        let
+          val rule = hd prems RS coind_lemma;
+        in
+          rtac take_lemma 1 THEN
+          asm_simp_tac (HOL_basic_ss addsimps (rule :: prems)) 1
+        end;
+    in
+      Goal.prove_global thy [] [assm1, assm2] goal tacf
+    end;
+  val coinds = map prove_coind (newTs ~~ Rs ~~ take_lemma_thms);
+  val coind_binds = map (Binding.qualified true "coinduct") dbinds;
 
 in
-  val (ax_bisim_def, thy) =
-      yield_singleton add_defs_infer
-        (Binding.qualified true "bisim_def" comp_dbind, bisim_eqn) thy;
-end; (* local *)
-
-(* ----- theorem concerning coinduction ------------------------------------- *)
-
-local
-  val pg = pg' thy;
-  val xs = mapn (fn n => K (x_name n)) 1 dnames;
-  fun bnd_arg n i = Bound(2*(n_eqs - n)-i-1);
-  val take_ss = HOL_ss addsimps (@{thm Rep_CFun_strict1} :: take_rews);
-  val sproj = prj (fn s => K("fst("^s^")")) (fn s => K("snd("^s^")"));
-  val _ = trace " Proving coind_lemma...";
-  val coind_lemma =
-    let
-      fun mk_prj n _ = proj (%:"R") eqs n $ bnd_arg n 0 $ bnd_arg n 1;
-      fun mk_eqn n dn =
-        (dc_take dn $ %:"n" ` bnd_arg n 0) ===
-        (dc_take dn $ %:"n" ` bnd_arg n 1);
-      fun mk_all2 (x,t) = mk_all (x, mk_all (x^"'", t));
-      val goal =
-        mk_trp (mk_imp (%%:(comp_dname^"_bisim") $ %:"R",
-          Library.foldr mk_all2 (xs,
-            Library.foldr mk_imp (mapn mk_prj 0 dnames,
-              foldr1 mk_conj (mapn mk_eqn 0 dnames)))));
-      fun x_tacs ctxt n x = [
-        rotate_tac (n+1) 1,
-        etac all2E 1,
-        eres_inst_tac ctxt [(("P", 1), sproj "R" eqs n^" "^x^" "^x^"'")] (mp RS disjE) 1,
-        TRY (safe_tac HOL_cs),
-        REPEAT (CHANGED (asm_simp_tac take_ss 1))];
-      fun tacs ctxt = [
-        rtac impI 1,
-        InductTacs.induct_tac ctxt [[SOME "n"]] 1,
-        simp_tac take_ss 1,
-        safe_tac HOL_cs] @
-        flat (mapn (x_tacs ctxt) 0 xs);
-    in pg [ax_bisim_def] goal tacs end;
-in
-  val _ = trace " Proving coind...";
-  val coind = 
-    let
-      fun mk_prj n x = mk_trp (proj (%:"R") eqs n $ %:x $ %:(x^"'"));
-      fun mk_eqn x = %:x === %:(x^"'");
-      val goal =
-        mk_trp (%%:(comp_dname^"_bisim") $ %:"R") ===>
-          Logic.list_implies (mapn mk_prj 0 xs,
-            mk_trp (foldr1 mk_conj (map mk_eqn xs)));
-      val tacs =
-        TRY (safe_tac HOL_cs) ::
-        maps (fn take_lemma => [
-          rtac take_lemma 1,
-          cut_facts_tac [coind_lemma] 1,
-          fast_tac HOL_cs 1])
-        take_lemmas;
-    in pg [] goal (K tacs) end;
-end; (* local *)
-
-in thy |> snd o Global_Theory.add_thmss
-    [((Binding.qualified true "coinduct" comp_dbind, [coind]), [])]
+  thy |> snd o Global_Theory.add_thms
+    (map Thm.no_attributes (coind_binds ~~ coinds))
 end; (* let *)
 
 (******************************************************************************)
@@ -500,7 +482,7 @@
 
 val thy =
     if is_indirect then thy else
-    prove_coinduction (comp_dbind, eqs) take_rews take_lemma_thms thy;
+    prove_coinduction (comp_dbind, dbinds) constr_infos take_info take_rewss thy;
 
 in
   (take_rews, thy)