made MaSh faster and less likely to hang seemingly forever
authorblanchet
Wed Nov 23 20:55:58 2016 +0100 (2016-11-23)
changeset 64522b66f8caf86b6
parent 64521 1aef5a0e18d7
child 64523 49a29161d8ef
made MaSh faster and less likely to hang seemingly forever
src/HOL/TPTP/mash_export.ML
src/HOL/Tools/Sledgehammer/sledgehammer_commands.ML
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
     1.1 --- a/src/HOL/TPTP/mash_export.ML	Wed Nov 23 16:15:17 2016 +0100
     1.2 +++ b/src/HOL/TPTP/mash_export.ML	Wed Nov 23 20:55:58 2016 +0100
     1.3 @@ -286,16 +286,16 @@
     1.4         not (Config.get ctxt Sledgehammer_MaSh.duplicates) ? Sledgehammer_Fact.drop_duplicate_facts
     1.5         #> Sledgehammer_MePo.mepo_suggested_facts ctxt params max_suggs NONE hyp_ts concl_t)
     1.6  
     1.7 -fun generate_mash_suggestions algorithm =
     1.8 +fun generate_mash_suggestions algorithm ctxt =
     1.9    (Options.put_default @{system_option MaSh} algorithm;
    1.10 -   Sledgehammer_MaSh.mash_unlearn ();
    1.11 +   Sledgehammer_MaSh.mash_unlearn ctxt;
    1.12     generate_mepo_or_mash_suggestions
    1.13       (fn ctxt => fn thy_name => fn params as {provers = prover :: _, ...} =>
    1.14            fn max_suggs => fn hyp_ts => fn concl_t =>
    1.15          tap (Sledgehammer_MaSh.mash_learn_facts ctxt params prover 2 false
    1.16            Sledgehammer_Util.one_year)
    1.17          #> Sledgehammer_MaSh.mash_suggested_facts ctxt thy_name params max_suggs hyp_ts concl_t
    1.18 -        #> fst))
    1.19 +        #> fst) ctxt)
    1.20  
    1.21  fun generate_mesh_suggestions max_suggs mash_file_name mepo_file_name mesh_file_name =
    1.22    let
     2.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_commands.ML	Wed Nov 23 16:15:17 2016 +0100
     2.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_commands.ML	Wed Nov 23 20:55:58 2016 +0100
     2.3 @@ -324,10 +324,10 @@
     2.4      else if subcommand = supported_proversN then
     2.5        supported_provers ctxt
     2.6      else if subcommand = unlearnN then
     2.7 -      mash_unlearn ()
     2.8 +      mash_unlearn ctxt
     2.9      else if subcommand = learn_isarN orelse subcommand = learn_proverN orelse
    2.10              subcommand = relearn_isarN orelse subcommand = relearn_proverN then
    2.11 -      (if subcommand = relearn_isarN orelse subcommand = relearn_proverN then mash_unlearn ()
    2.12 +      (if subcommand = relearn_isarN orelse subcommand = relearn_proverN then mash_unlearn ctxt
    2.13         else ();
    2.14         mash_learn ctxt
    2.15             (* TODO: Use MaSh mode instead and have the special defaults hardcoded in "get_params" *)
     3.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Wed Nov 23 16:15:17 2016 +0100
     3.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Wed Nov 23 20:55:58 2016 +0100
     3.3 @@ -69,13 +69,13 @@
     3.4    val mash_suggested_facts : Proof.context -> string -> params -> int -> term list -> term ->
     3.5      raw_fact list -> fact list * fact list
     3.6  
     3.7 -  val mash_unlearn : unit -> unit
     3.8 +  val mash_unlearn : Proof.context -> unit
     3.9    val mash_learn_proof : Proof.context -> params -> term -> ('a * thm) list -> thm list -> unit
    3.10    val mash_learn_facts : Proof.context -> params -> string -> int -> bool -> Time.time ->
    3.11      raw_fact list -> string
    3.12    val mash_learn : Proof.context -> params -> fact_override -> thm list -> bool -> unit
    3.13    val mash_can_suggest_facts : Proof.context -> bool
    3.14 -  val mash_can_suggest_facts_fast : unit -> bool
    3.15 +  val mash_can_suggest_facts_fast : Proof.context -> bool
    3.16  
    3.17    val generous_max_suggestions : int -> int
    3.18    val mepo_weight : real
    3.19 @@ -274,6 +274,18 @@
    3.20    end
    3.21  
    3.22  
    3.23 +(*** Convenience functions for synchronized access ***)
    3.24 +
    3.25 +fun synchronized_timed_value var time_limit =
    3.26 +  Synchronized.timed_access var time_limit (fn value => SOME (value, value))
    3.27 +fun synchronized_timed_change_result var time_limit f =
    3.28 +  Synchronized.timed_access var time_limit (SOME o f)
    3.29 +fun synchronized_timed_change var time_limit f =
    3.30 +  synchronized_timed_change_result var time_limit (fn x => ((), f x))
    3.31 +
    3.32 +fun mash_time_limit _ = SOME (seconds 0.1)
    3.33 +
    3.34 +
    3.35  (*** Isabelle-agnostic machine learning ***)
    3.36  
    3.37  structure MaSh =
    3.38 @@ -640,7 +652,7 @@
    3.39  
    3.40  local
    3.41  
    3.42 -val version = "*** MaSh version 20160805 ***"
    3.43 +val version = "*** MaSh version 20161123 ***"
    3.44  
    3.45  exception FILE_VERSION_TOO_NEW of unit
    3.46  
    3.47 @@ -734,42 +746,49 @@
    3.48  in
    3.49  
    3.50  fun map_state ctxt f =
    3.51 -  Synchronized.change global_state (load_state ctxt ##> f #> save_state ctxt)
    3.52 +  (trace_msg ctxt (fn () => "Changing MaSh state");
    3.53 +   synchronized_timed_change global_state mash_time_limit
    3.54 +     (load_state ctxt ##> f #> save_state ctxt))
    3.55 +  |> ignore
    3.56    handle FILE_VERSION_TOO_NEW () => ()
    3.57  
    3.58 -fun peek_state () =
    3.59 -  let val state = Synchronized.value global_state in
    3.60 -    if would_load_state state then NONE else SOME state
    3.61 -  end
    3.62 +fun peek_state ctxt =
    3.63 +  (trace_msg ctxt (fn () => "Peeking at MaSh state");
    3.64 +   (case synchronized_timed_value global_state mash_time_limit of
    3.65 +     NONE => NONE
    3.66 +   | SOME state => if would_load_state state then NONE else SOME state))
    3.67  
    3.68  fun get_state ctxt =
    3.69 -  Synchronized.change_result global_state (perhaps (try (load_state ctxt)) #> `snd)
    3.70 +  (trace_msg ctxt (fn () => "Retrieving MaSh state");
    3.71 +   synchronized_timed_change_result global_state mash_time_limit
    3.72 +     (perhaps (try (load_state ctxt)) #> `snd))
    3.73  
    3.74 -fun clear_state () =
    3.75 -  Synchronized.change global_state (fn _ => (remove_state_file (); (Time.zeroTime, empty_state)))
    3.76 +fun clear_state ctxt =
    3.77 +  (trace_msg ctxt (fn () => "Clearing MaSh state");
    3.78 +   Synchronized.change global_state (fn _ => (remove_state_file (); (Time.zeroTime, empty_state))))
    3.79  
    3.80  end
    3.81  
    3.82  
    3.83  (*** Isabelle helpers ***)
    3.84  
    3.85 -fun crude_printed_term depth t =
    3.86 +fun crude_printed_term size t =
    3.87    let
    3.88      fun term _ (res, 0) = (res, 0)
    3.89 -      | term (t $ u) (res, depth) =
    3.90 +      | term (t $ u) (res, size) =
    3.91          let
    3.92 -          val (res, depth) = term t (res ^ "(", depth)
    3.93 -          val (res, depth) = term u (res ^ " ", depth)
    3.94 +          val (res, size) = term t (res ^ "(", size)
    3.95 +          val (res, size) = term u (res ^ " ", size)
    3.96          in
    3.97 -          (res ^ ")", depth)
    3.98 +          (res ^ ")", size)
    3.99          end
   3.100 -      | term (Abs (s, _, t)) (res, depth) = term t (res ^ "%" ^ s ^ ".", depth - 1)
   3.101 -      | term (Bound n) (res, depth) = (res ^ "#" ^ string_of_int n, depth - 1)
   3.102 -      | term (Const (s, _)) (res, depth) = (res ^ Long_Name.base_name s, depth - 1)
   3.103 -      | term (Free (s, _)) (res, depth) = (res ^ s, depth - 1)
   3.104 -      | term (Var ((s, _), _)) (res, depth) = (res ^ s, depth - 1)
   3.105 +      | term (Abs (s, _, t)) (res, size) = term t (res ^ "%" ^ s ^ ".", size - 1)
   3.106 +      | term (Bound n) (res, size) = (res ^ "#" ^ string_of_int n, size - 1)
   3.107 +      | term (Const (s, _)) (res, size) = (res ^ Long_Name.base_name s, size - 1)
   3.108 +      | term (Free (s, _)) (res, size) = (res ^ s, size - 1)
   3.109 +      | term (Var ((s, _), _)) (res, size) = (res ^ s, size - 1)
   3.110    in
   3.111 -    fst (term t ("", depth))
   3.112 +    fst (term t ("", size))
   3.113    end
   3.114  
   3.115  fun nickname_of_thm th =
   3.116 @@ -778,11 +797,11 @@
   3.117        (* There must be a better way to detect local facts. *)
   3.118        (case Long_Name.dest_local hint of
   3.119          SOME suf =>
   3.120 -        Long_Name.implode [Thm.theory_name_of_thm th, suf, crude_printed_term 50 (Thm.prop_of th)]
   3.121 +        Long_Name.implode [Thm.theory_name_of_thm th, suf, crude_printed_term 25 (Thm.prop_of th)]
   3.122        | NONE => hint)
   3.123      end
   3.124    else
   3.125 -    crude_printed_term 100 (Thm.prop_of th)
   3.126 +    crude_printed_term 50 (Thm.prop_of th)
   3.127  
   3.128  fun find_suggested_facts ctxt facts =
   3.129    let
   3.130 @@ -857,23 +876,17 @@
   3.131  
   3.132  val bad_types = [@{type_name prop}, @{type_name bool}, @{type_name fun}]
   3.133  
   3.134 -val pat_tvar_prefix = "_"
   3.135 -val pat_var_prefix = "_"
   3.136 +val crude_str_of_sort = space_implode "," o map Long_Name.base_name o subtract (op =) @{sort type}
   3.137  
   3.138 -(* try "Long_Name.base_name" for shorter names *)
   3.139 -fun massage_long_name s = s
   3.140 -
   3.141 -val crude_str_of_sort = space_implode "," o map massage_long_name o subtract (op =) @{sort type}
   3.142 -
   3.143 -fun crude_str_of_typ (Type (s, [])) = massage_long_name s
   3.144 -  | crude_str_of_typ (Type (s, Ts)) = massage_long_name s ^ implode (map crude_str_of_typ Ts)
   3.145 +fun crude_str_of_typ (Type (s, [])) = Long_Name.base_name s
   3.146 +  | crude_str_of_typ (Type (s, Ts)) = Long_Name.base_name s ^ implode (map crude_str_of_typ Ts)
   3.147    | crude_str_of_typ (TFree (_, S)) = crude_str_of_sort S
   3.148    | crude_str_of_typ (TVar (_, S)) = crude_str_of_sort S
   3.149  
   3.150 -fun maybe_singleton_str _ "" = []
   3.151 -  | maybe_singleton_str pref s = [pref ^ s]
   3.152 +fun maybe_singleton_str "" = []
   3.153 +  | maybe_singleton_str s = [s]
   3.154  
   3.155 -val max_pat_breadth = 10 (* FUDGE *)
   3.156 +val max_pat_breadth = 5 (* FUDGE *)
   3.157  
   3.158  fun term_features_of ctxt thy_name term_max_depth type_max_depth ts =
   3.159    let
   3.160 @@ -886,13 +899,12 @@
   3.161        | add_classes S =
   3.162          fold (`(Sorts.super_classes classes)
   3.163            #> swap #> op ::
   3.164 -          #> subtract (op =) @{sort type} #> map massage_long_name
   3.165 +          #> subtract (op =) @{sort type}
   3.166            #> map class_feature_of
   3.167            #> union (op =)) S
   3.168  
   3.169      fun pattify_type 0 _ = []
   3.170 -      | pattify_type _ (Type (s, [])) =
   3.171 -        if member (op =) bad_types s then [] else [massage_long_name s]
   3.172 +      | pattify_type depth (Type (s, [])) = if member (op =) bad_types s then [] else [s]
   3.173        | pattify_type depth (Type (s, U :: Ts)) =
   3.174          let
   3.175            val T = Type (s, Ts)
   3.176 @@ -901,8 +913,8 @@
   3.177          in
   3.178            map_product (fn p => fn "" => p | q => p ^ "(" ^ q ^ ")") ps qs
   3.179          end
   3.180 -      | pattify_type _ (TFree (_, S)) = maybe_singleton_str pat_tvar_prefix (crude_str_of_sort S)
   3.181 -      | pattify_type _ (TVar (_, S)) = maybe_singleton_str pat_tvar_prefix (crude_str_of_sort S)
   3.182 +      | pattify_type _ (TFree (_, S)) = maybe_singleton_str (crude_str_of_sort S)
   3.183 +      | pattify_type _ (TVar (_, S)) = maybe_singleton_str (crude_str_of_sort S)
   3.184  
   3.185      fun add_type_pat depth T =
   3.186        union (op =) (map type_feature_of (pattify_type depth T))
   3.187 @@ -918,17 +930,16 @@
   3.188        | add_subtypes T = add_type T
   3.189  
   3.190      fun pattify_term _ 0 _ = []
   3.191 -      | pattify_term _ _ (Const (s, _)) =
   3.192 -        if is_widely_irrelevant_const s then [] else [massage_long_name s]
   3.193 +      | pattify_term _ depth (Const (s, _)) =
   3.194 +        if is_widely_irrelevant_const s then [] else [s]
   3.195        | pattify_term _ _ (Free (s, T)) =
   3.196 -        maybe_singleton_str pat_var_prefix (crude_str_of_typ T)
   3.197 -        |> (if member (op =) fixes s then
   3.198 -              cons (free_feature_of (massage_long_name (Long_Name.append thy_name s)))
   3.199 -            else
   3.200 -              I)
   3.201 -      | pattify_term _ _ (Var (_, T)) = maybe_singleton_str pat_var_prefix (crude_str_of_typ T)
   3.202 +        maybe_singleton_str (crude_str_of_typ T)
   3.203 +        |> (if member (op =) fixes s then cons (free_feature_of (Long_Name.append thy_name s))
   3.204 +            else I)
   3.205 +      | pattify_term _ _ (Var (_, T)) =
   3.206 +        maybe_singleton_str (crude_str_of_typ T)
   3.207        | pattify_term Ts _ (Bound j) =
   3.208 -        maybe_singleton_str pat_var_prefix (crude_str_of_typ (nth Ts j))
   3.209 +        maybe_singleton_str (crude_str_of_typ (nth Ts j))
   3.210        | pattify_term Ts depth (t $ u) =
   3.211          let
   3.212            val ps = take max_pat_breadth (pattify_term Ts depth t)
   3.213 @@ -972,9 +983,9 @@
   3.214  
   3.215  (* Too many dependencies is a sign that a decision procedure is at work. There is not much to learn
   3.216     from such proofs. *)
   3.217 -val max_dependencies = 20
   3.218 +val max_dependencies = 20 (* FUDGE *)
   3.219  
   3.220 -val prover_default_max_facts = 25
   3.221 +val prover_default_max_facts = 25 (* FUDGE *)
   3.222  
   3.223  (* "type_definition_xxx" facts are characterized by their use of "CollectI". *)
   3.224  val typedef_dep = nickname_of_thm @{thm CollectI}
   3.225 @@ -1161,7 +1172,7 @@
   3.226  fun maximal_wrt_access_graph _ [] = []
   3.227    | maximal_wrt_access_graph access_G (fact :: facts) =
   3.228      let
   3.229 -      fun cleanup_wrt (fact as (_, th)) =
   3.230 +      fun cleanup_wrt (_, th) =
   3.231          let val thy_id = Thm.theory_id_of_thm th in
   3.232            filter_out (fn (_, th') =>
   3.233              Context.proper_subthy_id (Thm.theory_id_of_thm th', thy_id))
   3.234 @@ -1224,54 +1235,57 @@
   3.235        [Thm.prop_of th]
   3.236        |> features_of ctxt (Thm.theory_name_of_thm th) stature
   3.237        |> map (rpair (weight * factor))
   3.238 -
   3.239 -    val {access_G, xtabs = ((num_facts, fact_tab), (num_feats, feat_tab)), ffds, freqs, ...} =
   3.240 -      get_state ctxt
   3.241 +  in
   3.242 +    (case get_state ctxt of
   3.243 +      NONE => ([], [])
   3.244 +    | SOME {access_G, xtabs = ((num_facts, fact_tab), (num_feats, feat_tab)), ffds, freqs, ...} =>
   3.245 +      let
   3.246 +        val goal_feats0 =
   3.247 +          features_of ctxt thy_name (Local, General) (concl_t :: hyp_ts)
   3.248 +        val chained_feats = chained
   3.249 +          |> map (rpair 1.0)
   3.250 +          |> map (chained_or_extra_features_of chained_feature_factor)
   3.251 +          |> rpair [] |-> fold (union (eq_fst (op =)))
   3.252 +        val extra_feats = facts
   3.253 +          |> take (Int.max (0, num_extra_feature_facts - length chained))
   3.254 +          |> filter fact_has_right_theory
   3.255 +          |> weight_facts_steeply
   3.256 +          |> map (chained_or_extra_features_of extra_feature_factor)
   3.257 +          |> rpair [] |-> fold (union (eq_fst (op =)))
   3.258  
   3.259 -    val goal_feats0 =
   3.260 -      features_of ctxt thy_name (Local, General) (concl_t :: hyp_ts)
   3.261 -    val chained_feats = chained
   3.262 -      |> map (rpair 1.0)
   3.263 -      |> map (chained_or_extra_features_of chained_feature_factor)
   3.264 -      |> rpair [] |-> fold (union (eq_fst (op =)))
   3.265 -    val extra_feats = facts
   3.266 -      |> take (Int.max (0, num_extra_feature_facts - length chained))
   3.267 -      |> filter fact_has_right_theory
   3.268 -      |> weight_facts_steeply
   3.269 -      |> map (chained_or_extra_features_of extra_feature_factor)
   3.270 -      |> rpair [] |-> fold (union (eq_fst (op =)))
   3.271 -
   3.272 -    val goal_feats =
   3.273 -      fold (union (eq_fst (op =))) [chained_feats, extra_feats] (map (rpair 1.0) goal_feats0)
   3.274 -      |> debug ? sort (Real.compare o swap o apply2 snd)
   3.275 +        val goal_feats =
   3.276 +          fold (union (eq_fst (op =))) [chained_feats, extra_feats] (map (rpair 1.0) goal_feats0)
   3.277 +          |> debug ? sort (Real.compare o swap o apply2 snd)
   3.278  
   3.279 -    val parents = maximal_wrt_access_graph access_G facts
   3.280 -    val visible_facts = map_filter (Symtab.lookup fact_tab) (Graph.all_preds access_G parents)
   3.281 +        val parents = maximal_wrt_access_graph access_G facts
   3.282 +        val visible_facts = map_filter (Symtab.lookup fact_tab) (Graph.all_preds access_G parents)
   3.283  
   3.284 -    val suggs =
   3.285 -      if algorithm = MaSh_NB_Ext orelse algorithm = MaSh_kNN_Ext then
   3.286 -        let
   3.287 -          val learns =
   3.288 -            Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
   3.289 -        in
   3.290 -          MaSh.query_external ctxt algorithm max_suggs learns goal_feats
   3.291 -        end
   3.292 -      else
   3.293 -        let
   3.294 -          val int_goal_feats =
   3.295 -            map_filter (fn (s, w) => Option.map (rpair w) (Symtab.lookup feat_tab s)) goal_feats
   3.296 -        in
   3.297 -          MaSh.query_internal ctxt algorithm num_facts num_feats ffds freqs visible_facts max_suggs
   3.298 -            goal_feats int_goal_feats
   3.299 -        end
   3.300 +        val suggs =
   3.301 +          if algorithm = MaSh_NB_Ext orelse algorithm = MaSh_kNN_Ext then
   3.302 +            let
   3.303 +              val learns =
   3.304 +                Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps))
   3.305 +                  access_G
   3.306 +            in
   3.307 +              MaSh.query_external ctxt algorithm max_suggs learns goal_feats
   3.308 +            end
   3.309 +          else
   3.310 +            let
   3.311 +              val int_goal_feats =
   3.312 +                map_filter (fn (s, w) => Option.map (rpair w) (Symtab.lookup feat_tab s)) goal_feats
   3.313 +            in
   3.314 +              MaSh.query_internal ctxt algorithm num_facts num_feats ffds freqs visible_facts
   3.315 +                max_suggs goal_feats int_goal_feats
   3.316 +            end
   3.317  
   3.318 -    val unknown = filter_out (is_fact_in_graph access_G o snd) facts
   3.319 -  in
   3.320 -    find_mash_suggestions ctxt max_suggs suggs facts chained unknown
   3.321 -    |> apply2 (map fact_of_raw_fact)
   3.322 +        val unknown = filter_out (is_fact_in_graph access_G o snd) facts
   3.323 +      in
   3.324 +        find_mash_suggestions ctxt max_suggs suggs facts chained unknown
   3.325 +        |> apply2 (map fact_of_raw_fact)
   3.326 +      end)
   3.327    end
   3.328  
   3.329 -fun mash_unlearn () = (clear_state (); writeln "Reset MaSh")
   3.330 +fun mash_unlearn ctxt = (clear_state ctxt; writeln "Reset MaSh")
   3.331  
   3.332  fun learn_wrt_access_graph ctxt (name, parents, feats, deps)
   3.333      (accum as (access_G, (fact_xtab, feat_xtab))) =
   3.334 @@ -1363,164 +1377,169 @@
   3.335    let
   3.336      val timer = Timer.startRealTimer ()
   3.337      fun next_commit_time () = Timer.checkRealTimer timer + commit_timeout
   3.338 -
   3.339 -    val {access_G, ...} = get_state ctxt
   3.340 -    val is_in_access_G = is_fact_in_graph access_G o snd
   3.341 -    val no_new_facts = forall is_in_access_G facts
   3.342    in
   3.343 -    if no_new_facts andalso not run_prover then
   3.344 -      if auto_level < 2 then
   3.345 -        "No new " ^ (if run_prover then "automatic" else "Isar") ^ " proofs to learn" ^
   3.346 -        (if auto_level = 0 andalso not run_prover then
   3.347 -           "\n\nHint: Try " ^ sendback learn_proverN ^ " to learn from an automatic prover"
   3.348 -         else
   3.349 -           "")
   3.350 -      else
   3.351 -        ""
   3.352 -    else
   3.353 +    (case get_state ctxt of
   3.354 +      NONE => "MaSh is busy\nPlease try again later"
   3.355 +    | SOME {access_G, ...} =>
   3.356        let
   3.357 -        val name_tabs = build_name_tables nickname_of_thm facts
   3.358 +        val is_in_access_G = is_fact_in_graph access_G o snd
   3.359 +        val no_new_facts = forall is_in_access_G facts
   3.360 +      in
   3.361 +        if no_new_facts andalso not run_prover then
   3.362 +          if auto_level < 2 then
   3.363 +            "No new " ^ (if run_prover then "automatic" else "Isar") ^ " proofs to learn" ^
   3.364 +            (if auto_level = 0 andalso not run_prover then
   3.365 +               "\n\nHint: Try " ^ sendback learn_proverN ^ " to learn from an automatic prover"
   3.366 +             else
   3.367 +               "")
   3.368 +          else
   3.369 +            ""
   3.370 +        else
   3.371 +          let
   3.372 +            val name_tabs = build_name_tables nickname_of_thm facts
   3.373  
   3.374 -        fun deps_of status th =
   3.375 -          if status = Non_Rec_Def orelse status = Rec_Def then
   3.376 -            SOME []
   3.377 -          else if run_prover then
   3.378 -            prover_dependencies_of ctxt params prover auto_level facts name_tabs th
   3.379 -            |> (fn (false, _) => NONE | (true, deps) => trim_dependencies deps)
   3.380 -          else
   3.381 -            isar_dependencies_of name_tabs th
   3.382 +            fun deps_of status th =
   3.383 +              if status = Non_Rec_Def orelse status = Rec_Def then
   3.384 +                SOME []
   3.385 +              else if run_prover then
   3.386 +                prover_dependencies_of ctxt params prover auto_level facts name_tabs th
   3.387 +                |> (fn (false, _) => NONE | (true, deps) => trim_dependencies deps)
   3.388 +              else
   3.389 +                isar_dependencies_of name_tabs th
   3.390  
   3.391 -        fun do_commit [] [] [] state = state
   3.392 -          | do_commit learns relearns flops
   3.393 -              {access_G, xtabs as ((num_facts0, _), _), ffds, freqs, dirty_facts} =
   3.394 -            let
   3.395 -              val was_empty = Graph.is_empty access_G
   3.396 +            fun do_commit [] [] [] state = state
   3.397 +              | do_commit learns relearns flops
   3.398 +                  {access_G, xtabs as ((num_facts0, _), _), ffds, freqs, dirty_facts} =
   3.399 +                let
   3.400 +                  val was_empty = Graph.is_empty access_G
   3.401  
   3.402 -              val (learns, (access_G', xtabs')) =
   3.403 -                fold_map (learn_wrt_access_graph ctxt) learns (access_G, xtabs)
   3.404 -                |>> map_filter I
   3.405 -              val (relearns, access_G'') =
   3.406 -                fold_map (relearn_wrt_access_graph ctxt) relearns access_G'
   3.407 +                  val (learns, (access_G', xtabs')) =
   3.408 +                    fold_map (learn_wrt_access_graph ctxt) learns (access_G, xtabs)
   3.409 +                    |>> map_filter I
   3.410 +                  val (relearns, access_G'') =
   3.411 +                    fold_map (relearn_wrt_access_graph ctxt) relearns access_G'
   3.412  
   3.413 -              val access_G''' = access_G'' |> fold flop_wrt_access_graph flops
   3.414 -              val dirty_facts' =
   3.415 -                (case (was_empty, dirty_facts) of
   3.416 -                  (false, SOME names) => SOME (map #1 learns @ map #1 relearns @ names)
   3.417 -                | _ => NONE)
   3.418 +                  val access_G''' = access_G'' |> fold flop_wrt_access_graph flops
   3.419 +                  val dirty_facts' =
   3.420 +                    (case (was_empty, dirty_facts) of
   3.421 +                      (false, SOME names) => SOME (map #1 learns @ map #1 relearns @ names)
   3.422 +                    | _ => NONE)
   3.423  
   3.424 -              val (ffds', freqs') =
   3.425 -                if null relearns then
   3.426 -                  recompute_ffds_freqs_from_learns
   3.427 -                    (map (fn (name, _, feats, deps) => (name, feats, deps)) learns) xtabs' num_facts0
   3.428 -                    ffds freqs
   3.429 -                else
   3.430 -                  recompute_ffds_freqs_from_access_G access_G''' xtabs'
   3.431 -            in
   3.432 -              {access_G = access_G''', xtabs = xtabs', ffds = ffds', freqs = freqs',
   3.433 -               dirty_facts = dirty_facts'}
   3.434 -            end
   3.435 +                  val (ffds', freqs') =
   3.436 +                    if null relearns then
   3.437 +                      recompute_ffds_freqs_from_learns
   3.438 +                        (map (fn (name, _, feats, deps) => (name, feats, deps)) learns) xtabs'
   3.439 +                        num_facts0 ffds freqs
   3.440 +                    else
   3.441 +                      recompute_ffds_freqs_from_access_G access_G''' xtabs'
   3.442 +                in
   3.443 +                  {access_G = access_G''', xtabs = xtabs', ffds = ffds', freqs = freqs',
   3.444 +                   dirty_facts = dirty_facts'}
   3.445 +                end
   3.446  
   3.447 -        fun commit last learns relearns flops =
   3.448 -          (if debug andalso auto_level = 0 then writeln "Committing..." else ();
   3.449 -           map_state ctxt (do_commit (rev learns) relearns flops);
   3.450 -           if not last andalso auto_level = 0 then
   3.451 -             let val num_proofs = length learns + length relearns in
   3.452 -               writeln ("Learned " ^ string_of_int num_proofs ^ " " ^
   3.453 -                 (if run_prover then "automatic" else "Isar") ^ " proof" ^
   3.454 -                 plural_s num_proofs ^ " in the last " ^ string_of_time commit_timeout)
   3.455 -             end
   3.456 -           else
   3.457 -             ())
   3.458 +            fun commit last learns relearns flops =
   3.459 +              (if debug andalso auto_level = 0 then writeln "Committing..." else ();
   3.460 +               map_state ctxt (do_commit (rev learns) relearns flops);
   3.461 +               if not last andalso auto_level = 0 then
   3.462 +                 let val num_proofs = length learns + length relearns in
   3.463 +                   writeln ("Learned " ^ string_of_int num_proofs ^ " " ^
   3.464 +                     (if run_prover then "automatic" else "Isar") ^ " proof" ^
   3.465 +                     plural_s num_proofs ^ " in the last " ^ string_of_time commit_timeout)
   3.466 +                 end
   3.467 +               else
   3.468 +                 ())
   3.469  
   3.470 -        fun learn_new_fact _ (accum as (_, (_, _, true))) = accum
   3.471 -          | learn_new_fact (parents, ((_, stature as (_, status)), th))
   3.472 -              (learns, (num_nontrivial, next_commit, _)) =
   3.473 -            let
   3.474 -              val name = nickname_of_thm th
   3.475 -              val feats = features_of ctxt (Thm.theory_name_of_thm th) stature [Thm.prop_of th]
   3.476 -              val deps = these (deps_of status th)
   3.477 -              val num_nontrivial = num_nontrivial |> not (null deps) ? Integer.add 1
   3.478 -              val learns = (name, parents, feats, deps) :: learns
   3.479 -              val (learns, next_commit) =
   3.480 -                if Timer.checkRealTimer timer > next_commit then
   3.481 -                  (commit false learns [] []; ([], next_commit_time ()))
   3.482 -                else
   3.483 -                  (learns, next_commit)
   3.484 -              val timed_out = Timer.checkRealTimer timer > learn_timeout
   3.485 -            in
   3.486 -              (learns, (num_nontrivial, next_commit, timed_out))
   3.487 -            end
   3.488 +            fun learn_new_fact _ (accum as (_, (_, _, true))) = accum
   3.489 +              | learn_new_fact (parents, ((_, stature as (_, status)), th))
   3.490 +                  (learns, (num_nontrivial, next_commit, _)) =
   3.491 +                let
   3.492 +                  val name = nickname_of_thm th
   3.493 +                  val feats = features_of ctxt (Thm.theory_name_of_thm th) stature [Thm.prop_of th]
   3.494 +                  val deps = these (deps_of status th)
   3.495 +                  val num_nontrivial = num_nontrivial |> not (null deps) ? Integer.add 1
   3.496 +                  val learns = (name, parents, feats, deps) :: learns
   3.497 +                  val (learns, next_commit) =
   3.498 +                    if Timer.checkRealTimer timer > next_commit then
   3.499 +                      (commit false learns [] []; ([], next_commit_time ()))
   3.500 +                    else
   3.501 +                      (learns, next_commit)
   3.502 +                  val timed_out = Timer.checkRealTimer timer > learn_timeout
   3.503 +                in
   3.504 +                  (learns, (num_nontrivial, next_commit, timed_out))
   3.505 +                end
   3.506  
   3.507 -        val (num_new_facts, num_nontrivial) =
   3.508 -          if no_new_facts then
   3.509 -            (0, 0)
   3.510 -          else
   3.511 -            let
   3.512 -              val new_facts = facts
   3.513 -                |> sort (crude_thm_ord ctxt o apply2 snd)
   3.514 -                |> attach_parents_to_facts []
   3.515 -                |> filter_out (is_in_access_G o snd)
   3.516 -              val (learns, (num_nontrivial, _, _)) =
   3.517 -                ([], (0, next_commit_time (), false))
   3.518 -                |> fold learn_new_fact new_facts
   3.519 -            in
   3.520 -              commit true learns [] []; (length new_facts, num_nontrivial)
   3.521 -            end
   3.522 +            val (num_new_facts, num_nontrivial) =
   3.523 +              if no_new_facts then
   3.524 +                (0, 0)
   3.525 +              else
   3.526 +                let
   3.527 +                  val new_facts = facts
   3.528 +                    |> sort (crude_thm_ord ctxt o apply2 snd)
   3.529 +                    |> attach_parents_to_facts []
   3.530 +                    |> filter_out (is_in_access_G o snd)
   3.531 +                  val (learns, (num_nontrivial, _, _)) =
   3.532 +                    ([], (0, next_commit_time (), false))
   3.533 +                    |> fold learn_new_fact new_facts
   3.534 +                in
   3.535 +                  commit true learns [] []; (length new_facts, num_nontrivial)
   3.536 +                end
   3.537  
   3.538 -        fun relearn_old_fact _ (accum as (_, (_, _, true))) = accum
   3.539 -          | relearn_old_fact ((_, (_, status)), th)
   3.540 -              ((relearns, flops), (num_nontrivial, next_commit, _)) =
   3.541 -            let
   3.542 -              val name = nickname_of_thm th
   3.543 -              val (num_nontrivial, relearns, flops) =
   3.544 -                (case deps_of status th of
   3.545 -                  SOME deps => (num_nontrivial + 1, (name, deps) :: relearns, flops)
   3.546 -                | NONE => (num_nontrivial, relearns, name :: flops))
   3.547 -              val (relearns, flops, next_commit) =
   3.548 -                if Timer.checkRealTimer timer > next_commit then
   3.549 -                  (commit false [] relearns flops; ([], [], next_commit_time ()))
   3.550 -                else
   3.551 -                  (relearns, flops, next_commit)
   3.552 -              val timed_out = Timer.checkRealTimer timer > learn_timeout
   3.553 -            in
   3.554 -              ((relearns, flops), (num_nontrivial, next_commit, timed_out))
   3.555 -            end
   3.556 +            fun relearn_old_fact _ (accum as (_, (_, _, true))) = accum
   3.557 +              | relearn_old_fact ((_, (_, status)), th)
   3.558 +                  ((relearns, flops), (num_nontrivial, next_commit, _)) =
   3.559 +                let
   3.560 +                  val name = nickname_of_thm th
   3.561 +                  val (num_nontrivial, relearns, flops) =
   3.562 +                    (case deps_of status th of
   3.563 +                      SOME deps => (num_nontrivial + 1, (name, deps) :: relearns, flops)
   3.564 +                    | NONE => (num_nontrivial, relearns, name :: flops))
   3.565 +                  val (relearns, flops, next_commit) =
   3.566 +                    if Timer.checkRealTimer timer > next_commit then
   3.567 +                      (commit false [] relearns flops; ([], [], next_commit_time ()))
   3.568 +                    else
   3.569 +                      (relearns, flops, next_commit)
   3.570 +                  val timed_out = Timer.checkRealTimer timer > learn_timeout
   3.571 +                in
   3.572 +                  ((relearns, flops), (num_nontrivial, next_commit, timed_out))
   3.573 +                end
   3.574  
   3.575 -        val num_nontrivial =
   3.576 -          if not run_prover then
   3.577 -            num_nontrivial
   3.578 -          else
   3.579 -            let
   3.580 -              val max_isar = 1000 * max_dependencies
   3.581 +            val num_nontrivial =
   3.582 +              if not run_prover then
   3.583 +                num_nontrivial
   3.584 +              else
   3.585 +                let
   3.586 +                  val max_isar = 1000 * max_dependencies
   3.587  
   3.588 -              fun priority_of th =
   3.589 -                Random.random_range 0 max_isar +
   3.590 -                (case try (Graph.get_node access_G) (nickname_of_thm th) of
   3.591 -                  SOME (Isar_Proof, _, deps) => ~100 * length deps
   3.592 -                | SOME (Automatic_Proof, _, _) => 2 * max_isar
   3.593 -                | SOME (Isar_Proof_wegen_Prover_Flop, _, _) => max_isar
   3.594 -                | NONE => 0)
   3.595 +                  fun priority_of th =
   3.596 +                    Random.random_range 0 max_isar +
   3.597 +                    (case try (Graph.get_node access_G) (nickname_of_thm th) of
   3.598 +                      SOME (Isar_Proof, _, deps) => ~100 * length deps
   3.599 +                    | SOME (Automatic_Proof, _, _) => 2 * max_isar
   3.600 +                    | SOME (Isar_Proof_wegen_Prover_Flop, _, _) => max_isar
   3.601 +                    | NONE => 0)
   3.602  
   3.603 -              val old_facts = facts
   3.604 -                |> filter is_in_access_G
   3.605 -                |> map (`(priority_of o snd))
   3.606 -                |> sort (int_ord o apply2 fst)
   3.607 -                |> map snd
   3.608 -              val ((relearns, flops), (num_nontrivial, _, _)) =
   3.609 -                (([], []), (num_nontrivial, next_commit_time (), false))
   3.610 -                |> fold relearn_old_fact old_facts
   3.611 -            in
   3.612 -              commit true [] relearns flops; num_nontrivial
   3.613 -            end
   3.614 -      in
   3.615 -        if verbose orelse auto_level < 2 then
   3.616 -          "Learned " ^ string_of_int num_new_facts ^ " fact" ^ plural_s num_new_facts ^ " and " ^
   3.617 -          string_of_int num_nontrivial ^ " nontrivial " ^
   3.618 -          (if run_prover then "automatic and " else "") ^ "Isar proof" ^ plural_s num_nontrivial ^
   3.619 -          (if verbose then " in " ^ string_of_time (Timer.checkRealTimer timer) else "")
   3.620 -        else
   3.621 -          ""
   3.622 -      end
   3.623 +                  val old_facts = facts
   3.624 +                    |> filter is_in_access_G
   3.625 +                    |> map (`(priority_of o snd))
   3.626 +                    |> sort (int_ord o apply2 fst)
   3.627 +                    |> map snd
   3.628 +                  val ((relearns, flops), (num_nontrivial, _, _)) =
   3.629 +                    (([], []), (num_nontrivial, next_commit_time (), false))
   3.630 +                    |> fold relearn_old_fact old_facts
   3.631 +                in
   3.632 +                  commit true [] relearns flops; num_nontrivial
   3.633 +                end
   3.634 +          in
   3.635 +            if verbose orelse auto_level < 2 then
   3.636 +              "Learned " ^ string_of_int num_new_facts ^ " fact" ^ plural_s num_new_facts ^
   3.637 +              " and " ^ string_of_int num_nontrivial ^ " nontrivial " ^
   3.638 +              (if run_prover then "automatic and " else "") ^ "Isar proof" ^
   3.639 +              plural_s num_nontrivial ^
   3.640 +              (if verbose then " in " ^ string_of_time (Timer.checkRealTimer timer) else "")
   3.641 +            else
   3.642 +              ""
   3.643 +          end
   3.644 +      end)
   3.645    end
   3.646  
   3.647  fun mash_learn ctxt (params as {provers, timeout, ...}) fact_override chained run_prover =
   3.648 @@ -1552,21 +1571,23 @@
   3.649    end
   3.650  
   3.651  fun mash_can_suggest_facts ctxt =
   3.652 -  not (Graph.is_empty (#access_G (get_state ctxt)))
   3.653 +  (case get_state ctxt of
   3.654 +    NONE => false
   3.655 +  | SOME {access_G, ...} => not (Graph.is_empty access_G))
   3.656  
   3.657 -fun mash_can_suggest_facts_fast () =
   3.658 -  (case peek_state () of
   3.659 +fun mash_can_suggest_facts_fast ctxt =
   3.660 +  (case peek_state ctxt of
   3.661      NONE => false
   3.662 -  | SOME (_, {access_G, ...}) => not (Graph.is_empty access_G));
   3.663 +  | SOME (_, {access_G, ...}) => not (Graph.is_empty access_G))
   3.664  
   3.665  (* Generate more suggestions than requested, because some might be thrown out later for various
   3.666     reasons (e.g., duplicates). *)
   3.667  fun generous_max_suggestions max_facts = 3 * max_facts div 2 + 25
   3.668  
   3.669 -val mepo_weight = 0.5
   3.670 -val mash_weight = 0.5
   3.671 +val mepo_weight = 0.5 (* FUDGE *)
   3.672 +val mash_weight = 0.5 (* FUDGE *)
   3.673  
   3.674 -val max_facts_to_learn_before_query = 100
   3.675 +val max_facts_to_learn_before_query = 100 (* FUDGE *)
   3.676  
   3.677  (* The threshold should be large enough so that MaSh does not get activated for Auto Sledgehammer. *)
   3.678  val min_secs_for_learning = 10
   3.679 @@ -1600,27 +1621,29 @@
   3.680            ()
   3.681  
   3.682        val mash_enabled = is_mash_enabled ()
   3.683 -      val mash_fast = mash_can_suggest_facts_fast ()
   3.684 +      val mash_fast = mash_can_suggest_facts_fast ctxt
   3.685  
   3.686        fun please_learn () =
   3.687          if mash_fast then
   3.688 -          let
   3.689 -            val {access_G, xtabs = ((num_facts0, _), _), ...} = get_state ctxt
   3.690 -            val is_in_access_G = is_fact_in_graph access_G o snd
   3.691 -            val min_num_facts_to_learn = length facts - num_facts0
   3.692 -          in
   3.693 -            if min_num_facts_to_learn <= max_facts_to_learn_before_query then
   3.694 -              (case length (filter_out is_in_access_G facts) of
   3.695 -                0 => ()
   3.696 -              | num_facts_to_learn =>
   3.697 -                if num_facts_to_learn <= max_facts_to_learn_before_query then
   3.698 -                  mash_learn_facts ctxt params prover 2 false timeout facts
   3.699 -                  |> (fn "" => () | s => writeln (MaShN ^ ": " ^ s))
   3.700 -                else
   3.701 -                  maybe_launch_thread true num_facts_to_learn)
   3.702 -            else
   3.703 -              maybe_launch_thread false min_num_facts_to_learn
   3.704 -          end
   3.705 +          (case get_state ctxt of
   3.706 +            NONE => maybe_launch_thread false (length facts)
   3.707 +          | SOME {access_G, xtabs = ((num_facts0, _), _), ...} =>
   3.708 +            let
   3.709 +              val is_in_access_G = is_fact_in_graph access_G o snd
   3.710 +              val min_num_facts_to_learn = length facts - num_facts0
   3.711 +            in
   3.712 +              if min_num_facts_to_learn <= max_facts_to_learn_before_query then
   3.713 +                (case length (filter_out is_in_access_G facts) of
   3.714 +                  0 => ()
   3.715 +                | num_facts_to_learn =>
   3.716 +                  if num_facts_to_learn <= max_facts_to_learn_before_query then
   3.717 +                    mash_learn_facts ctxt params prover 2 false timeout facts
   3.718 +                    |> (fn "" => () | s => writeln (MaShN ^ ": " ^ s))
   3.719 +                  else
   3.720 +                    maybe_launch_thread true num_facts_to_learn)
   3.721 +              else
   3.722 +                maybe_launch_thread false min_num_facts_to_learn
   3.723 +            end)
   3.724          else
   3.725            maybe_launch_thread false (length facts)
   3.726