src/Pure/thm.ML
changeset 4679 24917efb31b5
parent 4667 6328d427a339
child 4684 eb712fef644b
--- a/src/Pure/thm.ML	Wed Mar 04 13:15:05 1998 +0100
+++ b/src/Pure/thm.ML	Wed Mar 04 13:16:05 1998 +0100
@@ -164,7 +164,8 @@
   val add_prems         : meta_simpset * thm list -> meta_simpset
   val prems_of_mss      : meta_simpset -> thm list
   val set_mk_rews       : meta_simpset * (thm -> thm list) -> meta_simpset
-(*  val mk_rews_of_mss    : meta_simpset -> thm -> thm list *)
+  val set_mk_sym        : meta_simpset * (thm -> thm option) -> meta_simpset
+  val set_mk_eq_True    : meta_simpset * (thm -> thm option) -> meta_simpset
   val set_termless      : meta_simpset * (term * term -> bool) -> meta_simpset
   val trace_simp        : bool ref
   val rewrite_cterm     : bool * bool -> meta_simpset ->
@@ -1459,6 +1460,9 @@
 fun prtm warn a sign t =
   (prnt warn a; prnt warn (Sign.string_of_term sign t));
 
+fun prthm warn a (thm as Thm{sign_ref, prop, ...}) =
+  (prtm warn a (Sign.deref sign_ref) prop);
+
 val trace_simp = ref false;
 
 fun trace warn a = if !trace_simp then prnt warn a else ();
@@ -1505,7 +1509,9 @@
     bounds: names of bound variables already used
       (for generating new names when rewriting under lambda abstractions);
     prems: current premises;
-    mk_rews: turns simplification thms into rewrite rules;
+    mk_rews: mk: turns simplification thms into rewrite rules;
+             mk_sym: turns == around; (needs Drule!)
+             mk_eq_True: turns P into P == True - logic specific;
     termless: relation for ordered rewriting;
 *)
 
@@ -1516,15 +1522,21 @@
     procs: simproc Net.net,
     bounds: string list,
     prems: thm list,
-    mk_rews: thm -> thm list,
+    mk_rews: {mk: thm -> thm list,
+              mk_sym: thm -> thm option,
+              mk_eq_True: thm -> thm option},
     termless: term * term -> bool};
 
 fun mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless) =
   Mss {rules = rules, congs = congs, procs = procs, bounds = bounds,
-    prems = prems, mk_rews = mk_rews, termless = termless};
+       prems=prems, mk_rews=mk_rews, termless=termless};
+
+fun upd_rules(Mss{rules,congs,procs,bounds,prems,mk_rews,termless}, rules') =
+  mk_mss(rules',congs,procs,bounds,prems,mk_rews,termless);
 
 val empty_mss =
-  mk_mss (Net.empty, [], Net.empty, [], [], K [], Term.termless);
+  let val mk_rews = {mk = K [], mk_sym = K None, mk_eq_True = K None}
+  in mk_mss (Net.empty, [], Net.empty, [], [], mk_rews, Term.termless) end;
 
 
 
@@ -1556,65 +1568,123 @@
         generic_merge eq_prem I I prems1 prems2,
         mk_rews, termless);
 
+(* add_simps *)
 
-(* mk_rrule *)
+fun insert_rrule(mss as Mss {rules,...},
+                 rrule as {thm = thm, lhs = lhs, perm = perm}) =
+  (trace_thm false "Adding rewrite rule:" thm;
+   let val rules' = Net.insert_term ((lhs, rrule), rules, eq_rrule)
+   in upd_rules(mss,rules') end
+   handle Net.INSERT =>
+     (prthm true "Ignoring duplicate rewrite rule" thm; mss));
+
+fun vperm (Var _, Var _) = true
+  | vperm (Abs (_, _, s), Abs (_, _, t)) = vperm (s, t)
+  | vperm (t1 $ t2, u1 $ u2) = vperm (t1, u1) andalso vperm (t2, u2)
+  | vperm (t, u) = (t = u);
+
+fun var_perm (t, u) =
+  vperm (t, u) andalso eq_set_term (term_vars t, term_vars u);
+
+(* FIXME: it seems that the conditions on extra variables are too liberal if
+prems are nonempty: does solving the prems really guarantee instantiation of
+all its Vars? Better: a dynamic check each time a rule is applied.
+*)
+fun rewrite_rule_extra_vars prems elhs erhs =
+  not ((term_vars erhs) subset
+       (union_term (term_vars elhs, List.concat(map term_vars prems))))
+  orelse
+  not ((term_tvars erhs) subset
+       (term_tvars elhs  union  List.concat(map term_tvars prems)));
 
-fun mk_rrule (thm as Thm {sign_ref, prop, ...}) =
-  let
-    val sign = Sign.deref sign_ref;
-    val prems = Logic.strip_imp_prems prop;
-    val concl = Logic.strip_imp_concl prop;
-    val (lhs, rhs) = Logic.dest_equals concl handle TERM _ =>
-      raise SIMPLIFIER ("Rewrite rule not a meta-equality", thm);
-  in case Logic.rewrite_rule_ok sign prems lhs rhs of
-     (None,perm) => Some {thm = thm, lhs = lhs, perm = perm}
-   | (Some msg,_) =>
-        (prtm true ("ignoring rewrite rule ("^msg^")") sign prop; None)
+(*simple test for looping rewrite*)
+fun looptest sign prems lhs rhs =
+   rewrite_rule_extra_vars prems lhs rhs
+  orelse
+   is_Var (head_of lhs)
+  orelse
+   (exists (apl (lhs, op Logic.occs)) (rhs :: prems))
+  orelse
+   (null prems andalso
+    Pattern.matches (#tsig (Sign.rep_sg sign)) (lhs, rhs))
+(*the condition "null prems" in the last cases is necessary because
+  conditional rewrites with extra variables in the conditions may terminate
+  although the rhs is an instance of the lhs. Example:
+  ?m < ?n ==> f(?n) == f(?m)*)
+
+fun decomp_simp(thm as Thm {sign_ref, prop, ...}) =
+  let val sign = Sign.deref sign_ref;
+      val prems = Logic.strip_imp_prems prop;
+      val concl = Logic.strip_imp_concl prop;
+      val (lhs, rhs) = Logic.dest_equals concl handle TERM _ =>
+        raise SIMPLIFIER ("Rewrite rule not a meta-equality", thm)
+      val elhs = Pattern.eta_contract lhs;
+      val erhs = Pattern.eta_contract rhs;
+      val perm = var_perm (elhs, erhs) andalso not (elhs aconv erhs)
+                 andalso not (is_Var elhs)
+  in (sign,prems,lhs,rhs,perm) end;
+
+fun mk_eq_True (Mss{mk_rews={mk_eq_True,...},...}) thm =
+  apsome (fn eq_True => let val (_,_,lhs,_,_) = decomp_simp eq_True
+                        in {thm=eq_True, lhs=lhs, perm=false} end)
+         (mk_eq_True thm);
+
+fun mk_rrule mss thm =
+  let val (_,prems,lhs,rhs,perm) = decomp_simp thm
+  in if perm then Some{thm=thm, lhs=lhs, perm=true} else
+     (* weak test for loops: *)
+     if rewrite_rule_extra_vars prems lhs rhs orelse
+        is_Var (head_of lhs) (* mk_cases may do this! *)
+     then mk_eq_True mss thm
+     else Some{thm=thm, lhs=lhs, perm=false}
   end;
 
+fun orient_rrule mss thm =
+  let val (sign,prems,lhs,rhs,perm) = decomp_simp thm
+  in if perm then Some{thm=thm,lhs=lhs,perm=true}
+     else if looptest sign prems lhs rhs
+          then if looptest sign prems rhs lhs
+               then mk_eq_True mss thm
+               else let val Mss{mk_rews={mk_sym,...},...} = mss
+                    in apsome (fn thm' => {thm=thm', lhs=rhs, perm=false})
+                              (mk_sym thm)
+                    end
+          else Some{thm=thm, lhs=lhs, perm=false}
+  end;
 
-(* add_simps *)
+fun extract_rews(Mss{mk_rews = {mk,...},...},thms) = flat(map mk thms);
 
-fun add_simp1(mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless},
-              thm as Thm {sign_ref, prop, ...}) =
-  case mk_rrule thm of
-    None => mss
-  | Some (rrule as {lhs, ...}) =>
-      (trace_thm false "Adding rewrite rule:" thm;
-       mk_mss (Net.insert_term ((lhs, rrule), rules, eq_rrule),
-               congs, procs, bounds, prems, mk_rews, termless)
-       handle Net.INSERT =>
-       (prtm true "ignoring duplicate rewrite rule" (Sign.deref sign_ref) prop;
-        mss));
+fun orient_comb_simps comb mk_rrule (mss,thms) =
+  let val rews = extract_rews(mss,thms)
+      val rrules = mapfilter mk_rrule rews
+  in foldl comb (mss,rrules) end
 
-fun add_simp(mss as Mss{mk_rews,...},thm) = foldl add_simp1 (mss, mk_rews thm);
+(* Add rewrite rules explicitly; do not reorient! *)
+fun add_simps(mss,thms) =
+  orient_comb_simps insert_rrule (mk_rrule mss) (mss,thms);
 
-val add_simps = foldl add_simp;
+fun mss_of thms =
+  foldl insert_rrule (empty_mss, mapfilter (mk_rrule empty_mss) thms);
 
-fun mss_of thms = foldl add_simp1 (empty_mss, thms);
+fun safe_add_simps(mss,thms) =
+  orient_comb_simps insert_rrule (orient_rrule mss) (mss,thms);
 
 
 (* del_simps *)
 
-fun del_simp1(mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless},
-              thm as Thm {sign_ref, prop, ...}) =
-  case mk_rrule thm of
-    None => mss
-  | Some (rrule as {lhs, ...}) =>
-      (mk_mss (Net.delete_term ((lhs, rrule), rules, eq_rrule),
-               congs, procs, bounds, prems, mk_rews, termless)
-       handle Net.DELETE =>
-       (prtm true "rewrite rule not in simpset" (Sign.deref sign_ref) prop;
-        mss));
+fun del_rrule(mss as Mss {rules,...},
+              rrule as {thm = thm, lhs = lhs, perm = perm}) =
+  (upd_rules(mss, Net.delete_term ((lhs, rrule), rules, eq_rrule))
+   handle Net.DELETE =>
+     (prthm true "rewrite rule not in simpset" thm; mss));
 
-fun del_simp(mss as Mss{mk_rews,...},thm) = foldl del_simp1 (mss, mk_rews thm);
-
-val del_simps = foldl del_simp;
+fun del_simps(mss,thms) =
+  orient_comb_simps del_rrule (mk_rrule mss) (mss,thms);
 
 
 (* add_congs *)
 
-fun add_cong (Mss {rules, congs, procs, bounds, prems, mk_rews, termless}, thm) =
+fun add_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless}, thm) =
   let
     val (lhs, _) = Logic.dest_equals (concl_of thm) handle TERM _ =>
       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
@@ -1631,7 +1701,7 @@
 
 (* del_congs *)
 
-fun del_cong (Mss {rules, congs, procs, bounds, prems, mk_rews, termless}, thm) =
+fun del_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless}, thm) =
   let
     val (lhs, _) = Logic.dest_equals (concl_of thm) handle TERM _ =>
       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
@@ -1648,7 +1718,7 @@
 
 (* add_simprocs *)
 
-fun add_proc (mss as Mss {rules, congs, procs, bounds, prems, mk_rews, termless},
+fun add_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless},
     (name, lhs as Cterm {sign_ref, t, ...}, proc, id)) =
   (trace_term false ("Adding simplification procedure " ^ quote name ^ " for:")
       (Sign.deref sign_ref) t;
@@ -1665,7 +1735,7 @@
 
 (* del_simprocs *)
 
-fun del_proc (mss as Mss {rules, congs, procs, bounds, prems, mk_rews, termless},
+fun del_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless},
     (name, lhs as Cterm {t, ...}, proc, id)) =
   mk_mss (rules, congs,
     Net.delete_term ((t, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)
@@ -1680,7 +1750,7 @@
 
 (* prems *)
 
-fun add_prems (Mss {rules, congs, procs, bounds, prems, mk_rews, termless}, thms) =
+fun add_prems (Mss {rules,congs,procs,bounds,prems,mk_rews,termless}, thms) =
   mk_mss (rules, congs, procs, bounds, thms @ prems, mk_rews, termless);
 
 fun prems_of_mss (Mss {prems, ...}) = prems;
@@ -1689,11 +1759,22 @@
 (* mk_rews *)
 
 fun set_mk_rews
-  (Mss {rules, congs, procs, bounds, prems, mk_rews = _, termless}, mk_rews) =
-    mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless);
+  (Mss {rules, congs, procs, bounds, prems, mk_rews, termless}, mk) =
+    mk_mss (rules, congs, procs, bounds, prems,
+            {mk=mk, mk_sym= #mk_sym mk_rews, mk_eq_True= #mk_eq_True mk_rews},
+            termless);
 
-fun mk_rews_of_mss (Mss {mk_rews, ...}) = mk_rews;
+fun set_mk_sym
+  (Mss {rules, congs, procs, bounds, prems, mk_rews, termless}, mk_sym) =
+    mk_mss (rules, congs, procs, bounds, prems,
+            {mk= #mk mk_rews, mk_sym= mk_sym, mk_eq_True= #mk_eq_True mk_rews},
+            termless);
 
+fun set_mk_eq_True
+  (Mss {rules, congs, procs, bounds, prems, mk_rews, termless}, mk_eq_True) =
+    mk_mss (rules, congs, procs, bounds, prems,
+            {mk= #mk mk_rews, mk_sym= #mk_sym mk_rews, mk_eq_True= mk_eq_True},
+            termless);
 
 (* termless *)
 
@@ -1744,18 +1825,11 @@
 
 (* mk_procrule *)
 
-fun mk_procrule (thm as Thm {sign_ref, prop, ...}) =
-  let
-    val sign = Sign.deref sign_ref;
-    val prems = Logic.strip_imp_prems prop;
-    val concl = Logic.strip_imp_concl prop;
-    val (lhs, _) = Logic.dest_equals concl handle TERM _ =>
-      raise SIMPLIFIER ("Rewrite rule not a meta-equality", thm);
-    val econcl = Pattern.eta_contract concl;
-    val (elhs, erhs) = Logic.dest_equals econcl;
-  in case Logic.rewrite_rule_extra_vars prems elhs erhs of
-       Some msg => (prtm true msg sign prop; [])
-     | None => [{thm = thm, lhs = lhs, perm = false}]
+fun mk_procrule thm =
+  let val (_,prems,lhs,rhs,_) = decomp_simp thm
+  in if rewrite_rule_extra_vars prems lhs rhs
+     then (prthm true "Extra vars on rhs" thm; [])
+     else [{thm = thm, lhs = lhs, perm = false}]
   end;
 
 
@@ -1773,7 +1847,7 @@
 *)
 
 fun rewritec (prover,sign_reft,maxt)
-             (mss as Mss{rules, procs, mk_rews, termless, prems, ...}) 
+             (mss as Mss{rules, procs, termless, prems, ...}) 
              (shypst,hypst,t,ders) =
   let
       val signt = Sign.deref sign_reft;
@@ -1969,7 +2043,7 @@
                                                   (Sign.deref sign_ref) s1; mss)
              else let val thm = assume (Cterm{sign_ref=sign_ref, t=s1, 
                                               T=propT, maxidx= ~1})
-                  in add_simp(add_prems(mss,[thm]), thm) end
+                  in safe_add_simps(add_prems(mss,[thm]), [thm]) end
            val (shyps2,hyps2,u1,ders2) = try_botc mss1 (shyps1,hyps1,u,ders1)
            val hyps3 = if gen_mem (op aconv) (s1, hyps1)
                        then hyps2 else hyps2\s1