src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 53148 c898409d8630
parent 53142 966a251efd16
child 53150 5565d1b56f84
equal deleted inserted replaced
53147:8e8941fea278 53148:c898409d8630
    34   val encode_features : (string * real) list -> string
    34   val encode_features : (string * real) list -> string
    35   val extract_suggestions : string -> string * string list
    35   val extract_suggestions : string -> string * string list
    36 
    36 
    37   structure MaSh:
    37   structure MaSh:
    38   sig
    38   sig
    39     val unlearn : Proof.context -> unit
    39     val unlearn : Proof.context -> bool -> unit
    40     val learn :
    40     val learn :
    41       Proof.context -> bool
    41       Proof.context -> bool
    42       -> (string * string list * string list * string list) list -> unit
    42       -> (string * string list * string list * string list) list -> unit
    43     val relearn :
    43     val relearn :
    44       Proof.context -> bool -> (string * string list) list -> unit
    44       Proof.context -> bool -> (string * string list) list -> unit
    47       -> (string * string list * string list * string list) list
    47       -> (string * string list * string list * string list) list
    48          * string list * string list * (string * real) list
    48          * string list * string list * (string * real) list
    49       -> string list
    49       -> string list
    50   end
    50   end
    51 
    51 
    52   val mash_unlearn : Proof.context -> unit
    52   val mash_unlearn : Proof.context -> params -> unit
    53   val nickname_of_thm : thm -> string
    53   val nickname_of_thm : thm -> string
    54   val find_suggested_facts :
    54   val find_suggested_facts :
    55     Proof.context -> ('b * thm) list -> string list -> ('b * thm) list
    55     Proof.context -> ('b * thm) list -> string list -> ('b * thm) list
    56   val mesh_facts :
    56   val mesh_facts :
    57     ('a * 'a -> bool) -> int -> (real * (('a * real) list * 'a list)) list
    57     ('a * 'a -> bool) -> int -> (real * (('a * real) list * 'a list)) list
    88     Proof.context -> params -> string -> term -> ('a * thm) list -> thm list
    88     Proof.context -> params -> string -> term -> ('a * thm) list -> thm list
    89     -> unit
    89     -> unit
    90   val mash_learn :
    90   val mash_learn :
    91     Proof.context -> params -> fact_override -> thm list -> bool -> unit
    91     Proof.context -> params -> fact_override -> thm list -> bool -> unit
    92   val is_mash_enabled : unit -> bool
    92   val is_mash_enabled : unit -> bool
    93   val mash_can_suggest_facts : Proof.context -> bool
    93   val mash_can_suggest_facts : Proof.context -> bool -> bool
    94   val generous_max_facts : int -> int
    94   val generous_max_facts : int -> int
    95   val mepo_weight : real
    95   val mepo_weight : real
    96   val mash_weight : real
    96   val mash_weight : real
    97   val relevant_facts :
    97   val relevant_facts :
    98     Proof.context -> params -> string -> int -> fact_override -> term list
    98     Proof.context -> params -> string -> int -> fact_override -> term list
    99     -> term -> raw_fact list -> (string * fact list) list
    99     -> term -> raw_fact list -> (string * fact list) list
   100   val kill_learners : Proof.context -> unit
   100   val kill_learners : Proof.context -> params -> unit
   101   val running_learners : unit -> unit
   101   val running_learners : unit -> unit
   102 end;
   102 end;
   103 
   103 
   104 structure Sledgehammer_MaSh : SLEDGEHAMMER_MASH =
   104 structure Sledgehammer_MaSh : SLEDGEHAMMER_MASH =
   105 struct
   105 struct
   251   | _ => ("", [])
   251   | _ => ("", [])
   252 
   252 
   253 structure MaSh =
   253 structure MaSh =
   254 struct
   254 struct
   255 
   255 
   256 fun shutdown ctxt =
   256 fun shutdown ctxt overlord =
   257   run_mash_tool ctxt false [shutdown_server_arg] ([], K "") (K ())
   257   run_mash_tool ctxt overlord [shutdown_server_arg] ([], K "") (K ())
   258 
   258 
   259 fun unlearn ctxt =
   259 fun unlearn ctxt overlord =
   260   let val path = mash_model_dir () in
   260   let val path = mash_model_dir () in
   261     trace_msg ctxt (K "MaSh unlearn");
   261     trace_msg ctxt (K "MaSh unlearn");
   262     shutdown ctxt;
   262     shutdown ctxt overlord;
   263     try (File.fold_dir (fn file => fn _ =>
   263     try (File.fold_dir (fn file => fn _ =>
   264                            try File.rm (Path.append path (Path.basic file)))
   264                            try File.rm (Path.append path (Path.basic file)))
   265                        path) NONE;
   265                        path) NONE;
   266     ()
   266     ()
   267   end
   267   end
   357        SOME (unencode_str name, unencode_strs parents,
   357        SOME (unencode_str name, unencode_strs parents,
   358              try proof_kind_of_str kind |> the_default Isar_Proof)
   358              try proof_kind_of_str kind |> the_default Isar_Proof)
   359      | _ => NONE)
   359      | _ => NONE)
   360   | _ => NONE
   360   | _ => NONE
   361 
   361 
   362 fun load _ (state as (true, _)) = state
   362 fun load _ _ (state as (true, _)) = state
   363   | load ctxt _ =
   363   | load ctxt overlord _ =
   364     let val path = mash_state_file () in
   364     let val path = mash_state_file () in
   365       (true,
   365       (true,
   366        case try File.read_lines path of
   366        case try File.read_lines path of
   367          SOME (version' :: node_lines) =>
   367          SOME (version' :: node_lines) =>
   368          let
   368          let
   381                (try_graph ctxt "loading state" Graph.empty (fn () =>
   381                (try_graph ctxt "loading state" Graph.empty (fn () =>
   382                     fold add_node node_lines Graph.empty),
   382                     fold add_node node_lines Graph.empty),
   383                 length node_lines)
   383                 length node_lines)
   384              | LESS =>
   384              | LESS =>
   385                (* can't parse old file *)
   385                (* can't parse old file *)
   386                (MaSh.unlearn ctxt; (Graph.empty, 0))
   386                (MaSh.unlearn ctxt overlord; (Graph.empty, 0))
   387              | GREATER => raise FILE_VERSION_TOO_NEW ()
   387              | GREATER => raise FILE_VERSION_TOO_NEW ()
   388          in
   388          in
   389            trace_msg ctxt (fn () =>
   389            trace_msg ctxt (fn () =>
   390                "Loaded fact graph (" ^ graph_info access_G ^ ")");
   390                "Loaded fact graph (" ^ graph_info access_G ^ ")");
   391            {access_G = access_G, num_known_facts = num_known_facts,
   391            {access_G = access_G, num_known_facts = num_known_facts,
   421 val global_state =
   421 val global_state =
   422   Synchronized.var "Sledgehammer_MaSh.global_state" (false, empty_state)
   422   Synchronized.var "Sledgehammer_MaSh.global_state" (false, empty_state)
   423 
   423 
   424 in
   424 in
   425 
   425 
   426 fun map_state ctxt f =
   426 fun map_state ctxt overlord f =
   427   Synchronized.change global_state (load ctxt ##> (f #> save ctxt))
   427   Synchronized.change global_state (load ctxt overlord ##> (f #> save ctxt))
   428   handle FILE_VERSION_TOO_NEW () => ()
   428   handle FILE_VERSION_TOO_NEW () => ()
   429 
   429 
   430 fun peek_state ctxt f =
   430 fun peek_state ctxt overlord f =
   431   Synchronized.change_result global_state
   431   Synchronized.change_result global_state
   432       (perhaps (try (load ctxt)) #> `snd #>> f)
   432       (perhaps (try (load ctxt overlord)) #> `snd #>> f)
   433 
   433 
   434 fun clear_state ctxt =
   434 fun clear_state ctxt overlord =
   435   Synchronized.change global_state (fn _ =>
   435   Synchronized.change global_state (fn _ =>
   436       (MaSh.unlearn ctxt; (* also removes the state file *)
   436       (MaSh.unlearn ctxt overlord; (* also removes the state file *)
   437        (false, empty_state)))
   437        (false, empty_state)))
   438 
   438 
   439 end
   439 end
   440 
   440 
   441 val mash_unlearn = clear_state
   441 fun mash_unlearn ctxt ({overlord, ...} : params) = clear_state ctxt overlord
   442 
   442 
   443 
   443 
   444 (*** Isabelle helpers ***)
   444 (*** Isabelle helpers ***)
   445 
   445 
   446 val local_prefix = "local" ^ Long_Name.separator
   446 val local_prefix = "local" ^ Long_Name.separator
   587   | crude_str_of_typ (TVar (_, S)) = crude_str_of_sort S
   587   | crude_str_of_typ (TVar (_, S)) = crude_str_of_sort S
   588 
   588 
   589 fun maybe_singleton_str _ "" = []
   589 fun maybe_singleton_str _ "" = []
   590   | maybe_singleton_str pref s = [pref ^ s]
   590   | maybe_singleton_str pref s = [pref ^ s]
   591 
   591 
   592 val max_pat_breadth = 10
   592 val max_pat_breadth = 10 (* FUDGE *)
   593 
   593 
   594 fun term_features_of ctxt prover thy_name num_facts const_tab term_max_depth
   594 fun term_features_of ctxt prover thy_name num_facts const_tab term_max_depth
   595                      type_max_depth ts =
   595                      type_max_depth ts =
   596   let
   596   let
   597     val thy = Proof_Context.theory_of ctxt
   597     val thy = Proof_Context.theory_of ctxt
   639       16.0 +
   639       16.0 +
   640       (if num_facts = 0 then
   640       (if num_facts = 0 then
   641          0.0
   641          0.0
   642        else
   642        else
   643          let val count = Symtab.lookup const_tab s |> the_default 1 in
   643          let val count = Symtab.lookup const_tab s |> the_default 1 in
   644            (Real.fromInt num_facts / Real.fromInt count) (* FUDGE *)
   644            Real.fromInt num_facts / Real.fromInt count (* FUDGE *)
   645          end)
   645          end)
   646     fun pattify_term _ _ 0 _ = []
   646     fun pattify_term _ _ 0 _ = []
   647       | pattify_term _ args _ (Const (x as (s, _))) =
   647       | pattify_term _ args _ (Const (x as (s, _))) =
   648         if fst (is_built_in x args) then []
   648         if fst (is_built_in x args) then []
   649         else [(massage_long_name s, weight_of_const s)]
   649         else [(massage_long_name s, weight_of_const s)]
   904 
   904 
   905 fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm
   905 fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm
   906 
   906 
   907 val chained_feature_factor = 0.5
   907 val chained_feature_factor = 0.5
   908 val extra_feature_factor = 0.1
   908 val extra_feature_factor = 0.1
   909 val num_extra_feature_facts = 10 (* FUDGE *)
   909 val num_extra_feature_facts = 0 (* FUDGE *)
   910 
   910 
   911 (* FUDGE *)
   911 (* FUDGE *)
   912 fun weight_of_proximity_fact rank =
   912 fun weight_of_proximity_fact rank =
   913   Math.pow (1.3, 15.5 - 0.2 * Real.fromInt rank) + 15.0
   913   Math.pow (1.3, 15.5 - 0.2 * Real.fromInt rank) + 15.0
   914 
   914 
   956     fun chained_or_extra_features_of factor (((_, stature), th), weight) =
   956     fun chained_or_extra_features_of factor (((_, stature), th), weight) =
   957       [prop_of th]
   957       [prop_of th]
   958       |> features_of ctxt prover (theory_of_thm th) num_facts const_tab stature
   958       |> features_of ctxt prover (theory_of_thm th) num_facts const_tab stature
   959       |> map (apsnd (fn r => weight * factor * r))
   959       |> map (apsnd (fn r => weight * factor * r))
   960     val (access_G, suggs) =
   960     val (access_G, suggs) =
   961       peek_state ctxt (fn {access_G, ...} =>
   961       peek_state ctxt overlord (fn {access_G, ...} =>
   962           if Graph.is_empty access_G then
   962           if Graph.is_empty access_G then
   963             (access_G, [])
   963             (access_G, [])
   964           else
   964           else
   965             let
   965             let
   966               val parents = maximal_wrt_access_graph access_G facts
   966               val parents = maximal_wrt_access_graph access_G facts
   972                 |> map (rpair 1.0)
   972                 |> map (rpair 1.0)
   973                 |> map (chained_or_extra_features_of chained_feature_factor)
   973                 |> map (chained_or_extra_features_of chained_feature_factor)
   974                 |> rpair [] |-> fold (union (op = o pairself fst))
   974                 |> rpair [] |-> fold (union (op = o pairself fst))
   975               val extra_feats =
   975               val extra_feats =
   976                 facts
   976                 facts
   977                 |> take (num_extra_feature_facts - length chained)
   977                 |> take (Int.max (0, num_extra_feature_facts - length chained))
   978                 |> weight_facts_steeply
   978                 |> weight_facts_steeply
   979                 |> map (chained_or_extra_features_of extra_feature_factor)
   979                 |> map (chained_or_extra_features_of extra_feature_factor)
   980                 |> rpair [] |-> fold (union (op = o pairself fst))
   980                 |> rpair [] |-> fold (union (op = o pairself fst))
   981               val feats =
   981               val feats =
   982                 fold (union (op = o pairself fst)) [chained_feats, extra_feats]
   982                 fold (union (op = o pairself fst)) [chained_feats, extra_feats]
  1034         val name = freshish_name ()
  1034         val name = freshish_name ()
  1035         val feats =
  1035         val feats =
  1036           features_of ctxt prover thy 0 Symtab.empty (Local, General) [t]
  1036           features_of ctxt prover thy 0 Symtab.empty (Local, General) [t]
  1037           |> map fst
  1037           |> map fst
  1038       in
  1038       in
  1039         peek_state ctxt (fn {access_G, ...} =>
  1039         peek_state ctxt overlord (fn {access_G, ...} =>
  1040             let
  1040             let
  1041               val parents = maximal_wrt_access_graph access_G facts
  1041               val parents = maximal_wrt_access_graph access_G facts
  1042               val deps =
  1042               val deps =
  1043                 used_ths |> filter (is_fact_in_graph access_G)
  1043                 used_ths |> filter (is_fact_in_graph access_G)
  1044                          |> map nickname_of_thm
  1044                          |> map nickname_of_thm
  1058                      auto_level run_prover learn_timeout facts =
  1058                      auto_level run_prover learn_timeout facts =
  1059   let
  1059   let
  1060     val timer = Timer.startRealTimer ()
  1060     val timer = Timer.startRealTimer ()
  1061     fun next_commit_time () =
  1061     fun next_commit_time () =
  1062       Time.+ (Timer.checkRealTimer timer, commit_timeout)
  1062       Time.+ (Timer.checkRealTimer timer, commit_timeout)
  1063     val {access_G, ...} = peek_state ctxt I
  1063     val {access_G, ...} = peek_state ctxt overlord I
  1064     val is_in_access_G = is_fact_in_graph access_G o snd
  1064     val is_in_access_G = is_fact_in_graph access_G o snd
  1065     val no_new_facts = forall is_in_access_G facts
  1065     val no_new_facts = forall is_in_access_G facts
  1066   in
  1066   in
  1067     if no_new_facts andalso not run_prover then
  1067     if no_new_facts andalso not run_prover then
  1068       if auto_level < 2 then
  1068       if auto_level < 2 then
  1112         fun commit last learns relearns flops =
  1112         fun commit last learns relearns flops =
  1113           (if debug andalso auto_level = 0 then
  1113           (if debug andalso auto_level = 0 then
  1114              Output.urgent_message "Committing..."
  1114              Output.urgent_message "Committing..."
  1115            else
  1115            else
  1116              ();
  1116              ();
  1117            map_state ctxt (do_commit (rev learns) relearns flops);
  1117            map_state ctxt overlord (do_commit (rev learns) relearns flops);
  1118            if not last andalso auto_level = 0 then
  1118            if not last andalso auto_level = 0 then
  1119              let val num_proofs = length learns + length relearns in
  1119              let val num_proofs = length learns + length relearns in
  1120                "Learned " ^ string_of_int num_proofs ^ " " ^
  1120                "Learned " ^ string_of_int num_proofs ^ " " ^
  1121                (if run_prover then "automatic" else "Isar") ^ " proof" ^
  1121                (if run_prover then "automatic" else "Isar") ^ " proof" ^
  1122                plural_s num_proofs ^ " in the last " ^
  1122                plural_s num_proofs ^ " in the last " ^
  1248        |> Output.urgent_message;
  1248        |> Output.urgent_message;
  1249        learn 0 false)
  1249        learn 0 false)
  1250   end
  1250   end
  1251 
  1251 
  1252 fun is_mash_enabled () = (getenv "MASH" = "yes")
  1252 fun is_mash_enabled () = (getenv "MASH" = "yes")
  1253 fun mash_can_suggest_facts ctxt =
  1253 fun mash_can_suggest_facts ctxt overlord =
  1254   not (Graph.is_empty (#access_G (peek_state ctxt I)))
  1254   not (Graph.is_empty (#access_G (peek_state ctxt overlord I)))
  1255 
  1255 
  1256 (* Generate more suggestions than requested, because some might be thrown out
  1256 (* Generate more suggestions than requested, because some might be thrown out
  1257    later for various reasons. *)
  1257    later for various reasons. *)
  1258 fun generous_max_facts max_facts = max_facts + Int.min (50, max_facts)
  1258 fun generous_max_facts max_facts = max_facts + Int.min (50, max_facts)
  1259 
  1259 
  1262 
  1262 
  1263 (* The threshold should be large enough so that MaSh doesn't kick in for Auto
  1263 (* The threshold should be large enough so that MaSh doesn't kick in for Auto
  1264    Sledgehammer and Try. *)
  1264    Sledgehammer and Try. *)
  1265 val min_secs_for_learning = 15
  1265 val min_secs_for_learning = 15
  1266 
  1266 
  1267 fun relevant_facts ctxt (params as {learn, fact_filter, timeout, ...}) prover
  1267 fun relevant_facts ctxt (params as {overlord, learn, fact_filter, timeout, ...})
  1268         max_facts ({add, only, ...} : fact_override) hyp_ts concl_t facts =
  1268         prover max_facts ({add, only, ...} : fact_override) hyp_ts concl_t
       
  1269         facts =
  1269   if not (subset (op =) (the_list fact_filter, fact_filters)) then
  1270   if not (subset (op =) (the_list fact_filter, fact_filters)) then
  1270     error ("Unknown fact filter: " ^ quote (the fact_filter) ^ ".")
  1271     error ("Unknown fact filter: " ^ quote (the fact_filter) ^ ".")
  1271   else if only then
  1272   else if only then
  1272     let val facts = facts |> map fact_of_raw_fact in
  1273     let val facts = facts |> map fact_of_raw_fact in
  1273       [("", facts)]
  1274       [("", facts)]
  1293         case fact_filter of
  1294         case fact_filter of
  1294           SOME ff => (() |> ff <> mepoN ? maybe_learn; ff)
  1295           SOME ff => (() |> ff <> mepoN ? maybe_learn; ff)
  1295         | NONE =>
  1296         | NONE =>
  1296           if is_mash_enabled () then
  1297           if is_mash_enabled () then
  1297             (maybe_learn ();
  1298             (maybe_learn ();
  1298              if mash_can_suggest_facts ctxt then meshN else mepoN)
  1299              if mash_can_suggest_facts ctxt overlord then meshN else mepoN)
  1299           else
  1300           else
  1300             mepoN
  1301             mepoN
  1301       val add_ths = Attrib.eval_thms ctxt add
  1302       val add_ths = Attrib.eval_thms ctxt add
  1302       fun in_add (_, th) = member Thm.eq_thm_prop add_ths th
  1303       fun in_add (_, th) = member Thm.eq_thm_prop add_ths th
  1303       fun add_and_take accepts =
  1304       fun add_and_take accepts =
  1333         [(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take),
  1334         [(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take),
  1334          (mashN, mash |> map fst |> add_and_take)]
  1335          (mashN, mash |> map fst |> add_and_take)]
  1335       | _ => [("", mesh)]
  1336       | _ => [("", mesh)]
  1336     end
  1337     end
  1337 
  1338 
  1338 fun kill_learners ctxt =
  1339 fun kill_learners ctxt ({overlord, ...} : params) =
  1339   (Async_Manager.kill_threads MaShN "learner"; MaSh.shutdown ctxt)
  1340   (Async_Manager.kill_threads MaShN "learner"; MaSh.shutdown ctxt overlord)
  1340 fun running_learners () = Async_Manager.running_threads MaShN "learner"
  1341 fun running_learners () = Async_Manager.running_threads MaShN "learner"
  1341 
  1342 
  1342 end;
  1343 end;