Made simplification procedures simpset-aware.
authorskalberg
Wed, 30 Jun 2004 00:42:59 +0200
changeset 15011 35be762f58f9
parent 15010 72fbe711e414
child 15012 28fa57b57209
Made simplification procedures simpset-aware.
NEWS
TFL/rules.ML
src/Pure/meta_simplifier.ML
src/Pure/tactic.ML
--- a/NEWS	Tue Jun 29 11:18:34 2004 +0200
+++ b/NEWS	Wed Jun 30 00:42:59 2004 +0200
@@ -6,6 +6,11 @@
 
 *** General ***
 
+* Pure: Simplification procedures can now take the current simpset as
+  an additional argument; This is useful for calling the simplifier
+  recursively.  See the functions MetaSimplifier.full_{mk_simproc,
+  simproc,simproc_i}.
+
 * Pure: considerably improved version of 'constdefs' command.  Now
   performs automatic type-inference of declared constants; additional
   support for local structure declarations (cf. locales and HOL
--- a/TFL/rules.ML	Tue Jun 29 11:18:34 2004 +0200
+++ b/TFL/rules.ML	Wed Jun 30 00:42:59 2004 +0200
@@ -433,7 +433,7 @@
 local fun rew_conv mss = MetaSimplifier.rewrite_cterm (true,false,false) (K(K None)) mss
 in
 fun simpl_conv ss thl ctm =
- rew_conv (MetaSimplifier.mss_of (#simps (MetaSimplifier.dest_mss (#mss (rep_ss ss))) @ thl)) ctm
+ rew_conv (MetaSimplifier.ss_of (#simps (MetaSimplifier.dest_mss (#mss (rep_ss ss))) @ thl)) ctm
  RS meta_eq_to_obj_eq
 end;
 
@@ -688,7 +688,7 @@
                      val eq = Logic.strip_imp_concl imp
                      val lhs = tych(get_lhs eq)
                      val mss' = MetaSimplifier.add_prems(mss, map ASSUME ants)
-                     val lhs_eq_lhs1 = MetaSimplifier.rewrite_cterm (false,true,false) (prover used) mss' lhs
+                     val lhs_eq_lhs1 = MetaSimplifier.rewrite_cterm (false,true,false) (prover used) (MetaSimplifier.from_mss mss') lhs
                        handle U.ERR _ => Thm.reflexive lhs
                      val dummy = print_thms "proven:" [lhs_eq_lhs1]
                      val lhs_eq_lhs2 = implies_intr_list ants lhs_eq_lhs1
@@ -710,7 +710,7 @@
                   val QeqQ1 = pbeta_reduce (tych Q)
                   val Q1 = #2(D.dest_eq(cconcl QeqQ1))
                   val mss' = MetaSimplifier.add_prems(mss, map ASSUME ants1)
-                  val Q1eeqQ2 = MetaSimplifier.rewrite_cterm (false,true,false) (prover used') mss' Q1
+                  val Q1eeqQ2 = MetaSimplifier.rewrite_cterm (false,true,false) (prover used') (MetaSimplifier.from_mss mss') Q1
                                 handle U.ERR _ => Thm.reflexive Q1
                   val Q2 = #2 (Logic.dest_equals (Thm.prop_of Q1eeqQ2))
                   val Q3 = tych(list_comb(list_mk_aabs(vstrl,Q2),vstrl))
@@ -736,7 +736,7 @@
                      val ants1 = map tych ants
                      val mss' = MetaSimplifier.add_prems(mss, map ASSUME ants1)
                      val Q_eeq_Q1 = MetaSimplifier.rewrite_cterm
-                        (false,true,false) (prover used') mss' (tych Q)
+                        (false,true,false) (prover used') (MetaSimplifier.from_mss mss') (tych Q)
                       handle U.ERR _ => Thm.reflexive (tych Q)
                      val lhs_eeq_lhs2 = implies_intr_list ants1 Q_eeq_Q1
                      val lhs_eq_lhs2 = lhs_eeq_lhs2 RS meta_eq_to_obj_eq
@@ -806,7 +806,7 @@
     val ctm = cprop_of th
     val names = add_term_names (term_of ctm, [])
     val th1 = MetaSimplifier.rewrite_cterm(false,true,false)
-      (prover names) (MetaSimplifier.add_congs(MetaSimplifier.mss_of [cut_lemma'], congs)) ctm
+      (prover names) (MetaSimplifier.addeqcongs(MetaSimplifier.ss_of [cut_lemma'], congs)) ctm
     val th2 = equal_elim th1 th
  in
  (th2, filter (not o restricted) (!tc_list))
--- a/src/Pure/meta_simplifier.ML	Tue Jun 29 11:18:34 2004 +0200
+++ b/src/Pure/meta_simplifier.ML	Wed Jun 30 00:42:59 2004 +0200
@@ -20,12 +20,14 @@
 signature AUX_SIMPLIFIER =
 sig
   type meta_simpset
+  type simpset
   type simproc
+  val full_mk_simproc: string -> cterm list
+    -> (simpset -> Sign.sg -> thm list -> term -> thm option) -> simproc
   val mk_simproc: string -> cterm list
     -> (Sign.sg -> thm list -> term -> thm option) -> simproc
   type solver
   val mk_solver: string -> (thm list -> int -> tactic) -> solver
-  type simpset
   val empty_ss: simpset
   val rep_ss: simpset ->
    {mss: meta_simpset,
@@ -33,7 +35,9 @@
     subgoal_tac: simpset -> int -> tactic,
     loop_tacs: (string * (int -> tactic)) list,
     unsafe_solvers: solver list,
-    solvers: solver list};
+    solvers: solver list}
+  val from_mss: meta_simpset -> simpset
+  val ss_of            : thm list -> simpset
   val print_ss: simpset -> unit
   val setsubgoaler: simpset *  (simpset -> int -> tactic) -> simpset
   val setloop:      simpset *             (int -> tactic) -> simpset
@@ -63,6 +67,10 @@
     -> (Sign.sg -> thm list -> term -> thm option) -> simproc
   val simproc_i: Sign.sg -> string -> term list
     -> (Sign.sg -> thm list -> term -> thm option) -> simproc
+  val full_simproc: Sign.sg -> string -> string list
+    -> (simpset -> Sign.sg -> thm list -> term -> thm option) -> simproc
+  val full_simproc_i: Sign.sg -> string -> term list
+    -> (simpset -> Sign.sg -> thm list -> term -> thm option) -> simproc
   val clear_ss  : simpset -> simpset
   val simp_thm  : bool * bool * bool -> simpset -> thm -> thm
   val simp_cterm: bool * bool * bool -> simpset -> cterm -> thm
@@ -85,10 +93,10 @@
   val add_congs         : meta_simpset * thm list -> meta_simpset
   val del_congs         : meta_simpset * thm list -> meta_simpset
   val add_simprocs      : meta_simpset *
-    (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
+    (string * cterm list * (simpset -> Sign.sg -> thm list -> term -> thm option) * stamp) list
       -> meta_simpset
   val del_simprocs      : meta_simpset *
-    (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
+    (string * cterm list * (simpset -> Sign.sg -> thm list -> term -> thm option) * stamp) list
       -> meta_simpset
   val add_prems         : meta_simpset * thm list -> meta_simpset
   val prems_of_mss      : meta_simpset -> thm list
@@ -100,19 +108,19 @@
   val get_mk_eq_True    : meta_simpset -> thm -> thm option
   val set_termless      : meta_simpset * (term * term -> bool) -> meta_simpset
   val rewrite_cterm: bool * bool * bool ->
-    (meta_simpset -> thm -> thm option) -> meta_simpset -> cterm -> thm
+    (meta_simpset -> thm -> thm option) -> simpset -> cterm -> thm
   val rewrite_aux       : (meta_simpset -> thm -> thm option) -> bool -> thm list -> cterm -> thm
   val simplify_aux      : (meta_simpset -> thm -> thm option) -> bool -> thm list -> thm -> thm
   val rewrite_thm       : bool * bool * bool
                           -> (meta_simpset -> thm -> thm option)
-                          -> meta_simpset -> thm -> thm
+                          -> simpset -> thm -> thm
   val rewrite_goals_rule_aux: (meta_simpset -> thm -> thm option) -> thm list -> thm -> thm
   val rewrite_goal_rule : bool* bool * bool
                           -> (meta_simpset -> thm -> thm option)
-                          -> meta_simpset -> int -> thm -> thm
+                          -> simpset -> int -> thm -> thm
   val rewrite_term: Sign.sg -> thm list -> (term -> term option) list -> term -> term
   val asm_rewrite_goal_tac: bool*bool*bool ->
-    (meta_simpset -> tactic) -> meta_simpset -> int -> tactic
+    (meta_simpset -> tactic) -> simpset -> int -> tactic
 
 end;
 
@@ -181,8 +189,6 @@
        in which case there is nothing better to do.
 *)
 type cong = {thm: thm, lhs: cterm};
-type meta_simproc =
- {name: string, proc: Sign.sg -> thm list -> term -> thm option, lhs: cterm, id: stamp};
 
 fun eq_rrule ({thm = thm1, ...}: rrule, {thm = thm2, ...}: rrule) =
   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
@@ -193,11 +199,6 @@
 fun eq_prem (thm1, thm2) =
   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
 
-fun eq_simproc ({id = s1, ...}:meta_simproc, {id = s2, ...}:meta_simproc) = (s1 = s2);
-
-fun mk_simproc (name, proc, lhs, id) =
-  {name = name, proc = proc, lhs = lhs, id = id};
-
 
 (* datatype mss *)
 
@@ -219,6 +220,8 @@
     depth: depth of conditional rewriting;
 *)
 
+datatype solver = Solver of string * (thm list -> int -> tactic) * stamp;
+
 datatype meta_simpset =
   Mss of {
     rules: rrule Net.net,
@@ -230,7 +233,17 @@
               mk_sym: thm -> thm option,
               mk_eq_True: thm -> thm option},
     termless: term * term -> bool,
-    depth: int};
+    depth: int}
+and simpset =
+  Simpset of {
+    mss: meta_simpset,
+    mk_cong: thm -> thm,
+    subgoal_tac: simpset -> int -> tactic,
+    loop_tacs: (string * (int -> tactic)) list,
+    unsafe_solvers: solver list,
+    solvers: solver list}
+withtype meta_simproc =
+ {name: string, proc: simpset -> Sign.sg -> thm list -> term -> thm option, lhs: cterm, id: stamp};
 
 fun mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth) =
   Mss {rules = rules, congs = congs, procs = procs, bounds = bounds,
@@ -257,6 +270,14 @@
           )
   end;
 
+datatype simproc =
+  Simproc of string * cterm list * (simpset -> Sign.sg -> thm list -> term -> thm option) * stamp
+
+fun eq_simproc ({id = s1, ...}:meta_simproc, {id = s2, ...}:meta_simproc) = (s1 = s2);
+
+fun mk_simproc (name, proc, lhs, id) =
+  {name = name, proc = proc, lhs = lhs, id = id};
+
 
 (** simpset operations **)
 
@@ -591,11 +612,15 @@
 
 (* datatype simproc *)
 
-datatype simproc =
-  Simproc of string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp;
+fun full_mk_simproc name lhss proc =
+  Simproc (name, map (Thm.cterm_fun Logic.varify) lhss, proc, stamp ());
+
+fun full_simproc sg name ss =
+  full_mk_simproc name (map (fn s => Thm.read_cterm sg (s, TypeInfer.logicT)) ss);
+fun full_simproc_i sg name = full_mk_simproc name o map (Thm.cterm_of sg);
 
 fun mk_simproc name lhss proc =
-  Simproc (name, map (Thm.cterm_fun Logic.varify) lhss, proc, stamp ());
+  Simproc (name, map (Thm.cterm_fun Logic.varify) lhss, K proc, stamp ());
 
 fun simproc sg name ss =
   mk_simproc name (map (fn s => Thm.read_cterm sg (s, TypeInfer.logicT)) ss);
@@ -607,8 +632,6 @@
 
 (** solvers **)
 
-datatype solver = Solver of string * (thm list -> int -> tactic) * stamp;
-
 fun mk_solver name solver = Solver (name, solver, stamp());
 fun eq_solver (Solver (_, _, s1), Solver(_, _, s2)) = s1 = s2;
 
@@ -624,22 +647,13 @@
 
 (* type simpset *)
 
-datatype simpset =
-  Simpset of {
-    mss: meta_simpset,
-    mk_cong: thm -> thm,
-    subgoal_tac: simpset -> int -> tactic,
-    loop_tacs: (string * (int -> tactic)) list,
-    unsafe_solvers: solver list,
-    solvers: solver list};
-
 fun make_ss mss mk_cong subgoal_tac loop_tacs unsafe_solvers solvers =
   Simpset {mss = mss, mk_cong = mk_cong, subgoal_tac = subgoal_tac,
     loop_tacs = loop_tacs, unsafe_solvers = unsafe_solvers, solvers = solvers};
 
-val empty_ss =
-  let val mss = set_mk_sym (empty_mss, Some o symmetric_fun)
-  in make_ss mss I (K (K no_tac)) [] [] [] end;
+fun from_mss mss = make_ss mss I (K (K no_tac)) [] [] [];
+
+val empty_ss = from_mss (set_mk_sym (empty_mss, Some o symmetric_fun));
 
 fun rep_ss (Simpset args) = args;
 fun prems_of_ss (Simpset {mss, ...}) = prems_of_mss mss;
@@ -850,7 +864,7 @@
 *)
 
 fun rewritec (prover, signt, maxt)
-             (mss as Mss{rules, procs, termless, prems, congs, depth,...}) t =
+             (ss as Simpset{mss=mss as Mss{rules, procs, termless, prems, congs, depth,...},...}) t =
   let
     val eta_thm = Thm.eta_conversion t;
     val eta_t' = rhs_of eta_thm;
@@ -917,7 +931,7 @@
           if Pattern.matches tsigt (term_of lhs, term_of t) then
             (debug_term false ("Trying procedure " ^ quote name ^ " on:") signt eta_t;
              case transform_failure (curry SIMPROC_FAIL name)
-                 (fn () => proc signt prems eta_t) () of
+                 (fn () => proc ss signt prems eta_t) () of
                None => (debug false "FAILED"; proc_rews ps)
              | Some raw_thm =>
                  (trace_thm ("Procedure " ^ quote name ^ " produced rewrite rule:") raw_thm;
@@ -969,20 +983,24 @@
 fun transitive2 thm = transitive1 (Some thm);
 fun transitive3 thm = transitive1 thm o Some;
 
-fun bottomc ((simprem,useprem,mutsimp), prover, sign, maxidx) =
+fun replace_mss (Simpset{mss=_,mk_cong,subgoal_tac,loop_tacs,unsafe_solvers,solvers}) mss_new =
+    Simpset{mss=mss_new,mk_cong=mk_cong,subgoal_tac=subgoal_tac,loop_tacs=loop_tacs,
+	    unsafe_solvers=unsafe_solvers,solvers=solvers};
+
+fun bottomc ((simprem,useprem,mutsimp), prover, sign, maxidx) (ss as Simpset{mss,...}) =
   let
     fun botc skel mss t =
           if is_Var skel then None
           else
           (case subc skel mss t of
              some as Some thm1 =>
-               (case rewritec (prover, sign, maxidx) mss (rhs_of thm1) of
+               (case rewritec (prover, sign, maxidx) (replace_mss ss mss) (rhs_of thm1) of
                   Some (thm2, skel2) =>
                     transitive2 (transitive thm1 thm2)
                       (botc skel2 mss (rhs_of thm2))
                 | None => some)
            | None =>
-               (case rewritec (prover, sign, maxidx) mss t of
+               (case rewritec (prover, sign, maxidx) (replace_mss ss mss) t of
                   Some (thm2, skel2) => transitive2 thm2
                     (botc skel2 mss (rhs_of thm2))
                 | None => None))
@@ -1093,7 +1111,7 @@
             val concl' =
               Drule.mk_implies (prem, if_none (apsome rhs_of eq) concl);
             val dprem = apsome (curry (disch false) prem)
-          in case rewritec (prover, sign, maxidx) mss' concl' of
+          in case rewritec (prover, sign, maxidx) (replace_mss ss mss') concl' of
               None => rebuild prems concl' rrss asms mss (dprem eq)
             | Some (eq', _) => transitive2 (foldl (disch false o swap)
                   (the (transitive3 (dprem eq) eq'), prems))
@@ -1157,7 +1175,7 @@
            end)
        end
 
- in try_botc end;
+ in try_botc mss end;
 
 
 (*** Meta-rewriting: rewrites t to u and returns the theorem t==u ***)
@@ -1172,25 +1190,27 @@
     prover: how to solve premises in conditional rewrites and congruences
 *)
 
-fun rewrite_cterm mode prover mss ct =
+fun rewrite_cterm mode prover (ss as Simpset{mss,...}) ct =
   let val {sign, t, maxidx, ...} = rep_cterm ct
       val Mss{depth, ...} = mss
   in trace_cterm false "SIMPLIFIER INVOKED ON THE FOLLOWING TERM:" ct;
      simp_depth := depth;
-     bottomc (mode, prover, sign, maxidx) mss ct
+     bottomc (mode, prover, sign, maxidx) ss ct
   end
   handle THM (s, _, thms) =>
     error ("Exception THM was raised in simplifier:\n" ^ s ^ "\n" ^
       Pretty.string_of (Display.pretty_thms thms));
 
+val ss_of = from_mss o mss_of
+
 (*Rewrite a cterm*)
 fun rewrite_aux _ _ [] = (fn ct => Thm.reflexive ct)
-  | rewrite_aux prover full thms = rewrite_cterm (full, false, false) prover (mss_of thms);
+  | rewrite_aux prover full thms = rewrite_cterm (full, false, false) prover (ss_of thms);
 
 (*Rewrite a theorem*)
 fun simplify_aux _ _ [] = (fn th => th)
   | simplify_aux prover full thms =
-      Drule.fconv_rule (rewrite_cterm (full, false, false) prover (mss_of thms));
+      Drule.fconv_rule (rewrite_cterm (full, false, false) prover (ss_of thms));
 
 fun rewrite_thm mode prover mss = Drule.fconv_rule (rewrite_cterm mode prover mss);
 
@@ -1198,12 +1218,12 @@
 fun rewrite_goals_rule_aux _ []   th = th
   | rewrite_goals_rule_aux prover thms th =
       Drule.fconv_rule (Drule.goals_conv (K true) (rewrite_cterm (true, true, false) prover
-        (mss_of thms))) th;
+        (ss_of thms))) th;
 
 (*Rewrite the subgoal of a proof state (represented by a theorem) *)
-fun rewrite_goal_rule mode prover mss i thm =
+fun rewrite_goal_rule mode prover ss i thm =
   if 0 < i  andalso  i <= nprems_of thm
-  then Drule.fconv_rule (Drule.goals_conv (fn j => j=i) (rewrite_cterm mode prover mss)) thm
+  then Drule.fconv_rule (Drule.goals_conv (fn j => j=i) (rewrite_cterm mode prover ss)) thm
   else raise THM("rewrite_goal_rule",i,[thm]);
 
 
@@ -1229,25 +1249,25 @@
 
 (*note: may instantiate unknowns that appear also in other subgoals*)
 fun generic_simp_tac safe mode =
-  fn (Simpset {mss, mk_cong, subgoal_tac, loop_tacs, unsafe_solvers, solvers, ...}) =>
+  fn (ss as Simpset {mk_cong, subgoal_tac, loop_tacs, unsafe_solvers, solvers, ...}) =>
     let
       val solvs = app_sols (if safe then solvers else unsafe_solvers);
       fun simp_loop_tac i =
         asm_rewrite_goal_tac mode
           (solve_all_tac (mk_cong, subgoal_tac, loop_tacs, unsafe_solvers))
-          mss i
-        THEN (solvs (prems_of_mss mss) i ORELSE
+          ss i
+        THEN (solvs (prems_of_ss ss) i ORELSE
               TRY ((loop_tac loop_tacs THEN_ALL_NEW simp_loop_tac) i))
     in simp_loop_tac end;
 
 (** simplification rules and conversions **)
 
 fun simp rew mode
-     (Simpset {mss, mk_cong, subgoal_tac, loop_tacs, unsafe_solvers, ...}) thm =
+     (ss as Simpset {mk_cong, subgoal_tac, loop_tacs, unsafe_solvers, ...}) thm =
   let
     val tacf = solve_all_tac (mk_cong, subgoal_tac, loop_tacs, unsafe_solvers);
     fun prover m th = apsome fst (Seq.pull (tacf m th));
-  in rew mode prover mss thm end;
+  in rew mode prover ss thm end;
 
 val simp_thm = simp rewrite_thm;
 val simp_cterm = simp rewrite_cterm;
--- a/src/Pure/tactic.ML	Tue Jun 29 11:18:34 2004 +0200
+++ b/src/Pure/tactic.ML	Wed Jun 30 00:42:59 2004 +0200
@@ -498,7 +498,7 @@
 val rewrite_goals_rule = MetaSimplifier.rewrite_goals_rule_aux simple_prover;
 
 fun rewrite_goal_tac rews =
-  MetaSimplifier.asm_rewrite_goal_tac (true, false, false) (K no_tac) (MetaSimplifier.mss_of rews);
+  MetaSimplifier.asm_rewrite_goal_tac (true, false, false) (K no_tac) (MetaSimplifier.ss_of rews);
 
 (*Rewrite throughout proof state. *)
 fun rewrite_tac defs = PRIMITIVE(rewrite_rule defs);