src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 50861 fa4253914e98
parent 50860 e32a283b8ce0
child 50869 67bb94a6f780
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Sun Jan 13 15:04:55 2013 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Sun Jan 13 20:57:48 2013 +0100
@@ -150,7 +150,7 @@
     val core =
       "--inputFile " ^ cmd_file ^ " --predictions " ^ sugg_file ^
       " --numberOfPredictions " ^ string_of_int max_suggs ^
-      " --learnTheories" ^
+      " --learnTheories --NBSinePrior" ^
       (if save then " --saveModel" else "")
     val command =
       "\"$ISABELLE_SLEDGEHAMMER_MASH/src/mash.py\" --quiet --outputDir " ^
@@ -457,11 +457,10 @@
 
 fun mesh_facts _ max_facts [(_, (sels, unks))] =
     map fst (take max_facts sels) @ take (max_facts - length sels) unks
-  | mesh_facts eq max_facts mess =
+  | mesh_facts fact_eq max_facts mess =
     let
       val mess =
         mess |> map (apsnd (apfst (normalize_scores max_facts #> `length)))
-      val fact_eq = eq
       fun score_in fact (global_weight, ((sel_len, sels), unks)) =
         let
           fun score_at j =
@@ -471,7 +470,7 @@
         in
           case find_index (curry fact_eq fact o fst) sels of
             ~1 => (case find_index (curry fact_eq fact) unks of
-                     ~1 => score_at (sel_len - 1)
+                     ~1 => SOME 0.0
                    | _ => NONE)
           | rank => score_at rank
         end
@@ -760,8 +759,8 @@
                                       (Graph.imm_preds access_G new) news))
   in find_maxes Symtab.empty ([], Graph.maximals access_G) end
 
-fun is_fact_in_graph access_G (_, th) =
-  can (Graph.get_node access_G) (nickname_of_thm th)
+fun is_fact_in_graph access_G get_th fact =
+  can (Graph.get_node access_G) (nickname_of_thm (get_th fact))
 
 val weight_raw_mash_facts = weight_mepo_facts
 val weight_mash_facts = weight_raw_mash_facts
@@ -782,14 +781,19 @@
             (* The weights currently returned by "mash.py" are too spaced out to
                make any sense. *)
             |> map fst
+    val unknown_chained =
+      inter (Thm.eq_thm_prop o pairself snd) chained raw_unknown
     val proximity =
       facts |> sort (thm_ord o pairself snd o swap)
             |> take max_proximity_facts
     val mess =
-      [(0.8 (* FUDGE *), (weight_raw_mash_facts raw_mash, raw_unknown)),
-       (0.2 (* FUDGE *), (weight_proximity_facts proximity, []))]
+      [(0.90 (* FUDGE *), (map (rpair 1.0) unknown_chained, [])),
+       (0.08 (* FUDGE *), (weight_raw_mash_facts raw_mash, raw_unknown)),
+       (0.02 (* FUDGE *), (weight_proximity_facts proximity, []))]
     val unknown =
-      raw_unknown |> subtract (Thm.eq_thm_prop o pairself snd) proximity
+      raw_unknown
+      |> fold (subtract (Thm.eq_thm_prop o pairself snd))
+              [unknown_chained, proximity]
   in (mesh_facts (Thm.eq_thm_prop o pairself snd) max_facts mess, unknown) end
 
 fun mash_suggested_facts ctxt ({overlord, learn, ...} : params) prover max_facts
@@ -806,13 +810,15 @@
               val parents = maximal_in_graph access_G facts
               val feats =
                 features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts)
-              val hints = map (nickname_of_thm o snd) chained
+              val hints =
+                chained |> filter (is_fact_in_graph access_G snd)
+                        |> map (nickname_of_thm o snd)
             in
               (access_G,
                MaSh.suggest ctxt overlord learn max_facts
                             (parents, feats, hints))
             end)
-    val unknown = facts |> filter_out (is_fact_in_graph access_G)
+    val unknown = facts |> filter_out (is_fact_in_graph access_G snd)
   in find_mash_suggestions max_facts suggs facts chained unknown end
 
 fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, graph) =
@@ -849,22 +855,23 @@
 
 fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) prover t facts
                      used_ths =
-  if is_smt_prover ctxt prover then
-    ()
-  else
-    launch_thread (timeout |> the_default one_day) (fn () =>
-        let
-          val thy = Proof_Context.theory_of ctxt
-          val name = freshish_name ()
-          val feats = features_of ctxt prover thy (Local, General) [t]
-          val deps = used_ths |> map nickname_of_thm
-        in
-          peek_state ctxt (fn {access_G, ...} =>
-              let val parents = maximal_in_graph access_G facts in
-                MaSh.learn ctxt overlord [(name, parents, feats, deps)]
-              end);
-          (true, "")
-        end)
+  launch_thread (timeout |> the_default one_day) (fn () =>
+      let
+        val thy = Proof_Context.theory_of ctxt
+        val name = freshish_name ()
+        val feats = features_of ctxt prover thy (Local, General) [t]
+      in
+        peek_state ctxt (fn {access_G, ...} =>
+            let
+              val parents = maximal_in_graph access_G facts
+              val deps =
+                used_ths |> filter (is_fact_in_graph access_G I)
+                         |> map nickname_of_thm
+            in
+              MaSh.learn ctxt overlord [(name, parents, feats, deps)]
+            end);
+        (true, "")
+      end)
 
 fun sendback sub = Active.sendback_markup (sledgehammerN ^ " " ^ sub)
 
@@ -880,7 +887,7 @@
     val {access_G, ...} = peek_state ctxt I
     val facts = facts |> sort (thm_ord o pairself snd)
     val (old_facts, new_facts) =
-      facts |> List.partition (is_fact_in_graph access_G)
+      facts |> List.partition (is_fact_in_graph access_G snd)
   in
     if null new_facts andalso (not run_prover orelse null old_facts) then
       if auto_level < 2 then