src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 57353 ee493eb30c7b
parent 57306 ff10067b2248
child 57354 ded92100ffd7
equal deleted inserted replaced
57352:9801e9fa9270 57353:ee493eb30c7b
   417       end
   417       end
   418     else
   418     else
   419       ()
   419       ()
   420   end
   420   end
   421 
   421 
       
   422 val number_of_nearest_neighbors = 10 (* FUDGE *)
       
   423 
       
   424 exception EXIT of unit
       
   425 
   422 (*
   426 (*
   423   num_facts = maximum number of theorems to check dependencies and symbols
   427   num_facts = maximum number of theorems to check dependencies and symbols
   424   num_visible_facts = do not return theorems over or equal to this number.
   428   num_visible_facts = do not return theorems over or equal to this number.
   425     Must satisfy: num_visible_facts <= num_facts.
   429     Must satisfy: num_visible_facts <= num_facts.
   426   get_deps = returns dependencies of a theorem
   430   get_deps = returns dependencies of a theorem
   428   max_suggs = number of suggestions to return
   432   max_suggs = number of suggestions to return
   429   feats = features of the goal
   433   feats = features of the goal
   430 *)
   434 *)
   431 fun k_nearest_neighbors num_facts num_visible_facts get_deps get_sym_ths max_suggs feats =
   435 fun k_nearest_neighbors num_facts num_visible_facts get_deps get_sym_ths max_suggs feats =
   432   let
   436   let
   433     val number_of_nearest_neighbors = 40 (* FUDGE *)
       
   434 
       
   435     (* Can be later used for TFIDF *)
   437     (* Can be later used for TFIDF *)
   436     fun sym_wght _ = 1.0
   438     fun sym_wght _ = 1.0
   437 
   439 
   438     val overlaps_sqr = Array.tabulate (num_facts, rpair 0.0)
   440     val overlaps_sqr = Array.tabulate (num_facts, rpair 0.0)
   439 
   441 
   453       in
   455       in
   454         List.app do_th (get_sym_ths s)
   456         List.app do_th (get_sym_ths s)
   455       end
   457       end
   456 
   458 
   457     val _ = List.app do_feat feats
   459     val _ = List.app do_feat feats
   458     val _ = heap (Real.compare o pairself snd) number_of_nearest_neighbors overlaps_sqr
   460     val _ = heap (Real.compare o pairself snd) num_facts overlaps_sqr
       
   461     val no_recommends = Unsynchronized.ref 0
   459     val recommends = Array.tabulate (num_visible_facts, rpair 0.0)
   462     val recommends = Array.tabulate (num_visible_facts, rpair 0.0)
       
   463     val age = Unsynchronized.ref 1000000000.0
   460 
   464 
   461     fun inc_recommend j v =
   465     fun inc_recommend j v =
   462       if j >= num_visible_facts then ()
   466       let val ov = snd (Array.sub (recommends, j)) in
   463       else Array.update (recommends, j, (j, v + snd (Array.sub (recommends, j))))
   467         if ov <= 0.0 then
   464 
   468           (no_recommends := !no_recommends + 1; Array.update (recommends, j, (j, !age + ov)))
   465     fun for k =
   469         else
   466       if k = number_of_nearest_neighbors orelse k >= num_visible_facts then
   470           (if ov < !age + 1000.0 then Array.update (recommends, j, (j, v + ov)) else ())
   467         ()
   471       end
       
   472 
       
   473     val k = Unsynchronized.ref 0
       
   474     fun do_k k =
       
   475       if k >= num_visible_facts then
       
   476         raise EXIT ()
   468       else
   477       else
   469         let
   478         let
   470           val (j, o2) = Array.sub (overlaps_sqr, num_facts - k - 1)
   479           val (j, o2) = Array.sub (overlaps_sqr, num_facts - k - 1)
   471           val o1 = Math.sqrt o2
   480           val o1 = Math.sqrt o2
   472           val _ = inc_recommend j o1
   481           val _ = inc_recommend j o1
   473           val ds = get_deps j
   482           val ds = get_deps j
   474           val l = Real.fromInt (length ds)
   483           val l = Real.fromInt (length ds)
   475         in
   484         in
   476           List.app (fn d => inc_recommend d (o1 / l)) ds; for (k + 1)
   485           List.app (fn d => inc_recommend d (o1 / l)) ds
   477         end
   486         end
       
   487 
       
   488     fun while1 () =
       
   489       if !k = number_of_nearest_neighbors then () else (do_k (!k); k := !k + 1; while1 ())
       
   490       handle EXIT () => ()
       
   491 
       
   492     fun while2 () =
       
   493       if !no_recommends >= max_suggs then ()
       
   494       else (do_k (!k); k := !k + 1; age := !age - 10000.0; while2 ())
       
   495       handle EXIT () => ()
   478 
   496 
   479     fun ret acc at =
   497     fun ret acc at =
   480       if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
   498       if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
   481   in
   499   in
   482     for 0;
   500     while1 (); while2 ();
   483     heap (Real.compare o pairself snd) max_suggs recommends;
   501     heap (Real.compare o pairself snd) max_suggs recommends;
   484     ret [] (Integer.max 0 (num_visible_facts - max_suggs))
   502     ret [] (Integer.max 0 (num_visible_facts - max_suggs))
   485   end
   503   end
   486 
   504 
   487 val nb_def_prior_weight = 21 (* FUDGE *)
   505 val nb_def_prior_weight = 21 (* FUDGE *)
   622     (List.app do_learn learns; ol occ (fn sy => (os occ "\""; os occ sy; os occ "\"")) ", " cfeats;
   640     (List.app do_learn learns; ol occ (fn sy => (os occ "\""; os occ sy; os occ "\"")) ", " cfeats;
   623      TextIO.closeOut ocs; TextIO.closeOut ocd; TextIO.closeOut ocq; TextIO.closeOut occ;
   641      TextIO.closeOut ocs; TextIO.closeOut ocd; TextIO.closeOut ocq; TextIO.closeOut occ;
   624      forkexec max_suggs)
   642      forkexec max_suggs)
   625   end
   643   end
   626 
   644 
   627 val cpp_number_of_nearest_neighbors = 10 (* FUDGE *)
       
   628 
       
   629 val k_nearest_neighbors_cpp =
   645 val k_nearest_neighbors_cpp =
   630   c_plus_plus_tool ("newknn/knn" ^ " " ^ string_of_int cpp_number_of_nearest_neighbors)
   646   c_plus_plus_tool ("newknn/knn" ^ " " ^ string_of_int number_of_nearest_neighbors)
   631 val naive_bayes_cpp = c_plus_plus_tool "predict/nbayes"
   647 val naive_bayes_cpp = c_plus_plus_tool "predict/nbayes"
   632 
   648 
   633 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
   649 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
   634 
   650 
   635 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
   651 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
  1494         (true, "")
  1510         (true, "")
  1495       end)
  1511       end)
  1496   else
  1512   else
  1497     ()
  1513     ()
  1498 
  1514 
  1499 fun sendback sub =
  1515 fun sendback sub = Active.sendback_markup [Markup.padding_command] (sledgehammerN ^ " " ^ sub)
  1500   Active.sendback_markup [Markup.padding_command] (sledgehammerN ^ " " ^ sub)
       
  1501 
  1516 
  1502 val commit_timeout = seconds 30.0
  1517 val commit_timeout = seconds 30.0
  1503 
  1518 
  1504 (* The timeout is understood in a very relaxed fashion. *)
  1519 (* The timeout is understood in a very relaxed fashion. *)
  1505 fun mash_learn_facts ctxt (params as {debug, verbose, overlord, ...}) prover save auto_level
  1520 fun mash_learn_facts ctxt (params as {debug, verbose, overlord, ...}) prover save auto_level