Completely reimplemented mutual simplification of premises.
authorberghofe
Mon, 30 Sep 2002 16:38:22 +0200
changeset 13607 6908230623a3
parent 13606 2f121149acfe
child 13608 9a6f43b8eae1
Completely reimplemented mutual simplification of premises.
src/Pure/meta_simplifier.ML
--- a/src/Pure/meta_simplifier.ML	Mon Sep 30 16:37:44 2002 +0200
+++ b/src/Pure/meta_simplifier.ML	Mon Sep 30 16:38:22 2002 +0200
@@ -95,8 +95,8 @@
   let val {sign, prop, ...} = rep_thm thm
   in trace_term false a sign prop end;
 
-fun trace_named_thm a thm =
-  trace_thm (a ^ " " ^ quote(Thm.name_of_thm thm) ^ ":") thm;
+fun trace_named_thm a (thm, name) =
+  trace_thm (a ^ (if name = "" then "" else " " ^ quote name) ^ ":") thm;
 
 end;
 
@@ -105,8 +105,9 @@
 
 (* basic components *)
 
-type rrule = {thm: thm, lhs: term, elhs: cterm, fo: bool, perm: bool};
+type rrule = {thm: thm, name: string, lhs: term, elhs: cterm, fo: bool, perm: bool};
 (* thm: the rewrite rule
+   name: name of theorem from which rewrite rule was extracted
    lhs: the left-hand side
    elhs: the etac-contracted lhs.
    fo:  use first-order matching
@@ -230,17 +231,17 @@
 
 (* add_simps *)
 
-fun mk_rrule2{thm,lhs,elhs,perm} =
+fun mk_rrule2{thm, name, lhs, elhs, perm} =
   let val fo = Pattern.first_order (term_of elhs) orelse not(Pattern.pattern (term_of elhs))
-  in {thm=thm,lhs=lhs,elhs=elhs,fo=fo,perm=perm} end
+  in {thm=thm, name=name, lhs=lhs, elhs=elhs, fo=fo, perm=perm} end
 
-fun insert_rrule(mss as Mss {rules,...},
-                 rrule as {thm,lhs,elhs,perm}) =
-  (trace_named_thm "Adding rewrite rule" thm;
+fun insert_rrule quiet (mss as Mss {rules,...},
+                 rrule as {thm,name,lhs,elhs,perm}) =
+  (trace_named_thm "Adding rewrite rule" (thm, name);
    let val rrule2 as {elhs,...} = mk_rrule2 rrule
        val rules' = Net.insert_term ((term_of elhs, rrule2), rules, eq_rrule)
    in upd_rules(mss,rules') end
-   handle Net.INSERT =>
+   handle Net.INSERT => if quiet then mss else
      (prthm true "Ignoring duplicate rewrite rule:" thm; mss));
 
 fun vperm (Var _, Var _) = true
@@ -296,49 +297,51 @@
     else (lhs, rhs)
   end;
 
-fun mk_eq_True (Mss{mk_rews={mk_eq_True,...},...}) thm =
+fun mk_eq_True (Mss{mk_rews={mk_eq_True,...},...}) (thm, name) =
   case mk_eq_True thm of
     None => []
-  | Some eq_True => let val (_,_,lhs,elhs,_,_) = decomp_simp eq_True
-                    in [{thm=eq_True, lhs=lhs, elhs=elhs, perm=false}] end;
+  | Some eq_True =>
+      let val (_,_,lhs,elhs,_,_) = decomp_simp eq_True
+      in [{thm=eq_True, name=name, lhs=lhs, elhs=elhs, perm=false}] end;
 
 (* create the rewrite rule and possibly also the ==True variant,
    in case there are extra vars on the rhs *)
-fun rrule_eq_True(thm,lhs,elhs,rhs,mss,thm2) =
-  let val rrule = {thm=thm, lhs=lhs, elhs=elhs, perm=false}
+fun rrule_eq_True(thm,name,lhs,elhs,rhs,mss,thm2) =
+  let val rrule = {thm=thm, name=name, lhs=lhs, elhs=elhs, perm=false}
   in if (term_varnames rhs)  subset (term_varnames lhs) andalso
         (term_tvars rhs) subset (term_tvars lhs)
      then [rrule]
-     else mk_eq_True mss thm2 @ [rrule]
+     else mk_eq_True mss (thm2, name) @ [rrule]
   end;
 
-fun mk_rrule mss thm =
+fun mk_rrule mss (thm, name) =
   let val (_,prems,lhs,elhs,rhs,perm) = decomp_simp thm
-  in if perm then [{thm=thm, lhs=lhs, elhs=elhs, perm=true}] else
+  in if perm then [{thm=thm, name=name, lhs=lhs, elhs=elhs, perm=true}] else
      (* weak test for loops: *)
      if rewrite_rule_extra_vars prems lhs rhs orelse
         is_Var (term_of elhs)
-     then mk_eq_True mss thm
-     else rrule_eq_True(thm,lhs,elhs,rhs,mss,thm)
+     then mk_eq_True mss (thm, name)
+     else rrule_eq_True(thm,name,lhs,elhs,rhs,mss,thm)
   end;
 
-fun orient_rrule mss thm =
+fun orient_rrule mss (thm, name) =
   let val (sign,prems,lhs,elhs,rhs,perm) = decomp_simp thm
-  in if perm then [{thm=thm,lhs=lhs,elhs=elhs,perm=true}]
+  in if perm then [{thm=thm, name=name, lhs=lhs, elhs=elhs, perm=true}]
      else if reorient sign prems lhs rhs
           then if reorient sign prems rhs lhs
-               then mk_eq_True mss thm
+               then mk_eq_True mss (thm, name)
                else let val Mss{mk_rews={mk_sym,...},...} = mss
                     in case mk_sym thm of
                          None => []
                        | Some thm' =>
                            let val (_,_,lhs',elhs',rhs',_) = decomp_simp thm'
-                           in rrule_eq_True(thm',lhs',elhs',rhs',mss,thm) end
+                           in rrule_eq_True(thm',name,lhs',elhs',rhs',mss,thm) end
                     end
-          else rrule_eq_True(thm,lhs,elhs,rhs,mss,thm)
+          else rrule_eq_True(thm,name,lhs,elhs,rhs,mss,thm)
   end;
 
-fun extract_rews(Mss{mk_rews = {mk,...},...},thms) = flat(map mk thms);
+fun extract_rews(Mss{mk_rews = {mk,...},...},thms) =
+  flat (map (fn thm => map (rpair (Thm.name_of_thm thm)) (mk thm)) thms);
 
 fun orient_comb_simps comb mk_rrule (mss,thms) =
   let val rews = extract_rews(mss,thms)
@@ -347,17 +350,14 @@
 
 (* Add rewrite rules explicitly; do not reorient! *)
 fun add_simps(mss,thms) =
-  orient_comb_simps insert_rrule (mk_rrule mss) (mss,thms);
+  orient_comb_simps (insert_rrule false) (mk_rrule mss) (mss,thms);
 
-fun mss_of thms =
-  foldl insert_rrule (empty_mss, flat(map (mk_rrule empty_mss) thms));
+fun mss_of thms = foldl (insert_rrule false) (empty_mss, flat
+  (map (fn thm => mk_rrule empty_mss (thm, Thm.name_of_thm thm)) thms));
 
 fun extract_safe_rrules(mss,thm) =
   flat (map (orient_rrule mss) (extract_rews(mss,[thm])));
 
-fun add_safe_simp(mss,thm) =
-  foldl insert_rrule (mss, extract_safe_rrules(mss,thm))
-
 (* del_simps *)
 
 fun del_rrule(mss as Mss {rules,...},
@@ -517,10 +517,6 @@
     Science of Computer Programming 3 (1983), pages 119-149.
 *)
 
-type prover = meta_simpset -> thm -> thm option;
-type termrec = (Sign.sg_ref * term list) * term;
-type conv = meta_simpset -> termrec -> termrec;
-
 val dest_eq = Drule.dest_equals o cprop_of;
 val lhs_of = fst o dest_eq;
 val rhs_of = snd o dest_eq;
@@ -549,7 +545,7 @@
   let val (_,prems,lhs,elhs,rhs,_) = decomp_simp thm
   in if rewrite_rule_extra_vars prems lhs rhs
      then (prthm true "Extra vars on rhs:" thm; [])
-     else [mk_rrule2{thm=thm, lhs=lhs, elhs=elhs, perm=false}]
+     else [mk_rrule2{thm=thm, name="", lhs=lhs, elhs=elhs, perm=false}]
   end;
 
 
@@ -599,7 +595,7 @@
     val eta_t' = rhs_of eta_thm;
     val eta_t = term_of eta_t';
     val tsigt = Sign.tsig_of signt;
-    fun rew {thm, lhs, elhs, fo, perm} =
+    fun rew {thm, name, lhs, elhs, fo, perm} =
       let
         val {sign, prop, maxidx, ...} = rep_thm thm;
         val _ = if Sign.subsig (sign, signt) then ()
@@ -615,9 +611,9 @@
         val (lhs', rhs') = Logic.dest_equals (Logic.strip_imp_concl prop')
       in
         if perm andalso not (termless (rhs', lhs'))
-        then (trace_named_thm "Cannot apply permutative rewrite rule" thm;
+        then (trace_named_thm "Cannot apply permutative rewrite rule" (thm, name);
               trace_thm "Term does not become smaller:" thm'; None)
-        else (trace_named_thm "Applying instance of rewrite rule" thm;
+        else (trace_named_thm "Applying instance of rewrite rule" (thm, name);
            if unconditional
            then
              (trace_thm "Rewriting:" thm';
@@ -701,11 +697,15 @@
 val (cA, (cB, cC)) =
   apsnd dest_equals (dest_implies (hd (cprems_of Drule.imp_cong)));
 
-fun transitive' thm1 None = Some thm1
-  | transitive' thm1 (Some thm2) = Some (transitive thm1 thm2);
+fun transitive1 None None = None
+  | transitive1 (Some thm1) None = Some thm1
+  | transitive1 None (Some thm2) = Some thm2
+  | transitive1 (Some thm1) (Some thm2) = Some (transitive thm1 thm2)
 
-fun transitive'' None thm2 = Some thm2
-  | transitive'' (Some thm1) thm2 = Some (transitive thm1 thm2);
+fun transitive2 thm = transitive1 (Some thm);
+fun transitive3 thm = transitive1 thm o Some;
+
+fun imp_cong' e = combination (combination refl_implies e);
 
 fun bottomc ((simprem,useprem,mutsimp), prover, sign, maxidx) =
   let
@@ -716,12 +716,12 @@
              some as Some thm1 =>
                (case rewritec (prover, sign, maxidx) mss (rhs_of thm1) of
                   Some (thm2, skel2) =>
-                    transitive' (transitive thm1 thm2)
+                    transitive2 (transitive thm1 thm2)
                       (botc skel2 mss (rhs_of thm2))
                 | None => some)
            | None =>
                (case rewritec (prover, sign, maxidx) mss t of
-                  Some (thm2, skel2) => transitive' thm2
+                  Some (thm2, skel2) => transitive2 thm2
                     (botc skel2 mss (rhs_of thm2))
                 | None => None))
 
@@ -744,7 +744,7 @@
          | t $ _ => (case t of
              Const ("==>", _) $ _  =>
                let val (s, u) = Drule.dest_implies t0
-               in impc (s, u, mss) end
+               in impc t0 mss end
            | Abs _ =>
                let val thm = beta_conversion false t0
                in case subc skel0 mss (rhs_of thm) of
@@ -786,7 +786,7 @@
                                list_comb (h, replicate (length ts) dVar)
                            in case botc skel mss cl of
                                 None => thm
-                              | Some thm' => transitive'' thm
+                              | Some thm' => transitive3 thm
                                   (combination thm' (reflexive cr))
                            end handle TERM _ => error "congc result"
                                     | Pattern.MATCH => appc ()))
@@ -794,126 +794,106 @@
                end)
          | _ => None)
 
-    and impc args =
-      if mutsimp
-      then let val (prem, conc, mss) = args
-           in apsome snd (mut_impc ([], prem, conc, mss)) end
-      else nonmut_impc args
-
-    and mut_impc (prems, prem, conc, mss) = (case botc skel0 mss prem of
-        None => mut_impc1 (prems, prem, conc, mss)
-      | Some thm1 =>
-          let val prem1 = rhs_of thm1
-          in (case mut_impc1 (prems, prem1, conc, mss) of
-              None => Some (None,
-                combination (combination refl_implies thm1) (reflexive conc))
-            | Some (x, thm2) => Some (x, transitive (combination (combination
-                refl_implies thm1) (reflexive conc)) thm2))
-          end)
+    and impc ct mss =
+      if mutsimp then mut_impc0 [] ct [] [] mss else nonmut_impc ct mss
 
-    and mut_impc1 (prems, prem1, conc, mss) =
-      let
-        fun uncond ({thm, lhs, elhs, perm}) =
-          if Thm.no_prems thm then Some lhs else None
+    and rules_of_prem mss prem =
+      if maxidx_of_term (term_of prem) <> ~1
+      then (trace_cterm true
+        "Cannot add premise as rewrite rule because it contains (type) unknowns:" prem; ([], None))
+      else
+        let val asm = assume prem
+        in (extract_safe_rrules (mss, asm), Some asm) end
 
-        val (lhss1, mss1) =
-          if maxidx_of_term (term_of prem1) <> ~1
-          then (trace_cterm true
-            "Cannot add premise as rewrite rule because it contains (type) unknowns:" prem1;
-                ([],mss))
-          else let val thm = assume prem1
-                   val rrules1 = extract_safe_rrules (mss, thm)
-                   val lhss1 = mapfilter uncond rrules1
-                   val mss1 = foldl insert_rrule (add_prems (mss, [thm]), rrules1)
-               in (lhss1, mss1) end
-
-        fun disch1 thm =
-          let val (cB', cC') = dest_eq thm
-          in
-            implies_elim (Thm.instantiate
-              ([], [(cA, prem1), (cB, cB'), (cC, cC')]) Drule.imp_cong)
-              (implies_intr prem1 thm)
-          end
+    and add_rrules (rrss, asms) mss =
+      add_prems (foldl (insert_rrule true) (mss, flat rrss), mapfilter I asms)
 
-        fun rebuild None = (case rewritec (prover, sign, maxidx) mss
-            (mk_implies (prem1, conc)) of
-              None => None
-            | Some (thm, _) =>
-                let val (prem, conc) = Drule.dest_implies (rhs_of thm)
-                in (case mut_impc (prems, prem, conc, mss) of
-                    None => Some (None, thm)
-                  | Some (x, thm') => Some (x, transitive thm thm'))
-                end handle TERM _ => Some (None, thm))
-          | rebuild (Some thm2) =
-            let val thm = disch1 thm2
-            in (case rewritec (prover, sign, maxidx) mss (rhs_of thm) of
-                 None => Some (None, thm)
-               | Some (thm', _) =>
-                   let val (prem, conc) = Drule.dest_implies (rhs_of thm')
-                   in (case mut_impc (prems, prem, conc, mss) of
-                       None => Some (None, transitive thm thm')
-                     | Some (x, thm'') =>
-                         Some (x, transitive (transitive thm thm') thm''))
-                   end handle TERM _ => Some (None, transitive thm thm'))
-            end
-
-        fun simpconc () =
-          let val (s, t) = Drule.dest_implies conc
-          in case mut_impc (prems @ [prem1], s, t, mss1) of
-               None => rebuild None
-             | Some (Some i, thm2) =>
-                  let
-                    val (prem, cC') = Drule.dest_implies (rhs_of thm2);
-                    val thm2' = transitive (disch1 thm2) (Thm.instantiate
-                      ([], [(cA, prem1), (cB, prem), (cC, cC')])
-                      Drule.swap_prems_eq)
-                  in if i=0 then apsome (apsnd (transitive thm2'))
-                       (mut_impc1 (prems, prem, mk_implies (prem1, cC'), mss))
-                     else Some (Some (i-1), thm2')
-                  end
-             | Some (None, thm) => rebuild (Some thm)
-          end handle TERM _ => rebuild (botc skel0 mss1 conc)
-
-      in
+    and disch r (prem, eq) =
+      let
+        val (lhs, rhs) = dest_eq eq;
+        val eq' = implies_elim (Thm.instantiate
+          ([], [(cA, prem), (cB, lhs), (cC, rhs)]) Drule.imp_cong)
+          (implies_intr prem eq)
+      in if not r then eq' else
         let
-          val tsig = Sign.tsig_of sign
-          fun reducible t =
-            exists (fn lhs => Pattern.matches_subterm tsig (lhs, term_of t)) lhss1;
-        in case dropwhile (not o reducible) prems of
-            [] => simpconc ()
-          | red::rest => (trace_cterm false "Can now reduce premise:" red;
-              Some (Some (length rest), reflexive (mk_implies (prem1, conc))))
+          val (prem', concl) = dest_implies lhs;
+          val (prem'', _) = dest_implies rhs
+        in transitive (transitive
+          (Thm.instantiate ([], [(cA, prem'), (cB, prem), (cC, concl)])
+             Drule.swap_prems_eq) eq')
+          (Thm.instantiate ([], [(cA, prem), (cB, prem''), (cC, concl)])
+             Drule.swap_prems_eq)
         end
       end
 
+    and rebuild [] _ _ _ _ eq = eq
+      | rebuild (prem :: prems) concl (rrs :: rrss) (asm :: asms) mss eq =
+          let
+            val mss' = add_rrules (rev rrss, rev asms) mss;
+            val concl' =
+              Drule.mk_implies (prem, if_none (apsome rhs_of eq) concl);
+            val dprem = apsome (curry (disch false) prem)
+          in case rewritec (prover, sign, maxidx) mss' concl' of
+              None => rebuild prems concl' rrss asms mss (dprem eq)
+            | Some (eq', _) => transitive2 (foldl (disch false o swap)
+                  (the (transitive3 (dprem eq) eq'), prems))
+                (mut_impc0 (rev prems) (rhs_of eq') (rev rrss) (rev asms) mss)
+          end
+          
+    and mut_impc0 prems concl rrss asms mss =
+      let
+        val prems' = strip_imp_prems concl;
+        val (rrss', asms') = split_list (map (rules_of_prem mss) prems')
+      in mut_impc (prems @ prems') (strip_imp_concl concl) (rrss @ rrss')
+        (asms @ asms') [] [] [] [] mss ~1 ~1
+      end
+ 
+    and mut_impc [] concl [] [] prems' rrss' asms' eqns mss changed k =
+        transitive1 (foldl (fn (eq2, (eq1, prem)) => transitive1 eq1
+            (apsome (curry (disch false) prem) eq2)) (None, eqns ~~ prems'))
+          (if changed > 0 then
+             mut_impc (rev prems') concl (rev rrss') (rev asms')
+               [] [] [] [] mss ~1 changed
+           else rebuild prems' concl rrss' asms' mss
+             (botc skel0 (add_rrules (rev rrss', rev asms') mss) concl))
+
+      | mut_impc (prem :: prems) concl (rrs :: rrss) (asm :: asms)
+          prems' rrss' asms' eqns mss changed k =
+        case (if k = 0 then None else botc skel0 (add_rrules
+          (rev rrss' @ rrss, rev asms' @ asms) mss) prem) of
+            None => mut_impc prems concl rrss asms (prem :: prems')
+              (rrs :: rrss') (asm :: asms') (None :: eqns) mss changed
+              (if k = 0 then 0 else k - 1)
+          | Some eqn =>
+            let
+              val prem' = rhs_of eqn;
+              val tprems = map term_of prems;
+              val i = 1 + foldl Int.max (~1, map (fn p =>
+                find_index_eq p tprems) (#hyps (rep_thm eqn)));
+              val (rrs', asm') = rules_of_prem mss prem'
+            in mut_impc prems concl rrss asms (prem' :: prems')
+              (rrs' :: rrss') (asm' :: asms') (Some (foldr (disch true)
+                (take (i, prems), imp_cong' eqn (reflexive (Drule.list_implies
+                  (drop (i, prems), concl))))) :: eqns) mss (length prems') ~1
+            end
+
      (* legacy code - only for backwards compatibility *)
-     and nonmut_impc (prem, conc, mss) =
-       let val thm1 = if simprem then botc skel0 mss prem else None;
+     and nonmut_impc ct mss =
+       let val (prem, conc) = dest_implies ct;
+           val thm1 = if simprem then botc skel0 mss prem else None;
            val prem1 = if_none (apsome rhs_of thm1) prem;
-           val maxidx1 = maxidx_of_term (term_of prem1)
-           val mss1 =
-             if not useprem then mss else
-             if maxidx1 <> ~1
-             then (trace_cterm true
-               "Cannot add premise as rewrite rule because it contains (type) unknowns:" prem1;
-                   mss)
-             else let val thm = assume prem1
-                  in add_safe_simp (add_prems (mss, [thm]), thm) end
+           val mss1 = if not useprem then mss else add_rrules
+             (apsnd single (apfst single (rules_of_prem mss prem1))) mss
        in (case botc skel0 mss1 conc of
            None => (case thm1 of
                None => None
-             | Some thm1' => Some (combination
-                 (combination refl_implies thm1') (reflexive conc)))
+             | Some thm1' => Some (imp_cong' thm1' (reflexive conc)))
          | Some thm2 =>
-           let
-             val conc2 = rhs_of thm2;
-             val thm2' = implies_elim (Thm.instantiate
-               ([], [(cA, prem1), (cB, conc), (cC, conc2)]) Drule.imp_cong)
-               (implies_intr prem1 thm2)
+           let val thm2' = disch false (prem1, thm2)
            in (case thm1 of
                None => Some thm2'
-             | Some thm1' => Some (transitive (combination
-                 (combination refl_implies thm1') (reflexive conc)) thm2'))
+             | Some thm1' =>
+                 Some (transitive (imp_cong' thm1' (reflexive conc)) thm2'))
            end)
        end
 
@@ -950,7 +930,7 @@
             val (thA,j) = case term_of A of
                   Const("=?=",_)$_$_ => (reflexive A, i)
                 | _ => (if pred i then cv A else reflexive A, i+1)
-        in  combination (combination refl_implies thA) (gconv j B) end
+        in  imp_cong' thA (gconv j B) end
         handle TERM _ => reflexive ct
   in gconv 1 end;