src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 57356 9816f692b0ca
parent 57355 a9e0f9d35125
child 57357 30ee18eb23ac
equal deleted inserted replaced
57355:a9e0f9d35125 57356:9816f692b0ca
    58       (string * real) list
    58       (string * real) list
    59   end
    59   end
    60 
    60 
    61   structure MaSh_SML :
    61   structure MaSh_SML :
    62   sig
    62   sig
    63     val k_nearest_neighbors : int -> int -> (int -> int list) -> (int -> (int * real) list) ->
    63     val k_nearest_neighbors : int -> (int -> int list) -> (int -> (int * real) list) -> int ->
       
    64       int list -> (int * real) list -> (int * real) list
       
    65     val naive_bayes : (bool * bool) -> int -> (int -> int list) -> (int -> int list) -> int ->
    64       int -> (int * real) list -> (int * real) list
    66       int -> (int * real) list -> (int * real) list
    65     val naive_bayes : (bool * bool) -> int -> int -> (int -> int list) -> (int -> int list) ->
    67     val naive_bayes_py : Proof.context -> bool -> int -> (int -> int list) -> (int -> int list) ->
    66       int -> int -> (int * real) list -> (int * real) list
    68       int -> int -> (int * real) list -> (int * real) list
    67     val naive_bayes_py : Proof.context -> bool -> int -> int -> (int -> int list) ->
       
    68       (int -> int list) -> int -> int -> (int * real) list -> (int * real) list
       
    69     val query : Proof.context -> bool -> mash_engine -> string list -> int ->
    69     val query : Proof.context -> bool -> mash_engine -> string list -> int ->
    70       (string * (string * real) list * string list) list * string list * (string * real) list ->
    70       (string * (string * real) list * string list) list * string list * (string * real) list ->
    71       string list
    71       string list
    72   end
    72   end
    73 
    73 
   421 
   421 
   422 exception EXIT of unit
   422 exception EXIT of unit
   423 
   423 
   424 (*
   424 (*
   425   num_facts = maximum number of theorems whose dependencies and symbols are checked
   425   num_facts = maximum number of theorems whose dependencies and symbols are checked
   426   num_visible_facts = only theorems with an index strictly below this number are returned.
       
   427     Must satisfy: num_visible_facts <= num_facts.
       
   428   get_deps = returns dependencies of a theorem
   426   get_deps = returns dependencies of a theorem
   429   get_sym_ths = get theorems that have this feature
   427   get_sym_ths = get theorems that have this feature
   430   max_suggs = number of suggestions to return
   428   max_suggs = number of suggestions to return
   431   feats = features of the goal
   429   feats = features of the goal
   432 *)
   430 *)
   433 fun k_nearest_neighbors num_facts num_visible_facts get_deps get_sym_ths max_suggs feats =
   431 fun k_nearest_neighbors num_facts get_deps get_sym_ths max_suggs visible_facts feats =
   434   let
   432   let
   435     (* Can be later used for TFIDF *)
   433     (* Can be later used for TFIDF *)
   436     fun sym_wght _ = 1.0
   434     fun sym_wght _ = 1.0
   437 
   435 
   438     val overlaps_sqr = Array.tabulate (num_facts, rpair 0.0)
   436     val overlaps_sqr = Array.tabulate (num_facts, rpair 0.0)
   455       end
   453       end
   456 
   454 
   457     val _ = List.app do_feat feats
   455     val _ = List.app do_feat feats
   458     val _ = heap (Real.compare o pairself snd) num_facts num_facts overlaps_sqr
   456     val _ = heap (Real.compare o pairself snd) num_facts num_facts overlaps_sqr
   459     val no_recommends = Unsynchronized.ref 0
   457     val no_recommends = Unsynchronized.ref 0
   460     val recommends = Array.tabulate (num_visible_facts, rpair 0.0)
   458     val recommends = Array.tabulate (num_facts, rpair 0.0)
   461     val age = Unsynchronized.ref 1000000000.0
   459     val age = Unsynchronized.ref 1000000000.0
   462 
   460 
   463     fun inc_recommend j v =
   461     fun inc_recommend j v =
   464       let val ov = snd (Array.sub (recommends, j)) in
   462       let val ov = snd (Array.sub (recommends, j)) in
   465         if ov <= 0.0 then
   463         if ov <= 0.0 then
   468           (if ov < !age + 1000.0 then Array.update (recommends, j, (j, v + ov)) else ())
   466           (if ov < !age + 1000.0 then Array.update (recommends, j, (j, v + ov)) else ())
   469       end
   467       end
   470 
   468 
   471     val k = Unsynchronized.ref 0
   469     val k = Unsynchronized.ref 0
   472     fun do_k k =
   470     fun do_k k =
   473       if k >= num_visible_facts then
   471       if k >= num_facts then
   474         raise EXIT ()
   472         raise EXIT ()
   475       else
   473       else
   476         let
   474         let
   477           val (j, o2) = Array.sub (overlaps_sqr, num_facts - k - 1)
   475           val (j, o2) = Array.sub (overlaps_sqr, num_facts - k - 1)
   478           val o1 = Math.sqrt o2
   476           val o1 = Math.sqrt o2
   494 
   492 
   495     fun ret acc at =
   493     fun ret acc at =
   496       if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
   494       if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
   497   in
   495   in
   498     while1 (); while2 ();
   496     while1 (); while2 ();
   499     heap (Real.compare o pairself snd) max_suggs num_visible_facts recommends;
   497     heap (Real.compare o pairself snd) max_suggs num_facts recommends;
   500     ret [] (Integer.max 0 (num_visible_facts - max_suggs))
   498     ret [] (Integer.max 0 (num_facts - max_suggs))
   501   end
   499   end
   502 
   500 
   503 val nb_def_prior_weight = 21 (* FUDGE *)
   501 val nb_def_prior_weight = 21 (* FUDGE *)
   504 
   502 
   505 fun learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats =
   503 fun learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats =
   539     val dffreq = Array.array (num_feats, 0)
   537     val dffreq = Array.array (num_feats, 0)
   540   in
   538   in
   541     learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats
   539     learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats
   542   end
   540   end
   543 
   541 
   544 fun naive_bayes_query (kuehlwein_log, kuehlwein_params) num_facts num_visible_facts max_suggs feats
   542 fun naive_bayes_query (kuehlwein_log, kuehlwein_params) num_facts max_suggs feats
   545     (tfreq, sfreq, idf) =
   543     (tfreq, sfreq, idf) =
   546   let
   544   let
   547     val tau = if kuehlwein_params then 0.05 else 0.02 (* FUDGE *)
   545     val tau = if kuehlwein_params then 0.05 else 0.02 (* FUDGE *)
   548     val pos_weight = if kuehlwein_params then 10.0 else 2.0 (* FUDGE *)
   546     val pos_weight = if kuehlwein_params then 10.0 else 2.0 (* FUDGE *)
   549     val def_val = ~15.0 (* FUDGE *)
   547     val def_val = ~15.0 (* FUDGE *)
   574         val sum_of_weights = Inttab.fold fold_sfh sfh 0.0
   572         val sum_of_weights = Inttab.fold fold_sfh sfh 0.0
   575       in
   573       in
   576         res + tau * sum_of_weights
   574         res + tau * sum_of_weights
   577       end
   575       end
   578 
   576 
   579     val posterior = Array.tabulate (num_visible_facts, (fn j => (j, log_posterior j)))
   577     val posterior = Array.tabulate (num_facts, (fn j => (j, log_posterior j)))
   580 
   578 
   581     fun ret acc at =
   579     fun ret acc at =
   582       if at = num_visible_facts then acc else ret (Array.sub (posterior, at) :: acc) (at + 1)
   580       if at = num_facts then acc else ret (Array.sub (posterior, at) :: acc) (at + 1)
   583   in
   581   in
   584     heap (Real.compare o pairself snd) max_suggs num_visible_facts posterior;
   582     heap (Real.compare o pairself snd) max_suggs num_facts posterior;
   585     ret [] (Integer.max 0 (num_visible_facts - max_suggs))
   583     ret [] (Integer.max 0 (num_facts - max_suggs))
   586   end
   584   end
   587 
   585 
   588 fun naive_bayes opts num_facts num_visible_facts get_deps get_feats num_feats max_suggs feats =
   586 fun naive_bayes opts num_facts get_deps get_feats num_feats max_suggs feats =
   589   learn num_facts get_deps get_feats num_feats
   587   learn num_facts get_deps get_feats num_feats
   590   |> naive_bayes_query opts num_facts num_visible_facts max_suggs feats
   588   |> naive_bayes_query opts num_facts max_suggs feats
   591 
   589 
   592 (* experimental *)
   590 (* experimental *)
   593 fun naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps get_feats num_feats max_suggs
   591 fun naive_bayes_py ctxt overlord num_facts get_deps get_feats num_feats max_suggs feats =
   594     feats =
       
   595   let
   592   let
   596     fun name_of_fact j = "f" ^ string_of_int j
   593     fun name_of_fact j = "f" ^ string_of_int j
   597     fun fact_of_name s = the (Int.fromString (unprefix "f" s))
   594     fun fact_of_name s = the (Int.fromString (unprefix "f" s))
   598     fun name_of_feature j = "F" ^ string_of_int j
   595     fun name_of_feature j = "F" ^ string_of_int j
   599     fun parents_of j = if j = 0 then [] else [name_of_fact (j - 1)]
   596     fun parents_of j = if j = 0 then [] else [name_of_fact (j - 1)]
   600 
   597 
   601     val learns = map (fn j => (name_of_fact j, parents_of j, map name_of_feature (get_feats j),
   598     val learns = map (fn j => (name_of_fact j, parents_of j, map name_of_feature (get_feats j),
   602       map name_of_fact (get_deps j))) (0 upto num_facts - 1)
   599       map name_of_fact (get_deps j))) (0 upto num_facts - 1)
   603     val parents' = parents_of num_visible_facts
   600     val parents' = parents_of num_facts
   604     val feats' = map (apfst name_of_feature) feats
   601     val feats' = map (apfst name_of_feature) feats
   605   in
   602   in
   606     MaSh_Py.unlearn ctxt overlord;
   603     MaSh_Py.unlearn ctxt overlord;
   607     OS.Process.sleep (seconds 2.0); (* hack *)
   604     OS.Process.sleep (seconds 2.0); (* hack *)
   608     MaSh_Py.query ctxt overlord max_suggs (learns, [], parents', feats')
   605     MaSh_Py.query ctxt overlord max_suggs (learns, [], parents', feats')
   653 
   650 
   654 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
   651 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
   655 
   652 
   656 fun query ctxt overlord engine visible_facts max_suggs (learns0, hints, feats) =
   653 fun query ctxt overlord engine visible_facts max_suggs (learns0, hints, feats) =
   657   let
   654   let
   658     val visible_fact_set = Symtab.make_set visible_facts
   655     val learns = learns0 @ (if null hints then [] else [(".hints", feats, hints)])
   659     val learns =
       
   660       (learns0 |> List.partition (Symtab.defined visible_fact_set o #1) |> op @) @
       
   661       (if null hints then [] else [(".hints", feats, hints)])
       
   662   in
   656   in
   663     if engine = MaSh_SML_kNN_Cpp then
   657     if engine = MaSh_SML_kNN_Cpp then
   664       k_nearest_neighbors_cpp max_suggs learns (map fst feats)
   658       k_nearest_neighbors_cpp max_suggs learns (map fst feats)
   665     else if engine = MaSh_SML_NB_Cpp then
   659     else if engine = MaSh_SML_NB_Cpp then
   666       naive_bayes_cpp max_suggs learns (map fst feats)
   660       naive_bayes_cpp max_suggs learns (map fst feats)
   667     else
   661     else
   668       let
   662       let
   669         val (rev_depss, rev_featss, (num_facts, _, rev_facts), (num_feats, feat_tab, _)) =
   663         val (rev_depss, rev_featss, (num_facts, fact_tab, rev_facts), (num_feats, feat_tab, _)) =
   670           fold (fn (fact, feats, deps) =>
   664           fold (fn (fact, feats, deps) =>
   671                 fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
   665                 fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
   672               let
   666               let
   673                 fun add_feat (feat, weight) (xtab as (n, tab, _)) =
   667                 fun add_feat (feat, weight) (xtab as (n, tab, _)) =
   674                   (case Symtab.lookup tab feat of
   668                   (case Symtab.lookup tab feat of
   685         val facts = rev rev_facts
   679         val facts = rev rev_facts
   686         val fact_vec = Vector.fromList facts
   680         val fact_vec = Vector.fromList facts
   687 
   681 
   688         val deps_vec = Vector.fromList (rev rev_depss)
   682         val deps_vec = Vector.fromList (rev rev_depss)
   689 
   683 
   690         val num_visible_facts = length visible_facts
       
   691         val get_deps = curry Vector.sub deps_vec
   684         val get_deps = curry Vector.sub deps_vec
       
   685 
       
   686         val int_visible_facts = map (Symtab.lookup fact_tab) visible_facts
   692       in
   687       in
   693         trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^
   688         trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^
   694           elide_string 1000 (space_implode " " (take num_visible_facts facts)) ^ "}");
   689           elide_string 1000 (space_implode " " (take num_facts facts)) ^ "}");
   695         (if engine = MaSh_SML_kNN then
   690         (if engine = MaSh_SML_kNN then
   696            let
   691            let
   697              val facts_ary = Array.array (num_feats, [])
   692              val facts_ary = Array.array (num_feats, [])
   698              val _ =
   693              val _ =
   699                fold (fn feats => fn fact =>
   694                fold (fn feats => fn fact =>
   702                        map_array_at facts_ary (cons (fact', weight)) feat) feats;
   697                        map_array_at facts_ary (cons (fact', weight)) feat) feats;
   703                      fact'
   698                      fact'
   704                    end)
   699                    end)
   705                  rev_featss num_facts
   700                  rev_featss num_facts
   706              val get_facts = curry Array.sub facts_ary
   701              val get_facts = curry Array.sub facts_ary
   707              val feats' = map_filter (fn (feat, weight) =>
   702              val int_feats = map_filter (fn (feat, weight) =>
   708                Option.map (rpair weight) (Symtab.lookup feat_tab feat)) feats
   703                Option.map (rpair weight) (Symtab.lookup feat_tab feat)) feats
   709            in
   704            in
   710              k_nearest_neighbors num_facts num_visible_facts get_deps get_facts max_suggs feats'
   705              k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts int_feats
   711            end
   706            end
   712          else
   707          else
   713            let
   708            let
   714              val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss))
   709              val unweighted_feats_ary = Vector.fromList (map (map fst) (rev rev_featss))
   715              val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
   710              val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
   716              val int_feats = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats
   711              val int_feats = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats
   717            in
   712            in
   718              (case engine of
   713              (case engine of
   719                MaSh_SML_NB opts =>
   714                MaSh_SML_NB opts =>
   720                naive_bayes opts num_facts num_visible_facts get_deps get_unweighted_feats num_feats
   715                naive_bayes opts num_facts get_deps get_unweighted_feats num_feats max_suggs
   721                  max_suggs int_feats
   716                  int_feats
   722              | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps
   717              | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts get_deps
   723                  get_unweighted_feats num_feats max_suggs int_feats)
   718                  get_unweighted_feats num_feats max_suggs int_feats)
   724            end)
   719            end)
   725         |> map (curry Vector.sub fact_vec o fst)
   720         |> map (curry Vector.sub fact_vec o fst)
   726       end
   721       end
   727   end
   722   end