better way to take invisible facts into account than 'island' business
authorblanchet
Tue May 20 09:38:39 2014 +0200 (2014-05-20)
changeset 57013ed95456499e6
parent 57012 43fd82a537a3
child 57014 b7999893ffcc
better way to take invisible facts into account than 'island' business
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
     1.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue May 20 02:47:23 2014 +0200
     1.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue May 20 09:38:39 2014 +0200
     1.3 @@ -108,12 +108,6 @@
     1.4  val relearn_isarN = "relearn_isar"
     1.5  val relearn_proverN = "relearn_prover"
     1.6  
     1.7 -val learned_proof_prefix = ".."
     1.8 -
     1.9 -fun learned_proof_name () =
    1.10 -  learned_proof_prefix ^ Date.fmt "%Y%m%d_%H%M%S__" (Date.fromTimeLocal (Time.now ())) ^
    1.11 -  serial_string ()
    1.12 -
    1.13  fun mash_state_dir () = Path.explode "$ISABELLE_HOME_USER/mash" |> tap Isabelle_System.mkdir
    1.14  fun mash_state_file () = Path.append (mash_state_dir ()) (Path.explode "state")
    1.15  
    1.16 @@ -382,68 +376,62 @@
    1.17    end
    1.18  
    1.19  (*
    1.20 -  avail_no = maximum number of theorems to check dependencies and symbols
    1.21 +  avail_num = maximum number of theorems to check dependencies and symbols
    1.22 +  adv_max = do not return theorems over or equal to this number. Must satisfy: adv_max <= avail_num
    1.23    get_deps = returns dependencies of a theorem
    1.24    get_sym_ths = get theorems that have this feature
    1.25 -  knns    = number of nearest neighbours
    1.26 -  advno   = number of predictions to return
    1.27 -  syms    = symbols of the conjecture
    1.28 +  knns = number of nearest neighbours
    1.29 +  advno = number of predictions to return
    1.30 +  syms = symbols of the conjecture
    1.31  *)
    1.32 -fun knn avail_no get_deps get_sym_ths knns advno syms =
    1.33 +fun knn avail_num adv_max get_deps get_sym_ths knns advno syms =
    1.34    let
    1.35      (* Can be later used for TFIDF *)
    1.36 -    fun sym_wght _ = 1.0
    1.37 -
    1.38 -    val overlaps_sqr = Array.tabulate (avail_no, (fn i => (i, 0.0)))
    1.39 -
    1.40 +    fun sym_wght _ = 1.0;
    1.41 +    val overlaps_sqr = Array.tabulate (avail_num, (fn i => (i, 0.0)));
    1.42      fun inc_overlap j v =
    1.43        let
    1.44          val ov = snd (Array.sub (overlaps_sqr,j))
    1.45        in
    1.46          Array.update (overlaps_sqr, j, (j, v + ov))
    1.47 -      end
    1.48 -
    1.49 +      end;
    1.50      fun do_sym (s, con_wght) =
    1.51        let
    1.52 -        val sw = sym_wght s
    1.53 -        val w2 = sw * sw * con_wght
    1.54 -        fun do_th (j, prem_wght) = if j < avail_no then inc_overlap j (w2 * prem_wght) else ()
    1.55 +        val sw = sym_wght s;
    1.56 +        val w2 = sw * sw * con_wght;
    1.57 +        fun do_th (j, prem_wght) = if j < avail_num then inc_overlap j (w2 * prem_wght) else ()
    1.58        in
    1.59          ignore (map do_th (get_sym_ths s))
    1.60 -      end
    1.61 -
    1.62 -    val _ = ignore (map do_sym syms)
    1.63 -    val _ = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr
    1.64 -    val recommends = Array.tabulate (avail_no, (fn j => (j, 0.0)))
    1.65 -
    1.66 +      end;
    1.67 +    val () = ignore (map do_sym syms);
    1.68 +    val () = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr;
    1.69 +    val recommends = Array.tabulate (adv_max, (fn j => (j, 0.0)));
    1.70      fun inc_recommend j v =
    1.71 +      if j >= adv_max then () else
    1.72        let
    1.73          val ov = snd (Array.sub (recommends,j))
    1.74        in
    1.75          Array.update (recommends, j, (j, v + ov))
    1.76 -      end
    1.77 -
    1.78 +      end;
    1.79      fun for k =
    1.80        if k = knns then () else
    1.81 -      if k >= avail_no then () else
    1.82 +      if k >= adv_max then () else
    1.83        let
    1.84 -        val (j, o2) = Array.sub (overlaps_sqr, avail_no - k - 1)
    1.85 -        val o1 = Math.sqrt o2
    1.86 -        val _ = inc_recommend j o1
    1.87 -        val ds = get_deps j
    1.88 -        val l = Real.fromInt (length ds)
    1.89 +        val (j, o2) = Array.sub (overlaps_sqr, avail_num - k - 1);
    1.90 +        val o1 = Math.sqrt o2;
    1.91 +        val () = inc_recommend j o1;
    1.92 +        val ds = get_deps j;
    1.93 +        val l = Real.fromInt (length ds);
    1.94          val _ = map (fn d => inc_recommend d (o1 / l)) ds
    1.95        in
    1.96          for (k + 1)
    1.97 -      end
    1.98 -
    1.99 -    val _ = for 0
   1.100 -    val _ = heap (fn (a, b) => Real.compare (snd a, snd b)) advno recommends
   1.101 -
   1.102 +      end;
   1.103 +    val () = for 0;
   1.104 +    val () = heap (fn (a, b) => Real.compare (snd a, snd b)) advno recommends;
   1.105      fun ret acc at =
   1.106        if at = Array.length recommends then acc else ret (Array.sub (recommends,at) :: acc) (at + 1)
   1.107    in
   1.108 -    ret [] (max 0 (avail_no - advno))
   1.109 +    ret [] (max 0 (adv_max - advno))
   1.110    end
   1.111  
   1.112  val knns = 40 (* FUDGE *)
   1.113 @@ -456,11 +444,18 @@
   1.114    let
   1.115      val str_of_feat = space_implode "|"
   1.116  
   1.117 -    val (depss0, featss, (_, _, facts0), (num_feats, feat_tab, _)) =
   1.118 -      fold_rev (fn fact => fn (depss, featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
   1.119 +    val visible_facts = Graph.all_preds access_G parents
   1.120 +    val visible_fact_set = Symtab.make_set visible_facts
   1.121 +
   1.122 +    val all_nodes =
   1.123 +      Graph.schedule (K I) access_G
   1.124 +      |> List.partition (Symtab.defined visible_fact_set o fst)
   1.125 +      |> op @
   1.126 +
   1.127 +    val (rev_depss, featss, (_, _, rev_facts), (num_feats, feat_tab, _)) =
   1.128 +      fold (fn (fact, (_, feats, deps)) =>
   1.129 +            fn (depss, featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
   1.130            let
   1.131 -            val (_, feats, deps) = Graph.get_node access_G fact
   1.132 -
   1.133              fun add_feat (feat, weight) (xtab as (n, tab, _)) =
   1.134                (case Symtab.lookup tab feat of
   1.135                  SOME i => ((i, weight), xtab)
   1.136 @@ -471,12 +466,12 @@
   1.137              (map_filter (Symtab.lookup fact_tab) deps :: depss, feats' :: featss,
   1.138               add_to_xtab fact fact_xtab, feat_xtab')
   1.139            end)
   1.140 -        (Graph.all_preds access_G parents) ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
   1.141 +        all_nodes ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
   1.142  
   1.143 -    val facts = rev facts0
   1.144 +    val facts = rev rev_facts
   1.145      val fact_ary = Array.fromList facts
   1.146  
   1.147 -    val deps_ary = Array.fromList (rev depss0)
   1.148 +    val deps_ary = Array.fromList (rev rev_depss)
   1.149      val facts_ary = Array.array (num_feats, [])
   1.150      val _ =
   1.151        fold (fn feats => fn fact =>
   1.152 @@ -487,15 +482,13 @@
   1.153            end)
   1.154          featss (length featss)
   1.155    in
   1.156 -    trace_msg ctxt (fn () =>
   1.157 -      "MaSh_SML query " ^ encode_features feats ^ " from {" ^
   1.158 -       elide_string 1000 (space_implode " " facts) ^ "}");
   1.159 -    knn (Array.length deps_ary) (curry Array.sub deps_ary) (curry Array.sub facts_ary) knns
   1.160 -      max_suggs
   1.161 +    trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^
   1.162 +      elide_string 1000 (space_implode " " facts) ^ "}");
   1.163 +    knn (Array.length deps_ary) (length visible_facts) (curry Array.sub deps_ary)
   1.164 +      (curry Array.sub facts_ary) knns max_suggs
   1.165        (map_filter (fn (feat, weight) =>
   1.166           Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats)
   1.167      |> map ((fn i => Array.sub (fact_ary, i)) o fst)
   1.168 -    |> filter_out (String.isPrefix learned_proof_prefix)
   1.169    end
   1.170  
   1.171  end;
   1.172 @@ -517,9 +510,10 @@
   1.173    Graph.default_node (parent, (Isar_Proof, [], []))
   1.174    #> Graph.add_edge (parent, name)
   1.175  
   1.176 -fun add_node kind name feats deps =
   1.177 +fun add_node kind name parents feats deps =
   1.178    Graph.default_node (name, (kind, feats, deps))
   1.179    #> Graph.map_node name (K (kind, feats, deps))
   1.180 +  #> fold (add_edge_to name) parents
   1.181  
   1.182  fun try_graph ctxt when def f =
   1.183    f ()
   1.184 @@ -575,9 +569,7 @@
   1.185             fun extract_line_and_add_node line =
   1.186               (case extract_node line of
   1.187                 NONE => I (* should not happen *)
   1.188 -             | SOME (kind, name, parents, feats, deps) =>
   1.189 -               add_node kind name feats deps
   1.190 -               #> fold (add_edge_to name) parents)
   1.191 +             | SOME (kind, name, parents, feats, deps) => add_node kind name parents feats deps)
   1.192  
   1.193             val (access_G, num_known_facts) =
   1.194               (case string_ord (version', version) of
   1.195 @@ -1132,13 +1124,8 @@
   1.196      find_maxes Symtab.empty ([], Graph.maximals G)
   1.197    end
   1.198  
   1.199 -fun graph_islands G =
   1.200 -  Graph.fold (fn (m, (_, (preds, succs))) =>
   1.201 -    (Graph.Keys.is_empty preds andalso Graph.Keys.is_empty succs) ? cons m) G [];
   1.202 -
   1.203 -(* islands represent learned proofs associated with no facts *)
   1.204  fun maximal_wrt_access_graph access_G facts =
   1.205 -  map (nickname_of_thm o snd) facts @ graph_islands access_G
   1.206 +  map (nickname_of_thm o snd) facts
   1.207    |> maximal_wrt_graph access_G
   1.208  
   1.209  fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm
   1.210 @@ -1240,7 +1227,7 @@
   1.211  fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, G) =
   1.212    let
   1.213      fun maybe_learn_from from (accum as (parents, G)) =
   1.214 -      try_graph ctxt "updating G" accum (fn () =>
   1.215 +      try_graph ctxt "updating graph" accum (fn () =>
   1.216          (from :: parents, Graph.add_edge_acyclic (from, name) G))
   1.217      val G = G |> Graph.default_node (name, (Isar_Proof, feats, deps))
   1.218      val (parents, G) = ([], G) |> fold maybe_learn_from parents
   1.219 @@ -1275,6 +1262,9 @@
   1.220      Async_Manager.thread MaShN birth_time death_time desc task
   1.221    end
   1.222  
   1.223 +fun learned_proof_name () =
   1.224 +  Date.fmt ".%Y%m%d.%H%M%S." (Date.fromTimeLocal (Time.now ())) ^ serial_string ()
   1.225 +
   1.226  fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) t facts used_ths =
   1.227    if is_mash_enabled () then
   1.228      launch_thread timeout (fn () =>
   1.229 @@ -1285,19 +1275,18 @@
   1.230          map_state ctxt overlord (fn state as {access_G, num_known_facts, dirty} =>
   1.231            let
   1.232              val name = learned_proof_name ()
   1.233 +            val parents = maximal_wrt_access_graph access_G facts
   1.234              val deps = used_ths
   1.235                |> filter (is_fact_in_graph access_G)
   1.236                |> map nickname_of_thm
   1.237            in
   1.238              if Config.get ctxt sml then
   1.239 -              let val access_G = access_G |> add_node Automatic_Proof name feats deps in
   1.240 +              let val access_G = access_G |> add_node Automatic_Proof name parents feats deps in
   1.241                  {access_G = access_G, num_known_facts = num_known_facts + 1,
   1.242                   dirty = Option.map (cons name) dirty}
   1.243                end
   1.244              else
   1.245 -              let val parents = maximal_wrt_access_graph access_G facts in
   1.246 -                (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state)
   1.247 -              end
   1.248 +              (MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]; state)
   1.249            end);
   1.250          (true, "")
   1.251        end)