take weights into consideration in knn
authorblanchet
Mon May 19 23:43:53 2014 +0200 (2014-05-19)
changeset 57010121b63d7bcdb
parent 57009 8cb6a5f1ae84
child 57011 a4428f517f46
take weights into consideration in knn
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
     1.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Mon May 19 23:43:53 2014 +0200
     1.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Mon May 19 23:43:53 2014 +0200
     1.3 @@ -207,10 +207,15 @@
     1.4    encode_unweighted_feature names ^
     1.5    (if Real.== (weight, 1.0) then "" else "=" ^ safe_str_of_real weight)
     1.6  
     1.7 +fun decode_feature s =
     1.8 +  (case space_explode "=" s of
     1.9 +    [feat, weight] => (decode_unweighted_feature feat, Real.fromString weight |> the_default 1.0)
    1.10 +  | _ => (decode_unweighted_feature s, 1.0))
    1.11 +
    1.12  val encode_unweighted_features = map encode_unweighted_feature #> space_implode " "
    1.13 -val decode_unweighted_features = space_explode " " #> map decode_unweighted_feature
    1.14  
    1.15  val encode_features = map encode_feature #> space_implode " "
    1.16 +val decode_features = space_explode " " #> map decode_feature
    1.17  
    1.18  fun str_of_learn (name, parents, feats, deps) =
    1.19    "! " ^ encode_str name ^ ": " ^ encode_strs parents ^ "; " ^
    1.20 @@ -431,9 +436,8 @@
    1.21  
    1.22  fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
    1.23  
    1.24 -fun map_array_index ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
    1.25 +fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
    1.26  
    1.27 -(* TODO: take weight components of "feats" into consideration *)
    1.28  fun learn_and_query ctxt parents access_G max_suggs hints feats =
    1.29    let
    1.30      val str_of_feat = space_implode "|"
    1.31 @@ -443,12 +447,12 @@
    1.32            let
    1.33              val (_, feats, deps) = Graph.get_node access_G fact
    1.34  
    1.35 -            fun add_feat feat (xtab as (n, tab, _)) =
    1.36 +            fun add_feat (feat, weight) (xtab as (n, tab, _)) =
    1.37                (case Symtab.lookup tab feat of
    1.38 -                SOME i => (i, xtab)
    1.39 -              | NONE => (n, add_to_xtab feat xtab))
    1.40 +                SOME i => ((i, weight), xtab)
    1.41 +              | NONE => ((n, weight), add_to_xtab feat xtab))
    1.42  
    1.43 -            val (feats', feat_xtab') = fold_map (add_feat o str_of_feat) feats feat_xtab
    1.44 +            val (feats', feat_xtab') = fold_map (add_feat o apfst str_of_feat) feats feat_xtab
    1.45            in
    1.46              (map_filter (Symtab.lookup fact_tab) deps :: depss, feats' :: featss,
    1.47               add_to_xtab fact fact_xtab, feat_xtab')
    1.48 @@ -463,15 +467,17 @@
    1.49      val _ =
    1.50        fold (fn feats => fn fact =>
    1.51            let val fact' = fact - 1 in
    1.52 -            (List.app (map_array_index facts_ary (cons fact')) feats; fact')
    1.53 +            List.app (fn (feat, weight) => map_array_at facts_ary (cons (fact', weight)) feat)
    1.54 +              feats;
    1.55 +            fact'
    1.56            end)
    1.57          featss (length featss)
    1.58    in
    1.59      (trace_msg ctxt (fn () =>
    1.60         "MaSh_SML query " ^ encode_features feats ^ " from {" ^
    1.61          elide_string 1000 (space_implode " " facts) ^ "}");
    1.62 -     knn (Array.length deps_ary) (curry Array.sub deps_ary)
    1.63 -       (map (rpair 1.0) (* FIXME *) o curry Array.sub facts_ary) knns max_suggs
    1.64 +     knn (Array.length deps_ary) (curry Array.sub deps_ary) (curry Array.sub facts_ary) knns
    1.65 +       max_suggs
    1.66         (map_filter (fn (feat, weight) =>
    1.67            Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats)
    1.68       |> map ((fn i => Array.sub (fact_ary, i)) o fst))
    1.69 @@ -515,7 +521,7 @@
    1.70    string_of_int (length (Graph.maximals G)) ^ " maximal"
    1.71  
    1.72  type mash_state =
    1.73 -  {access_G : (proof_kind * string list list * string list) Graph.T,
    1.74 +  {access_G : (proof_kind * (string list * real) list * string list) Graph.T,
    1.75     num_known_facts : int,
    1.76     dirty : string list option}
    1.77  
    1.78 @@ -523,7 +529,7 @@
    1.79  
    1.80  local
    1.81  
    1.82 -val version = "*** MaSh version 20140516 ***"
    1.83 +val version = "*** MaSh version 20140519 ***"
    1.84  
    1.85  exception FILE_VERSION_TOO_NEW of unit
    1.86  
    1.87 @@ -532,8 +538,8 @@
    1.88      [head, tail] =>
    1.89      (case (space_explode " " head, map (unprefix " ") (space_explode ";" tail)) of
    1.90        ([kind, name], [parents, feats, deps]) =>
    1.91 -      SOME (proof_kind_of_str kind, decode_str name, decode_strs parents,
    1.92 -        decode_unweighted_features feats, decode_strs deps)
    1.93 +      SOME (proof_kind_of_str kind, decode_str name, decode_strs parents, decode_features feats,
    1.94 +        decode_strs deps)
    1.95      | _ => NONE)
    1.96    | _ => NONE)
    1.97  
    1.98 @@ -573,7 +579,7 @@
    1.99  
   1.100  fun str_of_entry (kind, name, parents, feats, deps) =
   1.101    str_of_proof_kind kind ^ " " ^ encode_str name ^ ": " ^ encode_strs parents ^ "; " ^
   1.102 -  encode_unweighted_features feats ^ "; " ^ encode_strs deps ^ "\n"
   1.103 +  encode_features feats ^ "; " ^ encode_strs deps ^ "\n"
   1.104  
   1.105  fun save_state _ (state as {dirty = SOME [], ...}) = state
   1.106    | save_state ctxt {access_G, num_known_facts, dirty} =
   1.107 @@ -1225,7 +1231,7 @@
   1.108      val (parents, G) = ([], G) |> fold maybe_learn_from parents
   1.109      val (deps, _) = ([], G) |> fold maybe_learn_from deps
   1.110    in
   1.111 -    ((name, parents, feats, deps) :: learns, G)
   1.112 +    ((name, parents, map fst feats, deps) :: learns, G)
   1.113    end
   1.114  
   1.115  fun relearn_wrt_access_graph ctxt (name, deps) (relearns, G) =
   1.116 @@ -1259,7 +1265,7 @@
   1.117      launch_thread timeout (fn () =>
   1.118        let
   1.119          val thy = Proof_Context.theory_of ctxt
   1.120 -        val feats = features_of ctxt thy 0 Symtab.empty (Local, General) false [t] |> map fst
   1.121 +        val feats = features_of ctxt thy 0 Symtab.empty (Local, General) false [t]
   1.122        in
   1.123          peek_state ctxt overlord (fn {access_G, ...} =>
   1.124            let
   1.125 @@ -1269,7 +1275,7 @@
   1.126                         |> map nickname_of_thm
   1.127            in
   1.128              if Config.get ctxt sml then () (* TODO: implement *)
   1.129 -            else MaSh_Py.learn ctxt overlord true [("", parents, feats, deps)]
   1.130 +            else MaSh_Py.learn ctxt overlord true [("", parents, map fst feats, deps)]
   1.131            end);
   1.132          (true, "")
   1.133        end)
   1.134 @@ -1361,7 +1367,6 @@
   1.135                val name = nickname_of_thm th
   1.136                val feats =
   1.137                  features_of ctxt (theory_of_thm th) 0 Symtab.empty stature false [prop_of th]
   1.138 -                |> map fst
   1.139                val deps = deps_of status th |> these
   1.140                val n = n |> not (null deps) ? Integer.add 1
   1.141                val learns = (name, parents, feats, deps) :: learns