src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 48669 cdcdb0547f29
parent 48668 5d63c23b4042
child 48699 a89b83204c24
equal deleted inserted replaced
48668:5d63c23b4042 48669:cdcdb0547f29
   131 val escape_metas = map escape_meta #> space_implode " "
   131 val escape_metas = map escape_meta #> space_implode " "
   132 val unescape_meta = String.explode #> unmeta_chars []
   132 val unescape_meta = String.explode #> unmeta_chars []
   133 val unescape_metas =
   133 val unescape_metas =
   134   space_explode " " #> filter_out (curry (op =) "") #> map unescape_meta
   134   space_explode " " #> filter_out (curry (op =) "") #> map unescape_meta
   135 
   135 
       
   136 datatype proof_kind = Isar_Proof | ATP_Proof | Isar_Proof_wegen_ATP_Flop
       
   137 
       
   138 fun str_of_proof_kind Isar_Proof = "i"
       
   139   | str_of_proof_kind ATP_Proof = "a"
       
   140   | str_of_proof_kind Isar_Proof_wegen_ATP_Flop = "x"
       
   141 
       
   142 fun proof_kind_of_str "i" = Isar_Proof
       
   143   | proof_kind_of_str "a" = ATP_Proof
       
   144   | proof_kind_of_str "x" = Isar_Proof_wegen_ATP_Flop
       
   145 
   136 fun extract_node line =
   146 fun extract_node line =
   137   case space_explode ":" line of
   147   case space_explode ":" line of
   138     [head, parents] =>
   148     [head, parents] =>
   139     (case space_explode " " head of
   149     (case space_explode " " head of
   140        [tag, name] => SOME (unescape_meta name, unescape_metas parents, tag = "a")
   150        [kind, name] =>
       
   151        SOME (unescape_meta name, unescape_metas parents,
       
   152              try proof_kind_of_str kind |> the_default Isar_Proof)
   141      | _ => NONE)
   153      | _ => NONE)
   142   | _ => NONE
   154   | _ => NONE
   143 
   155 
   144 fun extract_suggestion sugg =
   156 fun extract_suggestion sugg =
   145   case space_explode "=" sugg of
   157   case space_explode "=" sugg of
   499 
   511 
   500 
   512 
   501 (*** High-level communication with MaSh ***)
   513 (*** High-level communication with MaSh ***)
   502 
   514 
   503 (* FIXME: Here a "Graph.update_node" function would be useful *)
   515 (* FIXME: Here a "Graph.update_node" function would be useful *)
   504 fun update_fact_graph_node (name, atp) =
   516 fun update_fact_graph_node (name, kind) =
   505   Graph.default_node (name, false)
   517   Graph.default_node (name, Isar_Proof)
   506   #> atp ? Graph.map_node name (K atp)
   518   #> kind <> Isar_Proof ? Graph.map_node name (K kind)
   507 
   519 
   508 fun try_graph ctxt when def f =
   520 fun try_graph ctxt when def f =
   509   f ()
   521   f ()
   510   handle Graph.CYCLES (cycle :: _) =>
   522   handle Graph.CYCLES (cycle :: _) =>
   511          (trace_msg ctxt (fn () =>
   523          (trace_msg ctxt (fn () =>
   545       (true,
   557       (true,
   546        case try File.read_lines path of
   558        case try File.read_lines path of
   547          SOME (version' :: node_lines) =>
   559          SOME (version' :: node_lines) =>
   548          let
   560          let
   549            fun add_edge_to name parent =
   561            fun add_edge_to name parent =
   550              Graph.default_node (parent, false)
   562              Graph.default_node (parent, Isar_Proof)
   551              #> Graph.add_edge (parent, name)
   563              #> Graph.add_edge (parent, name)
   552            fun add_node line =
   564            fun add_node line =
   553              case extract_node line of
   565              case extract_node line of
   554                NONE => I (* shouldn't happen *)
   566                NONE => I (* shouldn't happen *)
   555              | SOME (name, parents, atp) =>
   567              | SOME (name, parents, kind) =>
   556                update_fact_graph_node (name, atp)
   568                update_fact_graph_node (name, kind)
   557                #> fold (add_edge_to name) parents
   569                #> fold (add_edge_to name) parents
   558            val fact_G =
   570            val fact_G =
   559              try_graph ctxt "loading state" Graph.empty (fn () =>
   571              try_graph ctxt "loading state" Graph.empty (fn () =>
   560                  Graph.empty |> version' = version ? fold add_node node_lines)
   572                  Graph.empty |> version' = version ? fold add_node node_lines)
   561          in
   573          in
   566        | _ => empty_state)
   578        | _ => empty_state)
   567     end
   579     end
   568 
   580 
   569 fun save ctxt {fact_G} =
   581 fun save ctxt {fact_G} =
   570   let
   582   let
   571     fun str_of_entry (name, parents, atp) =
   583     fun str_of_entry (name, parents, kind) =
   572       (if atp then "a" else "i") ^ " " ^ escape_meta name ^ ": " ^
   584       str_of_proof_kind kind ^ " " ^ escape_meta name ^ ": " ^
   573       escape_metas parents ^ "\n"
   585       escape_metas parents ^ "\n"
   574     fun append_entry (name, (atp, (parents, _))) =
   586     fun append_entry (name, (kind, (parents, _))) =
   575       cons (name, Graph.Keys.dest parents, atp)
   587       cons (name, Graph.Keys.dest parents, kind)
   576     val entries = [] |> Graph.fold append_entry fact_G
   588     val entries = [] |> Graph.fold append_entry fact_G
   577   in
   589   in
   578     write_file (version ^ "\n") (entries, str_of_entry) (mash_state_file ());
   590     write_file (version ^ "\n") (entries, str_of_entry) (mash_state_file ());
   579     trace_msg ctxt (fn () => "Saved fact graph (" ^ graph_info fact_G ^ ")")
   591     trace_msg ctxt (fn () => "Saved fact graph (" ^ graph_info fact_G ^ ")")
   580   end
   592   end
   679 fun add_wrt_fact_graph ctxt (name, parents, feats, deps) (adds, graph) =
   691 fun add_wrt_fact_graph ctxt (name, parents, feats, deps) (adds, graph) =
   680   let
   692   let
   681     fun maybe_add_from from (accum as (parents, graph)) =
   693     fun maybe_add_from from (accum as (parents, graph)) =
   682       try_graph ctxt "updating graph" accum (fn () =>
   694       try_graph ctxt "updating graph" accum (fn () =>
   683           (from :: parents, Graph.add_edge_acyclic (from, name) graph))
   695           (from :: parents, Graph.add_edge_acyclic (from, name) graph))
   684     val graph = graph |> Graph.default_node (name, false)
   696     val graph = graph |> Graph.default_node (name, Isar_Proof)
   685     val (parents, graph) = ([], graph) |> fold maybe_add_from parents
   697     val (parents, graph) = ([], graph) |> fold maybe_add_from parents
   686     val (deps, _) = ([], graph) |> fold maybe_add_from deps
   698     val (deps, _) = ([], graph) |> fold maybe_add_from deps
   687   in ((name, parents, feats, deps) :: adds, graph) end
   699   in ((name, parents, feats, deps) :: adds, graph) end
   688 
   700 
   689 fun reprove_wrt_fact_graph ctxt (name, deps) (reps, graph) =
   701 fun reprove_wrt_fact_graph ctxt (name, deps) (reps, graph) =
   690   let
   702   let
   691     fun maybe_rep_from from (accum as (parents, graph)) =
   703     fun maybe_rep_from from (accum as (parents, graph)) =
   692       try_graph ctxt "updating graph" accum (fn () =>
   704       try_graph ctxt "updating graph" accum (fn () =>
   693           (from :: parents, Graph.add_edge_acyclic (from, name) graph))
   705           (from :: parents, Graph.add_edge_acyclic (from, name) graph))
   694     val graph = graph |> update_fact_graph_node (name, true)
   706     val graph = graph |> update_fact_graph_node (name, ATP_Proof)
   695     val (deps, _) = ([], graph) |> fold maybe_rep_from deps
   707     val (deps, _) = ([], graph) |> fold maybe_rep_from deps
   696   in ((name, deps) :: reps, graph) end
   708   in ((name, deps) :: reps, graph) end
       
   709 
       
   710 fun flop_wrt_fact_graph name =
       
   711   update_fact_graph_node (name, Isar_Proof_wegen_ATP_Flop)
   697 
   712 
   698 val learn_timeout_slack = 2.0
   713 val learn_timeout_slack = 2.0
   699 
   714 
   700 fun launch_thread timeout task =
   715 fun launch_thread timeout task =
   701   let
   716   let
   764           else if atp then
   779           else if atp then
   765             atp_dependencies_of ctxt params prover auto_level facts all_names th
   780             atp_dependencies_of ctxt params prover auto_level facts all_names th
   766             |> (fn (false, _) => NONE | (true, deps) => deps)
   781             |> (fn (false, _) => NONE | (true, deps) => deps)
   767           else
   782           else
   768             isar_dependencies_of all_names th
   783             isar_dependencies_of all_names th
   769         fun do_commit [] [] state = state
   784         fun do_commit [] [] [] state = state
   770           | do_commit adds reps {fact_G} =
   785           | do_commit adds reps flops {fact_G} =
   771             let
   786             let
   772               val (adds, fact_G) =
   787               val (adds, fact_G) =
   773                 ([], fact_G) |> fold (add_wrt_fact_graph ctxt) adds
   788                 ([], fact_G) |> fold (add_wrt_fact_graph ctxt) adds
   774               val (reps, fact_G) =
   789               val (reps, fact_G) =
   775                 ([], fact_G) |> fold (reprove_wrt_fact_graph ctxt) reps
   790                 ([], fact_G) |> fold (reprove_wrt_fact_graph ctxt) reps
       
   791               val fact_G = fact_G |> fold flop_wrt_fact_graph flops
   776             in
   792             in
   777               mash_ADD ctxt overlord (rev adds);
   793               mash_ADD ctxt overlord (rev adds);
   778               mash_REPROVE ctxt overlord reps;
   794               mash_REPROVE ctxt overlord reps;
   779               {fact_G = fact_G}
   795               {fact_G = fact_G}
   780             end
   796             end
   781         fun commit last adds reps =
   797         fun commit last adds reps flops =
   782           (if debug andalso auto_level = 0 then
   798           (if debug andalso auto_level = 0 then
   783              Output.urgent_message "Committing..."
   799              Output.urgent_message "Committing..."
   784            else
   800            else
   785              ();
   801              ();
   786            mash_map ctxt (do_commit (rev adds) reps);
   802            mash_map ctxt (do_commit (rev adds) reps flops);
   787            if not last andalso auto_level = 0 then
   803            if not last andalso auto_level = 0 then
   788              let val num_proofs = length adds + length reps in
   804              let val num_proofs = length adds + length reps in
   789                "Learned " ^ string_of_int num_proofs ^ " " ^
   805                "Learned " ^ string_of_int num_proofs ^ " " ^
   790                (if atp then "ATP" else "Isar") ^ " proof" ^
   806                (if atp then "ATP" else "Isar") ^ " proof" ^
   791                plural_s num_proofs ^ " in the last " ^
   807                plural_s num_proofs ^ " in the last " ^
   804               val deps = deps_of status th |> these
   820               val deps = deps_of status th |> these
   805               val n = n |> not (null deps) ? Integer.add 1
   821               val n = n |> not (null deps) ? Integer.add 1
   806               val adds = (name, parents, feats, deps) :: adds
   822               val adds = (name, parents, feats, deps) :: adds
   807               val (adds, next_commit) =
   823               val (adds, next_commit) =
   808                 if Time.> (Timer.checkRealTimer timer, next_commit) then
   824                 if Time.> (Timer.checkRealTimer timer, next_commit) then
   809                   (commit false adds []; ([], next_commit_time ()))
   825                   (commit false adds [] []; ([], next_commit_time ()))
   810                 else
   826                 else
   811                   (adds, next_commit)
   827                   (adds, next_commit)
   812               val timed_out = Time.> (Timer.checkRealTimer timer, learn_timeout)
   828               val timed_out = Time.> (Timer.checkRealTimer timer, learn_timeout)
   813             in (adds, ([name], n, next_commit, timed_out)) end
   829             in (adds, ([name], n, next_commit, timed_out)) end
   814         val n =
   830         val n =
   823                 |> filter (fn (_, th) => thm_ord (th, last_th) <> GREATER)
   839                 |> filter (fn (_, th) => thm_ord (th, last_th) <> GREATER)
   824               val parents = maximal_in_graph fact_G ancestors
   840               val parents = maximal_in_graph fact_G ancestors
   825               val (adds, (_, n, _, _)) =
   841               val (adds, (_, n, _, _)) =
   826                 ([], (parents, 0, next_commit_time (), false))
   842                 ([], (parents, 0, next_commit_time (), false))
   827                 |> fold learn_new_fact new_facts
   843                 |> fold learn_new_fact new_facts
   828             in commit true adds []; n end
   844             in commit true adds [] []; n end
   829         fun relearn_old_fact _ (accum as (_, (_, _, true))) = accum
   845         fun relearn_old_fact _ (accum as (_, (_, _, true))) = accum
   830           | relearn_old_fact ((_, (_, status)), th)
   846           | relearn_old_fact ((_, (_, status)), th)
   831                              (reps, (n, next_commit, _)) =
   847                              ((reps, flops), (n, next_commit, _)) =
   832             let
   848             let
   833               val name = nickname_of th
   849               val name = nickname_of th
   834               val (n, reps) =
   850               val (n, reps, flops) =
   835                 case deps_of status th of
   851                 case deps_of status th of
   836                   SOME deps => (n + 1, (name, deps) :: reps)
   852                   SOME deps => (n + 1, (name, deps) :: reps, flops)
   837                 | NONE => (n, reps)
   853                 | NONE => (n, reps, name :: flops)
   838               val (reps, next_commit) =
   854               val (reps, flops, next_commit) =
   839                 if Time.> (Timer.checkRealTimer timer, next_commit) then
   855                 if Time.> (Timer.checkRealTimer timer, next_commit) then
   840                   (commit false [] reps; ([], next_commit_time ()))
   856                   (commit false [] reps flops; ([], [], next_commit_time ()))
   841                 else
   857                 else
   842                   (reps, next_commit)
   858                   (reps, flops, next_commit)
   843               val timed_out = Time.> (Timer.checkRealTimer timer, learn_timeout)
   859               val timed_out = Time.> (Timer.checkRealTimer timer, learn_timeout)
   844             in (reps, (n, next_commit, timed_out)) end
   860             in ((reps, flops), (n, next_commit, timed_out)) end
   845         val n =
   861         val n =
   846           if not atp orelse null old_facts then
   862           if not atp orelse null old_facts then
   847             n
   863             n
   848           else
   864           else
   849             let
   865             let
   850               val max_isar = 1000 * max_dependencies
   866               val max_isar = 1000 * max_dependencies
   851               fun has_atp_proof th =
   867               fun kind_of_proof th =
   852                 try (Graph.get_node fact_G) (nickname_of th)
   868                 try (Graph.get_node fact_G) (nickname_of th)
   853                 |> the_default false
   869                 |> the_default Isar_Proof
   854               fun priority_of (_, th) =
   870               fun priority_of (_, th) =
   855                 random_range 0 max_isar
   871                 random_range 0 max_isar
   856                 + (if has_atp_proof th then max_isar else 0)
   872                 + (case kind_of_proof th of
       
   873                      Isar_Proof => 0
       
   874                    | ATP_Proof => 2 * max_isar
       
   875                    | Isar_Proof_wegen_ATP_Flop => max_isar)
   857                 - 500 * (th |> isar_dependencies_of all_names
   876                 - 500 * (th |> isar_dependencies_of all_names
   858                             |> Option.map length
   877                             |> Option.map length
   859                             |> the_default max_dependencies)
   878                             |> the_default max_dependencies)
   860               val old_facts =
   879               val old_facts =
   861                 old_facts |> map (`priority_of)
   880                 old_facts |> map (`priority_of)
   862                           |> sort (int_ord o pairself fst)
   881                           |> sort (int_ord o pairself fst)
   863                           |> map snd
   882                           |> map snd
   864               val (reps, (n, _, _)) =
   883               val ((reps, flops), (n, _, _)) =
   865                 ([], (n, next_commit_time (), false))
   884                 (([], []), (n, next_commit_time (), false))
   866                 |> fold relearn_old_fact old_facts
   885                 |> fold relearn_old_fact old_facts
   867             in commit true [] reps; n end
   886             in commit true [] reps flops; n end
   868       in
   887       in
   869         if verbose orelse auto_level < 2 then
   888         if verbose orelse auto_level < 2 then
   870           "Learned " ^ string_of_int n ^ " nontrivial " ^
   889           "Learned " ^ string_of_int n ^ " nontrivial " ^
   871           (if atp then "ATP" else "Isar") ^ " proof" ^ plural_s n ^
   890           (if atp then "ATP" else "Isar") ^ " proof" ^ plural_s n ^
   872           (if verbose then
   891           (if verbose then