src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 57291 1bac14e0a728
parent 57281 bb671e6b740d
child 57294 ef9d4e1ceb00
equal deleted inserted replaced
57290:bc06471cb7b7 57291:1bac14e0a728
    36   val extract_suggestions : string -> string * (string * real) list
    36   val extract_suggestions : string -> string * (string * real) list
    37 
    37 
    38   datatype mash_engine =
    38   datatype mash_engine =
    39     MaSh_Py
    39     MaSh_Py
    40   | MaSh_SML_kNN
    40   | MaSh_SML_kNN
       
    41   | MaSh_SML_kNN_Cpp
    41   | MaSh_SML_NB of bool * bool
    42   | MaSh_SML_NB of bool * bool
    42   | MaSh_SML_NB_Py
    43   | MaSh_SML_NB_Py
    43 
    44 
    44   val is_mash_enabled : unit -> bool
    45   val is_mash_enabled : unit -> bool
    45   val the_mash_engine : unit -> mash_engine
    46   val the_mash_engine : unit -> mash_engine
   155   end
   156   end
   156 
   157 
   157 datatype mash_engine =
   158 datatype mash_engine =
   158   MaSh_Py
   159   MaSh_Py
   159 | MaSh_SML_kNN
   160 | MaSh_SML_kNN
       
   161 | MaSh_SML_kNN_Cpp
   160 | MaSh_SML_NB of bool * bool
   162 | MaSh_SML_NB of bool * bool
   161 | MaSh_SML_NB_Py
   163 | MaSh_SML_NB_Py
   162 
   164 
   163 val default_MaSh_SML_NB = MaSh_SML_NB (false, true)
   165 val default_MaSh_SML_NB = MaSh_SML_NB (false, true)
   164 
   166 
   167     (case if flag1 <> "none" (* default *) then flag1 else getenv "MASH" of
   169     (case if flag1 <> "none" (* default *) then flag1 else getenv "MASH" of
   168       "yes" => SOME default_MaSh_SML_NB
   170       "yes" => SOME default_MaSh_SML_NB
   169     | "py" => SOME MaSh_Py
   171     | "py" => SOME MaSh_Py
   170     | "sml" => SOME default_MaSh_SML_NB
   172     | "sml" => SOME default_MaSh_SML_NB
   171     | "sml_knn" => SOME MaSh_SML_kNN
   173     | "sml_knn" => SOME MaSh_SML_kNN
       
   174     | "sml_knn_cpp" => SOME MaSh_SML_kNN_Cpp
   172     | "sml_nb" => SOME default_MaSh_SML_NB
   175     | "sml_nb" => SOME default_MaSh_SML_NB
   173     | "sml_nbCC" => SOME (MaSh_SML_NB (false, false))
   176     | "sml_nbCC" => SOME (MaSh_SML_NB (false, false))
   174     | "sml_nbCD" => SOME (MaSh_SML_NB (false, true))
   177     | "sml_nbCD" => SOME (MaSh_SML_NB (false, true))
   175     | "sml_nbDC" => SOME (MaSh_SML_NB (true, false))
   178     | "sml_nbDC" => SOME (MaSh_SML_NB (true, false))
   176     | "sml_nbDD" => SOME (MaSh_SML_NB (true, true))
   179     | "sml_nbDD" => SOME (MaSh_SML_NB (true, true))
   581     OS.Process.sleep (seconds 2.0); (* hack *)
   584     OS.Process.sleep (seconds 2.0); (* hack *)
   582     MaSh_Py.query ctxt overlord max_suggs (learns, [], parents', feats')
   585     MaSh_Py.query ctxt overlord max_suggs (learns, [], parents', feats')
   583     |> map (apfst fact_of_name)
   586     |> map (apfst fact_of_name)
   584   end
   587   end
   585 
   588 
       
   589 (* experimental *)
       
   590 fun k_nearest_neighbors_cpp avail_num adv_max get_deps get_syms advno syms =
       
   591   let
       
   592     val ocs = TextIO.openOut "adv_syms"
       
   593     val ocd = TextIO.openOut "adv_deps"
       
   594     val ocq = TextIO.openOut "adv_seq"
       
   595     val occ = TextIO.openOut "adv_conj"
       
   596     fun os oc s = TextIO.output (oc, s)
       
   597     fun oi oc i = os oc (Int.toString i)
       
   598     fun ol _  _ _   [] = ()
       
   599       | ol _  f _   [e] = f e
       
   600       | ol oc f sep (h :: t) = (f h; os oc sep; ol oc f sep t)
       
   601     fun do_n n =
       
   602       (oi ocs n; os ocs ":"; ol ocs (fn i => (os ocs "\""; oi ocs i; os ocs "\"")) ", " (get_syms n); os ocs "\n";
       
   603        oi ocd n; os ocd ":"; ol ocd (fn i => oi ocd i) " " (get_deps n); os ocd "\n";
       
   604        oi ocq n; os ocq "\n")
       
   605     fun for n = if n = avail_num then () else (do_n n; for (n + 1))
       
   606     fun forkexec no =
       
   607       let
       
   608         val cmd =
       
   609           "~/misc/predict/knn " ^ string_of_int number_of_nearest_neighbors ^
       
   610           " adv_syms adv_deps " ^ string_of_int no ^ " adv_seq < adv_conj"
       
   611       in
       
   612         fst (Isabelle_System.bash_output cmd)
       
   613         |> space_explode " "
       
   614         |> map_filter (Option.map (rpair 1.0) o Int.fromString)
       
   615       end
       
   616   in
       
   617     (for 0; ol occ (fn i => (os occ "\""; oi occ i; os occ "\"")) ", " syms; TextIO.closeOut ocs;
       
   618      TextIO.closeOut ocd; TextIO.closeOut ocq; TextIO.closeOut occ;
       
   619      forkexec (advno + avail_num - adv_max))
       
   620   end
       
   621 
   586 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
   622 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
   587 
   623 
   588 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
   624 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
   589 
   625 
   590 fun query ctxt overlord engine visible_facts max_suggs (learns, hints, feats) =
   626 fun query ctxt overlord engine visible_facts max_suggs (learns, hints, feats) =
   617     val deps_vec = Vector.fromList (rev rev_depss)
   653     val deps_vec = Vector.fromList (rev rev_depss)
   618 
   654 
   619     val num_visible_facts = length visible_facts
   655     val num_visible_facts = length visible_facts
   620     val get_deps = curry Vector.sub deps_vec
   656     val get_deps = curry Vector.sub deps_vec
   621   in
   657   in
   622     trace_msg ctxt (fn () => "MaSh_SML " ^ (if engine = MaSh_SML_kNN then "kNN" else "NB") ^
   658     trace_msg ctxt (fn () => "MaSh_SML " ^ " query " ^ encode_features feats ^ " from {" ^
   623       " query " ^ encode_features feats ^ " from {" ^
       
   624       elide_string 1000 (space_implode " " (take num_visible_facts facts)) ^ "}");
   659       elide_string 1000 (space_implode " " (take num_visible_facts facts)) ^ "}");
   625     (if engine = MaSh_SML_kNN then
   660     (if engine = MaSh_SML_kNN then
   626        let
   661        let
   627          val facts_ary = Array.array (num_feats, [])
   662          val facts_ary = Array.array (num_feats, [])
   628          val _ =
   663          val _ =
   644          val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss))
   679          val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss))
   645          val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
   680          val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
   646          val feats' = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats
   681          val feats' = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats
   647        in
   682        in
   648          (case engine of
   683          (case engine of
   649            MaSh_SML_NB opts => naive_bayes opts
   684            MaSh_SML_kNN_Cpp =>
   650          | _ => naive_bayes_py ctxt overlord)
   685            k_nearest_neighbors_cpp num_facts num_visible_facts get_deps get_unweighted_feats
   651            num_facts num_visible_facts get_deps get_unweighted_feats num_feats max_suggs feats'
   686              max_suggs (map fst feats')
       
   687          | MaSh_SML_NB opts =>
       
   688            naive_bayes opts num_facts num_visible_facts get_deps get_unweighted_feats num_feats
       
   689              max_suggs feats'
       
   690          | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps
       
   691              get_unweighted_feats num_feats max_suggs feats')
   652        end)
   692        end)
   653     |> map (curry Vector.sub fact_vec o fst)
   693     |> map (curry Vector.sub fact_vec o fst)
   654   end
   694   end
   655 
   695 
   656 end;
   696 end;