166 fun avg [] = 0.0 |
166 fun avg [] = 0.0 |
167 | avg xs = fold (curry (op +)) xs 0.0 / Real.fromInt (length xs) |
167 | avg xs = fold (curry (op +)) xs 0.0 / Real.fromInt (length xs) |
168 |
168 |
169 fun normalize_scores _ [] = [] |
169 fun normalize_scores _ [] = [] |
170 | normalize_scores max_facts xs = |
170 | normalize_scores max_facts xs = |
171 map (apsnd (curry Real.* (1.0 / avg (map snd (take max_facts xs))))) xs |
171 map (apsnd (curry (op *) (1.0 / avg (map snd (take max_facts xs))))) xs |
172 |
172 |
173 fun mesh_facts maybe_distinct _ max_facts [(_, (sels, unks))] = |
173 fun mesh_facts maybe_distinct _ max_facts [(_, (sels, unks))] = |
174 map fst (take max_facts sels) @ take (max_facts - length sels) unks |
174 map fst (take max_facts sels) @ take (max_facts - length sels) unks |
175 |> maybe_distinct |
175 |> maybe_distinct |
176 | mesh_facts _ fact_eq max_facts mess = |
176 | mesh_facts _ fact_eq max_facts mess = |
653 fun load_state ctxt (time_state as (memory_time, _)) = |
653 fun load_state ctxt (time_state as (memory_time, _)) = |
654 let val path = state_file () in |
654 let val path = state_file () in |
655 (case try OS.FileSys.modTime (File.platform_path path) of |
655 (case try OS.FileSys.modTime (File.platform_path path) of |
656 NONE => time_state |
656 NONE => time_state |
657 | SOME disk_time => |
657 | SOME disk_time => |
658 if Time.>= (memory_time, disk_time) then |
658 if memory_time >= disk_time then |
659 time_state |
659 time_state |
660 else |
660 else |
661 (disk_time, |
661 (disk_time, |
662 (case try File.read_lines path of |
662 (case try File.read_lines path of |
663 SOME (version' :: node_lines) => |
663 SOME (version' :: node_lines) => |
698 |
698 |
699 val path = state_file () |
699 val path = state_file () |
700 val dirty_facts' = |
700 val dirty_facts' = |
701 (case try OS.FileSys.modTime (File.platform_path path) of |
701 (case try OS.FileSys.modTime (File.platform_path path) of |
702 NONE => NONE |
702 NONE => NONE |
703 | SOME disk_time => if Time.<= (disk_time, memory_time) then dirty_facts else NONE) |
703 | SOME disk_time => if disk_time <= memory_time then dirty_facts else NONE) |
704 val (banner, entries) = |
704 val (banner, entries) = |
705 (case dirty_facts' of |
705 (case dirty_facts' of |
706 SOME names => (NONE, fold (append_entry o Graph.get_entry access_G) names []) |
706 SOME names => (NONE, fold (append_entry o Graph.get_entry access_G) names []) |
707 | NONE => (SOME (version ^ "\n"), Graph.fold append_entry access_G [])) |
707 | NONE => (SOME (version ^ "\n"), Graph.fold append_entry access_G [])) |
708 in |
708 in |
1276 |
1276 |
1277 fun launch_thread timeout task = |
1277 fun launch_thread timeout task = |
1278 let |
1278 let |
1279 val hard_timeout = time_mult learn_timeout_slack timeout |
1279 val hard_timeout = time_mult learn_timeout_slack timeout |
1280 val birth_time = Time.now () |
1280 val birth_time = Time.now () |
1281 val death_time = Time.+ (birth_time, hard_timeout) |
1281 val death_time = birth_time + hard_timeout |
1282 val desc = ("Machine learner for Sledgehammer", "") |
1282 val desc = ("Machine learner for Sledgehammer", "") |
1283 in |
1283 in |
1284 Async_Manager_Legacy.thread MaShN birth_time death_time desc task |
1284 Async_Manager_Legacy.thread MaShN birth_time death_time desc task |
1285 end |
1285 end |
1286 |
1286 |
1326 (* The timeout is understood in a very relaxed fashion. *) |
1326 (* The timeout is understood in a very relaxed fashion. *) |
1327 fun mash_learn_facts ctxt (params as {debug, verbose, ...}) prover auto_level run_prover |
1327 fun mash_learn_facts ctxt (params as {debug, verbose, ...}) prover auto_level run_prover |
1328 learn_timeout facts = |
1328 learn_timeout facts = |
1329 let |
1329 let |
1330 val timer = Timer.startRealTimer () |
1330 val timer = Timer.startRealTimer () |
1331 fun next_commit_time () = Time.+ (Timer.checkRealTimer timer, commit_timeout) |
1331 fun next_commit_time () = Timer.checkRealTimer timer + commit_timeout |
1332 |
1332 |
1333 val {access_G, ...} = peek_state ctxt |
1333 val {access_G, ...} = peek_state ctxt |
1334 val is_in_access_G = is_fact_in_graph access_G o snd |
1334 val is_in_access_G = is_fact_in_graph access_G o snd |
1335 val no_new_facts = forall is_in_access_G facts |
1335 val no_new_facts = forall is_in_access_G facts |
1336 in |
1336 in |
1406 val feats = features_of ctxt (Thm.theory_name_of_thm th) stature [Thm.prop_of th] |
1406 val feats = features_of ctxt (Thm.theory_name_of_thm th) stature [Thm.prop_of th] |
1407 val deps = these (deps_of status th) |
1407 val deps = these (deps_of status th) |
1408 val num_nontrivial = num_nontrivial |> not (null deps) ? Integer.add 1 |
1408 val num_nontrivial = num_nontrivial |> not (null deps) ? Integer.add 1 |
1409 val learns = (name, parents, feats, deps) :: learns |
1409 val learns = (name, parents, feats, deps) :: learns |
1410 val (learns, next_commit) = |
1410 val (learns, next_commit) = |
1411 if Time.> (Timer.checkRealTimer timer, next_commit) then |
1411 if Timer.checkRealTimer timer > next_commit then |
1412 (commit false learns [] []; ([], next_commit_time ())) |
1412 (commit false learns [] []; ([], next_commit_time ())) |
1413 else |
1413 else |
1414 (learns, next_commit) |
1414 (learns, next_commit) |
1415 val timed_out = Time.> (Timer.checkRealTimer timer, learn_timeout) |
1415 val timed_out = Timer.checkRealTimer timer > learn_timeout |
1416 in |
1416 in |
1417 (learns, (num_nontrivial, next_commit, timed_out)) |
1417 (learns, (num_nontrivial, next_commit, timed_out)) |
1418 end |
1418 end |
1419 |
1419 |
1420 val (num_new_facts, num_nontrivial) = |
1420 val (num_new_facts, num_nontrivial) = |
1441 val (num_nontrivial, relearns, flops) = |
1441 val (num_nontrivial, relearns, flops) = |
1442 (case deps_of status th of |
1442 (case deps_of status th of |
1443 SOME deps => (num_nontrivial + 1, (name, deps) :: relearns, flops) |
1443 SOME deps => (num_nontrivial + 1, (name, deps) :: relearns, flops) |
1444 | NONE => (num_nontrivial, relearns, name :: flops)) |
1444 | NONE => (num_nontrivial, relearns, name :: flops)) |
1445 val (relearns, flops, next_commit) = |
1445 val (relearns, flops, next_commit) = |
1446 if Time.> (Timer.checkRealTimer timer, next_commit) then |
1446 if Timer.checkRealTimer timer > next_commit then |
1447 (commit false [] relearns flops; ([], [], next_commit_time ())) |
1447 (commit false [] relearns flops; ([], [], next_commit_time ())) |
1448 else |
1448 else |
1449 (relearns, flops, next_commit) |
1449 (relearns, flops, next_commit) |
1450 val timed_out = Time.> (Timer.checkRealTimer timer, learn_timeout) |
1450 val timed_out = Timer.checkRealTimer timer > learn_timeout |
1451 in |
1451 in |
1452 ((relearns, flops), (num_nontrivial, next_commit, timed_out)) |
1452 ((relearns, flops), (num_nontrivial, next_commit, timed_out)) |
1453 end |
1453 end |
1454 |
1454 |
1455 val num_nontrivial = |
1455 val num_nontrivial = |