cleaner handling of learned proofs
author blanchet
Tue May 20 02:47:23 2014 +0200 (2014-05-20)
changeset 57012 43fd82a537a3
parent 57011 a4428f517f46
child 57013 ed95456499e6
cleaner handling of learned proofs
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
     1.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue May 20 00:13:31 2014 +0200
     1.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue May 20 02:47:23 2014 +0200
     1.3 @@ -108,6 +108,12 @@
     1.4  val relearn_isarN = "relearn_isar"
     1.5  val relearn_proverN = "relearn_prover"
     1.6  
     1.7 +val learned_proof_prefix = ".."
     1.8 +
     1.9 +fun learned_proof_name () =
    1.10 +  learned_proof_prefix ^ Date.fmt "%Y%m%d_%H%M%S__" (Date.fromTimeLocal (Time.now ())) ^
    1.11 +  serial_string ()
    1.12 +
    1.13  fun mash_state_dir () = Path.explode "$ISABELLE_HOME_USER/mash" |> tap Isabelle_System.mkdir
    1.14  fun mash_state_file () = Path.append (mash_state_dir ()) (Path.explode "state")
    1.15  
    1.16 @@ -297,8 +303,8 @@
    1.17          if i31 + 2 < l then
    1.18            let
    1.19              val x = Unsynchronized.ref i31;
    1.20 -            val () = if cmp (Array.sub(a,i31),(Array.sub(a,i31+1))) = LESS then x := i31+1 else ();
    1.21 -            val () = if cmp (Array.sub(a,!x),(Array.sub (a,i31+2))) = LESS then x := i31+2 else ()
    1.22 +            val _ = if cmp (Array.sub(a,i31),(Array.sub(a,i31+1))) = LESS then x := i31+1 else ();
    1.23 +            val _ = if cmp (Array.sub(a,!x),(Array.sub (a,i31+2))) = LESS then x := i31+2 else ()
    1.24            in
    1.25              !x
    1.26            end
    1.27 @@ -312,7 +318,7 @@
    1.28          val j = maxson l i
    1.29        in
    1.30          if cmp (Array.sub (a, j), e) = GREATER then
    1.31 -          let val () = Array.update (a, i, Array.sub (a, j)) in trickledown l j e end
    1.32 +          let val _ = Array.update (a, i, Array.sub (a, j)) in trickledown l j e end
    1.33          else Array.update (a, i, e)
    1.34        end
    1.35  
    1.36 @@ -321,7 +327,7 @@
    1.37      fun bubbledown l i =
    1.38        let
    1.39          val j = maxson l i
    1.40 -        val () = Array.update (a, i, Array.sub (a, j))
    1.41 +        val _ = Array.update (a, i, Array.sub (a, j))
    1.42        in
    1.43          bubbledown l j
    1.44        end
    1.45 @@ -334,7 +340,7 @@
    1.46        in
    1.47          if cmp (Array.sub (a, father), e) = LESS then
    1.48            let
    1.49 -            val () = Array.update (a, i, Array.sub (a, father))
    1.50 +            val _ = Array.update (a, i, Array.sub (a, father))
    1.51            in
    1.52              if father > 0 then trickleup father e else Array.update (a, 0, e)
    1.53            end
    1.54 @@ -351,24 +357,24 @@
    1.55          for (i - 1)
    1.56        end
    1.57  
    1.58 -    val () = for (((l + 1) div 3) - 1)
    1.59 +    val _ = for (((l + 1) div 3) - 1)
    1.60  
    1.61      fun for2 i =
    1.62        if i < max 2 (l - bnd) then () else
    1.63        let
    1.64          val e = Array.sub (a, i)
    1.65 -        val () = Array.update (a, i, Array.sub (a, 0))
    1.66 -        val () = trickleup (bubble i 0) e
    1.67 +        val _ = Array.update (a, i, Array.sub (a, 0))
    1.68 +        val _ = trickleup (bubble i 0) e
    1.69        in
    1.70          for2 (i - 1)
    1.71        end
    1.72  
    1.73 -    val () = for2 (l - 1)
    1.74 +    val _ = for2 (l - 1)
    1.75    in
    1.76      if l > 1 then
    1.77        let
    1.78          val e = Array.sub (a, 1)
    1.79 -        val () = Array.update (a, 1, Array.sub (a, 0))
    1.80 +        val _ = Array.update (a, 1, Array.sub (a, 0))
    1.81        in
    1.82          Array.update (a, 0, e)
    1.83        end
    1.84 @@ -386,46 +392,54 @@
    1.85  fun knn avail_no get_deps get_sym_ths knns advno syms =
    1.86    let
    1.87      (* Can be later used for TFIDF *)
    1.88 -    fun sym_wght _ = 1.0;
    1.89 -    val overlaps_sqr = Array.tabulate (avail_no, (fn i => (i, 0.0)));
    1.90 +    fun sym_wght _ = 1.0
    1.91 +
    1.92 +    val overlaps_sqr = Array.tabulate (avail_no, (fn i => (i, 0.0)))
    1.93 +
    1.94      fun inc_overlap j v =
    1.95        let
    1.96          val ov = snd (Array.sub (overlaps_sqr,j))
    1.97        in
    1.98          Array.update (overlaps_sqr, j, (j, v + ov))
    1.99 -      end;
   1.100 +      end
   1.101 +
   1.102      fun do_sym (s, con_wght) =
   1.103        let
   1.104 -        val sw = sym_wght s;
   1.105 -        val w2 = sw * sw * con_wght;
   1.106 +        val sw = sym_wght s
   1.107 +        val w2 = sw * sw * con_wght
   1.108          fun do_th (j, prem_wght) = if j < avail_no then inc_overlap j (w2 * prem_wght) else ()
   1.109        in
   1.110          ignore (map do_th (get_sym_ths s))
   1.111 -      end;
   1.112 -    val () = ignore (map do_sym syms);
   1.113 -    val () = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr;
   1.114 -    val recommends = Array.tabulate (avail_no, (fn j => (j, 0.0)));
   1.115 +      end
   1.116 +
   1.117 +    val _ = ignore (map do_sym syms)
   1.118 +    val _ = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr
   1.119 +    val recommends = Array.tabulate (avail_no, (fn j => (j, 0.0)))
   1.120 +
   1.121      fun inc_recommend j v =
   1.122        let
   1.123          val ov = snd (Array.sub (recommends,j))
   1.124        in
   1.125          Array.update (recommends, j, (j, v + ov))
   1.126 -      end;
   1.127 +      end
   1.128 +
   1.129      fun for k =
   1.130        if k = knns then () else
   1.131        if k >= avail_no then () else
   1.132        let
   1.133 -        val (j, o2) = Array.sub (overlaps_sqr, avail_no - k - 1);
   1.134 -        val o1 = Math.sqrt o2;
   1.135 -        val () = inc_recommend j o1;
   1.136 -        val ds = get_deps j;
   1.137 -        val l = Real.fromInt (length ds);
   1.138 +        val (j, o2) = Array.sub (overlaps_sqr, avail_no - k - 1)
   1.139 +        val o1 = Math.sqrt o2
   1.140 +        val _ = inc_recommend j o1
   1.141 +        val ds = get_deps j
   1.142 +        val l = Real.fromInt (length ds)
   1.143          val _ = map (fn d => inc_recommend d (o1 / l)) ds
   1.144        in
   1.145          for (k + 1)
   1.146 -      end;
   1.147 -    val () = for 0;
   1.148 -    val () = heap (fn (a, b) => Real.compare (snd a, snd b)) advno recommends;
   1.149 +      end
   1.150 +
   1.151 +    val _ = for 0
   1.152 +    val _ = heap (fn (a, b) => Real.compare (snd a, snd b)) advno recommends
   1.153 +
   1.154      fun ret acc at =
   1.155        if at = Array.length recommends then acc else ret (Array.sub (recommends,at) :: acc) (at + 1)
   1.156    in
   1.157 @@ -473,14 +487,15 @@
   1.158            end)
   1.159          featss (length featss)
   1.160    in
   1.161 -    (trace_msg ctxt (fn () =>
   1.162 -       "MaSh_SML query " ^ encode_features feats ^ " from {" ^
   1.163 -        elide_string 1000 (space_implode " " facts) ^ "}");
   1.164 -     knn (Array.length deps_ary) (curry Array.sub deps_ary) (curry Array.sub facts_ary) knns
   1.165 -       max_suggs
   1.166 -       (map_filter (fn (feat, weight) =>
   1.167 -          Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats)
   1.168 -     |> map ((fn i => Array.sub (fact_ary, i)) o fst))
   1.169 +    trace_msg ctxt (fn () =>
   1.170 +      "MaSh_SML query " ^ encode_features feats ^ " from {" ^
   1.171 +       elide_string 1000 (space_implode " " facts) ^ "}");
   1.172 +    knn (Array.length deps_ary) (curry Array.sub deps_ary) (curry Array.sub facts_ary) knns
   1.173 +      max_suggs
   1.174 +      (map_filter (fn (feat, weight) =>
   1.175 +         Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats)
   1.176 +    |> map ((fn i => Array.sub (fact_ary, i)) o fst)
   1.177 +    |> filter_out (String.isPrefix learned_proof_prefix)
   1.178    end
   1.179  
   1.180  end;
   1.181 @@ -502,10 +517,9 @@
   1.182    Graph.default_node (parent, (Isar_Proof, [], []))
   1.183    #> Graph.add_edge (parent, name)
   1.184  
   1.185 -fun add_node kind name parents feats deps =
   1.186 +fun add_node kind name feats deps =
   1.187    Graph.default_node (name, (kind, feats, deps))
   1.188    #> Graph.map_node name (K (kind, feats, deps))
   1.189 -  #> fold (add_edge_to name) parents;
   1.190  
   1.191  fun try_graph ctxt when def f =
   1.192    f ()
   1.193 @@ -526,7 +540,6 @@
   1.194  fun graph_info G =
   1.195    string_of_int (length (Graph.keys G)) ^ " node(s), " ^
   1.196    string_of_int (fold (Integer.add o length o snd) (Graph.dest G) 0) ^ " edge(s), " ^
   1.197 -  string_of_int (length (Graph.minimals G)) ^ " minimal, " ^
   1.198    string_of_int (length (Graph.maximals G)) ^ " maximal"
   1.199  
   1.200  type mash_state =
   1.201 @@ -562,7 +575,9 @@
   1.202             fun extract_line_and_add_node line =
   1.203               (case extract_node line of
   1.204                 NONE => I (* should not happen *)
   1.205 -             | SOME (kind, name, parents, feats, deps) => add_node kind name parents feats deps)
   1.206 +             | SOME (kind, name, parents, feats, deps) =>
   1.207 +               add_node kind name feats deps
   1.208 +               #> fold (add_edge_to name) parents)
   1.209  
   1.210             val (access_G, num_known_facts) =
   1.211               (case string_ord (version', version) of
   1.212 @@ -1095,40 +1110,36 @@
   1.213    let
   1.214      val tab = Symtab.empty |> fold (fn name => Symtab.default (name, ())) keys
   1.215  
   1.216 -    fun insert_new seen name =
   1.217 -      not (Symtab.defined seen name) ? insert (op =) name
   1.218 +    fun insert_new seen name = not (Symtab.defined seen name) ? insert (op =) name
   1.219  
   1.220      fun num_keys keys = Graph.Keys.fold (K (Integer.add 1)) keys 0
   1.221  
   1.222      fun find_maxes _ (maxs, []) = map snd maxs
   1.223        | find_maxes seen (maxs, new :: news) =
   1.224 -        find_maxes
   1.225 -            (seen |> num_keys (Graph.imm_succs G new) > 1
   1.226 -                     ? Symtab.default (new, ()))
   1.227 -            (if Symtab.defined tab new then
   1.228 -               let
   1.229 -                 val newp = Graph.all_preds G [new]
   1.230 -                 fun is_ancestor x yp = member (op =) yp x
   1.231 -                 val maxs =
   1.232 -                   maxs |> filter (fn (_, max) => not (is_ancestor max newp))
   1.233 -               in
   1.234 -                 if exists (is_ancestor new o fst) maxs then
   1.235 -                   (maxs, news)
   1.236 -                 else
   1.237 -                   ((newp, new)
   1.238 -                    :: filter_out (fn (_, max) => is_ancestor max newp) maxs,
   1.239 -                    news)
   1.240 -               end
   1.241 -             else
   1.242 -               (maxs, Graph.Keys.fold (insert_new seen)
   1.243 -                                      (Graph.imm_preds G new) news))
   1.244 +        find_maxes (seen |> num_keys (Graph.imm_succs G new) > 1 ? Symtab.default (new, ()))
   1.245 +          (if Symtab.defined tab new then
   1.246 +             let
   1.247 +               val newp = Graph.all_preds G [new]
   1.248 +               fun is_ancestor x yp = member (op =) yp x
   1.249 +               val maxs = maxs |> filter (fn (_, max) => not (is_ancestor max newp))
   1.250 +             in
   1.251 +               if exists (is_ancestor new o fst) maxs then (maxs, news)
   1.252 +               else ((newp, new) :: filter_out (fn (_, max) => is_ancestor max newp) maxs, news)
   1.253 +             end
   1.254 +           else
   1.255 +             (maxs, Graph.Keys.fold (insert_new seen) (Graph.imm_preds G new) news))
   1.256    in
   1.257      find_maxes Symtab.empty ([], Graph.maximals G)
   1.258    end
   1.259  
   1.260 -fun maximal_wrt_access_graph access_G =
   1.261 -  map (nickname_of_thm o snd)
   1.262 -  #> maximal_wrt_graph access_G
   1.263 +fun graph_islands G =
   1.264 +  Graph.fold (fn (m, (_, (preds, succs))) =>
   1.265 +    (Graph.Keys.is_empty preds andalso Graph.Keys.is_empty succs) ? cons m) G [];
   1.266 +
   1.267 +(* islands represent learned proofs associated with no facts *)
   1.268 +fun maximal_wrt_access_graph access_G facts =
   1.269 +  map (nickname_of_thm o snd) facts @ graph_islands access_G
   1.270 +  |> maximal_wrt_graph access_G
   1.271  
   1.272  fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm
   1.273  
   1.274 @@ -1264,9 +1275,6 @@
   1.275      Async_Manager.thread MaShN birth_time death_time desc task
   1.276    end
   1.277  
   1.278 -fun fresh_enough_name () =
   1.279 -  Date.fmt ".%Y%m%d_%H%M%S__" (Date.fromTimeLocal (Time.now ())) ^ serial_string ()
   1.280 -
   1.281  fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) t facts used_ths =
   1.282    if is_mash_enabled () then
   1.283      launch_thread timeout (fn () =>
   1.284 @@ -1276,18 +1284,20 @@
   1.285        in
   1.286          map_state ctxt overlord (fn state as {access_G, num_known_facts, dirty} =>
   1.287            let
   1.288 -            val name = fresh_enough_name ()
   1.289 -            val parents = maximal_wrt_access_graph access_G facts
   1.290 +            val name = learned_proof_name ()
   1.291              val deps = used_ths
   1.292                |> filter (is_fact_in_graph access_G)
   1.293                |> map nickname_of_thm
   1.294            in
   1.295              if Config.get ctxt sml then
   1.296 -              {access_G = add_node Automatic_Proof name parents feats deps access_G,
   1.297 -               num_known_facts = num_known_facts + 1,
   1.298 -               dirty = Option.map (cons name) dirty}
   1.299 +              let val access_G = access_G |> add_node Automatic_Proof name feats deps in
   1.300 +                {access_G = access_G, num_known_facts = num_known_facts + 1,
   1.301 +                 dirty = Option.map (cons name) dirty}
   1.302 +              end
   1.303              else
   1.304 -              (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state)
   1.305 +              let val parents = maximal_wrt_access_graph access_G facts in
   1.306 +                (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state)
   1.307 +              end
   1.308            end);
   1.309          (true, "")
   1.310        end)