reimplement proof automation for coinduct rules
authorhuffman
Sat Oct 16 15:26:30 2010 -0700 (2010-10-16)
changeset 40025876689e6bbdf
parent 40024 a0f760ef6995
child 40026 8f8f18a88685
reimplement proof automation for coinduct rules
src/HOLCF/Library/Stream.thy
src/HOLCF/Tools/Domain/domain_theorems.ML
     1.1 --- a/src/HOLCF/Library/Stream.thy	Sat Oct 16 14:41:11 2010 -0700
     1.2 +++ b/src/HOLCF/Library/Stream.thy	Sat Oct 16 15:26:30 2010 -0700
     1.3 @@ -265,7 +265,7 @@
     1.4   apply (simp add: stream.bisim_def,clarsimp)
     1.5   apply (drule spec, drule spec, drule (1) mp)
     1.6   apply (case_tac "x", simp)
     1.7 - apply (case_tac "x'", simp)
     1.8 + apply (case_tac "y", simp)
     1.9  by auto
    1.10  
    1.11  
     2.1 --- a/src/HOLCF/Tools/Domain/domain_theorems.ML	Sat Oct 16 14:41:11 2010 -0700
     2.2 +++ b/src/HOLCF/Tools/Domain/domain_theorems.ML	Sat Oct 16 15:26:30 2010 -0700
     2.3 @@ -297,143 +297,125 @@
     2.4  (******************************************************************************)
     2.5  
     2.6  fun prove_coinduction
     2.7 -    (comp_dbind : binding, eqs : Domain_Library.eq list)
     2.8 -    (take_rews : thm list)
     2.9 -    (take_lemmas : thm list)
    2.10 +    (comp_dbind : binding, dbinds : binding list)
    2.11 +    (constr_infos : Domain_Constructors.constr_info list)
    2.12 +    (take_info : Domain_Take_Proofs.take_induct_info)
    2.13 +    (take_rews : thm list list)
    2.14      (thy : theory) : theory =
    2.15  let
    2.16 -open Domain_Library;
    2.17 +  val comp_dname = Sign.full_name thy comp_dbind;
    2.18  
    2.19 -val dnames = map (fst o fst) eqs;
    2.20 -val comp_dname = Sign.full_name thy comp_dbind;
    2.21 -fun dc_take dn = %%:(dn^"_take");
    2.22 -val x_name = idx_name dnames "x"; 
    2.23 -val n_eqs = length eqs;
    2.24 +  val iso_infos = map #iso_info constr_infos;
    2.25 +  val newTs = map #absT iso_infos;
    2.26 +
    2.27 +  val {take_consts, take_0_thms, take_lemma_thms, ...} = take_info;
    2.28  
    2.29 -(* ----- define bisimulation predicate -------------------------------------- *)
    2.30 +  val R_names = Datatype_Prop.indexify_names (map (K "R") newTs);
    2.31 +  val R_types = map (fn T => T --> T --> boolT) newTs;
    2.32 +  val Rs = map Free (R_names ~~ R_types);
    2.33 +  val n = Free ("n", natT);
    2.34 +  val reserved = "x" :: "y" :: R_names;
    2.35  
    2.36 -local
    2.37 -  open HOLCF_Library
    2.38 -  val dtypes  = map (Type o fst) eqs;
    2.39 -  val relprod = mk_tupleT (map (fn tp => tp --> tp --> boolT) dtypes);
    2.40 +  (* declare bisimulation predicate *)
    2.41    val bisim_bind = Binding.suffix_name "_bisim" comp_dbind;
    2.42 -  val bisim_type = relprod --> boolT;
    2.43 -in
    2.44 +  val bisim_type = R_types ---> boolT;
    2.45    val (bisim_const, thy) =
    2.46        Sign.declare_const ((bisim_bind, bisim_type), NoSyn) thy;
    2.47 -end;
    2.48 -
    2.49 -local
    2.50  
    2.51 -  fun legacy_infer_term thy t =
    2.52 -      singleton (Syntax.check_terms (ProofContext.init_global thy)) (intern_term thy t);
    2.53 -  fun legacy_infer_prop thy t = legacy_infer_term thy (Type.constraint propT t);
    2.54 -  fun infer_props thy = map (apsnd (legacy_infer_prop thy));
    2.55 -  fun add_defs_i x = Global_Theory.add_defs false (map Thm.no_attributes x);
    2.56 -  fun add_defs_infer defs thy = add_defs_i (infer_props thy defs) thy;
    2.57 +  (* define bisimulation predicate *)
    2.58 +  local
    2.59 +    fun one_con T (con, args) =
    2.60 +      let
    2.61 +        val Ts = map snd args;
    2.62 +        val ns1 = Name.variant_list reserved (Datatype_Prop.make_tnames Ts);
    2.63 +        val ns2 = map (fn n => n^"'") ns1;
    2.64 +        val vs1 = map Free (ns1 ~~ Ts);
    2.65 +        val vs2 = map Free (ns2 ~~ Ts);
    2.66 +        val eq1 = mk_eq (Free ("x", T), list_ccomb (con, vs1));
    2.67 +        val eq2 = mk_eq (Free ("y", T), list_ccomb (con, vs2));
    2.68 +        fun rel ((v1, v2), T) =
    2.69 +            case AList.lookup (op =) (newTs ~~ Rs) T of
    2.70 +              NONE => mk_eq (v1, v2) | SOME r => r $ v1 $ v2;
    2.71 +        val eqs = foldr1 mk_conj (map rel (vs1 ~~ vs2 ~~ Ts) @ [eq1, eq2]);
    2.72 +      in
    2.73 +        Library.foldr mk_ex (vs1 @ vs2, eqs)
    2.74 +      end;
    2.75 +    fun one_eq ((T, R), cons) =
    2.76 +      let
    2.77 +        val x = Free ("x", T);
    2.78 +        val y = Free ("y", T);
    2.79 +        val disj1 = mk_conj (mk_eq (x, mk_bottom T), mk_eq (y, mk_bottom T));
    2.80 +        val disjs = disj1 :: map (one_con T) cons;
    2.81 +      in
    2.82 +        mk_all (x, mk_all (y, mk_imp (R $ x $ y, foldr1 mk_disj disjs)))
    2.83 +      end;
    2.84 +    val conjs = map one_eq (newTs ~~ Rs ~~ map #con_specs constr_infos);
    2.85 +    val bisim_rhs = lambdas Rs (Library.foldr1 mk_conj conjs);
    2.86 +    val bisim_eqn = Logic.mk_equals (bisim_const, bisim_rhs);
    2.87 +  in
    2.88 +    val (bisim_def_thm, thy) = thy |>
    2.89 +        yield_singleton (Global_Theory.add_defs false)
    2.90 +         ((Binding.qualified true "bisim_def" comp_dbind, bisim_eqn), []);
    2.91 +  end (* local *)
    2.92  
    2.93 -  fun one_con (con, args) =
    2.94 +  (* prove coinduction lemma *)
    2.95 +  val coind_lemma =
    2.96      let
    2.97 -      val nonrec_args = filter_out is_rec args;
    2.98 -      val    rec_args = filter is_rec args;
    2.99 -      val    recs_cnt = length rec_args;
   2.100 -      val allargs     = nonrec_args @ rec_args
   2.101 -                        @ map (upd_vname (fn s=> s^"'")) rec_args;
   2.102 -      val allvns      = map vname allargs;
   2.103 -      fun vname_arg s arg = if is_rec arg then vname arg^s else vname arg;
   2.104 -      val vns1        = map (vname_arg "" ) args;
   2.105 -      val vns2        = map (vname_arg "'") args;
   2.106 -      val allargs_cnt = length nonrec_args + 2*recs_cnt;
   2.107 -      val rec_idxs    = (recs_cnt-1) downto 0;
   2.108 -      val nonlazy_idxs = map snd (filter_out (fn (arg,_) => is_lazy arg)
   2.109 -                                             (allargs~~((allargs_cnt-1) downto 0)));
   2.110 -      fun rel_app i ra = proj (Bound(allargs_cnt+2)) eqs (rec_of ra) $ 
   2.111 -                              Bound (2*recs_cnt-i) $ Bound (recs_cnt-i);
   2.112 -      val capps =
   2.113 -          List.foldr
   2.114 -            mk_conj
   2.115 -            (mk_conj(
   2.116 -             Bound(allargs_cnt+1)===list_ccomb(%%:con,map (bound_arg allvns) vns1),
   2.117 -             Bound(allargs_cnt+0)===list_ccomb(%%:con,map (bound_arg allvns) vns2)))
   2.118 -            (mapn rel_app 1 rec_args);
   2.119 +      val assm = mk_trp (list_comb (bisim_const, Rs));
   2.120 +      fun one ((T, R), take_const) =
   2.121 +        let
   2.122 +          val x = Free ("x", T);
   2.123 +          val y = Free ("y", T);
   2.124 +          val lhs = mk_capply (take_const $ n, x);
   2.125 +          val rhs = mk_capply (take_const $ n, y);
   2.126 +        in
   2.127 +          mk_all (x, mk_all (y, mk_imp (R $ x $ y, mk_eq (lhs, rhs))))
   2.128 +        end;
   2.129 +      val goal =
   2.130 +          mk_trp (foldr1 mk_conj (map one (newTs ~~ Rs ~~ take_consts)));
   2.131 +      val rules = @{thm Rep_CFun_strict1} :: take_0_thms;
   2.132 +      fun tacf {prems, context} =
   2.133 +        let
   2.134 +          val prem' = rewrite_rule [bisim_def_thm] (hd prems);
   2.135 +          val prems' = Project_Rule.projections context prem';
   2.136 +          val dests = map (fn th => th RS spec RS spec RS mp) prems';
   2.137 +          fun one_tac (dest, rews) =
   2.138 +              dtac dest 1 THEN safe_tac HOL_cs THEN
   2.139 +              ALLGOALS (asm_simp_tac (HOL_basic_ss addsimps rews));
   2.140 +        in
   2.141 +          rtac @{thm nat.induct} 1 THEN
   2.142 +          simp_tac (HOL_ss addsimps rules) 1 THEN
   2.143 +          safe_tac HOL_cs THEN
   2.144 +          EVERY (map one_tac (dests ~~ take_rews))
   2.145 +        end
   2.146      in
   2.147 -      List.foldr
   2.148 -        mk_ex
   2.149 -        (Library.foldr mk_conj
   2.150 -                       (map (defined o Bound) nonlazy_idxs,capps)) allvns
   2.151 +      Goal.prove_global thy [] [assm] goal tacf
   2.152      end;
   2.153 -  fun one_comp n (_,cons) =
   2.154 -      mk_all (x_name(n+1),
   2.155 -      mk_all (x_name(n+1)^"'",
   2.156 -      mk_imp (proj (Bound 2) eqs n $ Bound 1 $ Bound 0,
   2.157 -      foldr1 mk_disj (mk_conj(Bound 1 === UU,Bound 0 === UU)
   2.158 -                      ::map one_con cons))));
   2.159 -  val bisim_eqn =
   2.160 -      %%:(comp_dname^"_bisim") ==
   2.161 -         mk_lam("R", foldr1 mk_conj (mapn one_comp 0 eqs));
   2.162 +
   2.163 +  (* prove individual coinduction rules *)
   2.164 +  fun prove_coind ((T, R), take_lemma) =
   2.165 +    let
   2.166 +      val x = Free ("x", T);
   2.167 +      val y = Free ("y", T);
   2.168 +      val assm1 = mk_trp (list_comb (bisim_const, Rs));
   2.169 +      val assm2 = mk_trp (R $ x $ y);
   2.170 +      val goal = mk_trp (mk_eq (x, y));
   2.171 +      fun tacf {prems, context} =
   2.172 +        let
   2.173 +          val rule = hd prems RS coind_lemma;
   2.174 +        in
   2.175 +          rtac take_lemma 1 THEN
   2.176 +          asm_simp_tac (HOL_basic_ss addsimps (rule :: prems)) 1
   2.177 +        end;
   2.178 +    in
   2.179 +      Goal.prove_global thy [] [assm1, assm2] goal tacf
   2.180 +    end;
   2.181 +  val coinds = map prove_coind (newTs ~~ Rs ~~ take_lemma_thms);
   2.182 +  val coind_binds = map (Binding.qualified true "coinduct") dbinds;
   2.183  
   2.184  in
   2.185 -  val (ax_bisim_def, thy) =
   2.186 -      yield_singleton add_defs_infer
   2.187 -        (Binding.qualified true "bisim_def" comp_dbind, bisim_eqn) thy;
   2.188 -end; (* local *)
   2.189 -
   2.190 -(* ----- theorem concerning coinduction ------------------------------------- *)
   2.191 -
   2.192 -local
   2.193 -  val pg = pg' thy;
   2.194 -  val xs = mapn (fn n => K (x_name n)) 1 dnames;
   2.195 -  fun bnd_arg n i = Bound(2*(n_eqs - n)-i-1);
   2.196 -  val take_ss = HOL_ss addsimps (@{thm Rep_CFun_strict1} :: take_rews);
   2.197 -  val sproj = prj (fn s => K("fst("^s^")")) (fn s => K("snd("^s^")"));
   2.198 -  val _ = trace " Proving coind_lemma...";
   2.199 -  val coind_lemma =
   2.200 -    let
   2.201 -      fun mk_prj n _ = proj (%:"R") eqs n $ bnd_arg n 0 $ bnd_arg n 1;
   2.202 -      fun mk_eqn n dn =
   2.203 -        (dc_take dn $ %:"n" ` bnd_arg n 0) ===
   2.204 -        (dc_take dn $ %:"n" ` bnd_arg n 1);
   2.205 -      fun mk_all2 (x,t) = mk_all (x, mk_all (x^"'", t));
   2.206 -      val goal =
   2.207 -        mk_trp (mk_imp (%%:(comp_dname^"_bisim") $ %:"R",
   2.208 -          Library.foldr mk_all2 (xs,
   2.209 -            Library.foldr mk_imp (mapn mk_prj 0 dnames,
   2.210 -              foldr1 mk_conj (mapn mk_eqn 0 dnames)))));
   2.211 -      fun x_tacs ctxt n x = [
   2.212 -        rotate_tac (n+1) 1,
   2.213 -        etac all2E 1,
   2.214 -        eres_inst_tac ctxt [(("P", 1), sproj "R" eqs n^" "^x^" "^x^"'")] (mp RS disjE) 1,
   2.215 -        TRY (safe_tac HOL_cs),
   2.216 -        REPEAT (CHANGED (asm_simp_tac take_ss 1))];
   2.217 -      fun tacs ctxt = [
   2.218 -        rtac impI 1,
   2.219 -        InductTacs.induct_tac ctxt [[SOME "n"]] 1,
   2.220 -        simp_tac take_ss 1,
   2.221 -        safe_tac HOL_cs] @
   2.222 -        flat (mapn (x_tacs ctxt) 0 xs);
   2.223 -    in pg [ax_bisim_def] goal tacs end;
   2.224 -in
   2.225 -  val _ = trace " Proving coind...";
   2.226 -  val coind = 
   2.227 -    let
   2.228 -      fun mk_prj n x = mk_trp (proj (%:"R") eqs n $ %:x $ %:(x^"'"));
   2.229 -      fun mk_eqn x = %:x === %:(x^"'");
   2.230 -      val goal =
   2.231 -        mk_trp (%%:(comp_dname^"_bisim") $ %:"R") ===>
   2.232 -          Logic.list_implies (mapn mk_prj 0 xs,
   2.233 -            mk_trp (foldr1 mk_conj (map mk_eqn xs)));
   2.234 -      val tacs =
   2.235 -        TRY (safe_tac HOL_cs) ::
   2.236 -        maps (fn take_lemma => [
   2.237 -          rtac take_lemma 1,
   2.238 -          cut_facts_tac [coind_lemma] 1,
   2.239 -          fast_tac HOL_cs 1])
   2.240 -        take_lemmas;
   2.241 -    in pg [] goal (K tacs) end;
   2.242 -end; (* local *)
   2.243 -
   2.244 -in thy |> snd o Global_Theory.add_thmss
   2.245 -    [((Binding.qualified true "coinduct" comp_dbind, [coind]), [])]
   2.246 +  thy |> snd o Global_Theory.add_thms
   2.247 +    (map Thm.no_attributes (coind_binds ~~ coinds))
   2.248  end; (* let *)
   2.249  
   2.250  (******************************************************************************)
   2.251 @@ -500,7 +482,7 @@
   2.252  
   2.253  val thy =
   2.254      if is_indirect then thy else
   2.255 -    prove_coinduction (comp_dbind, eqs) take_rews take_lemma_thms thy;
   2.256 +    prove_coinduction (comp_dbind, dbinds) constr_infos take_info take_rewss thy;
   2.257  
   2.258  in
   2.259    (take_rews, thy)