src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 57013 ed95456499e6
parent 57012 43fd82a537a3
child 57014 b7999893ffcc
equal deleted inserted replaced
57012:43fd82a537a3 57013:ed95456499e6
   105 val unlearnN = "unlearn"
   105 val unlearnN = "unlearn"
   106 val learn_isarN = "learn_isar"
   106 val learn_isarN = "learn_isar"
   107 val learn_proverN = "learn_prover"
   107 val learn_proverN = "learn_prover"
   108 val relearn_isarN = "relearn_isar"
   108 val relearn_isarN = "relearn_isar"
   109 val relearn_proverN = "relearn_prover"
   109 val relearn_proverN = "relearn_prover"
   110 
       
   111 val learned_proof_prefix = ".."
       
   112 
       
   113 fun learned_proof_name () =
       
   114   learned_proof_prefix ^ Date.fmt "%Y%m%d_%H%M%S__" (Date.fromTimeLocal (Time.now ())) ^
       
   115   serial_string ()
       
   116 
   110 
   117 fun mash_state_dir () = Path.explode "$ISABELLE_HOME_USER/mash" |> tap Isabelle_System.mkdir
   111 fun mash_state_dir () = Path.explode "$ISABELLE_HOME_USER/mash" |> tap Isabelle_System.mkdir
   118 fun mash_state_file () = Path.append (mash_state_dir ()) (Path.explode "state")
   112 fun mash_state_file () = Path.append (mash_state_dir ()) (Path.explode "state")
   119 
   113 
   120 fun wipe_out_mash_state_dir () =
   114 fun wipe_out_mash_state_dir () =
   380       end
   374       end
   381     else ()
   375     else ()
   382   end
   376   end
   383 
   377 
   384 (*
   378 (*
   385   avail_no = maximum number of theorems to check dependencies and symbols
   379   avail_num = maximum number of theorems to check dependencies and symbols
       
   380   adv_max = do not return theorems over or equal to this number. Must satisfy: adv_max <= avail_num
   386   get_deps = returns dependencies of a theorem
   381   get_deps = returns dependencies of a theorem
   387   get_sym_ths = get theorems that have this feature
   382   get_sym_ths = get theorems that have this feature
   388   knns    = number of nearest neighbours
   383   knns = number of nearest neighbours
   389   advno   = number of predictions to return
   384   advno = number of predictions to return
   390   syms    = symbols of the conjecture
   385   syms = symbols of the conjecture
   391 *)
   386 *)
   392 fun knn avail_no get_deps get_sym_ths knns advno syms =
   387 fun knn avail_num adv_max get_deps get_sym_ths knns advno syms =
   393   let
   388   let
   394     (* Can be later used for TFIDF *)
   389     (* Can be later used for TFIDF *)
   395     fun sym_wght _ = 1.0
   390     fun sym_wght _ = 1.0;
   396 
   391     val overlaps_sqr = Array.tabulate (avail_num, (fn i => (i, 0.0)));
   397     val overlaps_sqr = Array.tabulate (avail_no, (fn i => (i, 0.0)))
       
   398 
       
   399     fun inc_overlap j v =
   392     fun inc_overlap j v =
   400       let
   393       let
   401         val ov = snd (Array.sub (overlaps_sqr,j))
   394         val ov = snd (Array.sub (overlaps_sqr,j))
   402       in
   395       in
   403         Array.update (overlaps_sqr, j, (j, v + ov))
   396         Array.update (overlaps_sqr, j, (j, v + ov))
   404       end
   397       end;
   405 
       
   406     fun do_sym (s, con_wght) =
   398     fun do_sym (s, con_wght) =
   407       let
   399       let
   408         val sw = sym_wght s
   400         val sw = sym_wght s;
   409         val w2 = sw * sw * con_wght
   401         val w2 = sw * sw * con_wght;
   410         fun do_th (j, prem_wght) = if j < avail_no then inc_overlap j (w2 * prem_wght) else ()
   402         fun do_th (j, prem_wght) = if j < avail_num then inc_overlap j (w2 * prem_wght) else ()
   411       in
   403       in
   412         ignore (map do_th (get_sym_ths s))
   404         ignore (map do_th (get_sym_ths s))
   413       end
   405       end;
   414 
   406     val () = ignore (map do_sym syms);
   415     val _ = ignore (map do_sym syms)
   407     val () = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr;
   416     val _ = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr
   408     val recommends = Array.tabulate (adv_max, (fn j => (j, 0.0)));
   417     val recommends = Array.tabulate (avail_no, (fn j => (j, 0.0)))
       
   418 
       
   419     fun inc_recommend j v =
   409     fun inc_recommend j v =
       
   410       if j >= adv_max then () else
   420       let
   411       let
   421         val ov = snd (Array.sub (recommends,j))
   412         val ov = snd (Array.sub (recommends,j))
   422       in
   413       in
   423         Array.update (recommends, j, (j, v + ov))
   414         Array.update (recommends, j, (j, v + ov))
   424       end
   415       end;
   425 
       
   426     fun for k =
   416     fun for k =
   427       if k = knns then () else
   417       if k = knns then () else
   428       if k >= avail_no then () else
   418       if k >= adv_max then () else
   429       let
   419       let
   430         val (j, o2) = Array.sub (overlaps_sqr, avail_no - k - 1)
   420         val (j, o2) = Array.sub (overlaps_sqr, avail_num - k - 1);
   431         val o1 = Math.sqrt o2
   421         val o1 = Math.sqrt o2;
   432         val _ = inc_recommend j o1
   422         val () = inc_recommend j o1;
   433         val ds = get_deps j
   423         val ds = get_deps j;
   434         val l = Real.fromInt (length ds)
   424         val l = Real.fromInt (length ds);
   435         val _ = map (fn d => inc_recommend d (o1 / l)) ds
   425         val _ = map (fn d => inc_recommend d (o1 / l)) ds
   436       in
   426       in
   437         for (k + 1)
   427         for (k + 1)
   438       end
   428       end;
   439 
   429     val () = for 0;
   440     val _ = for 0
   430     val () = heap (fn (a, b) => Real.compare (snd a, snd b)) advno recommends;
   441     val _ = heap (fn (a, b) => Real.compare (snd a, snd b)) advno recommends
       
   442 
       
   443     fun ret acc at =
   431     fun ret acc at =
   444       if at = Array.length recommends then acc else ret (Array.sub (recommends,at) :: acc) (at + 1)
   432       if at = Array.length recommends then acc else ret (Array.sub (recommends,at) :: acc) (at + 1)
   445   in
   433   in
   446     ret [] (max 0 (avail_no - advno))
   434     ret [] (max 0 (adv_max - advno))
   447   end
   435   end
   448 
   436 
   449 val knns = 40 (* FUDGE *)
   437 val knns = 40 (* FUDGE *)
   450 
   438 
   451 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
   439 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
   454 
   442 
   455 fun learn_and_query ctxt parents access_G max_suggs hints feats =
   443 fun learn_and_query ctxt parents access_G max_suggs hints feats =
   456   let
   444   let
   457     val str_of_feat = space_implode "|"
   445     val str_of_feat = space_implode "|"
   458 
   446 
   459     val (depss0, featss, (_, _, facts0), (num_feats, feat_tab, _)) =
   447     val visible_facts = Graph.all_preds access_G parents
   460       fold_rev (fn fact => fn (depss, featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
   448     val visible_fact_set = Symtab.make_set visible_facts
       
   449 
       
   450     val all_nodes =
       
   451       Graph.schedule (K I) access_G
       
   452       |> List.partition (Symtab.defined visible_fact_set o fst)
       
   453       |> op @
       
   454 
       
   455     val (rev_depss, featss, (_, _, rev_facts), (num_feats, feat_tab, _)) =
       
   456       fold (fn (fact, (_, feats, deps)) =>
       
   457             fn (depss, featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
   461           let
   458           let
   462             val (_, feats, deps) = Graph.get_node access_G fact
       
   463 
       
   464             fun add_feat (feat, weight) (xtab as (n, tab, _)) =
   459             fun add_feat (feat, weight) (xtab as (n, tab, _)) =
   465               (case Symtab.lookup tab feat of
   460               (case Symtab.lookup tab feat of
   466                 SOME i => ((i, weight), xtab)
   461                 SOME i => ((i, weight), xtab)
   467               | NONE => ((n, weight), add_to_xtab feat xtab))
   462               | NONE => ((n, weight), add_to_xtab feat xtab))
   468 
   463 
   469             val (feats', feat_xtab') = fold_map (add_feat o apfst str_of_feat) feats feat_xtab
   464             val (feats', feat_xtab') = fold_map (add_feat o apfst str_of_feat) feats feat_xtab
   470           in
   465           in
   471             (map_filter (Symtab.lookup fact_tab) deps :: depss, feats' :: featss,
   466             (map_filter (Symtab.lookup fact_tab) deps :: depss, feats' :: featss,
   472              add_to_xtab fact fact_xtab, feat_xtab')
   467              add_to_xtab fact fact_xtab, feat_xtab')
   473           end)
   468           end)
   474         (Graph.all_preds access_G parents) ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
   469         all_nodes ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
   475 
   470 
   476     val facts = rev facts0
   471     val facts = rev rev_facts
   477     val fact_ary = Array.fromList facts
   472     val fact_ary = Array.fromList facts
   478 
   473 
   479     val deps_ary = Array.fromList (rev depss0)
   474     val deps_ary = Array.fromList (rev rev_depss)
   480     val facts_ary = Array.array (num_feats, [])
   475     val facts_ary = Array.array (num_feats, [])
   481     val _ =
   476     val _ =
   482       fold (fn feats => fn fact =>
   477       fold (fn feats => fn fact =>
   483           let val fact' = fact - 1 in
   478           let val fact' = fact - 1 in
   484             List.app (fn (feat, weight) => map_array_at facts_ary (cons (fact', weight)) feat)
   479             List.app (fn (feat, weight) => map_array_at facts_ary (cons (fact', weight)) feat)
   485               feats;
   480               feats;
   486             fact'
   481             fact'
   487           end)
   482           end)
   488         featss (length featss)
   483         featss (length featss)
   489   in
   484   in
   490     trace_msg ctxt (fn () =>
   485     trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^
   491       "MaSh_SML query " ^ encode_features feats ^ " from {" ^
   486       elide_string 1000 (space_implode " " facts) ^ "}");
   492        elide_string 1000 (space_implode " " facts) ^ "}");
   487     knn (Array.length deps_ary) (length visible_facts) (curry Array.sub deps_ary)
   493     knn (Array.length deps_ary) (curry Array.sub deps_ary) (curry Array.sub facts_ary) knns
   488       (curry Array.sub facts_ary) knns max_suggs
   494       max_suggs
       
   495       (map_filter (fn (feat, weight) =>
   489       (map_filter (fn (feat, weight) =>
   496          Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats)
   490          Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats)
   497     |> map ((fn i => Array.sub (fact_ary, i)) o fst)
   491     |> map ((fn i => Array.sub (fact_ary, i)) o fst)
   498     |> filter_out (String.isPrefix learned_proof_prefix)
       
   499   end
   492   end
   500 
   493 
   501 end;
   494 end;
   502 
   495 
   503 
   496 
   515 
   508 
   516 fun add_edge_to name parent =
   509 fun add_edge_to name parent =
   517   Graph.default_node (parent, (Isar_Proof, [], []))
   510   Graph.default_node (parent, (Isar_Proof, [], []))
   518   #> Graph.add_edge (parent, name)
   511   #> Graph.add_edge (parent, name)
   519 
   512 
   520 fun add_node kind name feats deps =
   513 fun add_node kind name parents feats deps =
   521   Graph.default_node (name, (kind, feats, deps))
   514   Graph.default_node (name, (kind, feats, deps))
   522   #> Graph.map_node name (K (kind, feats, deps))
   515   #> Graph.map_node name (K (kind, feats, deps))
       
   516   #> fold (add_edge_to name) parents
   523 
   517 
   524 fun try_graph ctxt when def f =
   518 fun try_graph ctxt when def f =
   525   f ()
   519   f ()
   526   handle
   520   handle
   527     Graph.CYCLES (cycle :: _) =>
   521     Graph.CYCLES (cycle :: _) =>
   573          SOME (version' :: node_lines) =>
   567          SOME (version' :: node_lines) =>
   574          let
   568          let
   575            fun extract_line_and_add_node line =
   569            fun extract_line_and_add_node line =
   576              (case extract_node line of
   570              (case extract_node line of
   577                NONE => I (* should not happen *)
   571                NONE => I (* should not happen *)
   578              | SOME (kind, name, parents, feats, deps) =>
   572              | SOME (kind, name, parents, feats, deps) => add_node kind name parents feats deps)
   579                add_node kind name feats deps
       
   580                #> fold (add_edge_to name) parents)
       
   581 
   573 
   582            val (access_G, num_known_facts) =
   574            val (access_G, num_known_facts) =
   583              (case string_ord (version', version) of
   575              (case string_ord (version', version) of
   584                EQUAL =>
   576                EQUAL =>
   585                (try_graph ctxt "loading state" Graph.empty (fn () =>
   577                (try_graph ctxt "loading state" Graph.empty (fn () =>
  1130              (maxs, Graph.Keys.fold (insert_new seen) (Graph.imm_preds G new) news))
  1122              (maxs, Graph.Keys.fold (insert_new seen) (Graph.imm_preds G new) news))
  1131   in
  1123   in
  1132     find_maxes Symtab.empty ([], Graph.maximals G)
  1124     find_maxes Symtab.empty ([], Graph.maximals G)
  1133   end
  1125   end
  1134 
  1126 
  1135 fun graph_islands G =
       
  1136   Graph.fold (fn (m, (_, (preds, succs))) =>
       
  1137     (Graph.Keys.is_empty preds andalso Graph.Keys.is_empty succs) ? cons m) G [];
       
  1138 
       
  1139 (* islands represent learned proofs associated with no facts *)
       
  1140 fun maximal_wrt_access_graph access_G facts =
  1127 fun maximal_wrt_access_graph access_G facts =
  1141   map (nickname_of_thm o snd) facts @ graph_islands access_G
  1128   map (nickname_of_thm o snd) facts
  1142   |> maximal_wrt_graph access_G
  1129   |> maximal_wrt_graph access_G
  1143 
  1130 
  1144 fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm
  1131 fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm
  1145 
  1132 
  1146 val chained_feature_factor = 0.5 (* FUDGE *)
  1133 val chained_feature_factor = 0.5 (* FUDGE *)
  1238   end
  1225   end
  1239 
  1226 
  1240 fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, G) =
  1227 fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, G) =
  1241   let
  1228   let
  1242     fun maybe_learn_from from (accum as (parents, G)) =
  1229     fun maybe_learn_from from (accum as (parents, G)) =
  1243       try_graph ctxt "updating G" accum (fn () =>
  1230       try_graph ctxt "updating graph" accum (fn () =>
  1244         (from :: parents, Graph.add_edge_acyclic (from, name) G))
  1231         (from :: parents, Graph.add_edge_acyclic (from, name) G))
  1245     val G = G |> Graph.default_node (name, (Isar_Proof, feats, deps))
  1232     val G = G |> Graph.default_node (name, (Isar_Proof, feats, deps))
  1246     val (parents, G) = ([], G) |> fold maybe_learn_from parents
  1233     val (parents, G) = ([], G) |> fold maybe_learn_from parents
  1247     val (deps, _) = ([], G) |> fold maybe_learn_from deps
  1234     val (deps, _) = ([], G) |> fold maybe_learn_from deps
  1248   in
  1235   in
  1272     val death_time = Time.+ (birth_time, hard_timeout)
  1259     val death_time = Time.+ (birth_time, hard_timeout)
  1273     val desc = ("Machine learner for Sledgehammer", "")
  1260     val desc = ("Machine learner for Sledgehammer", "")
  1274   in
  1261   in
  1275     Async_Manager.thread MaShN birth_time death_time desc task
  1262     Async_Manager.thread MaShN birth_time death_time desc task
  1276   end
  1263   end
       
  1264 
       
  1265 fun learned_proof_name () =
       
  1266   Date.fmt ".%Y%m%d.%H%M%S." (Date.fromTimeLocal (Time.now ())) ^ serial_string ()
  1277 
  1267 
  1278 fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) t facts used_ths =
  1268 fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) t facts used_ths =
  1279   if is_mash_enabled () then
  1269   if is_mash_enabled () then
  1280     launch_thread timeout (fn () =>
  1270     launch_thread timeout (fn () =>
  1281       let
  1271       let
  1283         val feats = features_of ctxt thy 0 Symtab.empty (Local, General) false [t]
  1273         val feats = features_of ctxt thy 0 Symtab.empty (Local, General) false [t]
  1284       in
  1274       in
  1285         map_state ctxt overlord (fn state as {access_G, num_known_facts, dirty} =>
  1275         map_state ctxt overlord (fn state as {access_G, num_known_facts, dirty} =>
  1286           let
  1276           let
  1287             val name = learned_proof_name ()
  1277             val name = learned_proof_name ()
       
  1278             val parents = maximal_wrt_access_graph access_G facts
  1288             val deps = used_ths
  1279             val deps = used_ths
  1289               |> filter (is_fact_in_graph access_G)
  1280               |> filter (is_fact_in_graph access_G)
  1290               |> map nickname_of_thm
  1281               |> map nickname_of_thm
  1291           in
  1282           in
  1292             if Config.get ctxt sml then
  1283             if Config.get ctxt sml then
  1293               let val access_G = access_G |> add_node Automatic_Proof name feats deps in
  1284               let val access_G = access_G |> add_node Automatic_Proof name parents feats deps in
  1294                 {access_G = access_G, num_known_facts = num_known_facts + 1,
  1285                 {access_G = access_G, num_known_facts = num_known_facts + 1,
  1295                  dirty = Option.map (cons name) dirty}
  1286                  dirty = Option.map (cons name) dirty}
  1296               end
  1287               end
  1297             else
  1288             else
  1298               let val parents = maximal_wrt_access_graph access_G facts in
  1289               (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state)
  1299                 (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state)
       
  1300               end
       
  1301           end);
  1290           end);
  1302         (true, "")
  1291         (true, "")
  1303       end)
  1292       end)
  1304   else
  1293   else
  1305     ()
  1294     ()