src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 53095 667717a5ad80
parent 53094 e33d77814a92
child 53098 db5e1b53bbfc
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue Aug 20 11:42:51 2013 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue Aug 20 11:42:52 2013 +0200
@@ -43,7 +43,7 @@
     val relearn :
       Proof.context -> bool -> (string * string list) list -> unit
     val query :
-      Proof.context -> bool -> bool -> int
+      Proof.context -> bool -> int
       -> (string * string list * (string * real) list * string list) list
          * string list * string list * (string * real) list
       -> string list
@@ -71,6 +71,8 @@
     Proof.context -> params -> string -> int -> raw_fact list
     -> string Symtab.table * string Symtab.table -> thm
     -> bool * string list
+  val attach_parents_to_facts :
+    ('a * thm) list -> ('a * thm) list -> (string list * ('a * thm)) list
   val weight_mepo_facts : 'a list -> ('a * real) list
   val weight_mash_facts : 'a list -> ('a * real) list
   val find_mash_suggestions :
@@ -82,8 +84,6 @@
   val mash_learn_proof :
     Proof.context -> params -> string -> term -> ('a * thm) list -> thm list
     -> unit
-  val attach_parents_to_facts :
-    ('a * thm) list -> ('a * thm) list -> (string list * ('a * thm)) list
   val mash_learn :
     Proof.context -> params -> fact_override -> thm list -> bool -> unit
   val is_mash_enabled : unit -> bool
@@ -227,16 +227,10 @@
 fun str_of_relearn (name, deps) =
   "p " ^ encode_str name ^ ": " ^ encode_strs deps ^ "\n"
 
-fun str_of_query learn (learns, hints, parents, feats) =
-  (if learn then
-     implode (map str_of_learn learns) ^
-     (if null hints then ""
-      else str_of_learn (freshish_name (), parents, feats, hints))
-   else
-     "") ^
+fun str_of_query (learns, hints, parents, feats) =
+  implode (map str_of_learn learns) ^
   "? " ^ encode_strs parents ^ "; " ^ encode_features feats ^
-  (if learn orelse null hints then "" else "; " ^ encode_strs hints) ^
-  "\n"
+  (if null hints then "" else "; " ^ encode_strs hints) ^ "\n"
 
 (* The weights currently returned by "mash.py" are too spaced out to make any
    sense. *)
@@ -277,11 +271,9 @@
          elide_string 1000 (space_implode " " (map #1 relearns)));
      run_mash_tool ctxt overlord true 0 (relearns, str_of_relearn) (K ()))
 
-fun query ctxt overlord learn max_suggs (query as (learns, hints, _, feats)) =
+fun query ctxt overlord max_suggs (query as (learns, hints, _, feats)) =
   (trace_msg ctxt (fn () => "MaSh query " ^ encode_features feats);
-   run_mash_tool ctxt overlord
-       (learn andalso not (null learns) andalso not (null hints))
-       max_suggs ([query], str_of_query learn)
+   run_mash_tool ctxt overlord false max_suggs ([query], str_of_query)
        (fn suggs =>
            case suggs () of
              [] => []
@@ -335,15 +327,18 @@
   string_of_int (length (Graph.minimals G)) ^ " minimal, " ^
   string_of_int (length (Graph.maximals G)) ^ " maximal"
 
-type mash_state = {access_G : unit Graph.T, dirty : string list option}
+type mash_state =
+  {access_G : unit Graph.T,
+   num_known_facts : int,
+   dirty : string list option}
 
-val empty_state = {access_G = Graph.empty, dirty = SOME []}
+val empty_state = {access_G = Graph.empty, num_known_facts = 0, dirty = SOME []}
 
 local
 
-val version = "*** MaSh version 20130819a ***"
+val version = "*** MaSh version 20130820 ***"
 
-exception Too_New of unit
+exception FILE_VERSION_TOO_NEW of unit
 
 fun extract_node line =
   case space_explode ":" line of
@@ -371,24 +366,27 @@
              | SOME (name, parents, kind) =>
                update_access_graph_node (name, kind)
                #> fold (add_edge_to name) parents
-           val access_G =
+           val (access_G, num_known_facts) =
              case string_ord (version', version) of
                EQUAL =>
-               try_graph ctxt "loading state" Graph.empty (fn () =>
-                   fold add_node node_lines Graph.empty)
+               (try_graph ctxt "loading state" Graph.empty (fn () =>
+                    fold add_node node_lines Graph.empty),
+                length node_lines)
              | LESS =>
-               (MaSh.unlearn ctxt; Graph.empty) (* can't parse old file *)
-             | GREATER => raise Too_New ()
+               (* can't parse old file *)
+               (MaSh.unlearn ctxt; (Graph.empty, 0))
+             | GREATER => raise FILE_VERSION_TOO_NEW ()
          in
            trace_msg ctxt (fn () =>
                "Loaded fact graph (" ^ graph_info access_G ^ ")");
-           {access_G = access_G, dirty = SOME []}
+           {access_G = access_G, num_known_facts = num_known_facts,
+            dirty = SOME []}
          end
        | _ => empty_state)
     end
 
 fun save _ (state as {dirty = SOME [], ...}) = state
-  | save ctxt {access_G, dirty} =
+  | save ctxt {access_G, num_known_facts, dirty} =
     let
       fun str_of_entry (name, parents, kind) =
         str_of_proof_kind kind ^ " " ^ encode_str name ^ ": " ^
@@ -408,7 +406,7 @@
              SOME dirty =>
              "; " ^ string_of_int (length dirty) ^ " dirty fact(s)"
            | _ => "") ^  ")");
-      {access_G = access_G, dirty = SOME []}
+      {access_G = access_G, num_known_facts = num_known_facts, dirty = SOME []}
     end
 
 val global_state =
@@ -418,7 +416,7 @@
 
 fun map_state ctxt f =
   Synchronized.change global_state (load ctxt ##> (f #> save ctxt))
-  handle Too_New () => ()
+  handle FILE_VERSION_TOO_NEW () => ()
 
 fun peek_state ctxt f =
   Synchronized.change_result global_state
@@ -723,6 +721,9 @@
      | NONE => false)
   | is_size_def _ _ = false
 
+fun no_dependencies_for_status status =
+  status = Non_Rec_Def orelse status = Rec_Def
+
 fun trim_dependencies deps =
   if length deps > max_dependencies then NONE else SOME deps
 
@@ -790,159 +791,9 @@
 
 (*** High-level communication with MaSh ***)
 
-fun maximal_wrt_graph G keys =
-  let
-    val tab = Symtab.empty |> fold (fn name => Symtab.default (name, ())) keys
-    fun insert_new seen name =
-      not (Symtab.defined seen name) ? insert (op =) name
-    fun num_keys keys = Graph.Keys.fold (K (Integer.add 1)) keys 0
-    fun find_maxes _ (maxs, []) = map snd maxs
-      | find_maxes seen (maxs, new :: news) =
-        find_maxes
-            (seen |> num_keys (Graph.imm_succs G new) > 1
-                     ? Symtab.default (new, ()))
-            (if Symtab.defined tab new then
-               let
-                 val newp = Graph.all_preds G [new]
-                 fun is_ancestor x yp = member (op =) yp x
-                 val maxs =
-                   maxs |> filter (fn (_, max) => not (is_ancestor max newp))
-               in
-                 if exists (is_ancestor new o fst) maxs then
-                   (maxs, news)
-                 else
-                   ((newp, new)
-                    :: filter_out (fn (_, max) => is_ancestor max newp) maxs,
-                    news)
-               end
-             else
-               (maxs, Graph.Keys.fold (insert_new seen)
-                                      (Graph.imm_preds G new) news))
-  in find_maxes Symtab.empty ([], Graph.maximals G) end
-
-fun maximal_wrt_access_graph access_G =
-  map (nickname_of_thm o snd)
-  #> maximal_wrt_graph access_G
-
-fun is_fact_in_graph access_G get_th fact =
-  can (Graph.get_node access_G) (nickname_of_thm (get_th fact))
-
-(* FUDGE *)
-fun weight_of_mepo_fact rank =
-  Math.pow (0.62, log2 (Real.fromInt (rank + 1)))
-
-fun weight_mepo_facts facts =
-  facts ~~ map weight_of_mepo_fact (0 upto length facts - 1)
-
-val weight_raw_mash_facts = weight_mepo_facts
-val weight_mash_facts = weight_raw_mash_facts
-
-(* FUDGE *)
-fun weight_of_proximity_fact rank =
-  Math.pow (1.3, 15.5 - 0.2 * Real.fromInt rank) + 15.0
-
-fun weight_proximity_facts facts =
-  facts ~~ map weight_of_proximity_fact (0 upto length facts - 1)
-
-val max_proximity_facts = 100
-
-fun find_mash_suggestions _ _ [] _ _ raw_unknown = ([], raw_unknown)
-  | find_mash_suggestions ctxt max_facts suggs facts chained raw_unknown =
-    let
-      val raw_mash = find_suggested_facts ctxt facts suggs
-      val unknown_chained =
-        inter (Thm.eq_thm_prop o pairself snd) chained raw_unknown
-      val proximity =
-        facts |> sort (crude_thm_ord o pairself snd o swap)
-              |> take max_proximity_facts
-      val mess =
-        [(0.90 (* FUDGE *), (map (rpair 1.0) unknown_chained, [])),
-         (0.08 (* FUDGE *), (weight_raw_mash_facts raw_mash, raw_unknown)),
-         (0.02 (* FUDGE *), (weight_proximity_facts proximity, []))]
-      val unknown =
-        raw_unknown
-        |> fold (subtract (Thm.eq_thm_prop o pairself snd))
-                [unknown_chained, proximity]
-    in (mesh_facts (Thm.eq_thm_prop o pairself snd) max_facts mess, unknown) end
-
-fun mash_suggested_facts ctxt ({overlord, learn, ...} : params) prover max_facts
-                         hyp_ts concl_t facts =
-  let
-    val thy = Proof_Context.theory_of ctxt
-    val chained = facts |> filter (fn ((_, (scope, _)), _) => scope = Chained)
-    val (access_G, suggs) =
-      peek_state ctxt (fn {access_G, ...} =>
-          if Graph.is_empty access_G then
-            (access_G, [])
-          else
-            let
-              val parents = maximal_wrt_access_graph access_G facts
-              val feats =
-                features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts)
-              val hints =
-                chained |> filter (is_fact_in_graph access_G snd)
-                        |> map (nickname_of_thm o snd)
-            in
-              (access_G, MaSh.query ctxt overlord learn max_facts
-                                    ([], hints, parents, feats))
-            end)
-    val unknown = facts |> filter_out (is_fact_in_graph access_G snd)
-  in
-    find_mash_suggestions ctxt max_facts suggs facts chained unknown
-    |> pairself (map fact_of_raw_fact)
-  end
-
-fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, graph) =
-  let
-    fun maybe_learn_from from (accum as (parents, graph)) =
-      try_graph ctxt "updating graph" accum (fn () =>
-          (from :: parents, Graph.add_edge_acyclic (from, name) graph))
-    val graph = graph |> Graph.default_node (name, Isar_Proof)
-    val (parents, graph) = ([], graph) |> fold maybe_learn_from parents
-    val (deps, _) = ([], graph) |> fold maybe_learn_from deps
-  in ((name, parents, feats, deps) :: learns, graph) end
-
-fun relearn_wrt_access_graph ctxt (name, deps) (relearns, graph) =
-  let
-    fun maybe_relearn_from from (accum as (parents, graph)) =
-      try_graph ctxt "updating graph" accum (fn () =>
-          (from :: parents, Graph.add_edge_acyclic (from, name) graph))
-    val graph = graph |> update_access_graph_node (name, Automatic_Proof)
-    val (deps, _) = ([], graph) |> fold maybe_relearn_from deps
-  in ((name, deps) :: relearns, graph) end
-
-fun flop_wrt_access_graph name =
-  update_access_graph_node (name, Isar_Proof_wegen_Prover_Flop)
-
-val learn_timeout_slack = 2.0
-
-fun launch_thread timeout task =
-  let
-    val hard_timeout = time_mult learn_timeout_slack timeout
-    val birth_time = Time.now ()
-    val death_time = Time.+ (birth_time, hard_timeout)
-    val desc = ("Machine learner for Sledgehammer", "")
-  in Async_Manager.thread MaShN birth_time death_time desc task end
-
-fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) prover t facts
-                     used_ths =
-  launch_thread (timeout |> the_default one_day) (fn () =>
-      let
-        val thy = Proof_Context.theory_of ctxt
-        val name = freshish_name ()
-        val feats = features_of ctxt prover thy (Local, General) [t]
-      in
-        peek_state ctxt (fn {access_G, ...} =>
-            let
-              val parents = maximal_wrt_access_graph access_G facts
-              val deps =
-                used_ths |> filter (is_fact_in_graph access_G I)
-                         |> map nickname_of_thm
-            in
-              MaSh.learn ctxt overlord [(name, parents, feats, deps)]
-            end);
-        (true, "")
-      end)
+fun attach_crude_parents_to_facts _ [] = []
+  | attach_crude_parents_to_facts parents ((fact as (_, th)) :: facts) =
+    (parents, fact) :: attach_crude_parents_to_facts [nickname_of_thm th] facts
 
 (* In the following functions, chunks are risers w.r.t. "thm_less_eq". *)
 
@@ -994,6 +845,191 @@
       |> drop (length old_facts)
     end
 
+fun maximal_wrt_graph G keys =
+  let
+    val tab = Symtab.empty |> fold (fn name => Symtab.default (name, ())) keys
+    fun insert_new seen name =
+      not (Symtab.defined seen name) ? insert (op =) name
+    fun num_keys keys = Graph.Keys.fold (K (Integer.add 1)) keys 0
+    fun find_maxes _ (maxs, []) = map snd maxs
+      | find_maxes seen (maxs, new :: news) =
+        find_maxes
+            (seen |> num_keys (Graph.imm_succs G new) > 1
+                     ? Symtab.default (new, ()))
+            (if Symtab.defined tab new then
+               let
+                 val newp = Graph.all_preds G [new]
+                 fun is_ancestor x yp = member (op =) yp x
+                 val maxs =
+                   maxs |> filter (fn (_, max) => not (is_ancestor max newp))
+               in
+                 if exists (is_ancestor new o fst) maxs then
+                   (maxs, news)
+                 else
+                   ((newp, new)
+                    :: filter_out (fn (_, max) => is_ancestor max newp) maxs,
+                    news)
+               end
+             else
+               (maxs, Graph.Keys.fold (insert_new seen)
+                                      (Graph.imm_preds G new) news))
+  in find_maxes Symtab.empty ([], Graph.maximals G) end
+
+fun maximal_wrt_access_graph access_G =
+  map (nickname_of_thm o snd)
+  #> maximal_wrt_graph access_G
+
+fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm
+
+(* FUDGE *)
+fun weight_of_mepo_fact rank =
+  Math.pow (0.62, log2 (Real.fromInt (rank + 1)))
+
+fun weight_mepo_facts facts =
+  facts ~~ map weight_of_mepo_fact (0 upto length facts - 1)
+
+val weight_raw_mash_facts = weight_mepo_facts
+val weight_mash_facts = weight_raw_mash_facts
+
+(* FUDGE *)
+fun weight_of_proximity_fact rank =
+  Math.pow (1.3, 15.5 - 0.2 * Real.fromInt rank) + 15.0
+
+fun weight_proximity_facts facts =
+  facts ~~ map weight_of_proximity_fact (0 upto length facts - 1)
+
+val max_proximity_facts = 100
+
+fun find_mash_suggestions _ _ [] _ _ raw_unknown = ([], raw_unknown)
+  | find_mash_suggestions ctxt max_facts suggs facts chained raw_unknown =
+    let
+      val raw_mash = find_suggested_facts ctxt facts suggs
+      val unknown_chained =
+        inter (Thm.eq_thm_prop o pairself snd) chained raw_unknown
+      val proximity =
+        facts |> sort (crude_thm_ord o pairself snd o swap)
+              |> take max_proximity_facts
+      val mess =
+        [(0.90 (* FUDGE *), (map (rpair 1.0) unknown_chained, [])),
+         (0.08 (* FUDGE *), (weight_raw_mash_facts raw_mash, raw_unknown)),
+         (0.02 (* FUDGE *), (weight_proximity_facts proximity, []))]
+      val unknown =
+        raw_unknown
+        |> fold (subtract (Thm.eq_thm_prop o pairself snd))
+                [unknown_chained, proximity]
+    in (mesh_facts (Thm.eq_thm_prop o pairself snd) max_facts mess, unknown) end
+
+val max_learn_on_query = 500
+
+fun mash_suggested_facts ctxt ({overlord, learn, ...} : params) prover max_facts
+                         hyp_ts concl_t facts =
+  let
+    val thy = Proof_Context.theory_of ctxt
+    val chained = facts |> filter (fn ((_, (scope, _)), _) => scope = Chained)
+    val (access_G, suggs) =
+      peek_state ctxt (fn {access_G, num_known_facts, ...} =>
+          if Graph.is_empty access_G then
+            (access_G, [])
+          else
+            let
+              val parents = maximal_wrt_access_graph access_G facts
+              val feats =
+                features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts)
+              val hints =
+                chained |> filter (is_fact_in_graph access_G o snd)
+                        |> map (nickname_of_thm o snd)
+              val (learns, parents) =
+                if length facts - num_known_facts <= max_learn_on_query then
+                  let
+                    val name_tabs = build_name_tables nickname_of_thm facts
+                    fun deps_of status th =
+                      if no_dependencies_for_status status then
+                        SOME []
+                      else
+                        isar_dependencies_of name_tabs th
+                        |> trim_dependencies
+                    fun learn_new_fact (parents,
+                                        ((_, stature as (_, status)), th)) =
+                      let
+                        val name = nickname_of_thm th
+                        val feats =
+                          features_of ctxt prover (theory_of_thm th) stature
+                                      [prop_of th]
+                        val deps = deps_of status th |> these
+                      in (name, parents, feats, deps) end
+                    val new_facts =
+                      facts |> filter_out (is_fact_in_graph access_G o snd)
+                            |> sort (crude_thm_ord o pairself snd)
+                            |> attach_crude_parents_to_facts parents
+                    val learns = new_facts |> map learn_new_fact
+                    val parents =
+                      if null new_facts then parents
+                      else [#1 (List.last learns)]
+                  in (learns, parents) end
+                else
+                  ([], parents)
+            in
+              (access_G, MaSh.query ctxt overlord max_facts
+                                    (learns, hints, parents, feats))
+            end)
+    val unknown = facts |> filter_out (is_fact_in_graph access_G o snd)
+  in
+    find_mash_suggestions ctxt max_facts suggs facts chained unknown
+    |> pairself (map fact_of_raw_fact)
+  end
+
+fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, graph) =
+  let
+    fun maybe_learn_from from (accum as (parents, graph)) =
+      try_graph ctxt "updating graph" accum (fn () =>
+          (from :: parents, Graph.add_edge_acyclic (from, name) graph))
+    val graph = graph |> Graph.default_node (name, Isar_Proof)
+    val (parents, graph) = ([], graph) |> fold maybe_learn_from parents
+    val (deps, _) = ([], graph) |> fold maybe_learn_from deps
+  in ((name, parents, feats, deps) :: learns, graph) end
+
+fun relearn_wrt_access_graph ctxt (name, deps) (relearns, graph) =
+  let
+    fun maybe_relearn_from from (accum as (parents, graph)) =
+      try_graph ctxt "updating graph" accum (fn () =>
+          (from :: parents, Graph.add_edge_acyclic (from, name) graph))
+    val graph = graph |> update_access_graph_node (name, Automatic_Proof)
+    val (deps, _) = ([], graph) |> fold maybe_relearn_from deps
+  in ((name, deps) :: relearns, graph) end
+
+fun flop_wrt_access_graph name =
+  update_access_graph_node (name, Isar_Proof_wegen_Prover_Flop)
+
+val learn_timeout_slack = 2.0
+
+fun launch_thread timeout task =
+  let
+    val hard_timeout = time_mult learn_timeout_slack timeout
+    val birth_time = Time.now ()
+    val death_time = Time.+ (birth_time, hard_timeout)
+    val desc = ("Machine learner for Sledgehammer", "")
+  in Async_Manager.thread MaShN birth_time death_time desc task end
+
+fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) prover t facts
+                     used_ths =
+  launch_thread (timeout |> the_default one_day) (fn () =>
+      let
+        val thy = Proof_Context.theory_of ctxt
+        val name = freshish_name ()
+        val feats = features_of ctxt prover thy (Local, General) [t]
+      in
+        peek_state ctxt (fn {access_G, ...} =>
+            let
+              val parents = maximal_wrt_access_graph access_G facts
+              val deps =
+                used_ths |> filter (is_fact_in_graph access_G)
+                         |> map nickname_of_thm
+            in
+              MaSh.learn ctxt overlord [(name, parents, feats, deps)]
+            end);
+        (true, "")
+      end)
+
 fun sendback sub =
   Active.sendback_markup [Markup.padding_command] (sledgehammerN ^ " " ^ sub)
 
@@ -1007,7 +1043,7 @@
     fun next_commit_time () =
       Time.+ (Timer.checkRealTimer timer, commit_timeout)
     val {access_G, ...} = peek_state ctxt I
-    val is_in_access_G = is_fact_in_graph access_G snd
+    val is_in_access_G = is_fact_in_graph access_G o snd
     val no_new_facts = forall is_in_access_G facts
   in
     if no_new_facts andalso not run_prover then
@@ -1025,7 +1061,7 @@
       let
         val name_tabs = build_name_tables nickname_of_thm facts
         fun deps_of status th =
-          if status = Non_Rec_Def orelse status = Rec_Def then
+          if no_dependencies_for_status status then
             SOME []
           else if run_prover then
             prover_dependencies_of ctxt params prover auto_level facts name_tabs
@@ -1036,7 +1072,7 @@
             isar_dependencies_of name_tabs th
             |> trim_dependencies
         fun do_commit [] [] [] state = state
-          | do_commit learns relearns flops {access_G, dirty} =
+          | do_commit learns relearns flops {access_G, num_known_facts, dirty} =
             let
               val was_empty = Graph.is_empty access_G
               val (learns, access_G) =
@@ -1044,6 +1080,7 @@
               val (relearns, access_G) =
                 ([], access_G) |> fold (relearn_wrt_access_graph ctxt) relearns
               val access_G = access_G |> fold flop_wrt_access_graph flops
+              val num_known_facts = num_known_facts + length learns
               val dirty =
                 case (was_empty, dirty, relearns) of
                   (false, SOME names, []) => SOME (map #1 learns @ names)
@@ -1051,7 +1088,8 @@
             in
               MaSh.learn ctxt overlord (rev learns);
               MaSh.relearn ctxt overlord relearns;
-              {access_G = access_G, dirty = dirty}
+              {access_G = access_G, num_known_facts = num_known_facts,
+               dirty = dirty}
             end
         fun commit last learns relearns flops =
           (if debug andalso auto_level = 0 then