improved bounds: nameless Term.bound, recover names for output;
authorwenzelm
Mon, 01 Aug 2005 19:20:39 +0200
changeset 16985 7df8abe926c3
parent 16984 abc48b981e60
child 16986 68bc6dbea7d6
improved bounds: nameless Term.bound, recover names for output;
src/Pure/meta_simplifier.ML
--- a/src/Pure/meta_simplifier.ML	Mon Aug 01 19:20:38 2005 +0200
+++ b/src/Pure/meta_simplifier.ML	Mon Aug 01 19:20:39 2005 +0200
@@ -26,7 +26,7 @@
   val rep_ss: simpset ->
    {rules: rrule Net.net,
     prems: thm list,
-    bounds: int} *
+    bounds: int * (string * (string * typ)) list} *
    {congs: (string * cong) list * string list,
     procs: proc Net.net,
     mk_rews:
@@ -80,6 +80,7 @@
     -> (theory -> simpset -> term -> thm option) -> simproc
   val simproc: theory -> string -> string list
     -> (theory -> simpset -> term -> thm option) -> simproc
+  val inherit_bounds: simpset -> simpset -> simpset
   val rewrite_cterm: bool * bool * bool ->
     (simpset -> thm -> thm option) -> simpset -> cterm -> thm
   val rewrite_aux: (simpset -> thm -> thm option) -> bool -> thm list -> cterm -> thm
@@ -100,47 +101,6 @@
 struct
 
 
-(** diagnostics **)
-
-exception SIMPLIFIER of string * thm;
-
-val debug_simp = ref false;
-val trace_simp = ref false;
-val simp_depth = ref 0;
-val simp_depth_limit = ref 100;
-val trace_simp_depth_limit = ref 100;
-
-local
-
-fun println a =
-  if !simp_depth > !trace_simp_depth_limit then ()
-  else tracing (enclose "[" "]" (string_of_int(!simp_depth)) ^ a);
-
-fun prnt warn a = if warn then warning a else println a;
-fun prtm warn a thy t = prnt warn (a ^ "\n" ^ Sign.string_of_term thy t);
-fun prctm warn a t = prnt warn (a ^ "\n" ^ Display.string_of_cterm t);
-
-in
-
-fun debug warn a = if ! debug_simp then prnt warn a else ();
-fun trace warn a = if ! trace_simp then prnt warn a else ();
-
-fun debug_term warn a thy t = if ! debug_simp then prtm warn a thy t else ();
-fun trace_term warn a thy t = if ! trace_simp then prtm warn a thy t else ();
-fun trace_cterm warn a ct = if ! trace_simp then prctm warn a ct else ();
-fun trace_thm a th = if ! trace_simp then prctm false a (Thm.cprop_of th) else ();
-
-fun trace_named_thm a (thm, name) =
-  if ! trace_simp then
-    prctm false (if name = "" then a else a ^ " " ^ quote name ^ ":") (Thm.cprop_of thm)
-  else ();
-
-fun warn_thm a = prctm true a o Thm.cprop_of;
-
-end;
-
-
-
 (** datatype simpset **)
 
 (* rewrite rules *)
@@ -220,7 +180,7 @@
   Simpset of
    {rules: rrule Net.net,
     prems: thm list,
-    bounds: int} *
+    bounds: int * (string * (string * typ)) list} *
    {congs: (string * cong) list * string list,
     procs: proc Net.net,
     mk_rews: mk_rews,
@@ -263,6 +223,59 @@
 fun map_simpset2 f (Simpset (r1, r2)) = Simpset (r1, map_ss2 f r2);
 
 
+(* diagnostics *)
+
+exception SIMPLIFIER of string * thm;
+
+val debug_simp = ref false;
+val trace_simp = ref false;
+val simp_depth = ref 0;
+val simp_depth_limit = ref 100;
+val trace_simp_depth_limit = ref 100;
+
+local
+
+fun println a =
+  if ! simp_depth > ! trace_simp_depth_limit then ()
+  else tracing (enclose "[" "]" (string_of_int (! simp_depth)) ^ a);
+
+fun prnt warn a = if warn then warning a else println a;
+
+fun show_bounds (Simpset ({bounds = (_, bs), ...}, _)) t =
+  let
+    val used = Term.add_term_names (t, []);
+    val xs = rev (Term.variantlist (rev (map #1 bs), used));
+    fun subst ((_, (b, T)), x) = (Free (b, T), Syntax.mark_boundT (x, T));
+  in Term.subst_atomic (ListPair.map subst (bs, xs)) t end;
+
+fun prtm warn a ss thy t = prnt warn (a ^ "\n" ^
+  Sign.string_of_term thy (if ! debug_simp then t else show_bounds ss t));
+
+in
+
+fun debug warn a = if ! debug_simp then prnt warn a else ();
+fun trace warn a = if ! trace_simp then prnt warn a else ();
+
+fun debug_term warn a ss thy t = if ! debug_simp then prtm warn a ss thy t else ();
+fun trace_term warn a ss thy t = if ! trace_simp then prtm warn a ss thy t else ();
+
+fun trace_cterm warn a ss ct =
+  if ! trace_simp then prtm warn a ss (Thm.theory_of_cterm ct) (Thm.term_of ct) else ();
+
+fun trace_thm a ss th =
+  if ! trace_simp then prtm false a ss (Thm.theory_of_thm th) (Thm.full_prop_of th) else ();
+
+fun trace_named_thm a ss (th, name) =
+  if ! trace_simp then
+    prtm false (if name = "" then a else a ^ " " ^ quote name ^ ":") ss
+      (Thm.theory_of_thm th) (Thm.full_prop_of th)
+  else ();
+
+fun warn_thm a ss th = prtm true a ss (Thm.theory_of_thm th) (Thm.full_prop_of th);
+
+end;
+
+
 (* print simpsets *)
 
 fun print_ss ss =
@@ -298,7 +311,7 @@
 local
 
 fun init_ss mk_rews termless subgoal_tac solvers =
-  make_simpset ((Net.empty, [], 0),
+  make_simpset ((Net.empty, [], (0, [])),
     (([], []), Net.empty, mk_rews, termless, subgoal_tac, [], solvers));
 
 val basic_mk_rews: mk_rews =
@@ -330,7 +343,7 @@
 
     val rules' = Net.merge eq_rrule (rules1, rules2);
     val prems' = gen_merge_lists Drule.eq_thm_prop prems1 prems2;
-    val bounds' = Int.max (bounds1, bounds2);
+    val bounds' = if #1 bounds1 < #1 bounds2 then bounds2 else bounds1;
     val congs' = gen_merge_lists (eq_cong o pairself #2) congs1 congs2;
     val weak' = merge_lists weak1 weak2;
     val procs' = Net.merge eq_proc (procs1, procs2);
@@ -364,8 +377,11 @@
 
 (* bounds and prems *)
 
-val incr_bounds = map_simpset1 (fn (rules, prems, bounds) =>
-  (rules, prems, bounds + 1));
+fun inherit_bounds (Simpset ({bounds, ...}, _)) =
+  map_simpset1 (fn (rules, prems, _) => (rules, prems, bounds));
+
+fun add_bound bound = map_simpset1 (fn (rules, prems, (count, bounds)) =>
+  (rules, prems, (count + 1, bound :: bounds)));
 
 fun add_prems ths = map_simpset1 (fn (rules, prems, bounds) =>
   (rules, ths @ prems, bounds));
@@ -381,14 +397,14 @@
   in {thm = thm, name = name, lhs = lhs, elhs = elhs, fo = fo, perm = perm} end;
 
 fun insert_rrule quiet (ss, rrule as {thm, name, lhs, elhs, perm}) =
- (trace_named_thm "Adding rewrite rule" (thm, name);
+ (trace_named_thm "Adding rewrite rule" ss (thm, name);
   ss |> map_simpset1 (fn (rules, prems, bounds) =>
     let
       val rrule2 as {elhs, ...} = mk_rrule2 rrule;
       val rules' = Net.insert_term eq_rrule (term_of elhs, rrule2) rules;
     in (rules', prems, bounds) end)
   handle Net.INSERT =>
-    (if quiet then () else warn_thm "Ignoring duplicate rewrite rule:" thm; ss));
+    (if quiet then () else warn_thm "Ignoring duplicate rewrite rule:" ss thm; ss));
 
 fun vperm (Var _, Var _) = true
   | vperm (Abs (_, _, s), Abs (_, _, t)) = vperm (s, t)
@@ -515,7 +531,7 @@
 fun del_rrule (ss, rrule as {thm, elhs, ...}) =
   ss |> map_simpset1 (fn (rules, prems, bounds) =>
     (Net.delete_term eq_rrule (term_of elhs, rrule) rules, prems, bounds))
-  handle Net.DELETE => (warn_thm "Rewrite rule not in simpset:" thm; ss);
+  handle Net.DELETE => (warn_thm "Rewrite rule not in simpset:" ss thm; ss);
 
 fun ss delsimps thms =
   orient_comb_simps del_rrule (map mk_rrule2 o mk_rrule ss) (ss, thms);
@@ -597,15 +613,15 @@
 
 local
 
-fun add_proc (ss, proc as Proc {name, lhs, ...}) =
- (trace_cterm false ("Adding simplification procedure " ^ quote name ^ " for") lhs;
+fun add_proc (proc as Proc {name, lhs, ...}) ss =
+ (trace_cterm false ("Adding simplification procedure " ^ quote name ^ " for") ss lhs;
   map_simpset2 (fn (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers) =>
     (congs, Net.insert_term eq_proc (term_of lhs, proc) procs,
       mk_rews, termless, subgoal_tac, loop_tacs, solvers)) ss
   handle Net.INSERT =>
     (warning ("Ignoring duplicate simplification procedure " ^ quote name); ss));
 
-fun del_proc (ss, proc as Proc {name, lhs, ...}) =
+fun del_proc (proc as Proc {name, lhs, ...}) ss =
   map_simpset2 (fn (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers) =>
     (congs, Net.delete_term eq_proc (term_of lhs, proc) procs,
       mk_rews, termless, subgoal_tac, loop_tacs, solvers)) ss
@@ -614,8 +630,8 @@
 
 in
 
-val (op addsimprocs) = Library.foldl (fn (ss, Simproc procs) => Library.foldl add_proc (ss, procs));
-val (op delsimprocs) = Library.foldl (fn (ss, Simproc procs) => Library.foldl del_proc (ss, procs));
+fun ss addsimprocs ps = fold (fn Simproc procs => fold add_proc procs) ps ss;
+fun ss delsimprocs ps = fold (fn Simproc procs => fold del_proc procs) ps ss;
 
 end;
 
@@ -713,25 +729,25 @@
 val lhs_of = #1 o dest_eq;
 val rhs_of = #2 o dest_eq;
 
-fun check_conv msg thm thm' =
+fun check_conv msg ss thm thm' =
   let
     val thm'' = transitive thm (transitive
       (symmetric (Drule.beta_eta_conversion (lhs_of thm'))) thm')
-  in if msg then trace_thm "SUCCEEDED" thm' else (); SOME thm'' end
+  in if msg then trace_thm "SUCCEEDED" ss thm' else (); SOME thm'' end
   handle THM _ =>
     let val {thy, prop = _ $ _ $ prop0, ...} = Thm.rep_thm thm in
-      trace_thm "Proved wrong thm (Check subgoaler?)" thm';
-      trace_term false "Should have proved:" thy prop0;
+      trace_thm "Proved wrong thm (Check subgoaler?)" ss thm';
+      trace_term false "Should have proved:" ss thy prop0;
       NONE
     end;
 
 
 (* mk_procrule *)
 
-fun mk_procrule thm =
+fun mk_procrule ss thm =
   let val (_, prems, lhs, elhs, rhs, _) = decomp_simp thm in
     if rewrite_rule_extra_vars prems lhs rhs
-    then (warn_thm "Extra vars on rhs:" thm; [])
+    then (warn_thm "Extra vars on rhs:" ss thm; [])
     else [mk_rrule2 {thm = thm, name = "", lhs = lhs, elhs = elhs, perm = false}]
   end;
 
@@ -794,25 +810,25 @@
         val (lhs', rhs') = Logic.dest_equals (Logic.strip_imp_concl prop')
       in
         if perm andalso not (termless (rhs', lhs'))
-        then (trace_named_thm "Cannot apply permutative rewrite rule" (thm, name);
-              trace_thm "Term does not become smaller:" thm'; NONE)
-        else (trace_named_thm "Applying instance of rewrite rule" (thm, name);
+        then (trace_named_thm "Cannot apply permutative rewrite rule" ss (thm, name);
+              trace_thm "Term does not become smaller:" ss thm'; NONE)
+        else (trace_named_thm "Applying instance of rewrite rule" ss (thm, name);
            if unconditional
            then
-             (trace_thm "Rewriting:" thm';
+             (trace_thm "Rewriting:" ss thm';
               let val lr = Logic.dest_equals prop;
-                  val SOME thm'' = check_conv false eta_thm thm'
+                  val SOME thm'' = check_conv false ss eta_thm thm'
               in SOME (thm'', uncond_skel (congs, lr)) end)
            else
-             (trace_thm "Trying to rewrite:" thm';
+             (trace_thm "Trying to rewrite:" ss thm';
               if !simp_depth > !simp_depth_limit
               then let val s = "simp_depth_limit exceeded - giving up"
                    in trace false s; warning s; NONE end
               else
               case prover ss thm' of
-                NONE => (trace_thm "FAILED" thm'; NONE)
+                NONE => (trace_thm "FAILED" ss thm'; NONE)
               | SOME thm2 =>
-                  (case check_conv true eta_thm thm2 of
+                  (case check_conv true ss eta_thm thm2 of
                      NONE => NONE |
                      SOME thm2' =>
                        let val concl = Logic.strip_imp_concl prop
@@ -838,15 +854,15 @@
     fun proc_rews [] = NONE
       | proc_rews (Proc {name, proc, lhs, ...} :: ps) =
           if Pattern.matches tsigt (Thm.term_of lhs, Thm.term_of t) then
-            (debug_term false ("Trying procedure " ^ quote name ^ " on:") thyt eta_t;
+            (debug_term false ("Trying procedure " ^ quote name ^ " on:") ss thyt eta_t;
              case transform_failure (curry SIMPROC_FAIL name)
                  (fn () => proc thyt ss eta_t) () of
                NONE => (debug false "FAILED"; proc_rews ps)
              | SOME raw_thm =>
-                 (trace_thm ("Procedure " ^ quote name ^ " produced rewrite rule:") raw_thm;
-                  (case rews (mk_procrule raw_thm) of
+                 (trace_thm ("Procedure " ^ quote name ^ " produced rewrite rule:") ss raw_thm;
+                  (case rews (mk_procrule ss raw_thm) of
                     NONE => (trace_cterm true ("IGNORED result of simproc " ^ quote name ^
-                      " -- does not match") t; proc_rews ps)
+                      " -- does not match") ss t; proc_rews ps)
                   | some => some)))
           else proc_rews ps;
   in case eta_t of
@@ -860,19 +876,18 @@
 
 (* conversion to apply a congruence rule to a term *)
 
-fun congc (prover,thyt,maxt) {thm=cong,lhs=lhs} t =
-  let val thy = Thm.theory_of_thm cong
-      val rthm = if maxt = ~1 then cong else Thm.incr_indexes (maxt+1) cong;
+fun congc prover ss maxt {thm=cong,lhs=lhs} t =
+  let val rthm = Thm.incr_indexes (maxt+1) cong;
       val rlhs = fst (Drule.dest_equals (Drule.strip_imp_concl (cprop_of rthm)));
       val insts = Thm.cterm_match (rlhs, t)
       (* Pattern.match can raise Pattern.MATCH;
          is handled when congc is called *)
       val thm' = Thm.instantiate insts (Thm.rename_boundvars (term_of rlhs) (term_of t) rthm);
-      val unit = trace_thm "Applying congruence rule:" thm';
-      fun err (msg, thm) = (trace_thm msg thm; NONE)
+      val unit = trace_thm "Applying congruence rule:" ss thm';
+      fun err (msg, thm) = (trace_thm msg ss thm; NONE)
   in case prover thm' of
        NONE => err ("Congruence proof failed.  Could not prove", thm')
-     | SOME thm2 => (case check_conv true (Drule.beta_eta_conversion t) thm2 of
+     | SOME thm2 => (case check_conv true ss (Drule.beta_eta_conversion t) thm2 of
           NONE => err ("Congruence proof failed.  Should not have proved", thm2)
         | SOME thm2' =>
             if op aconv (pairself term_of (dest_equals (cprop_of thm2')))
@@ -916,8 +931,12 @@
        (case term_of t0 of
            Abs (a, T, t) =>
              let
-                 val (v, t') = Thm.dest_abs (SOME (Term.bound bounds a)) t0;
-                 val ss' = incr_bounds ss;
+                 val b = Term.bound (#1 bounds);
+                 val (v, t') = Thm.dest_abs (SOME b) t0;
+                 val b' = #1 (Term.dest_Free (Thm.term_of v));
+                 val _ = conditional (b <> b') (fn () =>
+                   warning ("Simplifier: renamed bound variable " ^ quote b ^ " to " ^ quote b'));
+                 val ss' = add_bound (a, (b', T)) ss;
                  val skel' = case skel of Abs (_, _, sk) => sk | _ => skel0;
              in case botc skel' ss' t' of
                   SOME thm => SOME (abstract_rule a v thm)
@@ -958,7 +977,7 @@
   (*post processing: some partial applications h t1 ... tj, j <= length ts,
     may be a redex. Example: map (%x. x) = (%xs. xs) wrt map_cong*)
                           (let
-                             val thm = congc (prover ss, thy, maxidx) cong t0;
+                             val thm = congc (prover ss) ss maxidx cong t0;
                              val t = getOpt (Option.map rhs_of thm, t0);
                              val (cl, cr) = Thm.dest_comb t
                              val dVar = Var(("", 0), dummyT)
@@ -980,7 +999,7 @@
     and rules_of_prem ss prem =
       if maxidx_of_term (term_of prem) <> ~1
       then (trace_cterm true
-        "Cannot add premise as rewrite rule because it contains (type) unknowns:" prem; ([], NONE))
+        "Cannot add premise as rewrite rule because it contains (type) unknowns:" ss prem; ([], NONE))
       else
         let val asm = assume prem
         in (extract_safe_rrules (ss, asm), SOME asm) end
@@ -1092,18 +1111,18 @@
 *)
 
 fun rewrite_cterm mode prover ss ct =
-  (simp_depth := !simp_depth + 1;
+  (inc simp_depth;
    if !simp_depth mod 10 = 0
    then warning ("Simplification depth " ^ string_of_int (!simp_depth))
    else ();
-   trace_cterm false "SIMPLIFIER INVOKED ON THE FOLLOWING TERM:" ct;
+   trace_cterm false "SIMPLIFIER INVOKED ON THE FOLLOWING TERM:" ss ct;
    let val {thy, t, maxidx, ...} = Thm.rep_cterm ct
        val res = bottomc (mode, prover, thy, maxidx) ss ct
          handle THM (s, _, thms) =>
          error ("Exception THM was raised in simplifier:\n" ^ s ^ "\n" ^
            Pretty.string_of (Display.pretty_thms thms))
-   in simp_depth := !simp_depth - 1; res end
-  ) handle exn => (simp_depth := !simp_depth - 1; raise exn);
+   in dec simp_depth; res end
+  ) handle exn => (dec simp_depth; raise exn);
 
 (*Rewrite a cterm*)
 fun rewrite_aux _ _ [] = (fn ct => Thm.reflexive ct)