src/HOL/Tools/Sledgehammer/sledgehammer_filter_mash.ML
changeset 48321 c552d7f1720b
parent 48320 891a24a48155
child 48322 8a8d71e34297
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_filter_mash.ML	Wed Jul 18 08:44:04 2012 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_filter_mash.ML	Wed Jul 18 08:44:04 2012 +0200
@@ -25,8 +25,9 @@
   val unescape_meta : string -> string
   val unescape_metas : string -> string list
   val extract_query : string -> string * string list
-  val suggested_facts : string list -> fact list -> fact list
-  val mesh_facts : int -> (fact list * fact list) list -> fact list
+  val suggested_facts : string list -> ('a * thm) list -> ('a * thm) list
+  val mesh_facts :
+    int -> (('a * thm) list * ('a * thm) list) list -> ('a * thm) list
   val is_likely_tautology : Proof.context -> string -> thm -> bool
   val is_too_meta : thm -> bool
   val theory_ord : theory * theory -> order
@@ -35,7 +36,7 @@
     Proof.context -> string -> theory -> status -> term list -> string list
   val isabelle_dependencies_of : unit Symtab.table -> thm -> string list
   val goal_of_thm : theory -> thm -> thm
-  val run_prover :
+  val run_prover_for_mash :
     Proof.context -> params -> string -> fact list -> thm -> prover_result
   val mash_RESET : Proof.context -> unit
   val mash_INIT :
@@ -48,14 +49,14 @@
     Proof.context -> bool -> int -> string list * string list -> string list
   val mash_reset : Proof.context -> unit
   val mash_could_suggest_facts : unit -> bool
-  val mash_can_suggest_facts : unit -> bool
+  val mash_can_suggest_facts : Proof.context -> bool
   val mash_suggest_facts :
-    Proof.context -> params -> string -> int -> term list -> term -> fact list
-    -> fact list * fact list
+    Proof.context -> params -> string -> int -> term list -> term
+    -> ('a * thm) list -> ('a * thm) list * ('a * thm) list
   val mash_learn_thy :
     Proof.context -> params -> theory -> Time.time -> fact list -> string
   val mash_learn_proof :
-    Proof.context -> params -> term -> thm list -> fact list -> unit
+    Proof.context -> params -> term -> ('a * thm) list -> thm list -> unit
   val relevant_facts :
     Proof.context -> params -> string -> int -> fact_override -> term list
     -> term -> fact list -> fact list
@@ -300,13 +301,13 @@
 
 fun goal_of_thm thy = prop_of #> freeze #> cterm_of thy #> Goal.init
 
-fun run_prover ctxt params prover facts goal =
+fun run_prover_for_mash ctxt params prover facts goal =
   let
     val problem =
       {state = Proof.init ctxt, goal = goal, subgoal = 1, subgoal_count = 1,
        facts = facts |> map (apfst (apfst (fn name => name ())))
                      |> map Untranslated_Fact}
-    val prover = get_minimizing_prover ctxt Normal prover
+    val prover = get_minimizing_prover ctxt Normal (K ()) prover
   in prover params (K (K (K ""))) problem end
 
 
@@ -406,6 +407,15 @@
 
 (*** High-level communication with MaSh ***)
 
+fun try_graph ctxt when def f =
+  f ()
+  handle Graph.CYCLES (cycle :: _) =>
+         (trace_msg ctxt (fn () =>
+              "Cycle involving " ^ commas cycle ^ " when " ^ when); def)
+       | Graph.UNDEF name =>
+         (trace_msg ctxt (fn () =>
+              "Unknown fact " ^ quote name ^ " when " ^ when); def)
+
 type mash_state =
   {thys : bool Symtab.table,
    fact_graph : unit Graph.T}
@@ -414,8 +424,8 @@
 
 local
 
-fun mash_load (state as (true, _)) = state
-  | mash_load _ =
+fun mash_load _ (state as (true, _)) = state
+  | mash_load ctxt _ =
     let val path = mash_state_path () in
       (true,
        case try File.read_lines path of
@@ -431,7 +441,9 @@
            val thys =
              Symtab.empty |> fold (add_thy true) (unescape_metas comp_thys)
                           |> fold (add_thy false) (unescape_metas incomp_thys)
-           val fact_graph = Graph.empty |> fold add_fact_line fact_lines
+           val fact_graph =
+             try_graph ctxt "loading state file" Graph.empty (fn () =>
+                 Graph.empty |> fold add_fact_line fact_lines)
          in {thys = thys, fact_graph = fact_graph} end
        | _ => empty_state)
     end
@@ -456,10 +468,11 @@
 
 in
 
-fun mash_map f =
-  Synchronized.change global_state (mash_load ##> (f #> tap mash_save))
+fun mash_map ctxt f =
+  Synchronized.change global_state (mash_load ctxt ##> (f #> tap mash_save))
 
-fun mash_get () = Synchronized.change_result global_state (mash_load #> `snd)
+fun mash_get ctxt =
+  Synchronized.change_result global_state (mash_load ctxt #> `snd)
 
 fun mash_reset ctxt =
   Synchronized.change global_state (fn _ =>
@@ -469,17 +482,22 @@
 end
 
 fun mash_could_suggest_facts () = mash_home () <> ""
-fun mash_can_suggest_facts () = not (Graph.is_empty (#fact_graph (mash_get ())))
+fun mash_can_suggest_facts ctxt =
+  not (Graph.is_empty (#fact_graph (mash_get ctxt)))
 
-fun parents_wrt_facts facts fact_graph =
+fun parents_wrt_facts ctxt facts fact_graph =
   let
     val graph_facts = Symtab.make (map (rpair ()) (Graph.keys fact_graph))
     val facts =
-      [] |> fold (cons o Thm.get_name_hint o snd) facts
-         |> filter (Symtab.defined graph_facts)
-         |> Graph.all_preds fact_graph
+      try_graph ctxt "when computing ancestor facts" [] (fn () =>
+          [] |> fold (cons o Thm.get_name_hint o snd) facts
+             |> filter (Symtab.defined graph_facts)
+             |> Graph.all_preds fact_graph)
     val facts = Symtab.empty |> fold (fn name => Symtab.update (name, ())) facts
-  in fact_graph |> Graph.restrict (Symtab.defined facts) |> Graph.maximals end
+  in
+    try_graph ctxt "when computing parent facts" [] (fn () =>
+        fact_graph |> Graph.restrict (Symtab.defined facts) |> Graph.maximals)
+  end
 
 (* Generate more suggestions than requested, because some might be thrown out
    later for various reasons and "meshing" gives better results with some
@@ -493,8 +511,8 @@
                        concl_t facts =
   let
     val thy = Proof_Context.theory_of ctxt
-    val fact_graph = #fact_graph (mash_get ())
-    val parents = parents_wrt_facts facts fact_graph
+    val fact_graph = #fact_graph (mash_get ctxt)
+    val parents = parents_wrt_facts ctxt facts fact_graph
     val feats = features_of ctxt prover thy General (concl_t :: hyp_ts)
     val suggs =
       if Graph.is_empty fact_graph then []
@@ -511,13 +529,9 @@
 fun update_fact_graph ctxt (name, parents, feats, deps) (upds, graph) =
   let
     fun maybe_add_from from (accum as (parents, graph)) =
-      (from :: parents, Graph.add_edge_acyclic (from, name) graph)
-      handle Graph.CYCLES _ =>
-             (trace_msg ctxt (fn () =>
-                  "Cycle between " ^ quote from ^ " and " ^ quote name); accum)
-           | Graph.UNDEF _ =>
-             (trace_msg ctxt (fn () => "Unknown node " ^ quote from); accum)
-    val graph = graph |> Graph.new_node (name, ())
+      try_graph ctxt "updating graph" accum (fn () =>
+          (from :: parents, Graph.add_edge_acyclic (from, name) graph))
+    val graph = graph |> Graph.default_node (name, ())
     val (parents, graph) = ([], graph) |> fold maybe_add_from parents
     val (deps, graph) = ([], graph) |> fold maybe_add_from deps
   in ((name, parents, feats, deps) :: upds, graph) end
@@ -532,7 +546,7 @@
     val prover = hd provers
     fun timed_out frac =
       Time.> (Timer.checkRealTimer timer, time_mult frac timeout)
-    val {fact_graph, ...} = mash_get ()
+    val {fact_graph, ...} = mash_get ctxt
     val new_facts =
       facts |> filter_out (is_fact_in_graph fact_graph)
             |> sort (thm_ord o pairself snd)
@@ -554,7 +568,7 @@
               val deps = isabelle_dependencies_of all_names th
               val upd = (name, parents, feats, deps)
             in (([name], upd :: upds), timed_out pass1_learn_timeout_factor) end
-        val parents = parents_wrt_facts facts fact_graph
+        val parents = parents_wrt_facts ctxt facts fact_graph
         val ((_, upds), _) =
           ((parents, []), false) |> fold do_fact new_facts |>> apsnd rev
         val n = length upds
@@ -570,7 +584,7 @@
               fact_graph = fact_graph})
           end
       in
-        mash_map trans;
+        mash_map ctxt trans;
         if verbose then
           "Processed " ^ string_of_int n ^ " proof" ^ plural_s n ^
           (if verbose then
@@ -582,7 +596,7 @@
       end
   end
 
-fun mash_learn_proof ctxt ({provers, overlord, ...} : params) t used_ths facts =
+fun mash_learn_proof ctxt ({provers, overlord, ...} : params) t facts used_ths =
   let
     val thy = Proof_Context.theory_of ctxt
     val prover = hd provers
@@ -590,9 +604,9 @@
     val feats = features_of ctxt prover thy General [t]
     val deps = used_ths |> map Thm.get_name_hint
   in
-    mash_map (fn {thys, fact_graph} =>
+    mash_map ctxt (fn {thys, fact_graph} =>
         let
-          val parents = parents_wrt_facts facts fact_graph
+          val parents = parents_wrt_facts ctxt facts fact_graph
           val upds = [(name, parents, feats, deps)]
           val (upds, fact_graph) =
             ([], fact_graph) |> fold (update_fact_graph ctxt) upds
@@ -608,19 +622,19 @@
 val short_learn_timeout_factor = 0.2
 val long_learn_timeout_factor = 4.0
 
-fun relevant_facts ctxt (params as {fact_filter, timeout, ...}) prover max_facts
-        ({add, only, ...} : fact_override) hyp_ts concl_t facts =
+fun relevant_facts ctxt (params as {learn, fact_filter, timeout, ...}) prover
+        max_facts ({add, only, ...} : fact_override) hyp_ts concl_t facts =
   if not (subset (op =) (the_list fact_filter, fact_filters)) then
     error ("Unknown fact filter: " ^ quote (the fact_filter) ^ ".")
   else if only then
     facts
-  else if max_facts <= 0 then
+  else if max_facts <= 0 orelse null facts then
     []
   else
     let
       val thy = Proof_Context.theory_of ctxt
       fun maybe_learn can_suggest =
-        if Async_Manager.has_running_threads MaShN orelse null facts then
+        if not learn orelse Async_Manager.has_running_threads MaShN then
           ()
         else if Time.toSeconds timeout >= min_secs_for_learning then
           let
@@ -642,10 +656,10 @@
       val fact_filter =
         case fact_filter of
           SOME ff =>
-          (if ff <> iterN then maybe_learn (mash_can_suggest_facts ()) else ();
-           ff)
+          (if ff <> iterN then maybe_learn (mash_can_suggest_facts ctxt)
+           else (); ff)
         | NONE =>
-          if mash_can_suggest_facts () then (maybe_learn true; meshN)
+          if mash_can_suggest_facts ctxt then (maybe_learn true; meshN)
           else if mash_could_suggest_facts () then (maybe_learn false; iterN)
           else iterN
       val add_ths = Attrib.eval_thms ctxt add