generic trace operations for main steps of Simplifier;
authorwenzelm
Thu, 12 Dec 2013 21:14:33 +0100
changeset 54729 c5cd7a58cf2d
parent 54728 445e7947c6b5
child 54730 de2d99b459b3
generic trace operations for main steps of Simplifier;
src/Pure/raw_simplifier.ML
src/Pure/simplifier.ML
--- a/src/Pure/raw_simplifier.ML	Thu Dec 12 17:34:50 2013 +0100
+++ b/src/Pure/raw_simplifier.ML	Thu Dec 12 21:14:33 2013 +0100
@@ -72,6 +72,8 @@
 sig
   include BASIC_RAW_SIMPLIFIER
   exception SIMPLIFIER of string * thm
+  type trace_ops
+  val set_trace_ops: trace_ops -> Proof.context -> Proof.context
   val internal_ss: simpset ->
    {congs: (cong_name * thm) list * cong_name list,
     procs: proc Net.net,
@@ -84,7 +86,8 @@
     termless: term * term -> bool,
     subgoal_tac: Proof.context -> int -> tactic,
     loop_tacs: (string * (Proof.context -> int -> tactic)) list,
-    solvers: solver list * solver list}
+    solvers: solver list * solver list,
+    trace_ops: trace_ops}
   val map_ss: (Proof.context -> Proof.context) -> Context.generic -> Context.generic
   val prems_of: Proof.context -> thm list
   val add_simp: thm -> Proof.context -> Proof.context
@@ -244,6 +247,18 @@
 fun eq_solver (Solver {id = id1, ...}, Solver {id = id2, ...}) = (id1 = id2);
 
 
+(* trace operations *)
+
+type trace_ops =
+ {trace_invoke: {depth: int, term: term} -> Proof.context -> Proof.context,
+  trace_apply: {unconditional: bool, term: term, thm: thm, name: string} ->
+    Proof.context -> (Proof.context -> (thm * term) option) -> (thm * term) option};
+
+val no_trace_ops : trace_ops =
+ {trace_invoke = fn _ => fn ctxt => ctxt,
+  trace_apply = fn _ => fn ctxt => fn cont => cont ctxt};
+
+
 (* simplification sets *)
 
 (*A simpset contains data required during conversion:
@@ -281,7 +296,8 @@
     termless: term * term -> bool,
     subgoal_tac: Proof.context -> int -> tactic,
     loop_tacs: (string * (Proof.context -> int -> tactic)) list,
-    solvers: solver list * solver list};
+    solvers: solver list * solver list,
+    trace_ops: trace_ops};
 
 fun internal_ss (Simpset (_, ss2)) = ss2;
 
@@ -291,12 +307,12 @@
 fun map_ss1 f {rules, prems, bounds, depth} =
   make_ss1 (f (rules, prems, bounds, depth));
 
-fun make_ss2 (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers) =
+fun make_ss2 (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers, trace_ops) =
   {congs = congs, procs = procs, mk_rews = mk_rews, termless = termless,
-    subgoal_tac = subgoal_tac, loop_tacs = loop_tacs, solvers = solvers};
+    subgoal_tac = subgoal_tac, loop_tacs = loop_tacs, solvers = solvers, trace_ops = trace_ops};
 
-fun map_ss2 f {congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers} =
-  make_ss2 (f (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers));
+fun map_ss2 f {congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers, trace_ops} =
+  make_ss2 (f (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers, trace_ops));
 
 fun make_simpset (args1, args2) = Simpset (make_ss1 args1, make_ss2 args2);
 
@@ -316,9 +332,9 @@
 
 (* empty *)
 
-fun init_ss mk_rews termless subgoal_tac solvers =
+fun init_ss mk_rews termless subgoal_tac solvers trace_ops =
   make_simpset ((Net.empty, [], (0, []), (0, Unsynchronized.ref false)),
-    (([], []), Net.empty, mk_rews, termless, subgoal_tac, [], solvers));
+    (([], []), Net.empty, mk_rews, termless, subgoal_tac, [], solvers, trace_ops));
 
 fun default_mk_sym _ th = SOME (th RS Drule.symmetric_thm);
 
@@ -329,7 +345,9 @@
       mk_sym = default_mk_sym,
       mk_eq_True = K (K NONE),
       reorient = default_reorient}
-    Term_Ord.termless (K (K no_tac)) ([], []);
+    Term_Ord.termless (K (K no_tac))
+    ([], [])
+    no_trace_ops;
 
 
 (* merge *)  (*NOTE: ignores some fields of 2nd simpset*)
@@ -340,10 +358,10 @@
     let
       val Simpset ({rules = rules1, prems = prems1, bounds = bounds1, depth = depth1},
        {congs = (congs1, weak1), procs = procs1, mk_rews, termless, subgoal_tac,
-        loop_tacs = loop_tacs1, solvers = (unsafe_solvers1, solvers1)}) = ss1;
+        loop_tacs = loop_tacs1, solvers = (unsafe_solvers1, solvers1), trace_ops}) = ss1;
       val Simpset ({rules = rules2, prems = prems2, bounds = bounds2, depth = depth2},
        {congs = (congs2, weak2), procs = procs2, mk_rews = _, termless = _, subgoal_tac = _,
-        loop_tacs = loop_tacs2, solvers = (unsafe_solvers2, solvers2)}) = ss2;
+        loop_tacs = loop_tacs2, solvers = (unsafe_solvers2, solvers2), trace_ops = _}) = ss2;
 
       val rules' = Net.merge eq_rrule (rules1, rules2);
       val prems' = Thm.merge_thms (prems1, prems2);
@@ -357,7 +375,7 @@
       val solvers' = merge eq_solver (solvers1, solvers2);
     in
       make_simpset ((rules', prems', bounds', depth'), ((congs', weak'), procs',
-        mk_rews, termless, subgoal_tac, loop_tacs', (unsafe_solvers', solvers')))
+        mk_rews, termless, subgoal_tac, loop_tacs', (unsafe_solvers', solvers'), trace_ops))
     end;
 
 
@@ -380,9 +398,15 @@
 
 fun simpset_map ctxt f ss = ctxt |> map_simpset (K ss) |> f |> Context.Proof |> Simpset.get;
 
-fun put_simpset (Simpset ({rules, prems, ...}, ss2)) =  (* FIXME prems from context (!?) *)
-  map_simpset (fn Simpset ({bounds, depth, ...}, _) =>
-    Simpset (make_ss1 (rules, prems, bounds, depth), ss2));
+fun put_simpset ss = map_simpset (fn context_ss =>
+  let
+    val Simpset ({rules, prems, ...},  (* FIXME prems from context (!?) *)
+      {congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers, ...}) = ss;
+    val Simpset ({bounds, depth, ...}, {trace_ops, ...}) = context_ss;
+  in
+    Simpset (make_ss1 (rules, prems, bounds, depth),
+      make_ss2 (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers, trace_ops))
+  end);
 
 fun global_context thy ss = Proof_Context.init_global thy |> put_simpset ss;
 
@@ -397,8 +421,8 @@
 fun map_ss f = Context.mapping (map_theory_simpset f) f;
 
 val clear_simpset =
-  map_simpset (fn Simpset (_, {mk_rews, termless, subgoal_tac, solvers, ...}) =>
-    init_ss mk_rews termless subgoal_tac solvers);
+  map_simpset (fn Simpset (_, {mk_rews, termless, subgoal_tac, solvers, trace_ops, ...}) =>
+    init_ss mk_rews termless subgoal_tac solvers trace_ops);
 
 
 (* simp depth *)
@@ -661,8 +685,8 @@
 
 in
 
-fun add_eqcong thm ctxt = ctxt |>
-  map_simpset2 (fn (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers) =>
+fun add_eqcong thm ctxt = ctxt |> map_simpset2
+  (fn (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers, trace_ops) =>
     let
       val (lhs, _) = Logic.dest_equals (Thm.concl_of thm)
         handle TERM _ => raise SIMPLIFIER ("Congruence not a meta-equality", thm);
@@ -677,10 +701,10 @@
         else ();
       val xs' = AList.update (op =) (a, thm) xs;
       val weak' = if is_full_cong thm then weak else a :: weak;
-    in ((xs', weak'), procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers) end);
+    in ((xs', weak'), procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers, trace_ops) end);
 
-fun del_eqcong thm ctxt = ctxt |>
-  map_simpset2 (fn (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers) =>
+fun del_eqcong thm ctxt = ctxt |> map_simpset2
+  (fn (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers, trace_ops) =>
     let
       val (lhs, _) = Logic.dest_equals (Thm.concl_of thm)
         handle TERM _ => raise SIMPLIFIER ("Congruence not a meta-equality", thm);
@@ -691,7 +715,7 @@
       val xs' = filter_out (fn (x : cong_name, _) => x = a) xs;
       val weak' = xs' |> map_filter (fn (a, thm) =>
         if is_full_cong thm then NONE else SOME a);
-    in ((xs', weak'), procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers) end);
+    in ((xs', weak'), procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers, trace_ops) end);
 
 fun add_cong thm ctxt = add_eqcong (mk_cong ctxt thm) ctxt;
 fun del_cong thm ctxt = del_eqcong (mk_cong ctxt thm) ctxt;
@@ -733,17 +757,19 @@
 
 fun add_proc (proc as Proc {name, lhs, ...}) ctxt =
  (trace_cterm ctxt false (fn () => "Adding simplification procedure " ^ quote name ^ " for") lhs;
-  ctxt |> map_simpset2 (fn (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers) =>
-    (congs, Net.insert_term eq_proc (term_of lhs, proc) procs,
-      mk_rews, termless, subgoal_tac, loop_tacs, solvers))
+  ctxt |> map_simpset2
+    (fn (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers, trace_ops) =>
+      (congs, Net.insert_term eq_proc (term_of lhs, proc) procs,
+        mk_rews, termless, subgoal_tac, loop_tacs, solvers, trace_ops))
   handle Net.INSERT =>
     (Context_Position.if_visible ctxt
       warning ("Ignoring duplicate simplification procedure " ^ quote name); ctxt));
 
 fun del_proc (proc as Proc {name, lhs, ...}) ctxt =
-  ctxt |> map_simpset2 (fn (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers) =>
-    (congs, Net.delete_term eq_proc (term_of lhs, proc) procs,
-      mk_rews, termless, subgoal_tac, loop_tacs, solvers))
+  ctxt |> map_simpset2
+    (fn (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers, trace_ops) =>
+      (congs, Net.delete_term eq_proc (term_of lhs, proc) procs,
+        mk_rews, termless, subgoal_tac, loop_tacs, solvers, trace_ops))
   handle Net.DELETE =>
     (Context_Position.if_visible ctxt
       warning ("Simplification procedure " ^ quote name ^ " not in simpset"); ctxt);
@@ -763,14 +789,15 @@
 
 local
 
-fun map_mk_rews f = map_simpset2 (fn (congs, procs, {mk, mk_cong, mk_sym, mk_eq_True, reorient},
-      termless, subgoal_tac, loop_tacs, solvers) =>
-  let
-    val (mk', mk_cong', mk_sym', mk_eq_True', reorient') =
-      f (mk, mk_cong, mk_sym, mk_eq_True, reorient);
-    val mk_rews' = {mk = mk', mk_cong = mk_cong', mk_sym = mk_sym', mk_eq_True = mk_eq_True',
-      reorient = reorient'};
-  in (congs, procs, mk_rews', termless, subgoal_tac, loop_tacs, solvers) end);
+fun map_mk_rews f =
+  map_simpset2 (fn (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers, trace_ops) =>
+    let
+      val {mk, mk_cong, mk_sym, mk_eq_True, reorient} = mk_rews;
+      val (mk', mk_cong', mk_sym', mk_eq_True', reorient') =
+        f (mk, mk_cong, mk_sym, mk_eq_True, reorient);
+      val mk_rews' = {mk = mk', mk_cong = mk_cong', mk_sym = mk_sym', mk_eq_True = mk_eq_True',
+        reorient = reorient'};
+    in (congs, procs, mk_rews', termless, subgoal_tac, loop_tacs, solvers, trace_ops) end);
 
 in
 
@@ -799,53 +826,64 @@
 (* termless *)
 
 fun set_termless termless =
-  map_simpset2 (fn (congs, procs, mk_rews, _, subgoal_tac, loop_tacs, solvers) =>
-   (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers));
+  map_simpset2 (fn (congs, procs, mk_rews, _, subgoal_tac, loop_tacs, solvers, trace_ops) =>
+   (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers, trace_ops));
 
 
 (* tactics *)
 
 fun set_subgoaler subgoal_tac =
-  map_simpset2 (fn (congs, procs, mk_rews, termless, _, loop_tacs, solvers) =>
-   (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers));
+  map_simpset2 (fn (congs, procs, mk_rews, termless, _, loop_tacs, solvers, trace_ops) =>
+   (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers, trace_ops));
 
 fun ctxt setloop tac = ctxt |>
-  map_simpset2 (fn (congs, procs, mk_rews, termless, subgoal_tac, _, solvers) =>
-   (congs, procs, mk_rews, termless, subgoal_tac, [("", tac)], solvers));
+  map_simpset2 (fn (congs, procs, mk_rews, termless, subgoal_tac, _, solvers, trace_ops) =>
+   (congs, procs, mk_rews, termless, subgoal_tac, [("", tac)], solvers, trace_ops));
 
 fun ctxt addloop (name, tac) = ctxt |>
-  map_simpset2 (fn (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers) =>
+  map_simpset2 (fn (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers, trace_ops) =>
     (congs, procs, mk_rews, termless, subgoal_tac,
-     AList.update (op =) (name, tac) loop_tacs, solvers));
+     AList.update (op =) (name, tac) loop_tacs, solvers, trace_ops));
 
 fun ctxt delloop name = ctxt |>
-  map_simpset2 (fn (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers) =>
+  map_simpset2 (fn (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers, trace_ops) =>
     (congs, procs, mk_rews, termless, subgoal_tac,
      (if AList.defined (op =) loop_tacs name then ()
       else
         Context_Position.if_visible ctxt
           warning ("No such looper in simpset: " ^ quote name);
-        AList.delete (op =) name loop_tacs), solvers));
+        AList.delete (op =) name loop_tacs), solvers, trace_ops));
 
-fun ctxt setSSolver solver = ctxt |> map_simpset2 (fn (congs, procs, mk_rews, termless,
-  subgoal_tac, loop_tacs, (unsafe_solvers, _)) =>
-    (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, (unsafe_solvers, [solver])));
+fun ctxt setSSolver solver = ctxt |> map_simpset2
+  (fn (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, (unsafe_solvers, _), trace_ops) =>
+    (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, (unsafe_solvers, [solver]), trace_ops));
 
 fun ctxt addSSolver solver = ctxt |> map_simpset2 (fn (congs, procs, mk_rews, termless,
-  subgoal_tac, loop_tacs, (unsafe_solvers, solvers)) => (congs, procs, mk_rews, termless,
-    subgoal_tac, loop_tacs, (unsafe_solvers, insert eq_solver solver solvers)));
+  subgoal_tac, loop_tacs, (unsafe_solvers, solvers), trace_ops) => (congs, procs, mk_rews, termless,
+    subgoal_tac, loop_tacs, (unsafe_solvers, insert eq_solver solver solvers), trace_ops));
 
 fun ctxt setSolver solver = ctxt |> map_simpset2 (fn (congs, procs, mk_rews, termless,
-  subgoal_tac, loop_tacs, (_, solvers)) => (congs, procs, mk_rews, termless,
-    subgoal_tac, loop_tacs, ([solver], solvers)));
+  subgoal_tac, loop_tacs, (_, solvers), trace_ops) => (congs, procs, mk_rews, termless,
+    subgoal_tac, loop_tacs, ([solver], solvers), trace_ops));
 
 fun ctxt addSolver solver = ctxt |> map_simpset2 (fn (congs, procs, mk_rews, termless,
-  subgoal_tac, loop_tacs, (unsafe_solvers, solvers)) => (congs, procs, mk_rews, termless,
-    subgoal_tac, loop_tacs, (insert eq_solver solver unsafe_solvers, solvers)));
+  subgoal_tac, loop_tacs, (unsafe_solvers, solvers), trace_ops) => (congs, procs, mk_rews, termless,
+    subgoal_tac, loop_tacs, (insert eq_solver solver unsafe_solvers, solvers), trace_ops));
 
 fun set_solvers solvers = map_simpset2 (fn (congs, procs, mk_rews, termless,
-  subgoal_tac, loop_tacs, _) => (congs, procs, mk_rews, termless,
-  subgoal_tac, loop_tacs, (solvers, solvers)));
+  subgoal_tac, loop_tacs, _, trace_ops) => (congs, procs, mk_rews, termless,
+  subgoal_tac, loop_tacs, (solvers, solvers), trace_ops));
+
+
+(* trace operations *)
+
+fun set_trace_ops trace_ops =
+  map_simpset2 (fn (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers, _) =>
+    (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers, trace_ops));
+
+val trace_ops = #trace_ops o internal_ss o simpset_of;
+fun trace_invoke args ctxt = #trace_invoke (trace_ops ctxt) args ctxt;
+fun trace_apply args ctxt = #trace_apply (trace_ops ctxt) args ctxt;
 
 
 
@@ -943,6 +981,7 @@
         val prop' = Thm.prop_of thm';
         val unconditional = (Logic.count_prems prop' = 0);
         val (lhs', rhs') = Logic.dest_equals (Logic.strip_imp_concl prop');
+        val trace_args = {unconditional = unconditional, term = eta_t, thm = thm', name = name};
       in
         if perm andalso not (termless (rhs', lhs'))
         then
@@ -954,10 +993,11 @@
           if unconditional
           then
            (trace_thm ctxt (fn () => "Rewriting:") thm';
-            let
-              val lr = Logic.dest_equals prop;
-              val SOME thm'' = check_conv ctxt false eta_thm thm';
-            in SOME (thm'', uncond_skel (congs, lr)) end)
+            trace_apply trace_args ctxt (fn ctxt' =>
+              let
+                val lr = Logic.dest_equals prop;
+                val SOME thm'' = check_conv ctxt' false eta_thm thm';
+              in SOME (thm'', uncond_skel (congs, lr)) end))
           else
            (trace_thm ctxt (fn () => "Trying to rewrite:") thm';
             if simp_depth ctxt > Config.get ctxt simp_depth_limit
@@ -968,16 +1008,17 @@
                 val _ = Context_Position.if_visible ctxt warning s;
               in NONE end
             else
-              (case prover ctxt thm' of
-                NONE => (trace_thm ctxt (fn () => "FAILED") thm'; NONE)
-              | SOME thm2 =>
-                  (case check_conv ctxt true eta_thm thm2 of
-                    NONE => NONE
-                  | SOME thm2' =>
-                      let
-                        val concl = Logic.strip_imp_concl prop;
-                        val lr = Logic.dest_equals concl;
-                      in SOME (thm2', cond_skel (congs, lr)) end))))
+              trace_apply trace_args ctxt (fn ctxt' =>
+                (case prover ctxt' thm' of
+                  NONE => (trace_thm ctxt' (fn () => "FAILED") thm'; NONE)
+                | SOME thm2 =>
+                    (case check_conv ctxt' true eta_thm thm2 of
+                      NONE => NONE
+                    | SOME thm2' =>
+                        let
+                          val concl = Logic.strip_imp_concl prop;
+                          val lr = Logic.dest_equals concl;
+                        in SOME (thm2', cond_skel (congs, lr)) end)))))
       end;
 
     fun rews [] = NONE
@@ -1311,11 +1352,7 @@
 
 fun rewrite_cterm mode prover raw_ctxt raw_ct =
   let
-    val ctxt =
-      raw_ctxt
-      |> Context_Position.set_visible false
-      |> inc_simp_depth;
-    val thy = Proof_Context.theory_of ctxt;
+    val thy = Proof_Context.theory_of raw_ctxt;
 
     val ct = Thm.adjust_maxidx_cterm ~1 raw_ct;
     val {maxidx, ...} = Thm.rep_cterm ct;
@@ -1323,6 +1360,12 @@
       Theory.subthy (theory_of_cterm ct, thy) orelse
         raise CTERM ("rewrite_cterm: bad background theory", [ct]);
 
+    val ctxt =
+      raw_ctxt
+      |> Context_Position.set_visible false
+      |> inc_simp_depth
+      |> (fn ctxt => trace_invoke {depth = simp_depth ctxt, term = Thm.term_of ct} ctxt);
+
     val _ = trace_cterm ctxt false (fn () => "SIMPLIFIER INVOKED ON THE FOLLOWING TERM:") ct;
     val _ = check_bounds ctxt ct;
   in bottomc (mode, Option.map Drule.flexflex_unique oo prover, maxidx) ctxt ct end;
--- a/src/Pure/simplifier.ML	Thu Dec 12 17:34:50 2013 +0100
+++ b/src/Pure/simplifier.ML	Thu Dec 12 21:14:33 2013 +0100
@@ -47,6 +47,8 @@
   val set_mkeqTrue: (Proof.context -> thm -> thm option) -> Proof.context -> Proof.context
   val set_termless: (term * term -> bool) -> Proof.context -> Proof.context
   val set_subgoaler: (Proof.context -> int -> tactic) -> Proof.context -> Proof.context
+  type trace_ops
+  val set_trace_ops: trace_ops -> Proof.context -> Proof.context
   val simproc_global_i: theory -> string -> term list ->
     (Proof.context -> term -> thm option) -> simproc
   val simproc_global: theory -> string -> string list ->