282 (*** Standard ML version of MaSh ***) |
282 (*** Standard ML version of MaSh ***) |
283 |
283 |
284 structure MaSh_SML = |
284 structure MaSh_SML = |
285 struct |
285 struct |
286 |
286 |
(* Curried integer maximum: returns the larger of a and b. *)
fun max a b = Int.max (a, b)
|
288 |
|
(* Raised by the heap code with a node index once that node has no child
   below the current bound. *)
exception BOTTOM of int
290 |
288 |
291 fun heap cmp bnd a = |
289 fun heap cmp bnd a = |
292 let |
290 let |
293 fun maxson l i = |
291 fun maxson l i = |
294 let |
292 let val i31 = i + i + i + 1 in |
295 val i31 = i + i + i + 1 |
|
296 in |
|
297 if i31 + 2 < l then |
293 if i31 + 2 < l then |
298 let |
294 let val x = Unsynchronized.ref i31 in |
299 val x = Unsynchronized.ref i31; |
295 if cmp (Array.sub (a, i31), Array.sub (a, i31 + 1)) = LESS then x := i31 + 1 else (); |
300 val _ = if cmp (Array.sub(a,i31),(Array.sub(a,i31+1))) = LESS then x := i31+1 else (); |
296 if cmp (Array.sub (a, !x), Array.sub (a, i31 + 2)) = LESS then x := i31 + 2 else (); |
301 val _ = if cmp (Array.sub(a,!x),(Array.sub (a,i31+2))) = LESS then x := i31+2 else () |
|
302 in |
|
303 !x |
297 !x |
304 end |
298 end |
305 else |
299 else |
306 if i31 + 1 < l andalso cmp (Array.sub (a, i31), Array.sub (a, i31 + 1)) = LESS |
300 if i31 + 1 < l andalso cmp (Array.sub (a, i31), Array.sub (a, i31 + 1)) = LESS |
307 then i31 + 1 else if i31 < l then i31 else raise BOTTOM i |
301 then i31 + 1 else if i31 < l then i31 else raise BOTTOM i |
385 syms = symbols of the conjecture |
379 syms = symbols of the conjecture |
386 *) |
380 *) |
(*
  Nearest-neighbour style scoring of facts.

  Every symbol in syms (paired with a conjecture-side weight) contributes to an
  overlap score for each fact index returned by get_sym_ths; indices that are
  not below avail_num are ignored.  After "heap" has rearranged the overlap
  array, up to knns entries are read from its end.  Each such entry adds the
  square root of its score to its own slot in "recommends" and an equal share
  of that value to each of its dependencies (get_deps); indices that are not
  below adv_max are ignored.  The result is the list of (index, score) pairs
  from position max (0, adv_max - advno) to the end of "recommends", in reverse
  array order.

  NOTE(review): this relies on "heap cmp bnd ary" leaving the bnd largest
  entries at the end of ary -- inferred from the indexing below, confirm
  against the definition of heap.
*)
fun knn avail_num adv_max get_deps get_sym_ths knns advno syms =
  let
    (* Can be later used for TFIDF *)
    fun sym_wght _ = 1.0

    (* Slot j holds (j, accumulated overlap score of fact j). *)
    val overlaps_sqr = Array.tabulate (avail_num, rpair 0.0)

    fun inc_overlap j v =
      let val (_, ov) = Array.sub (overlaps_sqr, j) in
        Array.update (overlaps_sqr, j, (j, v + ov))
      end

    fun do_sym (s, con_wght) =
      let
        val sw = sym_wght s
        val w2 = sw * sw * con_wght

        fun do_th (j, prem_wght) =
          if j < avail_num then inc_overlap j (w2 * prem_wght) else ()
      in
        List.app do_th (get_sym_ths s)
      end

    val _ = List.app do_sym syms
    val _ = heap (Real.compare o pairself snd) knns overlaps_sqr

    (* Slot j holds (j, accumulated recommendation score of fact j). *)
    val recommends = Array.tabulate (adv_max, rpair 0.0)

    fun inc_recommend j v =
      if j >= adv_max then ()
      else Array.update (recommends, j, (j, v + snd (Array.sub (recommends, j))))

    fun for k =
      (* The "k >= avail_num" test keeps the index below non-negative when
         fewer than knns facts are available; without it Array.sub raises
         Subscript. *)
      if k = knns orelse k >= adv_max orelse k >= avail_num then
        ()
      else
        let
          val (j, o2) = Array.sub (overlaps_sqr, avail_num - k - 1)
          val o1 = Math.sqrt o2
          val _ = inc_recommend j o1
          val ds = get_deps j
          (* l is only used as a divisor when ds is non-empty *)
          val l = Real.fromInt (length ds)
        in
          List.app (fn d => inc_recommend d (o1 / l)) ds;
          for (k + 1)
        end

    val _ = for 0
    val _ = heap (Real.compare o pairself snd) advno recommends

    fun ret acc at =
      if at = Array.length recommends then acc
      else ret (Array.sub (recommends, at) :: acc) (at + 1)
  in
    ret [] (Int.max (0, adv_max - advno))
  end
436 |
436 |
(* FUDGE: hand-tuned constant; presumably the neighbour count handed to knn
   as its "knns" argument -- confirm at the call site. *)
val knns = 40
438 |
438 |
(* Registers key under the index "next" in an indexed table triple
   (next free index, key -> index table, keys in reverse insertion order)
   and returns the updated triple. *)
fun add_to_xtab key (next, tab, keys) =
  let
    val tab' = Symtab.update_new (key, next) tab
  in
    (next + 1, tab', key :: keys)
  end
440 |
440 |
(* Replaces element i of ary, in place, by f applied to its current value. *)
fun map_array_at ary f i =
  let
    val old = Array.sub (ary, i)
  in
    Array.update (ary, i, f old)
  end
442 |
442 |
443 fun learn_and_query ctxt parents access_G max_suggs hints feats = |
443 fun query ctxt parents access_G max_suggs hints feats = |
444 let |
444 let |
445 val str_of_feat = space_implode "|" |
445 val str_of_feat = space_implode "|" |
446 |
446 |
447 val visible_facts = Graph.all_preds access_G parents |
447 val visible_facts = Graph.all_preds access_G parents |
448 val visible_fact_set = Symtab.make_set visible_facts |
448 val visible_fact_set = Symtab.make_set visible_facts |
467 add_to_xtab fact fact_xtab, feat_xtab') |
467 add_to_xtab fact fact_xtab, feat_xtab') |
468 end) |
468 end) |
469 all_nodes ([], [], (0, Symtab.empty, []), (0, Symtab.empty, [])) |
469 all_nodes ([], [], (0, Symtab.empty, []), (0, Symtab.empty, [])) |
470 |
470 |
471 val facts = rev rev_facts |
471 val facts = rev rev_facts |
472 val fact_ary = Array.fromList facts |
472 val fact_vec = Vector.fromList facts |
473 |
473 |
474 val deps_ary = Array.fromList (rev rev_depss) |
474 val deps_vec = Vector.fromList (rev rev_depss) |
475 val facts_ary = Array.array (num_feats, []) |
475 val facts_ary = Array.array (num_feats, []) |
476 val _ = |
476 val _ = |
477 fold (fn feats => fn fact => |
477 fold (fn feats => fn fact => |
478 let val fact' = fact - 1 in |
478 let val fact' = fact - 1 in |
479 List.app (fn (feat, weight) => map_array_at facts_ary (cons (fact', weight)) feat) |
479 List.app (fn (feat, weight) => map_array_at facts_ary (cons (fact', weight)) feat) |
1020 |> fold (add_isar_dep facts) isar_deps |
1017 |> fold (add_isar_dep facts) isar_deps |
1021 |> map nickify |
1018 |> map nickify |
1022 val num_isar_deps = length isar_deps |
1019 val num_isar_deps = length isar_deps |
1023 in |
1020 in |
1024 if verbose andalso auto_level = 0 then |
1021 if verbose andalso auto_level = 0 then |
1025 "MaSh: " ^ quote prover ^ " on " ^ quote name ^ " with " ^ string_of_int num_isar_deps ^ |
1022 Output.urgent_message ("MaSh: " ^ quote prover ^ " on " ^ quote name ^ " with " ^ |
1026 " + " ^ string_of_int (length facts - num_isar_deps) ^ " facts." |
1023 string_of_int num_isar_deps ^ " + " ^ string_of_int (length facts - num_isar_deps) ^ |
1027 |> Output.urgent_message |
1024 " facts.") |
1028 else |
1025 else |
1029 (); |
1026 (); |
1030 (case run_prover_for_mash ctxt params prover name facts goal of |
1027 (case run_prover_for_mash ctxt params prover name facts goal of |
1031 {outcome = NONE, used_facts, ...} => |
1028 {outcome = NONE, used_facts, ...} => |
1032 (if verbose andalso auto_level = 0 then |
1029 (if verbose andalso auto_level = 0 then |
1033 let val num_facts = length used_facts in |
1030 let val num_facts = length used_facts in |
1034 "Found proof with " ^ string_of_int num_facts ^ " fact" ^ |
1031 Output.urgent_message ("Found proof with " ^ string_of_int num_facts ^ " fact" ^ |
1035 plural_s num_facts ^ "." |
1032 plural_s num_facts ^ ".") |
1036 |> Output.urgent_message |
|
1037 end |
1033 end |
1038 else |
1034 else |
1039 (); |
1035 (); |
1040 (true, map fst used_facts)) |
1036 (true, map fst used_facts)) |
1041 | _ => (false, isar_deps)) |
1037 | _ => (false, isar_deps)) |
(* Features of the proposition of th, with every feature weight multiplied
   by weight * factor. *)
fun chained_or_extra_features_of factor (((_, stature), th), weight) =
  let
    val scale = apsnd (fn r => weight * factor * r)
    val feats =
      features_of ctxt (theory_of_thm th) num_facts const_tab stature false [prop_of th]
  in
    map scale feats
  end
1189 |
1185 |
1190 val (access_G, suggs) = |
1186 fun query_args access_G = |
|
1187 let |
|
1188 val parents = maximal_wrt_access_graph access_G facts |
|
1189 val hints = chained |
|
1190 |> filter (is_fact_in_graph access_G o snd) |
|
1191 |> map (nickname_of_thm o snd) |
|
1192 |
|
1193 val goal_feats = |
|
1194 features_of ctxt thy num_facts const_tab (Local, General) true (concl_t :: hyp_ts) |
|
1195 val chained_feats = chained |
|
1196 |> map (rpair 1.0) |
|
1197 |> map (chained_or_extra_features_of chained_feature_factor) |
|
1198 |> rpair [] |-> fold (union (eq_fst (op =))) |
|
1199 val extra_feats = facts |
|
1200 |> take (Int.max (0, num_extra_feature_facts - length chained)) |
|
1201 |> filter fact_has_right_theory |
|
1202 |> weight_facts_steeply |
|
1203 |> map (chained_or_extra_features_of extra_feature_factor) |
|
1204 |> rpair [] |-> fold (union (eq_fst (op =))) |
|
1205 val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats |
|
1206 |> debug ? sort (Real.compare o swap o pairself snd) |
|
1207 in |
|
1208 (parents, hints, feats) |
|
1209 end |
|
1210 |
|
1211 val sml = Config.get ctxt sml |
|
1212 |
|
1213 val (access_G, py_suggs) = |
1191 peek_state ctxt overlord (fn {access_G, ...} => |
1214 peek_state ctxt overlord (fn {access_G, ...} => |
1192 if Graph.is_empty access_G then |
1215 if Graph.is_empty access_G then |
1193 (trace_msg ctxt (K "Nothing has been learned yet"); (access_G, [])) |
1216 (trace_msg ctxt (K "Nothing has been learned yet"); (access_G, [])) |
1194 else |
1217 else |
1195 let |
1218 (access_G, |
1196 val parents = maximal_wrt_access_graph access_G facts |
1219 if sml then |
1197 val goal_feats = |
1220 [] |
1198 features_of ctxt thy num_facts const_tab (Local, General) true (concl_t :: hyp_ts) |
1221 else |
1199 val chained_feats = chained |
1222 let val (parents, hints, feats) = query_args access_G in |
1200 |> map (rpair 1.0) |
1223 MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats) |
1201 |> map (chained_or_extra_features_of chained_feature_factor) |
1224 end)) |
1202 |> rpair [] |-> fold (union (eq_fst (op =))) |
1225 |
1203 val extra_feats = facts |
1226 val sml_suggs = |
1204 |> take (Int.max (0, num_extra_feature_facts - length chained)) |
1227 if sml then |
1205 |> filter fact_has_right_theory |
1228 let val (parents, hints, feats) = query_args access_G in |
1206 |> weight_facts_steeply |
1229 MaSh_SML.query ctxt parents access_G max_facts hints feats |
1207 |> map (chained_or_extra_features_of extra_feature_factor) |
1230 end |
1208 |> rpair [] |-> fold (union (eq_fst (op =))) |
1231 else |
1209 val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats |
1232 [] |
1210 |> debug ? sort (Real.compare o swap o pairself snd) |
1233 |
1211 val hints = chained |
|
1212 |> filter (is_fact_in_graph access_G o snd) |
|
1213 |> map (nickname_of_thm o snd) |
|
1214 in |
|
1215 (access_G, |
|
1216 if Config.get ctxt sml then |
|
1217 MaSh_SML.learn_and_query ctxt parents access_G max_facts hints feats |
|
1218 else |
|
1219 MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats)) |
|
1220 end) |
|
1221 val unknown = filter_out (is_fact_in_graph access_G o snd) facts |
1234 val unknown = filter_out (is_fact_in_graph access_G o snd) facts |
1222 in |
1235 in |
1223 find_mash_suggestions ctxt max_facts suggs facts chained unknown |
1236 find_mash_suggestions ctxt max_facts (py_suggs @ sml_suggs) facts chained unknown |
1224 |> pairself (map fact_of_raw_fact) |
1237 |> pairself (map fact_of_raw_fact) |
1225 end |
1238 end |
1226 |
1239 |
1227 fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, G) = |
1240 fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, G) = |
1228 let |
1241 let |
1353 MaSh_Py.relearn ctxt overlord save relearns); |
1366 MaSh_Py.relearn ctxt overlord save relearns); |
1354 {access_G = access_G, num_known_facts = num_known_facts, dirty = dirty} |
1367 {access_G = access_G, num_known_facts = num_known_facts, dirty = dirty} |
1355 end |
1368 end |
1356 |
1369 |
(* Applies do_commit to the shared state via map_state, with the learns
   reversed first.  Progress messages are printed only when auto_level = 0:
   "Committing..." beforehand if debug is set, and a summary afterwards
   unless this is the last commit. *)
fun commit last learns relearns flops =
  let
    val at_level0 = (auto_level = 0)
    val _ = if debug andalso at_level0 then Output.urgent_message "Committing..." else ()
    val _ = map_state ctxt overlord (do_commit (rev learns) relearns flops)
  in
    if not last andalso at_level0 then
      let
        val num_proofs = length learns + length relearns
        val proof_kind = if run_prover then "automatic" else "Isar"
      in
        Output.urgent_message ("Learned " ^ string_of_int num_proofs ^ " " ^ proof_kind ^
          " proof" ^ plural_s num_proofs ^ " in the last " ^ string_of_time commit_timeout ^
          ".")
      end
    else
      ()
  end
1373 |
1381 |
1374 fun learn_new_fact _ (accum as (_, (_, _, true))) = accum |
1382 fun learn_new_fact _ (accum as (_, (_, _, true))) = accum |
(* Runs mash_learn_facts over facts with the given auto_level and run_prover
   flag and prints the message it returns. *)
fun learn auto_level run_prover =
  Output.urgent_message
    (mash_learn_facts ctxt params prover true auto_level run_prover one_year facts)
1479 in |
1487 in |
1480 if run_prover then |
1488 if run_prover then |
1481 ("MaShing through " ^ string_of_int num_facts ^ " fact" ^ plural_s num_facts ^ |
1489 (Output.urgent_message ("MaShing through " ^ string_of_int num_facts ^ " fact" ^ |
1482 " for automatic proofs (" ^ quote prover ^ " timeout: " ^ string_of_time timeout ^ |
1490 plural_s num_facts ^ " for automatic proofs (" ^ quote prover ^ " timeout: " ^ |
1483 ").\n\nCollecting Isar proofs first..." |
1491 string_of_time timeout ^ ").\n\nCollecting Isar proofs first..."); |
1484 |> Output.urgent_message; |
|
1485 learn 1 false; |
1492 learn 1 false; |
1486 "Now collecting automatic proofs. This may take several hours. You can safely stop the \ |
1493 Output.urgent_message "Now collecting automatic proofs. This may take several hours. You \ |
1487 \learning process at any point." |
1494 \can safely stop the learning process at any point."; |
1488 |> Output.urgent_message; |
|
1489 learn 0 true) |
1495 learn 0 true) |
1490 else |
1496 else |
1491 (Output.urgent_message ("MaShing through " ^ string_of_int num_facts ^ " fact" ^ |
1497 (Output.urgent_message ("MaShing through " ^ string_of_int num_facts ^ " fact" ^ |
1492 plural_s num_facts ^ " for Isar proofs..."); |
1498 plural_s num_facts ^ " for Isar proofs..."); |
1493 learn 0 false) |
1499 learn 0 false) |