simp_depth: now proper value in simpset (prevents problems with lost exception trace, enables multi-threaded simplification);
authorwenzelm
Wed, 09 May 2007 19:20:00 +0200
changeset 22892 c77a1e1c7323
parent 22891 ef91c38e7c0b
child 22893 1b0f4e6f81aa
simp_depth: now proper value in simpset (prevents problems with lost exception trace, enables multi-threaded simplification); trace_simp_depth_limit_exceeded: attempt to hide destructive pointer programming within simpset;
src/Pure/meta_simplifier.ML
--- a/src/Pure/meta_simplifier.ML	Wed May 09 18:58:03 2007 +0200
+++ b/src/Pure/meta_simplifier.ML	Wed May 09 19:20:00 2007 +0200
@@ -28,6 +28,7 @@
    {rules: rrule Net.net,
     prems: thm list,
     bounds: int * ((string * typ) * string) list,
+    depth: int * bool ref option,
     context: Proof.context option} *
    {congs: (string * cong) list * string list,
     procs: proc Net.net,
@@ -121,7 +122,6 @@
 structure MetaSimplifier: META_SIMPLIFIER =
 struct
 
-
 (** datatype simpset **)
 
 (* rewrite rules *)
@@ -165,6 +165,7 @@
     prems: current premises;
     bounds: maximal index of bound variables already used
       (for generating new names when rewriting under lambda abstractions);
+    depth: simp_depth and exceeded flag;
     congs: association list of congruence rules and
            a list of `weak' congruence constants.
            A congruence is `weak' if it avoids normalization of some argument.
@@ -189,6 +190,7 @@
    {rules: rrule Net.net,
     prems: thm list,
     bounds: int * ((string * typ) * string) list,
+    depth: int * bool ref option,
     context: Proof.context option} *
    {congs: (string * cong) list * string list,
     procs: proc Net.net,
@@ -212,11 +214,11 @@
 
 fun rep_ss (Simpset args) = args;
 
-fun make_ss1 (rules, prems, bounds, context) =
-  {rules = rules, prems = prems, bounds = bounds, context = context};
+fun make_ss1 (rules, prems, bounds, depth, context) =
+  {rules = rules, prems = prems, bounds = bounds, depth = depth, context = context};
 
-fun map_ss1 f {rules, prems, bounds, context} =
-  make_ss1 (f (rules, prems, bounds, context));
+fun map_ss1 f {rules, prems, bounds, depth, context} =
+  make_ss1 (f (rules, prems, bounds, depth, context));
 
 fun make_ss2 (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers) =
   {congs = congs, procs = procs, mk_rews = mk_rews, termless = termless,
@@ -227,9 +229,9 @@
 
 fun make_simpset (args1, args2) = Simpset (make_ss1 args1, make_ss2 args2);
 
-fun map_simpset f (Simpset ({rules, prems, bounds, context},
+fun map_simpset f (Simpset ({rules, prems, bounds, depth, context},
     {congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers})) =
-  make_simpset (f ((rules, prems, bounds, context),
+  make_simpset (f ((rules, prems, bounds, depth, context),
     (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers)));
 
 fun map_simpset1 f (Simpset (r1, r2)) = Simpset (map_ss1 f r1, r2);
@@ -249,27 +251,37 @@
 fun eq_solver (Solver {id = id1, ...}, Solver {id = id2, ...}) = (id1 = id2);
 
 
+(* simp depth *)
+
+val simp_depth_limit = ref 100;
+val trace_simp_depth_limit = ref 1;
+
+fun trace_depth (Simpset ({depth = (depth, exceeded), ...}, _)) msg =
+  if depth > !trace_simp_depth_limit then
+    (case exceeded of
+      NONE => ()
+    | SOME r => if !r then () else (tracing "trace_simp_depth_limit exceeded!"; r := true))
+  else
+    (tracing (enclose "[" "]" (string_of_int depth) ^ msg);
+      (case exceeded of SOME r => r := false | _ => ()));
+
+val inc_simp_depth = map_simpset1 (fn (rules, prems, bounds, (depth, exceeded), context) =>
+  (rules, prems, bounds,
+    (depth + 1, if depth = !trace_simp_depth_limit then SOME (ref false) else exceeded), context));
+
+fun simp_depth (Simpset ({depth = (depth, _), ...}, _)) = depth;
+
+
 (* diagnostics *)
 
 exception SIMPLIFIER of string * thm;
 
 val debug_simp = ref false;
 val trace_simp = ref false;
-val simp_depth = ref 0;
-val simp_depth_limit = ref 100;
-val trace_simp_depth_limit = ref 1;
-val trace_simp_depth_limit_exceeded = ref false;
+
 local
 
-fun println a =
-  if ! simp_depth > ! trace_simp_depth_limit
-  then if !trace_simp_depth_limit_exceeded then ()
-       else (tracing "trace_simp_depth_limit exceeded!";
-             trace_simp_depth_limit_exceeded := true)
-  else (tracing (enclose "[" "]" (string_of_int (! simp_depth)) ^ a);
-        trace_simp_depth_limit_exceeded := false);
-
-fun prnt warn a = if warn then warning a else println a;
+fun prnt ss warn a = if warn then warning a else trace_depth ss a;
 
 fun show_bounds (Simpset ({bounds = (_, bs), ...}, _)) t =
   let
@@ -280,30 +292,31 @@
 
 in
 
-fun print_term warn a ss thy t = prnt warn (a ^ "\n" ^
+fun print_term ss warn a thy t = prnt ss warn (a ^ "\n" ^
   Sign.string_of_term thy (if ! debug_simp then t else show_bounds ss t));
 
-fun debug warn a = if ! debug_simp then prnt warn (a ()) else ();
-fun trace warn a = if ! trace_simp then prnt warn (a ()) else ();
+fun debug warn a ss = if ! debug_simp then prnt ss warn (a ()) else ();
+fun trace warn a ss = if ! trace_simp then prnt ss warn (a ()) else ();
 
-fun debug_term warn a ss thy t = if ! debug_simp then print_term warn (a ()) ss thy t else ();
-fun trace_term warn a ss thy t = if ! trace_simp then print_term warn (a ()) ss thy t else ();
+fun debug_term warn a ss thy t = if ! debug_simp then print_term ss warn (a ()) thy t else ();
+fun trace_term warn a ss thy t = if ! trace_simp then print_term ss warn (a ()) thy t else ();
 
 fun trace_cterm warn a ss ct =
-  if ! trace_simp then print_term warn (a ()) ss (Thm.theory_of_cterm ct) (Thm.term_of ct)
+  if ! trace_simp then print_term ss warn (a ()) (Thm.theory_of_cterm ct) (Thm.term_of ct)
   else ();
 
 fun trace_thm a ss th =
-  if ! trace_simp then print_term false (a ()) ss (Thm.theory_of_thm th) (Thm.full_prop_of th)
+  if ! trace_simp then print_term ss false (a ()) (Thm.theory_of_thm th) (Thm.full_prop_of th)
   else ();
 
 fun trace_named_thm a ss (th, name) =
   if ! trace_simp then
-    print_term false (if name = "" then a () else a () ^ " " ^ quote name ^ ":") ss
+    print_term ss false (if name = "" then a () else a () ^ " " ^ quote name ^ ":")
       (Thm.theory_of_thm th) (Thm.full_prop_of th)
   else ();
 
-fun warn_thm a ss th = print_term true a ss (Thm.theory_of_thm th) (Thm.full_prop_of th);
+fun warn_thm a ss th =
+  print_term ss true a (Thm.theory_of_thm th) (Thm.full_prop_of th);
 
 fun cond_warn_thm a (ss as Simpset ({context, ...}, _)) th =
   if is_some context then () else warn_thm a ss th;
@@ -347,20 +360,20 @@
 
 fun eq_bound (x: string, (y, _)) = x = y;
 
-fun add_bound bound = map_simpset1 (fn (rules, prems, (count, bounds), context) =>
-  (rules, prems, (count + 1, bound :: bounds), context));
+fun add_bound bound = map_simpset1 (fn (rules, prems, (count, bounds), depth, context) =>
+  (rules, prems, (count + 1, bound :: bounds), depth, context));
 
-fun add_prems ths = map_simpset1 (fn (rules, prems, bounds, context) =>
-  (rules, ths @ prems, bounds, context));
+fun add_prems ths = map_simpset1 (fn (rules, prems, bounds, depth, context) =>
+  (rules, ths @ prems, bounds, depth, context));
 
-fun inherit_context (Simpset ({bounds, context, ...}, _)) =
-  map_simpset1 (fn (rules, prems, _, _) => (rules, prems, bounds, context));
+fun inherit_context (Simpset ({bounds, depth, context, ...}, _)) =
+  map_simpset1 (fn (rules, prems, _, _, _) => (rules, prems, bounds, depth, context));
 
 fun the_context (Simpset ({context = SOME ctxt, ...}, _)) = ctxt
   | the_context _ = raise Fail "Simplifier: no proof context in simpset";
 
 fun context ctxt =
-  map_simpset1 (fn (rules, prems, bounds, _) => (rules, prems, bounds, SOME ctxt));
+  map_simpset1 (fn (rules, prems, bounds, depth, _) => (rules, prems, bounds, depth, SOME ctxt));
 
 val theory_context = context o ProofContext.init;
 
@@ -400,17 +413,17 @@
   in {thm = thm, name = name, lhs = lhs, elhs = elhs, extra = extra, fo = fo, perm = perm} end;
 
 fun del_rrule (rrule as {thm, elhs, ...}) ss =
-  ss |> map_simpset1 (fn (rules, prems, bounds, context) =>
-    (Net.delete_term eq_rrule (term_of elhs, rrule) rules, prems, bounds, context))
+  ss |> map_simpset1 (fn (rules, prems, bounds, depth, context) =>
+    (Net.delete_term eq_rrule (term_of elhs, rrule) rules, prems, bounds, depth, context))
   handle Net.DELETE => (cond_warn_thm "Rewrite rule not in simpset:" ss thm; ss);
 
 fun insert_rrule (rrule as {thm, name, elhs, ...}) ss =
  (trace_named_thm (fn () => "Adding rewrite rule") ss (thm, name);
-  ss |> map_simpset1 (fn (rules, prems, bounds, context) =>
+  ss |> map_simpset1 (fn (rules, prems, bounds, depth, context) =>
     let
       val rrule2 as {elhs, ...} = mk_rrule2 rrule;
       val rules' = Net.insert_term eq_rrule (term_of elhs, rrule2) rules;
-    in (rules', prems, bounds, context) end)
+    in (rules', prems, bounds, depth, context) end)
   handle Net.INSERT => (cond_warn_thm "Ignoring duplicate rewrite rule:" ss thm; ss));
 
 fun vperm (Var _, Var _) = true
@@ -759,7 +772,7 @@
 (* empty *)
 
 fun init_ss mk_rews termless subgoal_tac solvers =
-  make_simpset ((Net.empty, [], (0, []), NONE),
+  make_simpset ((Net.empty, [], (0, []), (0, NONE), NONE),
     (([], []), Net.empty, mk_rews, termless, subgoal_tac, [], solvers));
 
 fun clear_ss (ss as Simpset (_, {mk_rews, termless, subgoal_tac, solvers, ...})) =
@@ -780,16 +793,17 @@
 
 fun merge_ss (ss1, ss2) =
   let
-    val Simpset ({rules = rules1, prems = prems1, bounds = bounds1, context = _},
+    val Simpset ({rules = rules1, prems = prems1, bounds = bounds1, depth = depth1, context = _},
      {congs = (congs1, weak1), procs = procs1, mk_rews, termless, subgoal_tac,
       loop_tacs = loop_tacs1, solvers = (unsafe_solvers1, solvers1)}) = ss1;
-    val Simpset ({rules = rules2, prems = prems2, bounds = bounds2, context = _},
+    val Simpset ({rules = rules2, prems = prems2, bounds = bounds2, depth = depth2, context = _},
      {congs = (congs2, weak2), procs = procs2, mk_rews = _, termless = _, subgoal_tac = _,
       loop_tacs = loop_tacs2, solvers = (unsafe_solvers2, solvers2)}) = ss2;
 
     val rules' = Net.merge eq_rrule (rules1, rules2);
     val prems' = gen_merge_lists Thm.eq_thm_prop prems1 prems2;
     val bounds' = if #1 bounds1 < #1 bounds2 then bounds2 else bounds1;
+    val depth' = if #1 depth1 < #1 depth2 then depth2 else depth1;
     val congs' = merge (eq_cong o pairself #2) (congs1, congs2);
     val weak' = merge (op =) (weak1, weak2);
     val procs' = Net.merge eq_proc (procs1, procs2);
@@ -797,7 +811,7 @@
     val unsafe_solvers' = merge eq_solver (unsafe_solvers1, unsafe_solvers2);
     val solvers' = merge eq_solver (solvers1, solvers2);
   in
-    make_simpset ((rules', prems', bounds', NONE), ((congs', weak'), procs',
+    make_simpset ((rules', prems', bounds', depth', NONE), ((congs', weak'), procs',
       mk_rews, termless, subgoal_tac, loop_tacs', (unsafe_solvers', solvers')))
   end;
 
@@ -903,9 +917,9 @@
               in SOME (thm'', uncond_skel (congs, lr)) end)
            else
              (trace_thm (fn () => "Trying to rewrite:") ss thm';
-              if !simp_depth > !simp_depth_limit
+              if simp_depth ss > ! simp_depth_limit
               then let val s = "simp_depth_limit exceeded - giving up"
-                   in trace false (fn () => s); warning s; NONE end
+                   in trace false (fn () => s) ss; warning s; NONE end
               else
               case prover ss thm' of
                 NONE => (trace_thm (fn () => "FAILED") ss thm'; NONE)
@@ -939,7 +953,7 @@
             (debug_term false (fn () => "Trying procedure " ^ quote name ^ " on:") ss thyt eta_t;
              case transform_failure (curry SIMPROC_FAIL name)
                  (fn () => proc ss eta_t') () of
-               NONE => (debug false (fn () => "FAILED"); proc_rews ps)
+               NONE => (debug false (fn () => "FAILED") ss; proc_rews ps)
              | SOME raw_thm =>
                  (trace_thm (fn () => "Procedure " ^ quote name ^ " produced rewrite rule:")
                    ss raw_thm;
@@ -1208,7 +1222,7 @@
         | _ => I) (term_of ct) [];
     in
       if null bs then ()
-      else print_term true ("Simplifier: term contains loose bounds: " ^ commas_quote bs) ss
+      else print_term ss true ("Simplifier: term contains loose bounds: " ^ commas_quote bs)
         (Thm.theory_of_cterm ct) (Thm.term_of ct)
     end
   else ();
@@ -1217,17 +1231,15 @@
   let
     val ct = Thm.adjust_maxidx_cterm ~1 raw_ct;
     val {thy, t, maxidx, ...} = Thm.rep_cterm ct;
-    val ss = activate_context thy raw_ss;
-    val _ = inc simp_depth;
+    val ss = inc_simp_depth (activate_context thy raw_ss);
+    val depth = simp_depth ss;
     val _ =
-      if ! simp_depth mod 20 = 0 then
-        warning ("Simplification depth " ^ string_of_int (! simp_depth))
+      if depth mod 20 = 0 then
+        warning ("Simplification depth " ^ string_of_int depth)
       else ();
     val _ = trace_cterm false (fn () => "SIMPLIFIER INVOKED ON THE FOLLOWING TERM:") ss ct;
     val _ = check_bounds ss ct;
-    val res = bottomc (mode, Option.map Drule.flexflex_unique oo prover, thy, maxidx) ss ct
-  in dec simp_depth; res end
-  handle exn => (dec simp_depth; raise exn);  (* FIXME avoid handling of generic exceptions *)
+  in bottomc (mode, Option.map Drule.flexflex_unique oo prover, thy, maxidx) ss ct end;
 
 val simple_prover =
   SINGLE o (fn ss => ALLGOALS (resolve_tac (prems_of_ss ss)));