Tuned simplifier not to re-normalized already normalized terms.
authornipkow
Wed, 07 Oct 1998 18:17:37 +0200
changeset 5623 75b513db9a3a
parent 5622 5b56804edf85
child 5624 4813dd0fe6e5
Tuned simplifier not to re-normalized already normalized terms.
src/Pure/thm.ML
--- a/src/Pure/thm.ML	Wed Oct 07 17:51:11 1998 +0200
+++ b/src/Pure/thm.ML	Wed Oct 07 18:17:37 1998 +0200
@@ -1509,7 +1509,9 @@
 (*
   A "mss" contains data needed during conversion:
     rules: discrimination net of rewrite rules;
-    congs: association list of congruence rules;
+    congs: association list of congruence rules and
+           a flag iff all of the congruences are 'full'.
+          A congruence is 'full' if it enforces normalization of all arguments.
     procs: discrimination net of simplification procedures
       (functions that prove rewrite rules on the fly);
     bounds: names of bound variables already used
@@ -1524,7 +1526,7 @@
 datatype meta_simpset =
   Mss of {
     rules: rrule Net.net,
-    congs: (string * cong) list,
+    congs: (string * cong) list * bool,
     procs: simproc Net.net,
     bounds: string list,
     prems: thm list,
@@ -1542,7 +1544,8 @@
 
 val empty_mss =
   let val mk_rews = {mk = K [], mk_sym = K None, mk_eq_True = K None}
-  in mk_mss (Net.empty, [], Net.empty, [], [], mk_rews, Term.termless) end;
+  in mk_mss (Net.empty, ([],true), Net.empty, [], [], mk_rews, Term.termless)
+  end;
 
 
 
@@ -1552,7 +1555,7 @@
 
 fun dest_mss (Mss {rules, congs, procs, ...}) =
   {simps = map (fn (_, {thm, ...}) => thm) (Net.dest rules),
-   congs = map (fn (_, {thm, ...}) => thm) congs,
+   congs = map (fn (_, {thm, ...}) => thm) (fst congs),
    procs =
      map (fn (_, {name, lhs, id, ...}) => ((name, lhs), id)) (Net.dest procs)
      |> partition_eq eq_snd
@@ -1562,13 +1565,14 @@
 (* merge_mss *)		(*NOTE: ignores mk_rews and termless of 2nd mss*)
 
 fun merge_mss
- (Mss {rules = rules1, congs = congs1, procs = procs1, bounds = bounds1,
-    prems = prems1, mk_rews, termless},
-  Mss {rules = rules2, congs = congs2, procs = procs2, bounds = bounds2,
-    prems = prems2, ...}) =
+ (Mss {rules = rules1, congs = (congs1,full1), procs = procs1,
+       bounds = bounds1, prems = prems1, mk_rews, termless},
+  Mss {rules = rules2, congs = (congs2,full2), procs = procs2,
+       bounds = bounds2, prems = prems2, ...}) =
       mk_mss
        (Net.merge (rules1, rules2, eq_rrule),
-        generic_merge (eq_cong o pairself snd) I I congs1 congs2,
+        (generic_merge (eq_cong o pairself snd) I I congs1 congs2,
+         full1 andalso full2),
         Net.merge (procs1, procs2, eq_simproc),
         merge_lists bounds1 bounds2,
         generic_merge eq_prem I I prems1 prems2,
@@ -1713,6 +1717,33 @@
 
 (* add_congs *)
 
+(*FIXME -> term.ML *)
+fun is_Bound (Bound _) = true
+fun is_Bound _         = false;
+
+fun is_full_cong_prems [] varpairs = null varpairs
+  | is_full_cong_prems (p::prems) varpairs =
+    (case Logic.strip_assums_concl p of
+       Const("==",_) $ lhs $ rhs =>
+         let val (x,xs) = strip_comb lhs and (y,ys) = strip_comb rhs
+         in is_Var x  andalso  forall is_Bound xs  andalso
+            null(findrep(xs))  andalso xs=ys andalso
+            (x,y) mem varpairs andalso
+            is_full_cong_prems (p::prems) (varpairs\(x,y))
+         end
+     | _ => false);
+
+fun is_full_cong (Thm{prop,...}) =
+let val prems = Logic.strip_imp_prems prop
+    and concl = Logic.strip_imp_concl prop
+    val (lhs,rhs) = Logic.dest_equals concl
+    val (f,xs) = strip_comb lhs
+    and (g,ys) = strip_comb rhs
+in
+  f=g andalso null(findrep(xs@ys)) andalso length xs = length ys andalso
+  is_full_cong_prems prems (xs ~~ ys)
+end
+
 fun add_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless}, thm) =
   let
     val (lhs, _) = Logic.dest_equals (concl_of thm) handle TERM _ =>
@@ -1720,9 +1751,11 @@
 (*   val lhs = Pattern.eta_contract lhs; *)
     val (a, _) = dest_Const (head_of lhs) handle TERM _ =>
       raise SIMPLIFIER ("Congruence must start with a constant", thm);
+    val (alist,full) = congs
+    val full2 = full andalso is_full_cong thm
   in
-    mk_mss (rules, (a, {lhs = lhs, thm = thm}) :: congs, procs, bounds,
-      prems, mk_rews, termless)
+    mk_mss (rules, ((a, {lhs = lhs, thm = thm}) :: alist, full2),
+            procs, bounds, prems, mk_rews, termless)
   end;
 
 val (op add_congs) = foldl add_cong;
@@ -1737,9 +1770,11 @@
 (*   val lhs = Pattern.eta_contract lhs; *)
     val (a, _) = dest_Const (head_of lhs) handle TERM _ =>
       raise SIMPLIFIER ("Congruence must start with a constant", thm);
+    val (alist,full) = congs
+    val alist2 = filter (fn (x,_)=> x<>a) alist
+    val full2 = forall (fn(_,{thm,...}) => is_full_cong thm) alist2
   in
-    mk_mss (rules, filter (fn (x,_)=> x<>a) congs, procs, bounds,
-      prems, mk_rews, termless)
+    mk_mss (rules, (alist2,full2), procs, bounds, prems, mk_rews, termless)
   end;
 
 val (op del_congs) = foldl del_cong;
@@ -1830,7 +1865,8 @@
 type termrec = (Sign.sg_ref * term list) * term;
 type conv = meta_simpset -> termrec -> termrec;
 
-fun check_conv (thm as Thm{shyps,hyps,prop,sign_ref,der,...}, prop0, ders) =
+fun check_conv
+      (thm as Thm{shyps,hyps,prop,sign_ref,der,...}, prop0, ders) =
   let fun err() = (trace_thm false "Proved wrong thm (Check subgoaler?)" thm;
                    trace_term false "Should have proved:" (Sign.deref sign_ref) prop0;
                    None)
@@ -1873,6 +1909,15 @@
 
 (* conversion to apply the meta simpset to a term *)
 
+(* Since the rewriting strategy is bottom-up, we avoid re-normalizing already
+   normalized terms by carrying around the rhs of the rewrite rule just
+   applied. This is called the `skeleton'. It is decomposed in parallel
+   with the term. Once a Var is encountered, the corresponding term is
+   already in normal form.
+   skel0 is a dummy skeleton that is to enforce complete normalization.
+*)
+val skel0 = Bound 0;
+
 (*
   we try in order:
     (1) beta reduction
@@ -1885,7 +1930,7 @@
 *)
 
 fun rewritec (prover,sign_reft,maxt)
-             (mss as Mss{rules, procs, termless, prems, ...}) 
+             (mss as Mss{rules, procs, termless, prems, congs, ...}) 
              (t:term,etc as (shypst,hypst,ders)) =
   let
     val signt = Sign.deref sign_reft;
@@ -1916,12 +1961,18 @@
         if perm andalso not(termless(rhs',lhs')) then None
         else (trace_thm false "Applying instance of rewrite rule:" thm;
               if unconditional
-              then (trace_thm false "Rewriting:" thm'; 
-                    Some(rhs', (shyps', hyps', der'::ders)))
+              then (trace_thm false "Rewriting:" thm';
+                    let val (_,rhs) = Logic.dest_equals prop
+                    in Some((rhs', (shyps', hyps', der'::ders)),
+                            if snd congs then rhs else skel0)
+                        (* use rhs as depth-limit only if all congs are full *)
+                    end)
               else (trace_thm false "Trying to rewrite:" thm';
                     case prover mss thm' of
                       None       => (trace_thm false "FAILED" thm'; None)
-                    | Some(thm2) => check_conv(thm2,prop',ders)))
+                    | Some(thm2) =>
+                        (case check_conv(thm2,prop',ders) of
+                           None => None | Some trec => Some(trec,skel0))))
       end
 
     fun rews [] = None
@@ -1952,7 +2003,7 @@
                   | some => some)))
           else proc_rews eta_t ps;
   in case t of
-       Abs (_, _, body) $ u => Some (subst_bound (u, body), etc)
+       Abs (_, _, body) $ u => Some ((subst_bound (u, body), etc),skel0)
      | _ => (case rews (sort_rrules (Net.match_term rules t)) of
                None => proc_rews (Pattern.eta_contract t)
                                  (Net.match_term procs t)
@@ -1999,29 +2050,33 @@
 
 fun bottomc ((simprem,useprem,mutsimp),prover,sign_ref,maxidx) =
   let
-    fun botc fail mss trec =
-          (case subc mss trec of
+    fun botc fail skel mss trec =
+          if is_Var skel then if fail then None else Some(trec)
+          else
+          (case subc skel mss trec of
              some as Some(trec1) =>
                (case rewritec (prover,sign_ref,maxidx) mss trec1 of
-                  Some(trec2) => botc false mss trec2
+                  Some(trec2,skel2) => botc false skel2 mss trec2
                 | None => some)
            | None =>
                (case rewritec (prover,sign_ref,maxidx) mss trec of
-                  Some(trec2) => botc false mss trec2
+                  Some(trec2,skel2) => botc false skel2 mss trec2
                 | None => if fail then None else Some(trec)))
 
-    and try_botc mss trec = (case botc true mss trec of
-                                Some(trec1) => trec1
-                              | None => trec)
+    and try_botc mss trec =
+          (case botc true skel0 mss trec of
+             Some(trec1) => trec1 | None => trec)
 
-    and subc (mss as Mss{rules,congs,procs,bounds,prems,mk_rews,termless})
+    and subc skel
+             (mss as Mss{rules,congs,procs,bounds,prems,mk_rews,termless})
              (trec as (t0:term,etc:sort list*term list * rule mtree list)) =
        (case t0 of
            Abs(a,T,t) =>
              let val b = variant bounds a
                  val v = Free("." ^ b,T)
                  val mss' = mk_mss (rules, congs, procs, b :: bounds, prems, mk_rews, termless)
-             in case botc true mss' (subst_bound(v,t),etc) of
+                 val skel' = case skel of Abs(_,_,sk) => sk | _ => skel0
+             in case botc true skel' mss' (subst_bound(v,t),etc) of
                   Some(t',etc') => Some(Abs(a, T, abstract_over(v,t')), etc')
                 | None => None
              end
@@ -2029,25 +2084,30 @@
              Const("==>",_)$s  => Some(impc(s,u,mss,etc))
            | Abs(_,_,body) =>
                let val trec = (subst_bound(u,body), etc)
-               in case subc mss trec of
+               in case subc skel0 mss trec of
                     None => Some(trec)
                   | trec => trec
                end
            | _  =>
                let fun appc() =
-                     (case botc true mss (t,etc) of
+                     let val (tskel,uskel) =
+                                case skel of tskel$uskel => (tskel,uskel)
+                                           | _ => (skel0,skel0)
+                     in
+                     (case botc true tskel mss (t,etc) of
                         Some(t1,etc1) =>
-                          (case botc true mss (u,etc1) of
+                          (case botc true uskel mss (u,etc1) of
                              Some(u1,etc2) => Some(t1$u1, etc2)
                            | None => Some(t1$u, etc1))
                       | None =>
-                          (case botc true mss (u,etc) of
+                          (case botc true uskel mss (u,etc) of
                              Some(u1,etc1) => Some(t$u1, etc1)
                            | None => None))
+                     end
                    val (h,ts) = strip_comb t
                in case h of
                     Const(a,_) =>
-                      (case assoc_string(congs,a) of
+                      (case assoc_string(fst congs,a) of
                          None => appc()
                        | Some(cong) =>
                            (congc (prover mss,sign_ref,maxidx) cong trec
@@ -2092,9 +2152,9 @@
           let val trec = disch1 trec2
           in case rewritec (prover,sign_ref,maxidx) mss trec of
                None => (None,trec)
-             | Some(Const("==>",_)$prem$conc,etc) =>
+             | Some((Const("==>",_)$prem$conc,etc),_) =>
                  mut_impc(prems,prem,conc,mss,etc)
-             | Some(trec') => (None,trec')
+             | Some(trec',_) => (None,trec')
           end
 
         fun simpconc() =