relearn ATP proofs
author blanchet
Fri, 20 Jul 2012 22:19:46 +0200
changeset 48404 0a261b4aa093
parent 48403 1f214c653c80
child 48405 7682bc885e8a
relearn ATP proofs
src/HOL/TPTP/mash_eval.ML
src/HOL/TPTP/mash_export.ML
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/TPTP/mash_eval.ML	Fri Jul 20 22:19:46 2012 +0200
+++ b/src/HOL/TPTP/mash_eval.ML	Fri Jul 20 22:19:46 2012 +0200
@@ -75,7 +75,7 @@
           | NONE => error ("No fact called \"" ^ name ^ "\"")
         val goal = goal_of_thm thy th
         val (_, hyp_ts, concl_t) = ATP_Util.strip_subgoal ctxt goal 1
-        val isar_deps = isar_dependencies_of all_names th
+        val isar_deps = isar_dependencies_of all_names th |> these
         val facts = facts |> filter (fn (_, th') => thm_ord (th', th) = LESS)
         val mepo_facts =
           Sledgehammer_MePo.iterative_relevant_facts ctxt params prover
--- a/src/HOL/TPTP/mash_export.ML	Fri Jul 20 22:19:46 2012 +0200
+++ b/src/HOL/TPTP/mash_export.ML	Fri Jul 20 22:19:46 2012 +0200
@@ -102,7 +102,7 @@
     fun do_thm th =
       let
         val name = nickname_of th
-        val deps = isar_dependencies_of all_names th
+        val deps = isar_dependencies_of all_names th |> these
         val s = escape_meta name ^ ": " ^ escape_metas deps ^ "\n"
       in File.append path s end
   in List.app do_thm ths end
@@ -122,7 +122,9 @@
       let
         val name = nickname_of th
         val deps =
-          atp_dependencies_of ctxt params prover false facts all_names th
+          case atp_dependencies_of ctxt params prover 0 facts all_names th of
+            SOME deps => deps
+          | NONE => isar_dependencies_of all_names th |> these
         val s = escape_meta name ^ ": " ^ escape_metas deps ^ "\n"
       in File.append path s end
   in List.app do_thm ths end
@@ -142,7 +144,7 @@
       let
         val name = nickname_of th
         val feats = features_of ctxt prover thy stature [prop_of th]
-        val deps = isar_dependencies_of all_names th
+        val deps = isar_dependencies_of all_names th |> these
         val kind = Thm.legacy_get_kind th
         val core = escape_metas prevs ^ "; " ^ escape_metas feats
         val query = if kind <> "" then "? " ^ core ^ "\n" else ""
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Fri Jul 20 22:19:46 2012 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Fri Jul 20 22:19:46 2012 +0200
@@ -41,14 +41,16 @@
     Proof.context -> params -> string -> fact list -> thm -> prover_result
   val features_of :
     Proof.context -> string -> theory -> stature -> term list -> string list
-  val isar_dependencies_of : unit Symtab.table -> thm -> string list
+  val isar_dependencies_of : unit Symtab.table -> thm -> string list option
   val atp_dependencies_of :
-    Proof.context -> params -> string -> bool -> fact list -> unit Symtab.table
-    -> thm -> string list
+    Proof.context -> params -> string -> int -> fact list -> unit Symtab.table
+    -> thm -> string list option
   val mash_CLEAR : Proof.context -> unit
   val mash_ADD :
     Proof.context -> bool
     -> (string * string list * string list * string list) list -> unit
+  val mash_REPROVE :
+    Proof.context -> bool -> (string * string list) list -> unit
   val mash_QUERY :
     Proof.context -> bool -> int -> string list * string list -> string list
   val mash_unlearn : Proof.context -> unit
@@ -60,9 +62,6 @@
   val mash_learn_proof :
     Proof.context -> params -> string -> term -> ('a * thm) list -> thm list
     -> unit
-  val mash_learn_facts :
-    Proof.context -> params -> string -> bool -> bool -> Time.time -> fact list
-    -> string
   val mash_learn :
     Proof.context -> params -> fact_override -> thm list -> bool -> unit
   val relevant_facts :
@@ -320,14 +319,21 @@
       | Simp => cons "simp"
       | Def => cons "def")
 
-fun isar_dependencies_of all_facts = thms_in_proof (SOME all_facts)
+(* Too many dependencies is a sign that a decision procedure is at work. There
+   isn't much to learn from such proofs. *)
+val max_dependencies = 10
+val atp_dependency_default_max_fact = 50
 
-val atp_dep_default_max_fact = 50
+fun trim_dependencies deps =
+  if length deps <= max_dependencies then SOME deps else NONE
 
-fun atp_dependencies_of ctxt (params as {verbose, max_facts, ...}) prover auto
-                        facts all_names th =
+fun isar_dependencies_of all_facts =
+  thms_in_proof (SOME all_facts) #> trim_dependencies
+
+fun atp_dependencies_of ctxt (params as {verbose, max_facts, ...}) prover
+                        auto_level facts all_names th =
   case isar_dependencies_of all_names th of
-    [] => []
+    SOME [] => NONE
   | isar_deps =>
     let
       val thy = Proof_Context.theory_of ctxt
@@ -344,12 +350,12 @@
         | NONE => accum (* shouldn't happen *)
       val facts =
         facts |> iterative_relevant_facts ctxt params prover
-                     (max_facts |> the_default atp_dep_default_max_fact) NONE
-                     hyp_ts concl_t
-              |> fold (add_isar_dep facts) isar_deps
+                     (max_facts |> the_default atp_dependency_default_max_fact)
+                     NONE hyp_ts concl_t
+              |> fold (add_isar_dep facts) (these isar_deps)
               |> map fix_name
     in
-      if verbose andalso not auto then
+      if verbose andalso auto_level = 0 then
         let val num_facts = length facts in
           "MaSh: " ^ quote prover ^ " on " ^ quote (nickname_of th) ^
           " with " ^ string_of_int num_facts ^ " fact" ^ plural_s num_facts ^
@@ -360,7 +366,7 @@
         ();
       case run_prover_for_mash ctxt params prover facts goal of
         {outcome = NONE, used_facts, ...} =>
-        (if verbose andalso not auto then
+        (if verbose andalso auto_level = 0 then
            let val num_facts = length used_facts in
              "Found proof with " ^ string_of_int num_facts ^ " fact" ^
              plural_s num_facts ^ "."
@@ -368,8 +374,8 @@
            end
          else
            ();
-         used_facts |> map fst)
-      | _ => isar_deps
+         used_facts |> map fst |> trim_dependencies)
+      | _ => NONE
     end
 
 
@@ -418,10 +424,13 @@
                                [err_file, sugg_file, cmd_file])
   end
 
-fun str_of_update (name, parents, feats, deps) =
+fun str_of_add (name, parents, feats, deps) =
   "! " ^ escape_meta name ^ ": " ^ escape_metas parents ^ "; " ^
   escape_metas feats ^ "; " ^ escape_metas deps ^ "\n"
 
+fun str_of_reprove (name, deps) =
+  "p " ^ escape_meta name ^ ": " ^ escape_metas deps ^ "\n"
+
 fun str_of_query (parents, feats) =
   "? " ^ escape_metas parents ^ "; " ^ escape_metas feats
 
@@ -435,10 +444,16 @@
   end
 
 fun mash_ADD _ _ [] = ()
-  | mash_ADD ctxt overlord upds =
+  | mash_ADD ctxt overlord adds =
     (trace_msg ctxt (fn () => "MaSh ADD " ^
-         elide_string 1000 (space_implode " " (map #1 upds)));
-     run_mash_tool ctxt overlord true 0 (upds, str_of_update) (K ()))
+         elide_string 1000 (space_implode " " (map #1 adds)));
+     run_mash_tool ctxt overlord true 0 (adds, str_of_add) (K ()))
+
+fun mash_REPROVE _ _ [] = ()
+  | mash_REPROVE ctxt overlord reps =
+    (trace_msg ctxt (fn () => "MaSh REPROVE " ^
+         elide_string 1000 (space_implode " " (map #1 reps)));
+     run_mash_tool ctxt overlord true 0 (reps, str_of_reprove) (K ()))
 
 fun mash_QUERY ctxt overlord max_suggs (query as (_, feats)) =
   (trace_msg ctxt (fn () => "MaSh QUERY " ^ space_implode " " feats);
@@ -584,7 +599,7 @@
     val unknown = facts |> filter_out (is_fact_in_graph fact_G)
   in (selected, unknown) end
 
-fun update_fact_graph ctxt (name, parents, feats, deps) (upds, graph) =
+fun add_to_fact_graph ctxt (name, parents, feats, deps) (adds, graph) =
   let
     fun maybe_add_from from (accum as (parents, graph)) =
       try_graph ctxt "updating graph" accum (fn () =>
@@ -592,7 +607,7 @@
     val graph = graph |> Graph.default_node (name, ())
     val (parents, graph) = ([], graph) |> fold maybe_add_from parents
     val (deps, graph) = ([], graph) |> fold maybe_add_from deps
-  in ((name, parents, feats, deps) :: upds, graph) end
+  in ((name, parents, feats, deps) :: adds, graph) end
 
 val learn_timeout_slack = 2.0
 
@@ -628,14 +643,11 @@
 fun sendback sub =
   Markup.markup Isabelle_Markup.sendback (sledgehammerN ^ " " ^ sub)
 
-(* Too many dependencies is a sign that a decision procedure is at work. There
-   isn't much too learn from such proofs. *)
-val max_dependencies = 10
 val commit_timeout = seconds 30.0
 
 (* The timeout is understood in a very slack fashion. *)
-fun mash_learn_facts ctxt (params as {debug, verbose, overlord, timeout, ...})
-                     prover auto atp learn_timeout facts =
+fun mash_learn_facts ctxt (params as {debug, verbose, overlord, ...}) prover
+                     auto_level atp learn_timeout facts =
   let
     val timer = Timer.startRealTimer ()
     fun next_commit_time () =
@@ -644,86 +656,123 @@
     val (old_facts, new_facts) =
       facts |> List.partition (is_fact_in_graph fact_G)
             ||> sort (thm_ord o pairself snd)
-    val num_new_facts = length new_facts
   in
-    (if not auto then
-       "MaShing" ^
-       (if not auto then
-          " " ^ string_of_int num_new_facts ^ " fact" ^
-          plural_s num_new_facts ^
-          (if atp then " (ATP timeout: " ^ string_from_time timeout ^ ")"
-           else "")
-        else
-          "") ^ "..."
-     else
-       "")
-    |> Output.urgent_message;
-    if num_new_facts = 0 then
-      if not auto then
-        "Nothing to learn.\n\nHint: Try " ^ sendback relearn_isarN ^ " or " ^
-        sendback relearn_atpN ^ " to learn from scratch."
+    if null new_facts andalso (not atp orelse null old_facts) then
+      if auto_level < 2 then
+        "No new " ^ (if atp then "ATP" else "Isar") ^ " proofs to learn." ^
+        (if auto_level = 0 andalso not atp then
+           "\n\nHint: Try " ^ sendback learn_atpN ^ " to learn from ATP proofs."
+         else
+           "")
       else
         ""
     else
       let
-        val last_th = new_facts |> List.last |> snd
-        (* crude approximation *)
-        val ancestors =
-          old_facts |> filter (fn (_, th) => thm_ord (th, last_th) <> GREATER)
         val all_names =
           facts |> map snd
                 |> filter_out is_likely_tautology_or_too_meta
                 |> map (rpair () o nickname_of)
                 |> Symtab.make
-        fun do_commit [] state = state
-          | do_commit upds {fact_G} =
+        val deps_of =
+          if atp then
+            atp_dependencies_of ctxt params prover auto_level facts all_names
+          else
+            isar_dependencies_of all_names
+        fun do_commit [] [] state = state
+          | do_commit adds reps {fact_G} =
             let
-              val (upds, fact_G) =
-                ([], fact_G) |> fold (update_fact_graph ctxt) upds
-            in mash_ADD ctxt overlord (rev upds); {fact_G = fact_G} end
-        fun trim_deps deps = if length deps > max_dependencies then [] else deps
-        fun commit last upds =
-          (if debug andalso not auto then Output.urgent_message "Committing..."
-           else ();
-           mash_map ctxt (do_commit (rev upds));
-           if not last andalso not auto then
-             let val num_upds = length upds in
-               "Processed " ^ string_of_int num_upds ^ " fact" ^
-               plural_s num_upds ^ " in the last " ^
+              val (adds, fact_G) =
+                ([], fact_G) |> fold (add_to_fact_graph ctxt) adds
+            in
+              mash_ADD ctxt overlord (rev adds);
+              mash_REPROVE ctxt overlord reps;
+              {fact_G = fact_G}
+            end
+        fun commit last adds reps =
+          (if debug andalso auto_level = 0 then
+             Output.urgent_message "Committing..."
+           else
+             ();
+           mash_map ctxt (do_commit (rev adds) reps);
+           if not last andalso auto_level = 0 then
+             let val num_proofs = length adds + length reps in
+               "Learned " ^ string_of_int num_proofs ^ " " ^
+               (if atp then "ATP" else "Isar") ^ " proof" ^
+               plural_s num_proofs ^ " in the last " ^
                string_from_time commit_timeout ^ "."
                |> Output.urgent_message
              end
            else
              ())
-        fun do_fact _ (accum as (_, (_, _, _, true))) = accum
-          | do_fact ((_, stature), th)
-                    (upds, (parents, n, next_commit, false)) =
+        fun learn_new_fact _ (accum as (_, (_, _, _, true))) = accum
+          | learn_new_fact ((_, stature), th)
+                           (adds, (parents, n, next_commit, _)) =
             let
               val name = nickname_of th
               val feats =
                 features_of ctxt prover (theory_of_thm th) stature [prop_of th]
-              val deps =
-                (if atp then atp_dependencies_of ctxt params prover auto facts
-                 else isar_dependencies_of) all_names th
-                |> trim_deps
+              val deps = deps_of th |> these
               val n = n |> not (null deps) ? Integer.add 1
-              val upds = (name, parents, feats, deps) :: upds
-              val (upds, next_commit) =
+              val adds = (name, parents, feats, deps) :: adds
+              val (adds, next_commit) =
                 if Time.> (Timer.checkRealTimer timer, next_commit) then
-                  (commit false upds; ([], next_commit_time ()))
+                  (commit false adds []; ([], next_commit_time ()))
                 else
-                  (upds, next_commit)
-              val timed_out =
-                Time.> (Timer.checkRealTimer timer, learn_timeout)
-            in (upds, ([name], n, next_commit, timed_out)) end
-        val parents = max_facts_in_graph fact_G ancestors
-        val (upds, (_, n, _, _)) =
-          ([], (parents, 0, next_commit_time (), false))
-          |> fold do_fact new_facts
+                  (adds, next_commit)
+              val timed_out = Time.> (Timer.checkRealTimer timer, learn_timeout)
+            in (adds, ([name], n, next_commit, timed_out)) end
+        val n =
+          if null new_facts then
+            0
+          else
+            let
+              val last_th = new_facts |> List.last |> snd
+              (* crude approximation *)
+              val ancestors =
+                old_facts
+                |> filter (fn (_, th) => thm_ord (th, last_th) <> GREATER)
+              val parents = max_facts_in_graph fact_G ancestors
+              val (adds, (_, n, _, _)) =
+                ([], (parents, 0, next_commit_time (), false))
+                |> fold learn_new_fact new_facts
+            in commit true adds []; n end
+        fun relearn_old_fact _ (accum as (_, (_, _, true))) = accum
+          | relearn_old_fact (_, th) (reps, (n, next_commit, _)) =
+            let
+              val name = nickname_of th
+              val (n, reps) =
+                case deps_of th of
+                  SOME deps => (n + 1, (name, deps) :: reps)
+                | NONE => (n, reps)
+              val (reps, next_commit) =
+                if Time.> (Timer.checkRealTimer timer, next_commit) then
+                  (commit false [] reps; ([], next_commit_time ()))
+                else
+                  (reps, next_commit)
+              val timed_out = Time.> (Timer.checkRealTimer timer, learn_timeout)
+            in (reps, (n, next_commit, timed_out)) end
+        val n =
+          if null old_facts then
+            n
+          else
+            let
+              fun score_of (_, th) =
+                random_range 0 (1000 * max_dependencies)
+                - 500 * (th |> isar_dependencies_of all_names
+                            |> Option.map length
+                            |> the_default max_dependencies)
+              val old_facts =
+                old_facts |> map (`score_of)
+                          |> sort (int_ord o pairself fst)
+                          |> map snd
+              val (reps, (n, _, _)) =
+                ([], (n, next_commit_time (), false))
+                |> fold relearn_old_fact old_facts
+            in commit true [] reps; n end
       in
-        commit true upds;
-        if verbose orelse not auto then
-          "Learned " ^ string_of_int n ^ " nontrivial proof" ^ plural_s n ^
+        if verbose orelse auto_level < 2 then
+          "Learned " ^ string_of_int n ^ " nontrivial " ^
+          (if atp then "ATP" else "Isar") ^ " proof" ^ plural_s n ^
           (if verbose then
              " in " ^ string_from_time (Timer.checkRealTimer timer)
            else
@@ -733,16 +782,35 @@
       end
   end
 
-fun mash_learn ctxt (params as {provers, ...}) fact_override chained atp =
+fun mash_learn ctxt (params as {provers, timeout, ...}) fact_override chained
+               atp =
   let
     val css = Sledgehammer_Fact.clasimpset_rule_table_of ctxt
     val ctxt = ctxt |> Config.put instantiate_inducts false
     val facts =
       nearly_all_facts ctxt false fact_override Symtab.empty css chained []
                        @{prop True}
+    val num_facts = length facts
+    val prover = hd provers
+    fun learn auto_level atp =
+      mash_learn_facts ctxt params prover auto_level atp infinite_timeout facts
+      |> Output.urgent_message
   in
-     mash_learn_facts ctxt params (hd provers) false atp infinite_timeout facts
-     |> Output.urgent_message
+    (if atp then
+       ("MaShing through " ^ string_of_int num_facts ^ " fact" ^
+        plural_s num_facts ^ " for ATP proofs (" ^ quote prover ^ " timeout: " ^
+        string_from_time timeout ^ ").\n\nCollecting Isar proofs first..."
+        |> Output.urgent_message;
+        learn 1 false;
+        "Now collecting ATP proofs. This may take several hours. You can \
+        \safely stop the learning process at any point."
+        |> Output.urgent_message;
+        learn 0 true)
+     else
+       ("MaShing through " ^ string_of_int num_facts ^ " fact" ^
+        plural_s num_facts ^ " for Isar proofs..."
+        |> Output.urgent_message;
+        learn 0 false))
   end
 
 (* The threshold should be large enough so that MaSh doesn't kick in for Auto
@@ -764,7 +832,7 @@
            Time.toSeconds timeout >= min_secs_for_learning then
           let val timeout = time_mult learn_timeout_slack timeout in
             launch_thread timeout
-                (fn () => (true, mash_learn_facts ctxt params prover true false
+                (fn () => (true, mash_learn_facts ctxt params prover 2 false
                                                   timeout facts))
           end
         else