tuning
author blanchet
Tue May 20 16:11:37 2014 +0200 (2014-05-20)
changeset 57017 afdf75c0de58
parent 57016 c44ce6f4067d
child 57018 142950e9c7e2
tuning
src/HOL/Tools/ATP/atp_proof.ML
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
     1.1 --- a/src/HOL/Tools/ATP/atp_proof.ML	Tue May 20 16:00:00 2014 +0200
     1.2 +++ b/src/HOL/Tools/ATP/atp_proof.ML	Tue May 20 16:11:37 2014 +0200
     1.3 @@ -183,8 +183,7 @@
     1.4    |> find_first (fn (_, pattern) => String.isSubstring pattern output)
     1.5    |> Option.map fst
     1.6  
     1.7 -fun extract_tstplike_proof_and_outcome verbose proof_delims known_failures
     1.8 -                                       output =
     1.9 +fun extract_tstplike_proof_and_outcome verbose proof_delims known_failures output =
    1.10    (case (extract_tstplike_proof proof_delims output,
    1.11        extract_known_atp_failure known_failures output) of
    1.12      (_, SOME ProofIncomplete) => ("", NONE)
     2.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue May 20 16:00:00 2014 +0200
     2.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue May 20 16:11:37 2014 +0200
     2.3 @@ -284,22 +284,16 @@
     2.4  structure MaSh_SML =
     2.5  struct
     2.6  
     2.7 -fun max a b = if a > b then a else b
     2.8 -
     2.9  exception BOTTOM of int
    2.10  
    2.11  fun heap cmp bnd a =
    2.12    let
    2.13      fun maxson l i =
    2.14 -      let
    2.15 -        val i31 = i + i + i + 1
    2.16 -      in
    2.17 +      let val i31 = i + i + i + 1 in
    2.18          if i31 + 2 < l then
    2.19 -          let
    2.20 -            val x = Unsynchronized.ref i31;
    2.21 -            val _ = if cmp (Array.sub(a,i31),(Array.sub(a,i31+1))) = LESS then x := i31+1 else ();
    2.22 -            val _ = if cmp (Array.sub(a,!x),(Array.sub (a,i31+2))) = LESS then x := i31+2 else ()
    2.23 -          in
    2.24 +          let val x = Unsynchronized.ref i31 in
    2.25 +            if cmp (Array.sub (a, i31), Array.sub (a, i31 + 1)) = LESS then x := i31 + 1 else ();
    2.26 +            if cmp (Array.sub (a, !x), Array.sub (a, i31 + 2)) = LESS then x := i31 + 2 else ();
    2.27              !x
    2.28            end
    2.29          else
    2.30 @@ -354,7 +348,7 @@
    2.31      val _ = for (((l + 1) div 3) - 1)
    2.32  
    2.33      fun for2 i =
    2.34 -      if i < max 2 (l - bnd) then () else
    2.35 +      if i < Integer.max 2 (l - bnd) then () else
    2.36        let
    2.37          val e = Array.sub (a, i)
    2.38          val _ = Array.update (a, i, Array.sub (a, 0))
    2.39 @@ -387,51 +381,57 @@
    2.40  fun knn avail_num adv_max get_deps get_sym_ths knns advno syms =
    2.41    let
    2.42      (* Can be later used for TFIDF *)
    2.43 -    fun sym_wght _ = 1.0;
    2.44 -    val overlaps_sqr = Array.tabulate (avail_num, (fn i => (i, 0.0)));
    2.45 +    fun sym_wght _ = 1.0
    2.46 +
    2.47 +    val overlaps_sqr = Array.tabulate (avail_num, (fn i => (i, 0.0)))
    2.48 +
    2.49      fun inc_overlap j v =
    2.50        let
    2.51 -        val ov = snd (Array.sub (overlaps_sqr,j))
    2.52 +        val ov = snd (Array.sub (overlaps_sqr, j))
    2.53        in
    2.54          Array.update (overlaps_sqr, j, (j, v + ov))
    2.55 -      end;
    2.56 +      end
    2.57 +
    2.58      fun do_sym (s, con_wght) =
    2.59        let
    2.60 -        val sw = sym_wght s;
    2.61 -        val w2 = sw * sw * con_wght;
    2.62 +        val sw = sym_wght s
    2.63 +        val w2 = sw * sw * con_wght
    2.64 +
    2.65          fun do_th (j, prem_wght) = if j < avail_num then inc_overlap j (w2 * prem_wght) else ()
    2.66        in
    2.67 -        ignore (map do_th (get_sym_ths s))
    2.68 -      end;
    2.69 -    val () = ignore (map do_sym syms);
    2.70 -    val () = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr;
    2.71 -    val recommends = Array.tabulate (adv_max, (fn j => (j, 0.0)));
    2.72 +        List.app do_th (get_sym_ths s)
    2.73 +      end
    2.74 +
    2.75 +    val _ = List.app do_sym syms
    2.76 +    val _ = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr
    2.77 +    val recommends = Array.tabulate (adv_max, rpair 0.0)
    2.78 +
    2.79      fun inc_recommend j v =
    2.80 -      if j >= adv_max then () else
    2.81 -      let
    2.82 -        val ov = snd (Array.sub (recommends,j))
    2.83 -      in
    2.84 -        Array.update (recommends, j, (j, v + ov))
    2.85 -      end;
    2.86 +      if j >= adv_max then ()
    2.87 +      else Array.update (recommends, j, (j, v + snd (Array.sub (recommends, j))))
    2.88 +
    2.89      fun for k =
    2.90 -      if k = knns then () else
    2.91 -      if k >= adv_max then () else
    2.92 -      let
    2.93 -        val (j, o2) = Array.sub (overlaps_sqr, avail_num - k - 1);
    2.94 -        val o1 = Math.sqrt o2;
    2.95 -        val () = inc_recommend j o1;
    2.96 -        val ds = get_deps j;
    2.97 -        val l = Real.fromInt (length ds);
    2.98 -        val _ = map (fn d => inc_recommend d (o1 / l)) ds
    2.99 -      in
   2.100 -        for (k + 1)
   2.101 -      end;
   2.102 -    val () = for 0;
   2.103 -    val () = heap (fn (a, b) => Real.compare (snd a, snd b)) advno recommends;
   2.104 +      if k = knns orelse k >= adv_max then
   2.105 +        ()
   2.106 +      else
   2.107 +        let
   2.108 +          val (j, o2) = Array.sub (overlaps_sqr, avail_num - k - 1)
   2.109 +          val o1 = Math.sqrt o2
   2.110 +          val _ = inc_recommend j o1
   2.111 +          val ds = get_deps j
   2.112 +          val l = Real.fromInt (length ds)
   2.113 +          val _ = map (fn d => inc_recommend d (o1 / l)) ds
   2.114 +        in
   2.115 +          for (k + 1)
   2.116 +        end
   2.117 +
   2.118 +    val _ = for 0
   2.119 +    val _ = heap (Real.compare o pairself snd) advno recommends
   2.120 +
   2.121      fun ret acc at =
   2.122 -      if at = Array.length recommends then acc else ret (Array.sub (recommends,at) :: acc) (at + 1)
   2.123 +      if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
   2.124    in
   2.125 -    ret [] (max 0 (adv_max - advno))
   2.126 +    ret [] (Integer.max 0 (adv_max - advno))
   2.127    end
   2.128  
   2.129  val knns = 40 (* FUDGE *)
   2.130 @@ -440,7 +440,7 @@
   2.131  
   2.132  fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
   2.133  
   2.134 -fun learn_and_query ctxt parents access_G max_suggs hints feats =
   2.135 +fun query ctxt parents access_G max_suggs hints feats =
   2.136    let
   2.137      val str_of_feat = space_implode "|"
   2.138  
   2.139 @@ -469,9 +469,9 @@
   2.140          all_nodes ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
   2.141  
   2.142      val facts = rev rev_facts
   2.143 -    val fact_ary = Array.fromList facts
   2.144 +    val fact_vec = Vector.fromList facts
   2.145  
   2.146 -    val deps_ary = Array.fromList (rev rev_depss)
   2.147 +    val deps_vec = Vector.fromList (rev rev_depss)
   2.148      val facts_ary = Array.array (num_feats, [])
   2.149      val _ =
   2.150        fold (fn feats => fn fact =>
   2.151 @@ -484,11 +484,11 @@
   2.152    in
   2.153      trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^
   2.154        elide_string 1000 (space_implode " " facts) ^ "}");
   2.155 -    knn (Array.length deps_ary) (length visible_facts) (curry Array.sub deps_ary)
   2.156 +    knn (Vector.length deps_vec) (length visible_facts) (curry Vector.sub deps_vec)
   2.157        (curry Array.sub facts_ary) knns max_suggs
   2.158        (map_filter (fn (feat, weight) =>
   2.159           Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats)
   2.160 -    |> map ((fn i => Array.sub (fact_ary, i)) o fst)
   2.161 +    |> map (curry Vector.sub fact_vec o fst)
   2.162    end
   2.163  
   2.164  end;
   2.165 @@ -625,7 +625,7 @@
   2.166    Synchronized.change_result global_state (perhaps (try (load_state ctxt overlord)) #> `snd #>> f)
   2.167  
   2.168  fun clear_state ctxt overlord =
   2.169 -  (* "unlearn" also removes the state file *)
   2.170 +  (* "MaSh_Py.unlearn" also removes the state file *)
   2.171    Synchronized.change global_state (fn _ =>
   2.172      (if Config.get ctxt sml then wipe_out_mash_state_dir ()
   2.173       else MaSh_Py.unlearn ctxt overlord; (false, empty_state)))
   2.174 @@ -971,9 +971,6 @@
   2.175      | NONE => false)
   2.176    | is_size_def _ _ = false
   2.177  
   2.178 -fun no_dependencies_for_status status =
   2.179 -  status = Non_Rec_Def orelse status = Rec_Def
   2.180 -
   2.181  fun trim_dependencies deps =
   2.182    if length deps > max_dependencies then NONE else SOME deps
   2.183  
   2.184 @@ -1022,18 +1019,17 @@
   2.185        val num_isar_deps = length isar_deps
   2.186      in
   2.187        if verbose andalso auto_level = 0 then
   2.188 -        "MaSh: " ^ quote prover ^ " on " ^ quote name ^ " with " ^ string_of_int num_isar_deps ^
   2.189 -        " + " ^ string_of_int (length facts - num_isar_deps) ^ " facts."
   2.190 -        |> Output.urgent_message
   2.191 +        Output.urgent_message ("MaSh: " ^ quote prover ^ " on " ^ quote name ^ " with " ^
   2.192 +          string_of_int num_isar_deps ^ " + " ^ string_of_int (length facts - num_isar_deps) ^
   2.193 +          " facts.")
   2.194        else
   2.195          ();
   2.196        (case run_prover_for_mash ctxt params prover name facts goal of
   2.197          {outcome = NONE, used_facts, ...} =>
   2.198          (if verbose andalso auto_level = 0 then
   2.199             let val num_facts = length used_facts in
   2.200 -             "Found proof with " ^ string_of_int num_facts ^ " fact" ^
   2.201 -             plural_s num_facts ^ "."
   2.202 -             |> Output.urgent_message
   2.203 +             Output.urgent_message ("Found proof with " ^ string_of_int num_facts ^ " fact" ^
   2.204 +               plural_s num_facts ^ ".")
   2.205             end
   2.206           else
   2.207             ();
   2.208 @@ -1187,40 +1183,57 @@
   2.209        |> features_of ctxt (theory_of_thm th) num_facts const_tab stature false
   2.210        |> map (apsnd (fn r => weight * factor * r))
   2.211  
   2.212 -    val (access_G, suggs) =
   2.213 +    fun query_args access_G =
   2.214 +      let
   2.215 +        val parents = maximal_wrt_access_graph access_G facts
   2.216 +        val hints = chained
   2.217 +          |> filter (is_fact_in_graph access_G o snd)
   2.218 +          |> map (nickname_of_thm o snd)
   2.219 +
   2.220 +        val goal_feats =
   2.221 +          features_of ctxt thy num_facts const_tab (Local, General) true (concl_t :: hyp_ts)
   2.222 +        val chained_feats = chained
   2.223 +          |> map (rpair 1.0)
   2.224 +          |> map (chained_or_extra_features_of chained_feature_factor)
   2.225 +          |> rpair [] |-> fold (union (eq_fst (op =)))
   2.226 +        val extra_feats = facts
   2.227 +          |> take (Int.max (0, num_extra_feature_facts - length chained))
   2.228 +          |> filter fact_has_right_theory
   2.229 +          |> weight_facts_steeply
   2.230 +          |> map (chained_or_extra_features_of extra_feature_factor)
   2.231 +          |> rpair [] |-> fold (union (eq_fst (op =)))
   2.232 +        val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats
   2.233 +          |> debug ? sort (Real.compare o swap o pairself snd)
   2.234 +      in
   2.235 +        (parents, hints, feats)
   2.236 +      end
   2.237 +
   2.238 +    val sml = Config.get ctxt sml
   2.239 +
   2.240 +    val (access_G, py_suggs) =
   2.241        peek_state ctxt overlord (fn {access_G, ...} =>
   2.242          if Graph.is_empty access_G then
   2.243            (trace_msg ctxt (K "Nothing has been learned yet"); (access_G, []))
   2.244          else
   2.245 -          let
   2.246 -            val parents = maximal_wrt_access_graph access_G facts
   2.247 -            val goal_feats =
   2.248 -              features_of ctxt thy num_facts const_tab (Local, General) true (concl_t :: hyp_ts)
   2.249 -            val chained_feats = chained
   2.250 -              |> map (rpair 1.0)
   2.251 -              |> map (chained_or_extra_features_of chained_feature_factor)
   2.252 -              |> rpair [] |-> fold (union (eq_fst (op =)))
   2.253 -            val extra_feats = facts
   2.254 -              |> take (Int.max (0, num_extra_feature_facts - length chained))
   2.255 -              |> filter fact_has_right_theory
   2.256 -              |> weight_facts_steeply
   2.257 -              |> map (chained_or_extra_features_of extra_feature_factor)
   2.258 -              |> rpair [] |-> fold (union (eq_fst (op =)))
   2.259 -            val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats
   2.260 -              |> debug ? sort (Real.compare o swap o pairself snd)
   2.261 -            val hints = chained
   2.262 -              |> filter (is_fact_in_graph access_G o snd)
   2.263 -              |> map (nickname_of_thm o snd)
   2.264 -          in
   2.265 -            (access_G,
   2.266 -             if Config.get ctxt sml then
   2.267 -               MaSh_SML.learn_and_query ctxt parents access_G max_facts hints feats
   2.268 -             else
   2.269 -               MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats))
   2.270 -          end)
   2.271 +          (access_G,
   2.272 +           if sml then
   2.273 +             []
   2.274 +           else
   2.275 +             let val (parents, hints, feats) = query_args access_G in
   2.276 +               MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats)
   2.277 +             end))
   2.278 +
   2.279 +    val sml_suggs =
   2.280 +      if sml then
   2.281 +        let val (parents, hints, feats) = query_args access_G in
   2.282 +          MaSh_SML.query ctxt parents access_G max_facts hints feats
   2.283 +        end
   2.284 +      else
   2.285 +        []
   2.286 +
   2.287      val unknown = filter_out (is_fact_in_graph access_G o snd) facts
   2.288    in
   2.289 -    find_mash_suggestions ctxt max_facts suggs facts chained unknown
   2.290 +    find_mash_suggestions ctxt max_facts (py_suggs @ sml_suggs) facts chained unknown
   2.291      |> pairself (map fact_of_raw_fact)
   2.292    end
   2.293  
   2.294 @@ -1323,7 +1336,7 @@
   2.295          val name_tabs = build_name_tables nickname_of_thm facts
   2.296  
   2.297          fun deps_of status th =
   2.298 -          if no_dependencies_for_status status then
   2.299 +          if status = Non_Rec_Def orelse status = Rec_Def then
   2.300              SOME []
   2.301            else if run_prover then
   2.302              prover_dependencies_of ctxt params prover auto_level facts name_tabs th
   2.303 @@ -1355,18 +1368,13 @@
   2.304              end
   2.305  
   2.306          fun commit last learns relearns flops =
   2.307 -          (if debug andalso auto_level = 0 then
   2.308 -             Output.urgent_message "Committing..."
   2.309 -           else
   2.310 -             ();
   2.311 +          (if debug andalso auto_level = 0 then Output.urgent_message "Committing..." else ();
   2.312             map_state ctxt overlord (do_commit (rev learns) relearns flops);
   2.313             if not last andalso auto_level = 0 then
   2.314               let val num_proofs = length learns + length relearns in
   2.315 -               "Learned " ^ string_of_int num_proofs ^ " " ^
   2.316 -               (if run_prover then "automatic" else "Isar") ^ " proof" ^
   2.317 -               plural_s num_proofs ^ " in the last " ^
   2.318 -               string_of_time commit_timeout ^ "."
   2.319 -               |> Output.urgent_message
   2.320 +               Output.urgent_message ("Learned " ^ string_of_int num_proofs ^ " " ^
   2.321 +                 (if run_prover then "automatic" else "Isar") ^ " proof" ^
   2.322 +                 plural_s num_proofs ^ " in the last " ^ string_of_time commit_timeout ^ ".")
   2.323               end
   2.324             else
   2.325               ())
   2.326 @@ -1478,14 +1486,12 @@
   2.327        |> Output.urgent_message
   2.328    in
   2.329      if run_prover then
   2.330 -      ("MaShing through " ^ string_of_int num_facts ^ " fact" ^ plural_s num_facts ^
   2.331 -       " for automatic proofs (" ^ quote prover ^ " timeout: " ^ string_of_time timeout ^
   2.332 -       ").\n\nCollecting Isar proofs first..."
   2.333 -       |> Output.urgent_message;
   2.334 +      (Output.urgent_message ("MaShing through " ^ string_of_int num_facts ^ " fact" ^
   2.335 +         plural_s num_facts ^ " for automatic proofs (" ^ quote prover ^ " timeout: " ^
   2.336 +         string_of_time timeout ^ ").\n\nCollecting Isar proofs first...");
   2.337         learn 1 false;
   2.338 -       "Now collecting automatic proofs. This may take several hours. You can safely stop the \
   2.339 -       \learning process at any point."
   2.340 -       |> Output.urgent_message;
   2.341 +       Output.urgent_message "Now collecting automatic proofs. This may take several hours. You \
   2.342 +         \can safely stop the learning process at any point.";
   2.343         learn 0 true)
   2.344      else
   2.345        (Output.urgent_message ("MaShing through " ^ string_of_int num_facts ^ " fact" ^