Generalized equivariance and nominal_inductive commands to
authorberghofe
Thu Sep 13 18:11:59 2007 +0200 (2007-09-13)
changeset 24570621b60b1df00
parent 24569 c80e1871098b
child 24571 a6d0428dea8e
Generalized equivariance and nominal_inductive commands to
inductive definitions involving arbitrary monotone operators.
src/HOL/Nominal/nominal_inductive.ML
     1.1 --- a/src/HOL/Nominal/nominal_inductive.ML	Thu Sep 13 18:08:08 2007 +0200
     1.2 +++ b/src/HOL/Nominal/nominal_inductive.ML	Thu Sep 13 18:11:59 2007 +0200
     1.3 @@ -15,14 +15,41 @@
     1.4  structure NominalInductive : NOMINAL_INDUCTIVE =
     1.5  struct
     1.6  
     1.7 +val inductive_forall_name = "HOL.induct_forall";
     1.8 +val inductive_forall_def = thm "induct_forall_def";
     1.9 +val inductive_atomize = thms "induct_atomize";
    1.10 +val inductive_rulify = thms "induct_rulify";
    1.11 +val inductive_rulify_fallback = thms "induct_rulify_fallback";
    1.12 +
    1.13 +val rulify =
    1.14 +  hol_simplify inductive_rulify
    1.15 +  #> hol_simplify inductive_rulify_fallback;
    1.16 +
    1.17 +fun rulify_term thy = MetaSimplifier.rewrite_term thy inductive_rulify [];
    1.18 +
    1.19 +val atomize_conv =
    1.20 +  MetaSimplifier.rewrite_cterm (true, false, false) (K (K NONE))
    1.21 +    (HOL_basic_ss addsimps inductive_atomize);
    1.22 +val atomize_intr = Conv.fconv_rule (Conv.prems_conv ~1 atomize_conv);
    1.23 +val atomize_induct = Conv.fconv_rule (Conv.prems_conv ~1
    1.24 +  (Conv.forall_conv ~1 (Conv.prems_conv ~1 atomize_conv)));
    1.25 +
    1.26  val finite_Un = thm "finite_Un";
    1.27  val supp_prod = thm "supp_prod";
    1.28  val fresh_prod = thm "fresh_prod";
    1.29  
    1.30 +val perm_bool = mk_meta_eq (thm "perm_bool");
    1.31  val perm_boolI = thm "perm_boolI";
    1.32  val (_, [perm_boolI_pi, _]) = Drule.strip_comb (snd (Thm.dest_comb
    1.33    (Drule.strip_imp_concl (cprop_of perm_boolI))));
    1.34  
    1.35 +fun mk_perm_bool_simproc names = Simplifier.simproc_i
    1.36 +  (theory_of_thm perm_bool) "perm_bool" [@{term "perm pi x"}] (fn thy => fn ss =>
    1.37 +    fn Const ("Nominal.perm", _) $ _ $ t =>
    1.38 +         if the_default "" (try (head_of #> dest_Const #> fst) t) mem names
    1.39 +         then SOME perm_bool else NONE
    1.40 +     | _ => NONE);
    1.41 +
    1.42  val allE_Nil = read_instantiate_sg (the_context()) [("x", "[]")] allE;
    1.43  
    1.44  fun transp ([] :: _) = []
    1.45 @@ -49,11 +76,77 @@
    1.46    | add_binders thy i (Abs (_, _, t)) bs = add_binders thy (i + 1) t bs
    1.47    | add_binders thy i _ bs = bs;
    1.48  
    1.49 +fun split_conj f names (Const ("op &", _) $ p $ q) _ = (case head_of p of
    1.50 +      Const (name, _) =>
    1.51 +        if name mem names then SOME (f p q) else NONE
    1.52 +    | _ => NONE)
    1.53 +  | split_conj _ _ _ _ = NONE;
    1.54 +
    1.55 +fun strip_all [] t = t
    1.56 +  | strip_all (_ :: xs) (Const ("All", _) $ Abs (s, T, t)) = strip_all xs t;
    1.57 +
    1.58 +(*********************************************************************)
    1.59 +(* maps  R ... & (ALL pi_1 ... pi_n z. P z (pi_1 o ... o pi_n o t))  *)
    1.60 +(* or    ALL pi_1 ... pi_n. P (pi_1 o ... o pi_n o t)                *)
    1.61 +(* to    R ... & id (ALL z. (pi_1 o ... o pi_n o t))                 *)
    1.62 +(* or    id (ALL z. (pi_1 o ... o pi_n o t))                         *)
    1.63 +(*                                                                   *)
    1.64 +(* where "id" protects the subformula from simplification            *)
    1.65 +(*********************************************************************)
    1.66 +
    1.67 +fun inst_conj_all names ps pis (Const ("op &", _) $ p $ q) _ =
    1.68 +      (case head_of p of
    1.69 +         Const (name, _) =>
    1.70 +           if name mem names then SOME (HOLogic.mk_conj (p,
    1.71 +             Const ("Fun.id", HOLogic.boolT --> HOLogic.boolT) $
    1.72 +               (subst_bounds (pis, strip_all pis q))))
    1.73 +           else NONE
    1.74 +       | _ => NONE)
    1.75 +  | inst_conj_all names ps pis t u =
    1.76 +      if member (op aconv) ps (head_of u) then
    1.77 +        SOME (Const ("Fun.id", HOLogic.boolT --> HOLogic.boolT) $
    1.78 +          (subst_bounds (pis, strip_all pis t)))
    1.79 +      else NONE
    1.80 +  | inst_conj_all _ _ _ _ _ = NONE;
    1.81 +
    1.82 +fun inst_conj_all_tac k = EVERY
    1.83 +  [TRY (EVERY [etac conjE 1, rtac conjI 1, atac 1]),
    1.84 +   REPEAT_DETERM_N k (etac allE 1),
    1.85 +   simp_tac (HOL_basic_ss addsimps [id_apply]) 1];
    1.86 +
    1.87 +fun map_term f t u = (case f t u of
    1.88 +      NONE => map_term' f t u | x => x)
    1.89 +and map_term' f (t $ u) (t' $ u') = (case (map_term f t t', map_term f u u') of
    1.90 +      (NONE, NONE) => NONE
    1.91 +    | (SOME t'', NONE) => SOME (t'' $ u)
    1.92 +    | (NONE, SOME u'') => SOME (t $ u'')
    1.93 +    | (SOME t'', SOME u'') => SOME (t'' $ u''))
    1.94 +  | map_term' f (Abs (s, T, t)) (Abs (s', T', t')) = (case map_term f t t' of
    1.95 +      NONE => NONE
    1.96 +    | SOME t'' => SOME (Abs (s, T, t'')))
    1.97 +  | map_term' _ _ _ = NONE;
    1.98 +
    1.99 +(*********************************************************************)
   1.100 +(*         Prove  F[f t]  from  F[t],  where F is monotone           *)
   1.101 +(*********************************************************************)
   1.102 +
   1.103 +fun map_thm ctxt f tac monos opt th =
   1.104 +  let
   1.105 +    val prop = prop_of th;
   1.106 +    fun prove t =
   1.107 +      Goal.prove ctxt [] [] t (fn _ =>
   1.108 +        EVERY [cut_facts_tac [th] 1, etac rev_mp 1,
   1.109 +          REPEAT_DETERM (FIRSTGOAL (resolve_tac monos)),
   1.110 +          REPEAT_DETERM (rtac impI 1 THEN (atac 1 ORELSE tac))])
   1.111 +  in Option.map prove (map_term f prop (the_default prop opt)) end;
   1.112 +
   1.113  fun prove_strong_ind s avoids thy =
   1.114    let
   1.115      val ctxt = ProofContext.init thy;
   1.116      val ({names, ...}, {raw_induct, ...}) =
   1.117        InductivePackage.the_inductive ctxt (Sign.intern_const thy s);
   1.118 +    val raw_induct = atomize_induct raw_induct;
   1.119 +    val monos = InductivePackage.get_monos ctxt;
   1.120      val eqvt_thms = NominalThmDecls.get_eqvt_thms thy;
   1.121      val _ = (case names \\ foldl (apfst prop_of #> add_term_consts) [] eqvt_thms of
   1.122          [] => ()
   1.123 @@ -104,24 +197,20 @@
   1.124      val fs_ctxt_name = Name.variant (add_term_names (raw_induct', [])) "z";
   1.125      val fsT = TFree (fs_ctxt_tyname, ind_sort);
   1.126  
   1.127 +    val inductive_forall_def' = Drule.instantiate'
   1.128 +      [SOME (ctyp_of thy fsT)] [] inductive_forall_def;
   1.129 +
   1.130      fun lift_pred' t (Free (s, T)) ts =
   1.131        list_comb (Free (s, fsT --> T), t :: ts);
   1.132      val lift_pred = lift_pred' (Bound 0);
   1.133  
   1.134 -    fun lift_prem (Const ("Trueprop", _) $ t) =
   1.135 -          let val (u, ts) = strip_comb t
   1.136 -          in
   1.137 -            if u mem ps then
   1.138 -              all fsT $ Abs ("z", fsT, HOLogic.mk_Trueprop
   1.139 -                (lift_pred u (map (incr_boundvars 1) ts)))
   1.140 -            else HOLogic.mk_Trueprop (lift_prem t)
   1.141 -          end
   1.142 -      | lift_prem (t as (f $ u)) =
   1.143 +    fun lift_prem (t as (f $ u)) =
   1.144            let val (p, ts) = strip_comb t
   1.145            in
   1.146              if p mem ps then
   1.147 -              HOLogic.all_const fsT $ Abs ("z", fsT,
   1.148 -                lift_pred p (map (incr_boundvars 1) ts))
   1.149 +              Const (inductive_forall_name,
   1.150 +                (fsT --> HOLogic.boolT) --> HOLogic.boolT) $
   1.151 +                  Abs ("z", fsT, lift_pred p (map (incr_boundvars 1) ts))
   1.152              else lift_prem f $ lift_prem u
   1.153            end
   1.154        | lift_prem (Abs (s, T, t)) = Abs (s, T, lift_prem t)
   1.155 @@ -167,13 +256,19 @@
   1.156  
   1.157      val vc_compat = map (fn (params, bvars, prems, (p, ts)) =>
   1.158        map (fn q => list_all (params, incr_boundvars ~1 (Logic.list_implies
   1.159 -          (filter (fn prem => null (ps inter term_frees prem)) prems, q))))
   1.160 +          (List.mapPartial (fn prem =>
   1.161 +             if null (ps inter term_frees prem) then SOME prem
   1.162 +             else map_term (split_conj (K o I) names) prem prem) prems, q))))
   1.163          (mk_distinct bvars @
   1.164           maps (fn (t, T) => map (fn (u, U) => HOLogic.mk_Trueprop
   1.165             (Const ("Nominal.fresh", U --> T --> HOLogic.boolT) $ u $ t)) bvars)
   1.166               (ts ~~ binder_types (fastype_of p)))) prems;
   1.167  
   1.168 -    val eqvt_ss = HOL_basic_ss addsimps eqvt_thms;
   1.169 +    val perm_pi_simp = PureThy.get_thms thy (Name "perm_pi_simp");
   1.170 +    val pt2_atoms = map (fn aT => PureThy.get_thm thy
   1.171 +      (Name ("pt_" ^ Sign.base_name (fst (dest_Type aT)) ^ "2"))) atomTs;
   1.172 +    val eqvt_ss = HOL_basic_ss addsimps (eqvt_thms @ perm_pi_simp @ pt2_atoms)
   1.173 +      addsimprocs [mk_perm_bool_simproc ["Fun.id"]];
   1.174      val fresh_bij = PureThy.get_thms thy (Name "fresh_bij");
   1.175      val perm_bij = PureThy.get_thms thy (Name "perm_bij");
   1.176      val fs_atoms = map (fn aT => PureThy.get_thm thy
   1.177 @@ -182,8 +277,6 @@
   1.178      val fresh_atm = PureThy.get_thms thy (Name "fresh_atm");
   1.179      val calc_atm = PureThy.get_thms thy (Name "calc_atm");
   1.180      val perm_fresh_fresh = PureThy.get_thms thy (Name "perm_fresh_fresh");
   1.181 -    val pt2_atoms = map (fn aT => PureThy.get_thm thy
   1.182 -      (Name ("pt_" ^ Sign.base_name (fst (dest_Type aT)) ^ "2")) RS sym) atomTs;
   1.183  
   1.184      fun obtain_fresh_name ts T (freshs1, freshs2, ctxt) =
   1.185        let
   1.186 @@ -231,14 +324,35 @@
   1.187                     (map snd bvars') ([], [], ctxt');
   1.188                   val freshs2' = NominalPackage.mk_not_sym freshs2;
   1.189                   val pis' = map NominalPackage.perm_of_pair (pi_bvars ~~ freshs1);
   1.190 +                 fun concat_perm pi1 pi2 =
   1.191 +                   let val T = fastype_of pi1
   1.192 +                   in if T = fastype_of pi2 then
   1.193 +                       Const ("List.append", T --> T --> T) $ pi1 $ pi2
   1.194 +                     else pi2
   1.195 +                   end;
   1.196 +                 val pis'' = fold (concat_perm #> map) pis' pis;
   1.197                   val env = Pattern.first_order_match thy (ihypt, prop_of ihyp)
   1.198                     (Vartab.empty, Vartab.empty);
   1.199                   val ihyp' = Thm.instantiate ([], map (pairself (cterm_of thy))
   1.200                     (map (Envir.subst_vars env) vs ~~
   1.201                      map (fold_rev (NominalPackage.mk_perm [])
   1.202                        (rev pis' @ pis)) params' @ [z])) ihyp;
   1.203 -                 val (gprems1, gprems2) = pairself (map fst) (List.partition
   1.204 -                   (fn (th, t) => null (term_frees t inter ps)) (gprems ~~ oprems));
   1.205 +                 fun mk_pi th =
   1.206 +                   Simplifier.simplify (HOL_basic_ss addsimps [id_apply]
   1.207 +                       addsimprocs [NominalPackage.perm_simproc])
   1.208 +                     (Simplifier.simplify eqvt_ss
   1.209 +                       (fold_rev (fn pi => fn th' => th' RS Drule.cterm_instantiate
   1.210 +                         [(perm_boolI_pi, cterm_of thy pi)] perm_boolI)
   1.211 +                           (rev pis' @ pis) th));
   1.212 +                 val (gprems1, gprems2) = split_list
   1.213 +                   (map (fn (th, t) =>
   1.214 +                      if null (term_frees t inter ps) then (SOME th, mk_pi th)
   1.215 +                      else
   1.216 +                        (map_thm ctxt (split_conj (K o I) names)
   1.217 +                           (etac conjunct1 1) monos NONE th,
   1.218 +                         mk_pi (the (map_thm ctxt (inst_conj_all names ps (rev pis''))
   1.219 +                           (inst_conj_all_tac (length pis'')) monos (SOME t) th))))
   1.220 +                      (gprems ~~ oprems)) |>> List.mapPartial I;
   1.221                   val vc_compat_ths' = map (fn th =>
   1.222                     let
   1.223                       val th' = gprems1 MRS
   1.224 @@ -258,11 +372,6 @@
   1.225                     in Simplifier.simplify (eqvt_ss addsimps fresh_atm) th'' end)
   1.226                       vc_compat_ths;
   1.227                   val vc_compat_ths'' = NominalPackage.mk_not_sym vc_compat_ths';
   1.228 -                 val gprems1' = map (fn th => fold_rev (fn pi => fn th' =>
   1.229 -                   Simplifier.simplify eqvt_ss (th' RS Drule.cterm_instantiate
   1.230 -                     [(perm_boolI_pi, cterm_of thy pi)] perm_boolI))
   1.231 -                       (rev pis' @ pis) th) gprems1;
   1.232 -                 val gprems2' = map (Simplifier.simplify eqvt_ss) gprems2;
   1.233                   (** Since calc_atm simplifies (pi :: 'a prm) o (x :: 'b) to x **)
   1.234                   (** we have to pre-simplify the rewrite rules                 **)
   1.235                   val calc_atm_ss = HOL_ss addsimps calc_atm @
   1.236 @@ -275,8 +384,8 @@
   1.237                       REPEAT_DETERM_N (nprems_of ihyp - length gprems)
   1.238                         (simp_tac calc_atm_ss 1),
   1.239                       REPEAT_DETERM_N (length gprems)
   1.240 -                       (resolve_tac gprems1' 1 ORELSE
   1.241 -                        simp_tac (HOL_basic_ss addsimps pt2_atoms @ gprems2'
   1.242 +                       (simp_tac (HOL_ss
   1.243 +                          addsimps inductive_forall_def' :: gprems2
   1.244                            addsimprocs [NominalPackage.perm_simproc]) 1)]));
   1.245                   val final = Goal.prove ctxt'' [] [] (term_of concl)
   1.246                     (fn _ => cut_facts_tac [th] 1 THEN full_simp_tac (HOL_ss
   1.247 @@ -301,7 +410,9 @@
   1.248          val ctxt = ProofContext.init thy;
   1.249          val rec_name = space_implode "_" (map Sign.base_name names);
   1.250          val ind_case_names = RuleCases.case_names induct_cases;
   1.251 -        val strong_raw_induct = mk_proof thy thss;
   1.252 +        val strong_raw_induct =
   1.253 +          mk_proof thy (map (map atomize_intr) thss) |>
   1.254 +          rulify |> MetaSimplifier.norm_hhf;
   1.255          val strong_induct =
   1.256            if length names > 1 then
   1.257              (strong_raw_induct, [ind_case_names, RuleCases.consumes 0])
   1.258 @@ -318,7 +429,7 @@
   1.259            [ind_case_names, RuleCases.consumes 1])] |> snd |>
   1.260          Theory.parent_path
   1.261        end))
   1.262 -      (map (map (rpair [])) vc_compat)
   1.263 +      (map (map (rulify_term thy #> rpair [])) vc_compat)
   1.264    end;
   1.265  
   1.266  fun prove_eqvt s xatoms thy =
   1.267 @@ -326,6 +437,10 @@
   1.268      val ctxt = ProofContext.init thy;
   1.269      val ({names, ...}, {raw_induct, intrs, elims, ...}) =
   1.270        InductivePackage.the_inductive ctxt (Sign.intern_const thy s);
   1.271 +    val raw_induct = atomize_induct raw_induct;
   1.272 +    val elims = map atomize_induct elims;
   1.273 +    val intrs = map atomize_intr intrs;
   1.274 +    val monos = InductivePackage.get_monos ctxt;
   1.275      val intrs' = InductivePackage.unpartition_rules intrs
   1.276        (map (fn (((s, ths), (_, k)), th) =>
   1.277             (s, ths ~~ InductivePackage.infer_intro_vars th k ths))
   1.278 @@ -344,7 +459,10 @@
   1.279             | xs => error ("No such atoms: " ^ commas xs);
   1.280           atoms)
   1.281        end;
   1.282 -    val eqvt_ss = HOL_basic_ss addsimps NominalThmDecls.get_eqvt_thms thy;
   1.283 +    val perm_pi_simp = PureThy.get_thms thy (Name "perm_pi_simp");
   1.284 +    val eqvt_ss = HOL_basic_ss addsimps
   1.285 +      (NominalThmDecls.get_eqvt_thms thy @ perm_pi_simp) addsimprocs
   1.286 +      [mk_perm_bool_simproc names];
   1.287      val t = Logic.unvarify (concl_of raw_induct);
   1.288      val pi = Name.variant (add_term_names (t, [])) "pi";
   1.289      val ps = map (fst o HOLogic.dest_imp)
   1.290 @@ -357,13 +475,14 @@
   1.291               (Logic.unvarify (prop_of intr)) ^ "\n" ^ s);
   1.292          val res = SUBPROOF (fn {prems, params, ...} =>
   1.293            let
   1.294 -            val prems' = map (fn th' => Simplifier.simplify eqvt_ss
   1.295 -              (if null (names inter term_consts (prop_of th')) then th' RS th
   1.296 -               else th')) prems;
   1.297 +            val prems' = map (fn th => the_default th (map_thm ctxt
   1.298 +              (split_conj (K I) names) (etac conjunct2 1) monos NONE th)) prems;
   1.299 +            val prems'' = map (fn th' =>
   1.300 +              Simplifier.simplify eqvt_ss (th' RS th)) prems';
   1.301              val intr' = Drule.cterm_instantiate (map (cterm_of thy) vs ~~
   1.302                 map (cterm_of thy o NominalPackage.mk_perm [] pi o term_of) params)
   1.303                 intr
   1.304 -          in (rtac intr' THEN_ALL_NEW (TRY o resolve_tac prems')) 1
   1.305 +          in (rtac intr' THEN_ALL_NEW (TRY o resolve_tac prems'')) 1
   1.306            end) ctxt 1 st
   1.307        in
   1.308          case (Seq.pull res handle THM (s, _, _) => eqvt_err s) of