Generalized equivariance and nominal_inductive commands to
authorberghofe
Thu, 13 Sep 2007 18:11:59 +0200
changeset 24570 621b60b1df00
parent 24569 c80e1871098b
child 24571 a6d0428dea8e
Generalized equivariance and nominal_inductive commands to inductive definitions involving arbitrary monotone operators.
src/HOL/Nominal/nominal_inductive.ML
--- a/src/HOL/Nominal/nominal_inductive.ML	Thu Sep 13 18:08:08 2007 +0200
+++ b/src/HOL/Nominal/nominal_inductive.ML	Thu Sep 13 18:11:59 2007 +0200
@@ -15,14 +15,41 @@
 structure NominalInductive : NOMINAL_INDUCTIVE =
 struct
 
+val inductive_forall_name = "HOL.induct_forall";
+val inductive_forall_def = thm "induct_forall_def";
+val inductive_atomize = thms "induct_atomize";
+val inductive_rulify = thms "induct_rulify";
+val inductive_rulify_fallback = thms "induct_rulify_fallback";
+
+val rulify =
+  hol_simplify inductive_rulify
+  #> hol_simplify inductive_rulify_fallback;
+
+fun rulify_term thy = MetaSimplifier.rewrite_term thy inductive_rulify [];
+
+val atomize_conv =
+  MetaSimplifier.rewrite_cterm (true, false, false) (K (K NONE))
+    (HOL_basic_ss addsimps inductive_atomize);
+val atomize_intr = Conv.fconv_rule (Conv.prems_conv ~1 atomize_conv);
+val atomize_induct = Conv.fconv_rule (Conv.prems_conv ~1
+  (Conv.forall_conv ~1 (Conv.prems_conv ~1 atomize_conv)));
+
 val finite_Un = thm "finite_Un";
 val supp_prod = thm "supp_prod";
 val fresh_prod = thm "fresh_prod";
 
+val perm_bool = mk_meta_eq (thm "perm_bool");
 val perm_boolI = thm "perm_boolI";
 val (_, [perm_boolI_pi, _]) = Drule.strip_comb (snd (Thm.dest_comb
   (Drule.strip_imp_concl (cprop_of perm_boolI))));
 
+fun mk_perm_bool_simproc names = Simplifier.simproc_i
+  (theory_of_thm perm_bool) "perm_bool" [@{term "perm pi x"}] (fn thy => fn ss =>
+    fn Const ("Nominal.perm", _) $ _ $ t =>
+         if the_default "" (try (head_of #> dest_Const #> fst) t) mem names
+         then SOME perm_bool else NONE
+     | _ => NONE);
+
 val allE_Nil = read_instantiate_sg (the_context()) [("x", "[]")] allE;
 
 fun transp ([] :: _) = []
@@ -49,11 +76,77 @@
   | add_binders thy i (Abs (_, _, t)) bs = add_binders thy (i + 1) t bs
   | add_binders thy i _ bs = bs;
 
+fun split_conj f names (Const ("op &", _) $ p $ q) _ = (case head_of p of
+      Const (name, _) =>
+        if name mem names then SOME (f p q) else NONE
+    | _ => NONE)
+  | split_conj _ _ _ _ = NONE;
+
+fun strip_all [] t = t
+  | strip_all (_ :: xs) (Const ("All", _) $ Abs (s, T, t)) = strip_all xs t;
+
+(*********************************************************************)
+(* maps  R ... & (ALL pi_1 ... pi_n z. P z (pi_1 o ... o pi_n o t))  *)
+(* or    ALL pi_1 ... pi_n. P (pi_1 o ... o pi_n o t)                *)
+(* to    R ... & id (ALL z. (pi_1 o ... o pi_n o t))                 *)
+(* or    id (ALL z. (pi_1 o ... o pi_n o t))                         *)
+(*                                                                   *)
+(* where "id" protects the subformula from simplification            *)
+(*********************************************************************)
+
+fun inst_conj_all names ps pis (Const ("op &", _) $ p $ q) _ =
+      (case head_of p of
+         Const (name, _) =>
+           if name mem names then SOME (HOLogic.mk_conj (p,
+             Const ("Fun.id", HOLogic.boolT --> HOLogic.boolT) $
+               (subst_bounds (pis, strip_all pis q))))
+           else NONE
+       | _ => NONE)
+  | inst_conj_all names ps pis t u =
+      if member (op aconv) ps (head_of u) then
+        SOME (Const ("Fun.id", HOLogic.boolT --> HOLogic.boolT) $
+          (subst_bounds (pis, strip_all pis t)))
+      else NONE
+  | inst_conj_all _ _ _ _ _ = NONE;
+
+fun inst_conj_all_tac k = EVERY
+  [TRY (EVERY [etac conjE 1, rtac conjI 1, atac 1]),
+   REPEAT_DETERM_N k (etac allE 1),
+   simp_tac (HOL_basic_ss addsimps [id_apply]) 1];
+
+fun map_term f t u = (case f t u of
+      NONE => map_term' f t u | x => x)
+and map_term' f (t $ u) (t' $ u') = (case (map_term f t t', map_term f u u') of
+      (NONE, NONE) => NONE
+    | (SOME t'', NONE) => SOME (t'' $ u)
+    | (NONE, SOME u'') => SOME (t $ u'')
+    | (SOME t'', SOME u'') => SOME (t'' $ u''))
+  | map_term' f (Abs (s, T, t)) (Abs (s', T', t')) = (case map_term f t t' of
+      NONE => NONE
+    | SOME t'' => SOME (Abs (s, T, t'')))
+  | map_term' _ _ _ = NONE;
+
+(*********************************************************************)
+(*         Prove  F[f t]  from  F[t],  where F is monotone           *)
+(*********************************************************************)
+
+fun map_thm ctxt f tac monos opt th =
+  let
+    val prop = prop_of th;
+    fun prove t =
+      Goal.prove ctxt [] [] t (fn _ =>
+        EVERY [cut_facts_tac [th] 1, etac rev_mp 1,
+          REPEAT_DETERM (FIRSTGOAL (resolve_tac monos)),
+          REPEAT_DETERM (rtac impI 1 THEN (atac 1 ORELSE tac))])
+  in Option.map prove (map_term f prop (the_default prop opt)) end;
+
 fun prove_strong_ind s avoids thy =
   let
     val ctxt = ProofContext.init thy;
     val ({names, ...}, {raw_induct, ...}) =
       InductivePackage.the_inductive ctxt (Sign.intern_const thy s);
+    val raw_induct = atomize_induct raw_induct;
+    val monos = InductivePackage.get_monos ctxt;
     val eqvt_thms = NominalThmDecls.get_eqvt_thms thy;
     val _ = (case names \\ foldl (apfst prop_of #> add_term_consts) [] eqvt_thms of
         [] => ()
@@ -104,24 +197,20 @@
     val fs_ctxt_name = Name.variant (add_term_names (raw_induct', [])) "z";
     val fsT = TFree (fs_ctxt_tyname, ind_sort);
 
+    val inductive_forall_def' = Drule.instantiate'
+      [SOME (ctyp_of thy fsT)] [] inductive_forall_def;
+
     fun lift_pred' t (Free (s, T)) ts =
       list_comb (Free (s, fsT --> T), t :: ts);
     val lift_pred = lift_pred' (Bound 0);
 
-    fun lift_prem (Const ("Trueprop", _) $ t) =
-          let val (u, ts) = strip_comb t
-          in
-            if u mem ps then
-              all fsT $ Abs ("z", fsT, HOLogic.mk_Trueprop
-                (lift_pred u (map (incr_boundvars 1) ts)))
-            else HOLogic.mk_Trueprop (lift_prem t)
-          end
-      | lift_prem (t as (f $ u)) =
+    fun lift_prem (t as (f $ u)) =
           let val (p, ts) = strip_comb t
           in
             if p mem ps then
-              HOLogic.all_const fsT $ Abs ("z", fsT,
-                lift_pred p (map (incr_boundvars 1) ts))
+              Const (inductive_forall_name,
+                (fsT --> HOLogic.boolT) --> HOLogic.boolT) $
+                  Abs ("z", fsT, lift_pred p (map (incr_boundvars 1) ts))
             else lift_prem f $ lift_prem u
           end
       | lift_prem (Abs (s, T, t)) = Abs (s, T, lift_prem t)
@@ -167,13 +256,19 @@
 
     val vc_compat = map (fn (params, bvars, prems, (p, ts)) =>
       map (fn q => list_all (params, incr_boundvars ~1 (Logic.list_implies
-          (filter (fn prem => null (ps inter term_frees prem)) prems, q))))
+          (List.mapPartial (fn prem =>
+             if null (ps inter term_frees prem) then SOME prem
+             else map_term (split_conj (K o I) names) prem prem) prems, q))))
         (mk_distinct bvars @
          maps (fn (t, T) => map (fn (u, U) => HOLogic.mk_Trueprop
            (Const ("Nominal.fresh", U --> T --> HOLogic.boolT) $ u $ t)) bvars)
              (ts ~~ binder_types (fastype_of p)))) prems;
 
-    val eqvt_ss = HOL_basic_ss addsimps eqvt_thms;
+    val perm_pi_simp = PureThy.get_thms thy (Name "perm_pi_simp");
+    val pt2_atoms = map (fn aT => PureThy.get_thm thy
+      (Name ("pt_" ^ Sign.base_name (fst (dest_Type aT)) ^ "2"))) atomTs;
+    val eqvt_ss = HOL_basic_ss addsimps (eqvt_thms @ perm_pi_simp @ pt2_atoms)
+      addsimprocs [mk_perm_bool_simproc ["Fun.id"]];
     val fresh_bij = PureThy.get_thms thy (Name "fresh_bij");
     val perm_bij = PureThy.get_thms thy (Name "perm_bij");
     val fs_atoms = map (fn aT => PureThy.get_thm thy
@@ -182,8 +277,6 @@
     val fresh_atm = PureThy.get_thms thy (Name "fresh_atm");
     val calc_atm = PureThy.get_thms thy (Name "calc_atm");
     val perm_fresh_fresh = PureThy.get_thms thy (Name "perm_fresh_fresh");
-    val pt2_atoms = map (fn aT => PureThy.get_thm thy
-      (Name ("pt_" ^ Sign.base_name (fst (dest_Type aT)) ^ "2")) RS sym) atomTs;
 
     fun obtain_fresh_name ts T (freshs1, freshs2, ctxt) =
       let
@@ -231,14 +324,35 @@
                    (map snd bvars') ([], [], ctxt');
                  val freshs2' = NominalPackage.mk_not_sym freshs2;
                  val pis' = map NominalPackage.perm_of_pair (pi_bvars ~~ freshs1);
+                 fun concat_perm pi1 pi2 =
+                   let val T = fastype_of pi1
+                   in if T = fastype_of pi2 then
+                       Const ("List.append", T --> T --> T) $ pi1 $ pi2
+                     else pi2
+                   end;
+                 val pis'' = fold (concat_perm #> map) pis' pis;
                  val env = Pattern.first_order_match thy (ihypt, prop_of ihyp)
                    (Vartab.empty, Vartab.empty);
                  val ihyp' = Thm.instantiate ([], map (pairself (cterm_of thy))
                    (map (Envir.subst_vars env) vs ~~
                     map (fold_rev (NominalPackage.mk_perm [])
                       (rev pis' @ pis)) params' @ [z])) ihyp;
-                 val (gprems1, gprems2) = pairself (map fst) (List.partition
-                   (fn (th, t) => null (term_frees t inter ps)) (gprems ~~ oprems));
+                 fun mk_pi th =
+                   Simplifier.simplify (HOL_basic_ss addsimps [id_apply]
+                       addsimprocs [NominalPackage.perm_simproc])
+                     (Simplifier.simplify eqvt_ss
+                       (fold_rev (fn pi => fn th' => th' RS Drule.cterm_instantiate
+                         [(perm_boolI_pi, cterm_of thy pi)] perm_boolI)
+                           (rev pis' @ pis) th));
+                 val (gprems1, gprems2) = split_list
+                   (map (fn (th, t) =>
+                      if null (term_frees t inter ps) then (SOME th, mk_pi th)
+                      else
+                        (map_thm ctxt (split_conj (K o I) names)
+                           (etac conjunct1 1) monos NONE th,
+                         mk_pi (the (map_thm ctxt (inst_conj_all names ps (rev pis''))
+                           (inst_conj_all_tac (length pis'')) monos (SOME t) th))))
+                      (gprems ~~ oprems)) |>> List.mapPartial I;
                  val vc_compat_ths' = map (fn th =>
                    let
                      val th' = gprems1 MRS
@@ -258,11 +372,6 @@
                    in Simplifier.simplify (eqvt_ss addsimps fresh_atm) th'' end)
                      vc_compat_ths;
                  val vc_compat_ths'' = NominalPackage.mk_not_sym vc_compat_ths';
-                 val gprems1' = map (fn th => fold_rev (fn pi => fn th' =>
-                   Simplifier.simplify eqvt_ss (th' RS Drule.cterm_instantiate
-                     [(perm_boolI_pi, cterm_of thy pi)] perm_boolI))
-                       (rev pis' @ pis) th) gprems1;
-                 val gprems2' = map (Simplifier.simplify eqvt_ss) gprems2;
                  (** Since calc_atm simplifies (pi :: 'a prm) o (x :: 'b) to x **)
                  (** we have to pre-simplify the rewrite rules                 **)
                  val calc_atm_ss = HOL_ss addsimps calc_atm @
@@ -275,8 +384,8 @@
                      REPEAT_DETERM_N (nprems_of ihyp - length gprems)
                        (simp_tac calc_atm_ss 1),
                      REPEAT_DETERM_N (length gprems)
-                       (resolve_tac gprems1' 1 ORELSE
-                        simp_tac (HOL_basic_ss addsimps pt2_atoms @ gprems2'
+                       (simp_tac (HOL_ss
+                          addsimps inductive_forall_def' :: gprems2
                           addsimprocs [NominalPackage.perm_simproc]) 1)]));
                  val final = Goal.prove ctxt'' [] [] (term_of concl)
                    (fn _ => cut_facts_tac [th] 1 THEN full_simp_tac (HOL_ss
@@ -301,7 +410,9 @@
         val ctxt = ProofContext.init thy;
         val rec_name = space_implode "_" (map Sign.base_name names);
         val ind_case_names = RuleCases.case_names induct_cases;
-        val strong_raw_induct = mk_proof thy thss;
+        val strong_raw_induct =
+          mk_proof thy (map (map atomize_intr) thss) |>
+          rulify |> MetaSimplifier.norm_hhf;
         val strong_induct =
           if length names > 1 then
             (strong_raw_induct, [ind_case_names, RuleCases.consumes 0])
@@ -318,7 +429,7 @@
           [ind_case_names, RuleCases.consumes 1])] |> snd |>
         Theory.parent_path
       end))
-      (map (map (rpair [])) vc_compat)
+      (map (map (rulify_term thy #> rpair [])) vc_compat)
   end;
 
 fun prove_eqvt s xatoms thy =
@@ -326,6 +437,10 @@
     val ctxt = ProofContext.init thy;
     val ({names, ...}, {raw_induct, intrs, elims, ...}) =
       InductivePackage.the_inductive ctxt (Sign.intern_const thy s);
+    val raw_induct = atomize_induct raw_induct;
+    val elims = map atomize_induct elims;
+    val intrs = map atomize_intr intrs;
+    val monos = InductivePackage.get_monos ctxt;
     val intrs' = InductivePackage.unpartition_rules intrs
       (map (fn (((s, ths), (_, k)), th) =>
            (s, ths ~~ InductivePackage.infer_intro_vars th k ths))
@@ -344,7 +459,10 @@
            | xs => error ("No such atoms: " ^ commas xs);
          atoms)
       end;
-    val eqvt_ss = HOL_basic_ss addsimps NominalThmDecls.get_eqvt_thms thy;
+    val perm_pi_simp = PureThy.get_thms thy (Name "perm_pi_simp");
+    val eqvt_ss = HOL_basic_ss addsimps
+      (NominalThmDecls.get_eqvt_thms thy @ perm_pi_simp) addsimprocs
+      [mk_perm_bool_simproc names];
     val t = Logic.unvarify (concl_of raw_induct);
     val pi = Name.variant (add_term_names (t, [])) "pi";
     val ps = map (fst o HOLogic.dest_imp)
@@ -357,13 +475,14 @@
              (Logic.unvarify (prop_of intr)) ^ "\n" ^ s);
         val res = SUBPROOF (fn {prems, params, ...} =>
           let
-            val prems' = map (fn th' => Simplifier.simplify eqvt_ss
-              (if null (names inter term_consts (prop_of th')) then th' RS th
-               else th')) prems;
+            val prems' = map (fn th => the_default th (map_thm ctxt
+              (split_conj (K I) names) (etac conjunct2 1) monos NONE th)) prems;
+            val prems'' = map (fn th' =>
+              Simplifier.simplify eqvt_ss (th' RS th)) prems';
             val intr' = Drule.cterm_instantiate (map (cterm_of thy) vs ~~
                map (cterm_of thy o NominalPackage.mk_perm [] pi o term_of) params)
                intr
-          in (rtac intr' THEN_ALL_NEW (TRY o resolve_tac prems')) 1
+          in (rtac intr' THEN_ALL_NEW (TRY o resolve_tac prems'')) 1
           end) ctxt 1 st
       in
         case (Seq.pull res handle THM (s, _, _) => eqvt_err s) of