learn new facts on query if there aren't too many of them in MaSh
author blanchet
Fri, 23 Aug 2013 13:30:25 +0200
changeset 53152 cbd3c7c48d2c
parent 53151 fbf4d50dec91
child 53153 1e9735cd27aa
learn new facts on query if there aren't too many of them in MaSh
src/HOL/Tools/Sledgehammer/MaSh/src/mash.py
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py	Fri Aug 23 00:12:20 2013 +0200
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py	Fri Aug 23 13:30:25 2013 +0200
@@ -103,30 +103,27 @@
         received = communicate(data,args.host,args.port)
         logger.info(received)     
     
-    if args.inputFile == None:
-        return
-    logger.debug('Using the following settings: %s',args)
-    # IO Streams
-    OS = open(args.predictions,'w')
-    IS = open(args.inputFile,'r')
-    lineCount = 0
-    for line in IS:
-        lineCount += 1
-        if lineCount % 100 == 0:
-            logger.info('On line %s', lineCount)
-        #if lineCount == 50: ###
-        #    break       
-        received = communicate(line,args.host,args.port)
-        if not received == '':
-            OS.write('%s\n' % received)
-    OS.close()
-    IS.close()
+    if not args.inputFile == None:
+        logger.debug('Using the following settings: %s',args)
+        # IO Streams
+        OS = open(args.predictions,'w')
+        IS = open(args.inputFile,'r')
+        lineCount = 0
+        for line in IS:
+            lineCount += 1
+            if lineCount % 100 == 0:
+                logger.info('On line %s', lineCount)
+            received = communicate(line,args.host,args.port)
+            if not received == '':
+                OS.write('%s\n' % received)
+        OS.close()
+        IS.close()
 
     # Statistics
     if args.statistics:
         received = communicate('avgStats',args.host,args.port)
         logger.info(received)
-    elif args.saveModels:
+    if args.saveModels:
         communicate('save',args.host,args.port)
 
 
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Fri Aug 23 00:12:20 2013 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Fri Aug 23 13:30:25 2013 +0200
@@ -151,7 +151,7 @@
    xs |> chunk_list 500 |> List.app (File.append path o implode o map f))
   handle IO.Io _ => ()
 
-fun run_mash_tool ctxt overlord extra_args write_cmds read_suggs =
+fun run_mash_tool ctxt overlord extra_args background write_cmds read_suggs =
   let
     val (temp_dir, serial) =
       if overlord then (getenv "ISABELLE_HOME_USER", "")
@@ -172,7 +172,8 @@
       " --dictsFile=" ^ model_dir ^ "/dict.pickle" ^
       " --log " ^ log_file ^ " " ^ core ^
       (if extra_args = [] then "" else " " ^ space_implode " " extra_args) ^
-      " >& " ^ err_file
+      " >& " ^ err_file ^
+      (if background then " &" else "")
     fun run_on () =
       (Isabelle_System.bash command
        |> tap (fn _ => trace_msg ctxt (fn () =>
@@ -254,7 +255,10 @@
 struct
 
 fun shutdown ctxt overlord =
-  run_mash_tool ctxt overlord [shutdown_server_arg] ([], K "") (K ())
+  run_mash_tool ctxt overlord [shutdown_server_arg] true ([], K "") (K ())
+
+fun save ctxt overlord =
+  run_mash_tool ctxt overlord [save_models_arg] true ([], K "") (K ())
 
 fun unlearn ctxt overlord =
   let val path = mash_model_dir () in
@@ -270,19 +274,19 @@
   | learn ctxt overlord learns =
     (trace_msg ctxt (fn () => "MaSh learn " ^
          elide_string 1000 (space_implode " " (map #1 learns)));
-     run_mash_tool ctxt overlord [save_models_arg] (learns, str_of_learn)
+     run_mash_tool ctxt overlord [] false (learns, str_of_learn)
                    (K ()))
 
 fun relearn _ _ [] = ()
   | relearn ctxt overlord relearns =
     (trace_msg ctxt (fn () => "MaSh relearn " ^
          elide_string 1000 (space_implode " " (map #1 relearns)));
-     run_mash_tool ctxt overlord [save_models_arg] (relearns, str_of_relearn)
+     run_mash_tool ctxt overlord [] false (relearns, str_of_relearn)
                    (K ()))
 
 fun query ctxt overlord max_suggs (query as (_, _, _, feats)) =
   (trace_msg ctxt (fn () => "MaSh query " ^ encode_features feats);
-   run_mash_tool ctxt overlord [] ([query], str_of_query max_suggs)
+   run_mash_tool ctxt overlord [] false ([query], str_of_query max_suggs)
        (fn suggs =>
            case suggs () of
              [] => []
@@ -359,8 +363,8 @@
      | _ => NONE)
   | _ => NONE
 
-fun load _ _ (state as (true, _)) = state
-  | load ctxt overlord _ =
+fun load_state _ _ (state as (true, _)) = state
+  | load_state ctxt overlord _ =
     let val path = mash_state_file () in
       (true,
        case try File.read_lines path of
@@ -394,8 +398,8 @@
        | _ => empty_state)
     end
 
-fun save _ (state as {dirty = SOME [], ...}) = state
-  | save ctxt {access_G, num_known_facts, dirty} =
+fun save_state _ (state as {dirty = SOME [], ...}) = state
+  | save_state ctxt {access_G, num_known_facts, dirty} =
     let
       fun str_of_entry (name, parents, kind) =
         str_of_proof_kind kind ^ " " ^ encode_str name ^ ": " ^
@@ -424,12 +428,13 @@
 in
 
 fun map_state ctxt overlord f =
-  Synchronized.change global_state (load ctxt overlord ##> (f #> save ctxt))
+  Synchronized.change global_state
+                      (load_state ctxt overlord ##> (f #> save_state ctxt))
   handle FILE_VERSION_TOO_NEW () => ()
 
 fun peek_state ctxt overlord f =
   Synchronized.change_result global_state
-      (perhaps (try (load ctxt overlord)) #> `snd #>> f)
+      (perhaps (try (load_state ctxt overlord)) #> `snd #>> f)
 
 fun clear_state ctxt overlord =
   Synchronized.change global_state (fn _ =>
@@ -1044,7 +1049,8 @@
                 used_ths |> filter (is_fact_in_graph access_G)
                          |> map nickname_of_thm
             in
-              MaSh.learn ctxt overlord [(name, parents, feats, deps)]
+              MaSh.learn ctxt overlord [(name, parents, feats, deps)];
+              MaSh.save ctxt overlord
             end);
         (true, "")
       end)
@@ -1056,7 +1062,7 @@
 
 (* The timeout is understood in a very relaxed fashion. *)
 fun mash_learn_facts ctxt (params as {debug, verbose, overlord, ...}) prover
-                     auto_level run_prover learn_timeout facts =
+                     save auto_level run_prover learn_timeout facts =
   let
     val timer = Timer.startRealTimer ()
     fun next_commit_time () =
@@ -1107,6 +1113,7 @@
             in
               MaSh.learn ctxt overlord (rev learns);
               MaSh.relearn ctxt overlord relearns;
+              if save then MaSh.save ctxt overlord else ();
               {access_G = access_G, num_known_facts = num_known_facts,
                dirty = dirty}
             end
@@ -1228,7 +1235,7 @@
     val num_facts = length facts
     val prover = hd provers
     fun learn auto_level run_prover =
-      mash_learn_facts ctxt params prover auto_level run_prover NONE facts
+      mash_learn_facts ctxt params prover true auto_level run_prover NONE facts
       |> Output.urgent_message
   in
     if run_prover then
@@ -1261,6 +1268,8 @@
 val mepo_weight = 0.5
 val mash_weight = 0.5
 
+val max_facts_to_learn_before_query = 100
+
 (* The threshold should be large enough so that MaSh doesn't kick in for Auto
    Sledgehammer and Try. *)
 val min_secs_for_learning = 15
@@ -1278,28 +1287,45 @@
     [("", [])]
   else
     let
-      fun maybe_learn () =
-        if learn andalso not (Async_Manager.has_running_threads MaShN) andalso
+      fun maybe_launch_thread () =
+        if not (Async_Manager.has_running_threads MaShN) andalso
            (timeout = NONE orelse
             Time.toSeconds (the timeout) >= min_secs_for_learning) then
           let
             val timeout = Option.map (time_mult learn_timeout_slack) timeout
           in
             launch_thread (timeout |> the_default one_day)
-                (fn () => (true, mash_learn_facts ctxt params prover 2 false
-                                                  timeout facts))
+                (fn () => (true, mash_learn_facts ctxt params prover true 2
+                                                  false timeout facts))
           end
         else
           ()
-      val effective_fact_filter =
+      fun maybe_learn () =
+        if learn then
+          let
+            val {access_G, num_known_facts, ...} = peek_state ctxt overlord I
+            val is_in_access_G = is_fact_in_graph access_G o snd
+          in
+            if length facts - num_known_facts <= max_facts_to_learn_before_query
+               andalso length (filter_out is_in_access_G facts)
+                       <= max_facts_to_learn_before_query then
+              (mash_learn_facts ctxt params prover false 2 false timeout facts
+               |> (fn "" => () | s => Output.urgent_message s);
+               true)
+            else
+              (maybe_launch_thread (); false)
+          end
+        else
+          false
+      val (save, effective_fact_filter) =
         case fact_filter of
-          SOME ff => (() |> ff <> mepoN ? maybe_learn; ff)
+          SOME ff => (ff <> mepoN andalso maybe_learn (), ff)
         | NONE =>
           if is_mash_enabled () then
-            (maybe_learn ();
+            (maybe_learn (),
              if mash_can_suggest_facts ctxt overlord then meshN else mepoN)
           else
-            mepoN
+            (false, mepoN)
       val add_ths = Attrib.eval_thms ctxt add
       fun in_add (_, th) = member Thm.eq_thm_prop add_ths th
       fun add_and_take accepts =
@@ -1330,6 +1356,7 @@
         mesh_facts (Thm.eq_thm_prop o pairself snd) max_facts mess
         |> add_and_take
     in
+      if save then MaSh.save ctxt overlord else ();
       case (fact_filter, mess) of
         (NONE, [(_, (mepo, _)), (_, (mash, _))]) =>
         [(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take),