formal proof context for axclass proofs;
authorwenzelm
Wed, 10 Apr 2013 13:10:38 +0200
changeset 51671 0d142a78fb7c
parent 51670 d721d21e6374
child 51672 d5c5e088ebdf
child 51685 385ef6706252
formal proof context for axclass proofs; avoid global_simpset_of -- refer to simpset of proof context;
src/HOL/HOLCF/Tools/cpodef.ML
src/HOL/Nominal/nominal_atoms.ML
src/HOL/Nominal/nominal_datatype.ML
src/Pure/axclass.ML
--- a/src/HOL/HOLCF/Tools/cpodef.ML	Wed Apr 10 12:31:35 2013 +0200
+++ b/src/HOL/HOLCF/Tools/cpodef.ML	Wed Apr 10 13:10:38 2013 +0200
@@ -74,7 +74,7 @@
     val (full_tname, Ts) = dest_Type newT
     val lhs_sorts = map (snd o dest_TFree) Ts
     val tac = Tactic.rtac (@{thm typedef_cpo} OF cpo_thms) 1
-    val thy = AxClass.prove_arity (full_tname, lhs_sorts, @{sort cpo}) tac thy
+    val thy = AxClass.prove_arity (full_tname, lhs_sorts, @{sort cpo}) (K tac) thy
     (* transfer thms so that they will know about the new cpo instance *)
     val cpo_thms' = map (Thm.transfer thy) cpo_thms
     fun make thm = Drule.zero_var_indexes (thm OF cpo_thms')
@@ -113,7 +113,7 @@
     val (full_tname, Ts) = dest_Type newT
     val lhs_sorts = map (snd o dest_TFree) Ts
     val tac = Tactic.rtac (@{thm typedef_pcpo} OF pcpo_thms) 1
-    val thy = AxClass.prove_arity (full_tname, lhs_sorts, @{sort pcpo}) tac thy
+    val thy = AxClass.prove_arity (full_tname, lhs_sorts, @{sort pcpo}) (K tac) thy
     val pcpo_thms' = map (Thm.transfer thy) pcpo_thms
     fun make thm = Drule.zero_var_indexes (thm OF pcpo_thms')
     val Rep_strict = make @{thm typedef_Rep_strict}
--- a/src/HOL/Nominal/nominal_atoms.ML	Wed Apr 10 12:31:35 2013 +0200
+++ b/src/HOL/Nominal/nominal_atoms.ML	Wed Apr 10 13:10:38 2013 +0200
@@ -518,7 +518,7 @@
          in
            thy'
            |> AxClass.prove_arity (qu_name,[],[cls_name])
-              (if ak_name = ak_name' then proof1 else proof2)
+              (fn _ => if ak_name = ak_name' then proof1 else proof2)
          end) ak_names thy) ak_names thy12d;
 
      (* show that                       *)
@@ -537,7 +537,7 @@
           val at_thm   = Global_Theory.get_thm thy ("at_"^ak_name^"_inst");
           val pt_inst  = Global_Theory.get_thm thy ("pt_"^ak_name^"_inst");
 
-          fun pt_proof thm = 
+          fun pt_proof thm ctxt =
               EVERY [Class.intro_classes_tac [],
                      rtac (thm RS pt1) 1, rtac (thm RS pt2) 1, rtac (thm RS pt3) 1, atac 1];
 
@@ -592,7 +592,7 @@
                       val simp_s = HOL_basic_ss addsimps [dj_inst RS dj_supp, finite_emptyI];
                   in EVERY [Class.intro_classes_tac [], asm_simp_tac simp_s 1] end)
         in
-         AxClass.prove_arity (qu_name,[],[qu_class]) proof thy'
+         AxClass.prove_arity (qu_name,[],[qu_class]) (fn _ => proof) thy'
         end) ak_names thy) ak_names thy18;
 
        (* shows that                  *)
@@ -607,7 +607,7 @@
         let
           val cls_name = Sign.full_bname thy ("fs_"^ak_name);
           val fs_inst  = Global_Theory.get_thm thy ("fs_"^ak_name^"_inst");
-          fun fs_proof thm = EVERY [Class.intro_classes_tac [], rtac (thm RS fs1) 1];
+          fun fs_proof thm ctxt = EVERY [Class.intro_classes_tac [], rtac (thm RS fs1) 1];
 
           val fs_thm_unit  = fs_unit_inst;
           val fs_thm_prod  = fs_inst RS (fs_inst RS fs_prod_inst);
@@ -669,7 +669,7 @@
                     EVERY [Class.intro_classes_tac [], simp_tac simp_s 1]
                   end))
               in
-                AxClass.prove_arity (name,[],[cls_name]) proof thy''
+                AxClass.prove_arity (name,[],[cls_name]) (fn _ => proof) thy''
               end) ak_names thy') ak_names thy) ak_names thy24;
 
        (* shows that                                                    *) 
@@ -689,7 +689,7 @@
             val pt_inst  = Global_Theory.get_thm thy' ("pt_"^ak_name^"_inst");
             val at_inst  = Global_Theory.get_thm thy' ("at_"^ak_name^"_inst");
 
-            fun cp_proof thm  = EVERY [Class.intro_classes_tac [],rtac (thm RS cp1) 1];
+            fun cp_proof thm ctxt = EVERY [Class.intro_classes_tac [],rtac (thm RS cp1) 1];
           
             val cp_thm_unit = cp_unit_inst;
             val cp_thm_prod = cp_inst RS (cp_inst RS cp_prod_inst);
@@ -722,7 +722,7 @@
                val simp_s = HOL_basic_ss addsimps [Simpdata.mk_eq defn];
                val proof = EVERY [Class.intro_classes_tac [], REPEAT (asm_simp_tac simp_s 1)];
              in 
-               AxClass.prove_arity (discrete_ty, [], [qu_class]) proof thy
+               AxClass.prove_arity (discrete_ty, [], [qu_class]) (fn _ => proof) thy
              end) ak_names;
 
           fun discrete_fs_inst discrete_ty defn = 
@@ -733,7 +733,7 @@
                val simp_s = HOL_ss addsimps [supp_def, Collect_const, finite_emptyI, Simpdata.mk_eq defn];
                val proof = EVERY [Class.intro_classes_tac [], asm_simp_tac simp_s 1];
              in 
-               AxClass.prove_arity (discrete_ty, [], [qu_class]) proof thy
+               AxClass.prove_arity (discrete_ty, [], [qu_class]) (fn _ => proof) thy
              end) ak_names;
 
           fun discrete_cp_inst discrete_ty defn = 
@@ -744,7 +744,7 @@
                val simp_s = HOL_ss addsimps [Simpdata.mk_eq defn];
                val proof = EVERY [Class.intro_classes_tac [], asm_simp_tac simp_s 1];
              in
-               AxClass.prove_arity (discrete_ty, [], [qu_class]) proof thy
+               AxClass.prove_arity (discrete_ty, [], [qu_class]) (fn _ => proof) thy
              end) ak_names)) ak_names;
 
         in
--- a/src/HOL/Nominal/nominal_datatype.ML	Wed Apr 10 12:31:35 2013 +0200
+++ b/src/HOL/Nominal/nominal_datatype.ML	Wed Apr 10 13:10:38 2013 +0200
@@ -278,9 +278,9 @@
                  Const ("Nominal.perm", T) $ pi $ Free (x, T2))
                end)
              (perm_names_types ~~ perm_indnames))))
-          (fn _ => EVERY [Datatype_Aux.ind_tac induct perm_indnames 1,
+          (fn {context = ctxt, ...} => EVERY [Datatype_Aux.ind_tac induct perm_indnames 1,
             ALLGOALS (asm_full_simp_tac
-              (global_simpset_of thy2 addsimps [perm_fun_def]))])),
+              (simpset_of ctxt addsimps [perm_fun_def]))])),
         length new_type_names));
 
     (**** prove [] \<bullet> t = t ****)
@@ -299,8 +299,8 @@
                    Free (x, T)))
                (perm_names ~~
                 map body_type perm_types ~~ perm_indnames)))))
-          (fn _ => EVERY [Datatype_Aux.ind_tac induct perm_indnames 1,
-            ALLGOALS (asm_full_simp_tac (global_simpset_of thy2))])),
+          (fn {context = ctxt, ...} => EVERY [Datatype_Aux.ind_tac induct perm_indnames 1,
+            ALLGOALS (asm_full_simp_tac (simpset_of ctxt))])),
         length new_type_names))
       end)
       atoms;
@@ -334,8 +334,8 @@
                     end)
                   (perm_names ~~
                    map body_type perm_types ~~ perm_indnames)))))
-           (fn _ => EVERY [Datatype_Aux.ind_tac induct perm_indnames 1,
-              ALLGOALS (asm_full_simp_tac (global_simpset_of thy2 addsimps [pt2', pt2_ax]))]))),
+           (fn {context = ctxt, ...} => EVERY [Datatype_Aux.ind_tac induct perm_indnames 1,
+              ALLGOALS (asm_full_simp_tac (simpset_of ctxt addsimps [pt2', pt2_ax]))]))),
          length new_type_names)
       end) atoms;
 
@@ -370,8 +370,8 @@
                     end)
                   (perm_names ~~
                    map body_type perm_types ~~ perm_indnames))))))
-           (fn _ => EVERY [Datatype_Aux.ind_tac induct perm_indnames 1,
-              ALLGOALS (asm_full_simp_tac (global_simpset_of thy2 addsimps [pt3', pt3_rev', pt3_ax]))]))),
+           (fn {context = ctxt, ...} => EVERY [Datatype_Aux.ind_tac induct perm_indnames 1,
+              ALLGOALS (asm_full_simp_tac (simpset_of ctxt addsimps [pt3', pt3_rev', pt3_ax]))]))),
          length new_type_names)
       end) atoms;
 
@@ -393,7 +393,7 @@
         val permT2 = mk_permT (Type (name2, []));
         val Ts = map body_type perm_types;
         val cp_inst = cp_inst_of thy name1 name2;
-        val simps = global_simpset_of thy addsimps (perm_fun_def ::
+        fun simps ctxt = simpset_of ctxt addsimps (perm_fun_def ::
           (if name1 <> name2 then
              let val dj = dj_thm_of thy name2 name1
              in [dj RS (cp_inst RS dj_cp), dj RS dj_perm_perm_forget] end
@@ -422,12 +422,12 @@
                      perm2 $ (perm3 $ pi1 $ pi2) $ (perm1 $ pi1 $ Free (x, T)))
                   end)
                 (perm_names ~~ Ts ~~ perm_indnames)))))
-          (fn _ => EVERY [Datatype_Aux.ind_tac induct perm_indnames 1,
-             ALLGOALS (asm_full_simp_tac simps)]))
+          (fn {context = ctxt, ...} => EVERY [Datatype_Aux.ind_tac induct perm_indnames 1,
+             ALLGOALS (asm_full_simp_tac (simps ctxt))]))
       in
         fold (fn (s, tvs) => fn thy => AxClass.prove_arity
             (s, map (inter_sort thy sort o snd) tvs, [cp_class])
-            (Class.intro_classes_tac [] THEN ALLGOALS (resolve_tac thms)) thy)
+            (fn _ => Class.intro_classes_tac [] THEN ALLGOALS (resolve_tac thms)) thy)
           (full_new_type_names' ~~ tyvars) thy
         |> Theory.checkpoint
       end;
@@ -439,7 +439,7 @@
         in
           fold (fn (s, tvs) => fn thy => AxClass.prove_arity
               (s, map (inter_sort thy [pt_name] o snd) tvs, [pt_name])
-              (EVERY
+              (fn _ => EVERY
                 [Class.intro_classes_tac [],
                  resolve_tac perm_empty_thms 1,
                  resolve_tac perm_append_thms 1,
@@ -561,9 +561,9 @@
                  S $ (Const ("Nominal.perm", permT --> T --> T) $
                    Free ("pi", permT) $ Free (x, T)))
                end) (rep_set_names'' ~~ recTs' ~~ perm_indnames')))))
-        (fn _ => EVERY
+        (fn {context = ctxt, ...} => EVERY
            [Datatype_Aux.ind_tac rep_induct [] 1,
-            ALLGOALS (simp_tac (global_simpset_of thy4 addsimps
+            ALLGOALS (simp_tac (simpset_of ctxt addsimps
               (Thm.symmetric perm_fun_def :: abs_perm))),
             ALLGOALS (resolve_tac rep_intrs THEN_ALL_NEW assume_tac)])),
         length new_type_names));
@@ -621,10 +621,10 @@
           in AxClass.prove_arity
             (Sign.intern_type thy name,
               map (inter_sort thy sort o snd) tvs, [pt_class])
-            (EVERY [Class.intro_classes_tac [],
+            (fn ctxt => EVERY [Class.intro_classes_tac [],
               rewrite_goals_tac [perm_def],
-              asm_full_simp_tac (global_simpset_of thy addsimps [Rep_inverse]) 1,
-              asm_full_simp_tac (global_simpset_of thy addsimps
+              asm_full_simp_tac (simpset_of ctxt addsimps [Rep_inverse]) 1,
+              asm_full_simp_tac (simpset_of ctxt addsimps
                 [Rep RS perm_closed RS Abs_inverse]) 1,
               asm_full_simp_tac (HOL_basic_ss addsimps [Global_Theory.get_thm thy
                 ("pt_" ^ Long_Name.base_name atom ^ "3")]) 1]) thy
@@ -651,9 +651,9 @@
           AxClass.prove_arity
             (Sign.intern_type thy name,
               map (inter_sort thy sort o snd) tvs, [cp_class])
-            (EVERY [Class.intro_classes_tac [],
+            (fn ctxt => EVERY [Class.intro_classes_tac [],
               rewrite_goals_tac [perm_def],
-              asm_full_simp_tac (global_simpset_of thy addsimps
+              asm_full_simp_tac (simpset_of ctxt addsimps
                 ((Rep RS perm_closed1 RS Abs_inverse) ::
                  (if atom1 = atom2 then []
                   else [Rep RS perm_closed2 RS Abs_inverse]))) 1,
@@ -841,8 +841,8 @@
     fun prove_distinct_thms _ [] = []
       | prove_distinct_thms (p as (rep_thms, dist_lemma)) (t :: ts) =
           let
-            val dist_thm = Goal.prove_global_future thy8 [] [] t (fn _ =>
-              simp_tac (global_simpset_of thy8 addsimps (dist_lemma :: rep_thms)) 1)
+            val dist_thm = Goal.prove_global_future thy8 [] [] t (fn {context = ctxt, ...} =>
+              simp_tac (simpset_of ctxt addsimps (dist_lemma :: rep_thms)) 1)
           in
             dist_thm :: Drule.export_without_context (dist_thm RS not_sym) ::
               prove_distinct_thms p ts
@@ -889,8 +889,8 @@
               (pt_class_of thy8 atom :: map (cp_class_of thy8 atom) (remove (op =) atom dt_atoms))
               (HOLogic.mk_Trueprop (HOLogic.mk_eq
                 (perm (list_comb (c, l_args)), list_comb (c, r_args)))))
-            (fn _ => EVERY
-              [simp_tac (global_simpset_of thy8 addsimps (constr_rep_thm :: perm_defs)) 1,
+            (fn {context = ctxt, ...} => EVERY
+              [simp_tac (simpset_of ctxt addsimps (constr_rep_thm :: perm_defs)) 1,
                simp_tac (HOL_basic_ss addsimps (Rep_thms @ Abs_inverse_thms @
                  constr_defs @ perm_closed_thms)) 1,
                TRY (simp_tac (HOL_basic_ss addsimps
@@ -945,8 +945,8 @@
               (HOLogic.mk_Trueprop (HOLogic.mk_eq
                 (HOLogic.mk_eq (list_comb (c, args1), list_comb (c, args2)),
                  foldr1 HOLogic.mk_conj eqs))))
-            (fn _ => EVERY
-               [asm_full_simp_tac (global_simpset_of thy8 addsimps (constr_rep_thm ::
+            (fn {context = ctxt, ...} => EVERY
+               [asm_full_simp_tac (simpset_of ctxt addsimps (constr_rep_thm ::
                   rep_inject_thms')) 1,
                 TRY (asm_full_simp_tac (HOL_basic_ss addsimps (fresh_def :: supp_def ::
                   alpha @ abs_perm @ abs_fresh @ rep_inject_thms @
@@ -1075,8 +1075,8 @@
                  Const ("Finite_Set.finite", HOLogic.mk_setT atomT --> HOLogic.boolT) $
                    (Const ("Nominal.supp", T --> HOLogic.mk_setT atomT) $ Free (s, T)))
                    (indnames ~~ recTs)))))
-           (fn _ => Datatype_Aux.ind_tac dt_induct indnames 1 THEN
-            ALLGOALS (asm_full_simp_tac (global_simpset_of thy8 addsimps
+           (fn {context = ctxt, ...} => Datatype_Aux.ind_tac dt_induct indnames 1 THEN
+            ALLGOALS (asm_full_simp_tac (simpset_of ctxt addsimps
               (abs_supp @ supp_atm @
                Global_Theory.get_thms thy8 ("fs_" ^ Long_Name.base_name atom ^ "1") @
                flat supp_thms))))),
@@ -1107,7 +1107,7 @@
           val sort = Sign.minimize_sort thy (Sign.certify_sort thy (class :: pt_cp_sort));
         in fold (fn Type (s, Ts) => AxClass.prove_arity
           (s, map (inter_sort thy sort o snd o dest_TFree) Ts, [class])
-          (Class.intro_classes_tac [] THEN resolve_tac ths 1)) newTs thy
+          (fn _ => Class.intro_classes_tac [] THEN resolve_tac ths 1)) newTs thy
         end) (atoms ~~ finite_supp_thms) ||>
       Theory.checkpoint;
 
--- a/src/Pure/axclass.ML	Wed Apr 10 12:31:35 2013 +0200
+++ b/src/Pure/axclass.ML	Wed Apr 10 13:10:38 2013 +0200
@@ -26,8 +26,8 @@
   val define_overloaded: binding -> string * term -> theory -> thm * theory
   val add_classrel: thm -> theory -> theory
   val add_arity: thm -> theory -> theory
-  val prove_classrel: class * class -> tactic -> theory -> theory
-  val prove_arity: string * sort list * sort -> tactic -> theory -> theory
+  val prove_classrel: class * class -> (Proof.context -> tactic) -> theory -> theory
+  val prove_arity: string * sort list * sort -> (Proof.context -> tactic) -> theory -> theory
   val define_class: binding * class list -> string list ->
     (Thm.binding * term list) list -> theory -> class * theory
   val axiomatize_class: binding * class list -> theory -> theory
@@ -439,9 +439,11 @@
   let
     val ctxt = Proof_Context.init_global thy;
     val (c1, c2) = cert_classrel thy raw_rel;
-    val th = Goal.prove ctxt [] [] (Logic.mk_classrel (c1, c2)) (K tac) handle ERROR msg =>
-      cat_error msg ("The error(s) above occurred while trying to prove class relation " ^
-        quote (Syntax.string_of_classrel ctxt [c1, c2]));
+    val th =
+      Goal.prove ctxt [] [] (Logic.mk_classrel (c1, c2)) (fn {context, ...} => tac context)
+        handle ERROR msg =>
+          cat_error msg ("The error(s) above occurred while trying to prove class relation " ^
+            quote (Syntax.string_of_classrel ctxt [c1, c2]));
   in
     thy |> add_classrel th
   end;
@@ -452,10 +454,12 @@
     val arity = Proof_Context.cert_arity ctxt raw_arity;
     val names = map (prefix arity_prefix) (Logic.name_arities arity);
     val props = Logic.mk_arities arity;
-    val ths = Goal.prove_multi ctxt [] [] props
-      (fn _ => Goal.precise_conjunction_tac (length props) 1 THEN tac) handle ERROR msg =>
-        cat_error msg ("The error(s) above occurred while trying to prove type arity " ^
-          quote (Syntax.string_of_arity ctxt arity));
+    val ths =
+      Goal.prove_multi ctxt [] [] props
+      (fn {context, ...} => Goal.precise_conjunction_tac (length props) 1 THEN tac context)
+        handle ERROR msg =>
+          cat_error msg ("The error(s) above occurred while trying to prove type arity " ^
+            quote (Syntax.string_of_arity ctxt arity));
   in
     thy |> fold add_arity ths
   end;