148 val cmd_file = temp_dir ^ "/mash_commands" ^ serial |
148 val cmd_file = temp_dir ^ "/mash_commands" ^ serial |
149 val cmd_path = Path.explode cmd_file |
149 val cmd_path = Path.explode cmd_file |
150 val core = |
150 val core = |
151 "--inputFile " ^ cmd_file ^ " --predictions " ^ sugg_file ^ |
151 "--inputFile " ^ cmd_file ^ " --predictions " ^ sugg_file ^ |
152 " --numberOfPredictions " ^ string_of_int max_suggs ^ |
152 " --numberOfPredictions " ^ string_of_int max_suggs ^ |
153 " --learnTheories" ^ |
153 " --learnTheories --NBSinePrior" ^ |
154 (if save then " --saveModel" else "") |
154 (if save then " --saveModel" else "") |
155 val command = |
155 val command = |
156 "\"$ISABELLE_SLEDGEHAMMER_MASH/src/mash.py\" --quiet --outputDir " ^ |
156 "\"$ISABELLE_SLEDGEHAMMER_MASH/src/mash.py\" --quiet --outputDir " ^ |
157 File.shell_path (mash_model_dir ()) ^ " --log " ^ log_file ^ " " ^ core ^ |
157 File.shell_path (mash_model_dir ()) ^ " --log " ^ log_file ^ " " ^ core ^ |
158 " >& " ^ err_file |
158 " >& " ^ err_file |
455 map (apsnd (curry Real.* (1.0 / avg))) xs |
455 map (apsnd (curry Real.* (1.0 / avg))) xs |
456 end |
456 end |
457 |
457 |
458 fun mesh_facts _ max_facts [(_, (sels, unks))] = |
458 fun mesh_facts _ max_facts [(_, (sels, unks))] = |
459 map fst (take max_facts sels) @ take (max_facts - length sels) unks |
459 map fst (take max_facts sels) @ take (max_facts - length sels) unks |
460 | mesh_facts eq max_facts mess = |
460 | mesh_facts fact_eq max_facts mess = |
461 let |
461 let |
462 val mess = |
462 val mess = |
463 mess |> map (apsnd (apfst (normalize_scores max_facts #> `length))) |
463 mess |> map (apsnd (apfst (normalize_scores max_facts #> `length))) |
464 val fact_eq = eq |
|
465 fun score_in fact (global_weight, ((sel_len, sels), unks)) = |
464 fun score_in fact (global_weight, ((sel_len, sels), unks)) = |
466 let |
465 let |
467 fun score_at j = |
466 fun score_at j = |
468 case try (nth sels) j of |
467 case try (nth sels) j of |
469 SOME (_, score) => SOME (global_weight * score) |
468 SOME (_, score) => SOME (global_weight * score) |
470 | NONE => NONE |
469 | NONE => NONE |
471 in |
470 in |
472 case find_index (curry fact_eq fact o fst) sels of |
471 case find_index (curry fact_eq fact o fst) sels of |
473 ~1 => (case find_index (curry fact_eq fact) unks of |
472 ~1 => (case find_index (curry fact_eq fact) unks of |
474 ~1 => score_at (sel_len - 1) |
473 ~1 => SOME 0.0 |
475 | _ => NONE) |
474 | _ => NONE) |
476 | rank => score_at rank |
475 | rank => score_at rank |
477 end |
476 end |
478 fun weight_of fact = mess |> map_filter (score_in fact) |> scaled_avg |
477 fun weight_of fact = mess |> map_filter (score_in fact) |> scaled_avg |
479 val facts = |
478 val facts = |
758 else |
757 else |
759 (maxs, Graph.Keys.fold (insert_new seen) |
758 (maxs, Graph.Keys.fold (insert_new seen) |
760 (Graph.imm_preds access_G new) news)) |
759 (Graph.imm_preds access_G new) news)) |
761 in find_maxes Symtab.empty ([], Graph.maximals access_G) end |
760 in find_maxes Symtab.empty ([], Graph.maximals access_G) end |
762 |
761 |
763 fun is_fact_in_graph access_G (_, th) = |
762 fun is_fact_in_graph access_G get_th fact = |
764 can (Graph.get_node access_G) (nickname_of_thm th) |
763 can (Graph.get_node access_G) (nickname_of_thm (get_th fact)) |
765 |
764 |
766 val weight_raw_mash_facts = weight_mepo_facts |
765 val weight_raw_mash_facts = weight_mepo_facts |
767 val weight_mash_facts = weight_raw_mash_facts |
766 val weight_mash_facts = weight_raw_mash_facts |
768 |
767 |
769 (* FUDGE *) |
768 (* FUDGE *) |
780 val raw_mash = |
779 val raw_mash = |
781 facts |> find_suggested_facts suggs |
780 facts |> find_suggested_facts suggs |
782 (* The weights currently returned by "mash.py" are too spaced out to |
781 (* The weights currently returned by "mash.py" are too spaced out to |
783 make any sense. *) |
782 make any sense. *) |
784 |> map fst |
783 |> map fst |
|
784 val unknown_chained = |
|
785 inter (Thm.eq_thm_prop o pairself snd) chained raw_unknown |
785 val proximity = |
786 val proximity = |
786 facts |> sort (thm_ord o pairself snd o swap) |
787 facts |> sort (thm_ord o pairself snd o swap) |
787 |> take max_proximity_facts |
788 |> take max_proximity_facts |
788 val mess = |
789 val mess = |
789 [(0.8 (* FUDGE *), (weight_raw_mash_facts raw_mash, raw_unknown)), |
790 [(0.90 (* FUDGE *), (map (rpair 1.0) unknown_chained, [])), |
790 (0.2 (* FUDGE *), (weight_proximity_facts proximity, []))] |
791 (0.08 (* FUDGE *), (weight_raw_mash_facts raw_mash, raw_unknown)), |
|
792 (0.02 (* FUDGE *), (weight_proximity_facts proximity, []))] |
791 val unknown = |
793 val unknown = |
792 raw_unknown |> subtract (Thm.eq_thm_prop o pairself snd) proximity |
794 raw_unknown |
|
795 |> fold (subtract (Thm.eq_thm_prop o pairself snd)) |
|
796 [unknown_chained, proximity] |
793 in (mesh_facts (Thm.eq_thm_prop o pairself snd) max_facts mess, unknown) end |
797 in (mesh_facts (Thm.eq_thm_prop o pairself snd) max_facts mess, unknown) end |
794 |
798 |
795 fun mash_suggested_facts ctxt ({overlord, learn, ...} : params) prover max_facts |
799 fun mash_suggested_facts ctxt ({overlord, learn, ...} : params) prover max_facts |
796 hyp_ts concl_t facts = |
800 hyp_ts concl_t facts = |
797 let |
801 let |
804 else |
808 else |
805 let |
809 let |
806 val parents = maximal_in_graph access_G facts |
810 val parents = maximal_in_graph access_G facts |
807 val feats = |
811 val feats = |
808 features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts) |
812 features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts) |
809 val hints = map (nickname_of_thm o snd) chained |
813 val hints = |
|
814 chained |> filter (is_fact_in_graph access_G snd) |
|
815 |> map (nickname_of_thm o snd) |
810 in |
816 in |
811 (access_G, |
817 (access_G, |
812 MaSh.suggest ctxt overlord learn max_facts |
818 MaSh.suggest ctxt overlord learn max_facts |
813 (parents, feats, hints)) |
819 (parents, feats, hints)) |
814 end) |
820 end) |
815 val unknown = facts |> filter_out (is_fact_in_graph access_G) |
821 val unknown = facts |> filter_out (is_fact_in_graph access_G snd) |
816 in find_mash_suggestions max_facts suggs facts chained unknown end |
822 in find_mash_suggestions max_facts suggs facts chained unknown end |
817 |
823 |
818 fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, graph) = |
824 fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, graph) = |
819 let |
825 let |
820 fun maybe_learn_from from (accum as (parents, graph)) = |
826 fun maybe_learn_from from (accum as (parents, graph)) = |
847 val desc = ("Machine learner for Sledgehammer", "") |
853 val desc = ("Machine learner for Sledgehammer", "") |
848 in Async_Manager.launch MaShN birth_time death_time desc task end |
854 in Async_Manager.launch MaShN birth_time death_time desc task end |
849 |
855 |
850 fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) prover t facts |
856 fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) prover t facts |
851 used_ths = |
857 used_ths = |
852 if is_smt_prover ctxt prover then |
858 launch_thread (timeout |> the_default one_day) (fn () => |
853 () |
859 let |
854 else |
860 val thy = Proof_Context.theory_of ctxt |
855 launch_thread (timeout |> the_default one_day) (fn () => |
861 val name = freshish_name () |
856 let |
862 val feats = features_of ctxt prover thy (Local, General) [t] |
857 val thy = Proof_Context.theory_of ctxt |
863 in |
858 val name = freshish_name () |
864 peek_state ctxt (fn {access_G, ...} => |
859 val feats = features_of ctxt prover thy (Local, General) [t] |
865 let |
860 val deps = used_ths |> map nickname_of_thm |
866 val parents = maximal_in_graph access_G facts |
861 in |
867 val deps = |
862 peek_state ctxt (fn {access_G, ...} => |
868 used_ths |> filter (is_fact_in_graph access_G I) |
863 let val parents = maximal_in_graph access_G facts in |
869 |> map nickname_of_thm |
864 MaSh.learn ctxt overlord [(name, parents, feats, deps)] |
870 in |
865 end); |
871 MaSh.learn ctxt overlord [(name, parents, feats, deps)] |
866 (true, "") |
872 end); |
867 end) |
873 (true, "") |
|
874 end) |
868 |
875 |
869 fun sendback sub = Active.sendback_markup (sledgehammerN ^ " " ^ sub) |
876 fun sendback sub = Active.sendback_markup (sledgehammerN ^ " " ^ sub) |
870 |
877 |
871 val commit_timeout = seconds 30.0 |
878 val commit_timeout = seconds 30.0 |
872 |
879 |
878 fun next_commit_time () = |
885 fun next_commit_time () = |
879 Time.+ (Timer.checkRealTimer timer, commit_timeout) |
886 Time.+ (Timer.checkRealTimer timer, commit_timeout) |
880 val {access_G, ...} = peek_state ctxt I |
887 val {access_G, ...} = peek_state ctxt I |
881 val facts = facts |> sort (thm_ord o pairself snd) |
888 val facts = facts |> sort (thm_ord o pairself snd) |
882 val (old_facts, new_facts) = |
889 val (old_facts, new_facts) = |
883 facts |> List.partition (is_fact_in_graph access_G) |
890 facts |> List.partition (is_fact_in_graph access_G snd) |
884 in |
891 in |
885 if null new_facts andalso (not run_prover orelse null old_facts) then |
892 if null new_facts andalso (not run_prover orelse null old_facts) then |
886 if auto_level < 2 then |
893 if auto_level < 2 then |
887 "No new " ^ (if run_prover then "automatic" else "Isar") ^ |
894 "No new " ^ (if run_prover then "automatic" else "Isar") ^ |
888 " proofs to learn." ^ |
895 " proofs to learn." ^ |