src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 57294 ef9d4e1ceb00
parent 57291 1bac14e0a728
child 57296 8a98f08a0523
equal deleted inserted replaced
57293:4e619ee65a61 57294:ef9d4e1ceb00
   585     MaSh_Py.query ctxt overlord max_suggs (learns, [], parents', feats')
   585     MaSh_Py.query ctxt overlord max_suggs (learns, [], parents', feats')
   586     |> map (apfst fact_of_name)
   586     |> map (apfst fact_of_name)
   587   end
   587   end
   588 
   588 
   589 (* experimental *)
   589 (* experimental *)
   590 fun k_nearest_neighbors_cpp avail_num adv_max get_deps get_syms advno syms =
   590 fun k_nearest_neighbors_cpp max_suggs learns cfeats =
   591   let
   591   let
   592     val ocs = TextIO.openOut "adv_syms"
   592     val ocs = TextIO.openOut "adv_syms"
   593     val ocd = TextIO.openOut "adv_deps"
   593     val ocd = TextIO.openOut "adv_deps"
   594     val ocq = TextIO.openOut "adv_seq"
   594     val ocq = TextIO.openOut "adv_seq"
   595     val occ = TextIO.openOut "adv_conj"
   595     val occ = TextIO.openOut "adv_conj"
   596     fun os oc s = TextIO.output (oc, s)
   596     fun os oc s = TextIO.output (oc, s)
   597     fun oi oc i = os oc (Int.toString i)
   597     fun ol _ _ _   [] = ()
   598     fun ol _  _ _   [] = ()
   598       | ol _ f _   [e] = f e
   599       | ol _  f _   [e] = f e
       
   600       | ol oc f sep (h :: t) = (f h; os oc sep; ol oc f sep t)
   599       | ol oc f sep (h :: t) = (f h; os oc sep; ol oc f sep t)
   601     fun do_n n =
   600     fun do_learn (name, feats, deps) =
   602       (oi ocs n; os ocs ":"; ol ocs (fn i => (os ocs "\""; oi ocs i; os ocs "\"")) ", " (get_syms n); os ocs "\n";
   601       (os ocs name; os ocs ":"; ol ocs (fn (sy, _) => (os ocs "\""; os ocs sy; os ocs "\"")) ", " feats; os ocs "\n";
   603        oi ocd n; os ocd ":"; ol ocd (fn i => oi ocd i) " " (get_deps n); os ocd "\n";
   602        os ocd name; os ocd ":"; ol ocd (os ocd) " " deps; os ocd "\n";
   604        oi ocq n; os ocq "\n")
   603        os ocq name; os ocq "\n")
   605     fun for n = if n = avail_num then () else (do_n n; for (n + 1))
       
   606     fun forkexec no =
   604     fun forkexec no =
   607       let
   605       let
   608         val cmd =
   606         val cmd =
   609           "~/misc/predict/knn " ^ string_of_int number_of_nearest_neighbors ^
   607           "~/misc/predict/knn " ^ string_of_int number_of_nearest_neighbors ^
   610           " adv_syms adv_deps " ^ string_of_int no ^ " adv_seq < adv_conj"
   608           " adv_syms adv_deps " ^ string_of_int no ^ " adv_seq < adv_conj"
   611       in
   609       in
   612         fst (Isabelle_System.bash_output cmd)
   610         fst (Isabelle_System.bash_output cmd)
   613         |> space_explode " "
   611         |> space_explode " "
   614         |> map_filter (Option.map (rpair 1.0) o Int.fromString)
   612         |> filter_out (curry (op =) "")
   615       end
   613       end
   616   in
   614   in
   617     (for 0; ol occ (fn i => (os occ "\""; oi occ i; os occ "\"")) ", " syms; TextIO.closeOut ocs;
   615     (List.app do_learn learns; ol occ (fn sy => (os occ "\""; os occ sy; os occ "\"")) ", " cfeats; TextIO.closeOut ocs;
   618      TextIO.closeOut ocd; TextIO.closeOut ocq; TextIO.closeOut occ;
   616      TextIO.closeOut ocd; TextIO.closeOut ocq; TextIO.closeOut occ;
   619      forkexec (advno + avail_num - adv_max))
   617      forkexec max_suggs)
   620   end
   618   end
   621 
   619 
   622 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
   620 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
   623 
   621 
   624 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
   622 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
   625 
   623 
   626 fun query ctxt overlord engine visible_facts max_suggs (learns, hints, feats) =
   624 fun query ctxt overlord engine visible_facts max_suggs (learns0, hints, feats) =
   627   let
   625   let
   628     val visible_fact_set = Symtab.make_set visible_facts
   626     val visible_fact_set = Symtab.make_set visible_facts
   629 
   627     val learns =
   630     val learns' =
   628       (learns0 |> List.partition (Symtab.defined visible_fact_set o #1) |> op @) @
   631       (learns |> List.partition (Symtab.defined visible_fact_set o #1) |> op @) @
       
   632       (if null hints then [] else [(".goal", feats, hints)])
   629       (if null hints then [] else [(".goal", feats, hints)])
   633 
   630   in
   634     val (rev_depss, rev_featss, (num_facts, _, rev_facts), (num_feats, feat_tab, _)) =
   631     if engine = MaSh_SML_kNN_Cpp then
   635       fold (fn (fact, feats, deps) =>
   632       k_nearest_neighbors_cpp max_suggs learns (map fst feats)
   636             fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
   633     else
   637           let
   634       let
   638             fun add_feat (feat, weight) (xtab as (n, tab, _)) =
   635         val (rev_depss, rev_featss, (num_facts, _, rev_facts), (num_feats, feat_tab, _)) =
   639               (case Symtab.lookup tab feat of
   636           fold (fn (fact, feats, deps) =>
   640                 SOME i => ((i, weight), xtab)
   637                 fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
   641               | NONE => ((n, weight), add_to_xtab feat xtab))
   638               let
   642 
   639                 fun add_feat (feat, weight) (xtab as (n, tab, _)) =
   643             val (feats', feat_xtab') = fold_map add_feat feats feat_xtab
   640                   (case Symtab.lookup tab feat of
   644           in
   641                     SOME i => ((i, weight), xtab)
   645             (map_filter (Symtab.lookup fact_tab) deps :: rev_depss, feats' :: rev_featss,
   642                   | NONE => ((n, weight), add_to_xtab feat xtab))
   646              add_to_xtab fact fact_xtab, feat_xtab')
   643 
   647           end)
   644                 val (feats', feat_xtab') = fold_map add_feat feats feat_xtab
   648         learns' ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
   645               in
   649 
   646                 (map_filter (Symtab.lookup fact_tab) deps :: rev_depss, feats' :: rev_featss,
   650     val facts = rev rev_facts
   647                  add_to_xtab fact fact_xtab, feat_xtab')
   651     val fact_vec = Vector.fromList facts
   648               end)
   652 
   649             learns ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
   653     val deps_vec = Vector.fromList (rev rev_depss)
   650 
   654 
   651         val facts = rev rev_facts
   655     val num_visible_facts = length visible_facts
   652         val fact_vec = Vector.fromList facts
   656     val get_deps = curry Vector.sub deps_vec
   653 
   657   in
   654         val deps_vec = Vector.fromList (rev rev_depss)
   658     trace_msg ctxt (fn () => "MaSh_SML " ^ " query " ^ encode_features feats ^ " from {" ^
   655 
   659       elide_string 1000 (space_implode " " (take num_visible_facts facts)) ^ "}");
   656         val num_visible_facts = length visible_facts
   660     (if engine = MaSh_SML_kNN then
   657         val get_deps = curry Vector.sub deps_vec
   661        let
   658       in
   662          val facts_ary = Array.array (num_feats, [])
   659         trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^
   663          val _ =
   660           elide_string 1000 (space_implode " " (take num_visible_facts facts)) ^ "}");
   664            fold (fn feats => fn fact =>
   661         (if engine = MaSh_SML_kNN then
   665                let val fact' = fact - 1 in
   662            let
   666                  List.app (fn (feat, weight) => map_array_at facts_ary (cons (fact', weight)) feat)
   663              val facts_ary = Array.array (num_feats, [])
   667                    feats;
   664              val _ =
   668                  fact'
   665                fold (fn feats => fn fact =>
   669                end)
   666                    let val fact' = fact - 1 in
   670              rev_featss num_facts
   667                      List.app (fn (feat, weight) =>
   671          val get_facts = curry Array.sub facts_ary
   668                        map_array_at facts_ary (cons (fact', weight)) feat) feats;
   672          val feats' = map_filter (fn (feat, weight) =>
   669                      fact'
   673            Option.map (rpair weight) (Symtab.lookup feat_tab feat)) feats
   670                    end)
   674        in
   671                  rev_featss num_facts
   675          k_nearest_neighbors num_facts num_visible_facts get_deps get_facts max_suggs feats'
   672              val get_facts = curry Array.sub facts_ary
   676        end
   673              val feats' = map_filter (fn (feat, weight) =>
   677      else
   674                Option.map (rpair weight) (Symtab.lookup feat_tab feat)) feats
   678        let
   675            in
   679          val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss))
   676              k_nearest_neighbors num_facts num_visible_facts get_deps get_facts max_suggs feats'
   680          val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
   677            end
   681          val feats' = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats
   678          else
   682        in
   679            let
   683          (case engine of
   680              val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss))
   684            MaSh_SML_kNN_Cpp =>
   681              val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
   685            k_nearest_neighbors_cpp num_facts num_visible_facts get_deps get_unweighted_feats
   682              val int_feats = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats
   686              max_suggs (map fst feats')
   683            in
   687          | MaSh_SML_NB opts =>
   684              (case engine of
   688            naive_bayes opts num_facts num_visible_facts get_deps get_unweighted_feats num_feats
   685                MaSh_SML_NB opts =>
   689              max_suggs feats'
   686                naive_bayes opts num_facts num_visible_facts get_deps get_unweighted_feats num_feats
   690          | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps
   687                  max_suggs int_feats
   691              get_unweighted_feats num_feats max_suggs feats')
   688              | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps
   692        end)
   689                  get_unweighted_feats num_feats max_suggs int_feats)
   693     |> map (curry Vector.sub fact_vec o fst)
   690            end)
       
   691         |> map (curry Vector.sub fact_vec o fst)
       
   692       end
   694   end
   693   end
   695 
   694 
   696 end;
   695 end;
   697 
   696 
   698 
   697