src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 50861 fa4253914e98
parent 50860 e32a283b8ce0
child 50869 67bb94a6f780
equal deleted inserted replaced
50860:e32a283b8ce0 50861:fa4253914e98
   148     val cmd_file = temp_dir ^ "/mash_commands" ^ serial
   148     val cmd_file = temp_dir ^ "/mash_commands" ^ serial
   149     val cmd_path = Path.explode cmd_file
   149     val cmd_path = Path.explode cmd_file
   150     val core =
   150     val core =
   151       "--inputFile " ^ cmd_file ^ " --predictions " ^ sugg_file ^
   151       "--inputFile " ^ cmd_file ^ " --predictions " ^ sugg_file ^
   152       " --numberOfPredictions " ^ string_of_int max_suggs ^
   152       " --numberOfPredictions " ^ string_of_int max_suggs ^
   153       " --learnTheories" ^
   153       " --learnTheories --NBSinePrior" ^
   154       (if save then " --saveModel" else "")
   154       (if save then " --saveModel" else "")
   155     val command =
   155     val command =
   156       "\"$ISABELLE_SLEDGEHAMMER_MASH/src/mash.py\" --quiet --outputDir " ^
   156       "\"$ISABELLE_SLEDGEHAMMER_MASH/src/mash.py\" --quiet --outputDir " ^
   157       File.shell_path (mash_model_dir ()) ^ " --log " ^ log_file ^ " " ^ core ^
   157       File.shell_path (mash_model_dir ()) ^ " --log " ^ log_file ^ " " ^ core ^
   158       " >& " ^ err_file
   158       " >& " ^ err_file
   455       map (apsnd (curry Real.* (1.0 / avg))) xs
   455       map (apsnd (curry Real.* (1.0 / avg))) xs
   456     end
   456     end
   457 
   457 
   458 fun mesh_facts _ max_facts [(_, (sels, unks))] =
   458 fun mesh_facts _ max_facts [(_, (sels, unks))] =
   459     map fst (take max_facts sels) @ take (max_facts - length sels) unks
   459     map fst (take max_facts sels) @ take (max_facts - length sels) unks
   460   | mesh_facts eq max_facts mess =
   460   | mesh_facts fact_eq max_facts mess =
   461     let
   461     let
   462       val mess =
   462       val mess =
   463         mess |> map (apsnd (apfst (normalize_scores max_facts #> `length)))
   463         mess |> map (apsnd (apfst (normalize_scores max_facts #> `length)))
   464       val fact_eq = eq
       
   465       fun score_in fact (global_weight, ((sel_len, sels), unks)) =
   464       fun score_in fact (global_weight, ((sel_len, sels), unks)) =
   466         let
   465         let
   467           fun score_at j =
   466           fun score_at j =
   468             case try (nth sels) j of
   467             case try (nth sels) j of
   469               SOME (_, score) => SOME (global_weight * score)
   468               SOME (_, score) => SOME (global_weight * score)
   470             | NONE => NONE
   469             | NONE => NONE
   471         in
   470         in
   472           case find_index (curry fact_eq fact o fst) sels of
   471           case find_index (curry fact_eq fact o fst) sels of
   473             ~1 => (case find_index (curry fact_eq fact) unks of
   472             ~1 => (case find_index (curry fact_eq fact) unks of
   474                      ~1 => score_at (sel_len - 1)
   473                      ~1 => SOME 0.0
   475                    | _ => NONE)
   474                    | _ => NONE)
   476           | rank => score_at rank
   475           | rank => score_at rank
   477         end
   476         end
   478       fun weight_of fact = mess |> map_filter (score_in fact) |> scaled_avg
   477       fun weight_of fact = mess |> map_filter (score_in fact) |> scaled_avg
   479       val facts =
   478       val facts =
   758              else
   757              else
   759                (maxs, Graph.Keys.fold (insert_new seen)
   758                (maxs, Graph.Keys.fold (insert_new seen)
   760                                       (Graph.imm_preds access_G new) news))
   759                                       (Graph.imm_preds access_G new) news))
   761   in find_maxes Symtab.empty ([], Graph.maximals access_G) end
   760   in find_maxes Symtab.empty ([], Graph.maximals access_G) end
   762 
   761 
   763 fun is_fact_in_graph access_G (_, th) =
   762 fun is_fact_in_graph access_G get_th fact =
   764   can (Graph.get_node access_G) (nickname_of_thm th)
   763   can (Graph.get_node access_G) (nickname_of_thm (get_th fact))
   765 
   764 
   766 val weight_raw_mash_facts = weight_mepo_facts
   765 val weight_raw_mash_facts = weight_mepo_facts
   767 val weight_mash_facts = weight_raw_mash_facts
   766 val weight_mash_facts = weight_raw_mash_facts
   768 
   767 
   769 (* FUDGE *)
   768 (* FUDGE *)
   780     val raw_mash =
   779     val raw_mash =
   781       facts |> find_suggested_facts suggs
   780       facts |> find_suggested_facts suggs
   782             (* The weights currently returned by "mash.py" are too spaced out to
   781             (* The weights currently returned by "mash.py" are too spaced out to
   783                make any sense. *)
   782                make any sense. *)
   784             |> map fst
   783             |> map fst
       
   784     val unknown_chained =
       
   785       inter (Thm.eq_thm_prop o pairself snd) chained raw_unknown
   785     val proximity =
   786     val proximity =
   786       facts |> sort (thm_ord o pairself snd o swap)
   787       facts |> sort (thm_ord o pairself snd o swap)
   787             |> take max_proximity_facts
   788             |> take max_proximity_facts
   788     val mess =
   789     val mess =
   789       [(0.8 (* FUDGE *), (weight_raw_mash_facts raw_mash, raw_unknown)),
   790       [(0.90 (* FUDGE *), (map (rpair 1.0) unknown_chained, [])),
   790        (0.2 (* FUDGE *), (weight_proximity_facts proximity, []))]
   791        (0.08 (* FUDGE *), (weight_raw_mash_facts raw_mash, raw_unknown)),
       
   792        (0.02 (* FUDGE *), (weight_proximity_facts proximity, []))]
   791     val unknown =
   793     val unknown =
   792       raw_unknown |> subtract (Thm.eq_thm_prop o pairself snd) proximity
   794       raw_unknown
       
   795       |> fold (subtract (Thm.eq_thm_prop o pairself snd))
       
   796               [unknown_chained, proximity]
   793   in (mesh_facts (Thm.eq_thm_prop o pairself snd) max_facts mess, unknown) end
   797   in (mesh_facts (Thm.eq_thm_prop o pairself snd) max_facts mess, unknown) end
   794 
   798 
   795 fun mash_suggested_facts ctxt ({overlord, learn, ...} : params) prover max_facts
   799 fun mash_suggested_facts ctxt ({overlord, learn, ...} : params) prover max_facts
   796                          hyp_ts concl_t facts =
   800                          hyp_ts concl_t facts =
   797   let
   801   let
   804           else
   808           else
   805             let
   809             let
   806               val parents = maximal_in_graph access_G facts
   810               val parents = maximal_in_graph access_G facts
   807               val feats =
   811               val feats =
   808                 features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts)
   812                 features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts)
   809               val hints = map (nickname_of_thm o snd) chained
   813               val hints =
       
   814                 chained |> filter (is_fact_in_graph access_G snd)
       
   815                         |> map (nickname_of_thm o snd)
   810             in
   816             in
   811               (access_G,
   817               (access_G,
   812                MaSh.suggest ctxt overlord learn max_facts
   818                MaSh.suggest ctxt overlord learn max_facts
   813                             (parents, feats, hints))
   819                             (parents, feats, hints))
   814             end)
   820             end)
   815     val unknown = facts |> filter_out (is_fact_in_graph access_G)
   821     val unknown = facts |> filter_out (is_fact_in_graph access_G snd)
   816   in find_mash_suggestions max_facts suggs facts chained unknown end
   822   in find_mash_suggestions max_facts suggs facts chained unknown end
   817 
   823 
   818 fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, graph) =
   824 fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, graph) =
   819   let
   825   let
   820     fun maybe_learn_from from (accum as (parents, graph)) =
   826     fun maybe_learn_from from (accum as (parents, graph)) =
   847     val desc = ("Machine learner for Sledgehammer", "")
   853     val desc = ("Machine learner for Sledgehammer", "")
   848   in Async_Manager.launch MaShN birth_time death_time desc task end
   854   in Async_Manager.launch MaShN birth_time death_time desc task end
   849 
   855 
   850 fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) prover t facts
   856 fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) prover t facts
   851                      used_ths =
   857                      used_ths =
   852   if is_smt_prover ctxt prover then
   858   launch_thread (timeout |> the_default one_day) (fn () =>
   853     ()
   859       let
   854   else
   860         val thy = Proof_Context.theory_of ctxt
   855     launch_thread (timeout |> the_default one_day) (fn () =>
   861         val name = freshish_name ()
   856         let
   862         val feats = features_of ctxt prover thy (Local, General) [t]
   857           val thy = Proof_Context.theory_of ctxt
   863       in
   858           val name = freshish_name ()
   864         peek_state ctxt (fn {access_G, ...} =>
   859           val feats = features_of ctxt prover thy (Local, General) [t]
   865             let
   860           val deps = used_ths |> map nickname_of_thm
   866               val parents = maximal_in_graph access_G facts
   861         in
   867               val deps =
   862           peek_state ctxt (fn {access_G, ...} =>
   868                 used_ths |> filter (is_fact_in_graph access_G I)
   863               let val parents = maximal_in_graph access_G facts in
   869                          |> map nickname_of_thm
   864                 MaSh.learn ctxt overlord [(name, parents, feats, deps)]
   870             in
   865               end);
   871               MaSh.learn ctxt overlord [(name, parents, feats, deps)]
   866           (true, "")
   872             end);
   867         end)
   873         (true, "")
       
   874       end)
   868 
   875 
   869 fun sendback sub = Active.sendback_markup (sledgehammerN ^ " " ^ sub)
   876 fun sendback sub = Active.sendback_markup (sledgehammerN ^ " " ^ sub)
   870 
   877 
   871 val commit_timeout = seconds 30.0
   878 val commit_timeout = seconds 30.0
   872 
   879 
   878     fun next_commit_time () =
   885     fun next_commit_time () =
   879       Time.+ (Timer.checkRealTimer timer, commit_timeout)
   886       Time.+ (Timer.checkRealTimer timer, commit_timeout)
   880     val {access_G, ...} = peek_state ctxt I
   887     val {access_G, ...} = peek_state ctxt I
   881     val facts = facts |> sort (thm_ord o pairself snd)
   888     val facts = facts |> sort (thm_ord o pairself snd)
   882     val (old_facts, new_facts) =
   889     val (old_facts, new_facts) =
   883       facts |> List.partition (is_fact_in_graph access_G)
   890       facts |> List.partition (is_fact_in_graph access_G snd)
   884   in
   891   in
   885     if null new_facts andalso (not run_prover orelse null old_facts) then
   892     if null new_facts andalso (not run_prover orelse null old_facts) then
   886       if auto_level < 2 then
   893       if auto_level < 2 then
   887         "No new " ^ (if run_prover then "automatic" else "Isar") ^
   894         "No new " ^ (if run_prover then "automatic" else "Isar") ^
   888         " proofs to learn." ^
   895         " proofs to learn." ^