src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 62826 eb94e570c1a4
parent 62735 23de054397e5
child 63518 ae8fd6fe63a1
equal deleted inserted replaced
62825:e6e80a8bf624 62826:eb94e570c1a4
   166 fun avg [] = 0.0
   166 fun avg [] = 0.0
   167   | avg xs = fold (curry (op +)) xs 0.0 / Real.fromInt (length xs)
   167   | avg xs = fold (curry (op +)) xs 0.0 / Real.fromInt (length xs)
   168 
   168 
   169 fun normalize_scores _ [] = []
   169 fun normalize_scores _ [] = []
   170   | normalize_scores max_facts xs =
   170   | normalize_scores max_facts xs =
   171     map (apsnd (curry Real.* (1.0 / avg (map snd (take max_facts xs))))) xs
   171     map (apsnd (curry (op *) (1.0 / avg (map snd (take max_facts xs))))) xs
   172 
   172 
   173 fun mesh_facts maybe_distinct _ max_facts [(_, (sels, unks))] =
   173 fun mesh_facts maybe_distinct _ max_facts [(_, (sels, unks))] =
   174     map fst (take max_facts sels) @ take (max_facts - length sels) unks
   174     map fst (take max_facts sels) @ take (max_facts - length sels) unks
   175     |> maybe_distinct
   175     |> maybe_distinct
   176   | mesh_facts _ fact_eq max_facts mess =
   176   | mesh_facts _ fact_eq max_facts mess =
   653 fun load_state ctxt (time_state as (memory_time, _)) =
   653 fun load_state ctxt (time_state as (memory_time, _)) =
   654   let val path = state_file () in
   654   let val path = state_file () in
   655     (case try OS.FileSys.modTime (File.platform_path path) of
   655     (case try OS.FileSys.modTime (File.platform_path path) of
   656       NONE => time_state
   656       NONE => time_state
   657     | SOME disk_time =>
   657     | SOME disk_time =>
   658       if Time.>= (memory_time, disk_time) then
   658       if memory_time >= disk_time then
   659         time_state
   659         time_state
   660       else
   660       else
   661         (disk_time,
   661         (disk_time,
   662          (case try File.read_lines path of
   662          (case try File.read_lines path of
   663            SOME (version' :: node_lines) =>
   663            SOME (version' :: node_lines) =>
   698 
   698 
   699       val path = state_file ()
   699       val path = state_file ()
   700       val dirty_facts' =
   700       val dirty_facts' =
   701         (case try OS.FileSys.modTime (File.platform_path path) of
   701         (case try OS.FileSys.modTime (File.platform_path path) of
   702           NONE => NONE
   702           NONE => NONE
   703         | SOME disk_time => if Time.<= (disk_time, memory_time) then dirty_facts else NONE)
   703         | SOME disk_time => if disk_time <= memory_time then dirty_facts else NONE)
   704       val (banner, entries) =
   704       val (banner, entries) =
   705         (case dirty_facts' of
   705         (case dirty_facts' of
   706           SOME names => (NONE, fold (append_entry o Graph.get_entry access_G) names [])
   706           SOME names => (NONE, fold (append_entry o Graph.get_entry access_G) names [])
   707         | NONE => (SOME (version ^ "\n"), Graph.fold append_entry access_G []))
   707         | NONE => (SOME (version ^ "\n"), Graph.fold append_entry access_G []))
   708     in
   708     in
  1276 
  1276 
  1277 fun launch_thread timeout task =
  1277 fun launch_thread timeout task =
  1278   let
  1278   let
  1279     val hard_timeout = time_mult learn_timeout_slack timeout
  1279     val hard_timeout = time_mult learn_timeout_slack timeout
  1280     val birth_time = Time.now ()
  1280     val birth_time = Time.now ()
  1281     val death_time = Time.+ (birth_time, hard_timeout)
  1281     val death_time = birth_time + hard_timeout
  1282     val desc = ("Machine learner for Sledgehammer", "")
  1282     val desc = ("Machine learner for Sledgehammer", "")
  1283   in
  1283   in
  1284     Async_Manager_Legacy.thread MaShN birth_time death_time desc task
  1284     Async_Manager_Legacy.thread MaShN birth_time death_time desc task
  1285   end
  1285   end
  1286 
  1286 
  1326 (* The timeout is understood in a very relaxed fashion. *)
  1326 (* The timeout is understood in a very relaxed fashion. *)
  1327 fun mash_learn_facts ctxt (params as {debug, verbose, ...}) prover auto_level run_prover
  1327 fun mash_learn_facts ctxt (params as {debug, verbose, ...}) prover auto_level run_prover
  1328     learn_timeout facts =
  1328     learn_timeout facts =
  1329   let
  1329   let
  1330     val timer = Timer.startRealTimer ()
  1330     val timer = Timer.startRealTimer ()
  1331     fun next_commit_time () = Time.+ (Timer.checkRealTimer timer, commit_timeout)
  1331     fun next_commit_time () = Timer.checkRealTimer timer + commit_timeout
  1332 
  1332 
  1333     val {access_G, ...} = peek_state ctxt
  1333     val {access_G, ...} = peek_state ctxt
  1334     val is_in_access_G = is_fact_in_graph access_G o snd
  1334     val is_in_access_G = is_fact_in_graph access_G o snd
  1335     val no_new_facts = forall is_in_access_G facts
  1335     val no_new_facts = forall is_in_access_G facts
  1336   in
  1336   in
  1406               val feats = features_of ctxt (Thm.theory_name_of_thm th) stature [Thm.prop_of th]
  1406               val feats = features_of ctxt (Thm.theory_name_of_thm th) stature [Thm.prop_of th]
  1407               val deps = these (deps_of status th)
  1407               val deps = these (deps_of status th)
  1408               val num_nontrivial = num_nontrivial |> not (null deps) ? Integer.add 1
  1408               val num_nontrivial = num_nontrivial |> not (null deps) ? Integer.add 1
  1409               val learns = (name, parents, feats, deps) :: learns
  1409               val learns = (name, parents, feats, deps) :: learns
  1410               val (learns, next_commit) =
  1410               val (learns, next_commit) =
  1411                 if Time.> (Timer.checkRealTimer timer, next_commit) then
  1411                 if Timer.checkRealTimer timer > next_commit then
  1412                   (commit false learns [] []; ([], next_commit_time ()))
  1412                   (commit false learns [] []; ([], next_commit_time ()))
  1413                 else
  1413                 else
  1414                   (learns, next_commit)
  1414                   (learns, next_commit)
  1415               val timed_out = Time.> (Timer.checkRealTimer timer, learn_timeout)
  1415               val timed_out = Timer.checkRealTimer timer > learn_timeout
  1416             in
  1416             in
  1417               (learns, (num_nontrivial, next_commit, timed_out))
  1417               (learns, (num_nontrivial, next_commit, timed_out))
  1418             end
  1418             end
  1419 
  1419 
  1420         val (num_new_facts, num_nontrivial) =
  1420         val (num_new_facts, num_nontrivial) =
  1441               val (num_nontrivial, relearns, flops) =
  1441               val (num_nontrivial, relearns, flops) =
  1442                 (case deps_of status th of
  1442                 (case deps_of status th of
  1443                   SOME deps => (num_nontrivial + 1, (name, deps) :: relearns, flops)
  1443                   SOME deps => (num_nontrivial + 1, (name, deps) :: relearns, flops)
  1444                 | NONE => (num_nontrivial, relearns, name :: flops))
  1444                 | NONE => (num_nontrivial, relearns, name :: flops))
  1445               val (relearns, flops, next_commit) =
  1445               val (relearns, flops, next_commit) =
  1446                 if Time.> (Timer.checkRealTimer timer, next_commit) then
  1446                 if Timer.checkRealTimer timer > next_commit then
  1447                   (commit false [] relearns flops; ([], [], next_commit_time ()))
  1447                   (commit false [] relearns flops; ([], [], next_commit_time ()))
  1448                 else
  1448                 else
  1449                   (relearns, flops, next_commit)
  1449                   (relearns, flops, next_commit)
  1450               val timed_out = Time.> (Timer.checkRealTimer timer, learn_timeout)
  1450               val timed_out = Timer.checkRealTimer timer > learn_timeout
  1451             in
  1451             in
  1452               ((relearns, flops), (num_nontrivial, next_commit, timed_out))
  1452               ((relearns, flops), (num_nontrivial, next_commit, timed_out))
  1453             end
  1453             end
  1454 
  1454 
  1455         val num_nontrivial =
  1455         val num_nontrivial =