src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 50383 4274b25ff4e7
parent 50382 cb564ff43c28
child 50389 ad0ac9112d2c
equal deleted inserted replaced
50382:cb564ff43c28 50383:4274b25ff4e7
    42   val mash_unlearn : Proof.context -> unit
    42   val mash_unlearn : Proof.context -> unit
    43   val nickname_of : thm -> string
    43   val nickname_of : thm -> string
    44   val suggested_facts :
    44   val suggested_facts :
    45     (string * 'a) list -> ('b * thm) list -> (('b * thm) * 'a) list
    45     (string * 'a) list -> ('b * thm) list -> (('b * thm) * 'a) list
    46   val mesh_facts :
    46   val mesh_facts :
    47     int -> ((('a * thm) * real) list * ('a * thm) list) list -> ('a * thm) list
    47     int -> (real * ((('a * thm) * real) list * ('a * thm) list)) list
       
    48     -> ('a * thm) list
    48   val theory_ord : theory * theory -> order
    49   val theory_ord : theory * theory -> order
    49   val thm_ord : thm * thm -> order
    50   val thm_ord : thm * thm -> order
    50   val goal_of_thm : theory -> thm -> thm
    51   val goal_of_thm : theory -> thm -> thm
    51   val run_prover_for_mash :
    52   val run_prover_for_mash :
    52     Proof.context -> params -> string -> fact list -> thm -> prover_result
    53     Proof.context -> params -> string -> fact list -> thm -> prover_result
    57   val atp_dependencies_of :
    58   val atp_dependencies_of :
    58     Proof.context -> params -> string -> int -> fact list -> unit Symtab.table
    59     Proof.context -> params -> string -> int -> fact list -> unit Symtab.table
    59     -> thm -> bool * string list option
    60     -> thm -> bool * string list option
    60   val weight_mash_facts : ('a * thm) list -> (('a * thm) * real) list
    61   val weight_mash_facts : ('a * thm) list -> (('a * thm) * real) list
    61   val mash_suggested_facts :
    62   val mash_suggested_facts :
    62     Proof.context -> params -> string -> int -> term list -> term
    63     Proof.context -> params -> string -> int -> term list -> term -> fact list
    63     -> fact list -> fact list * fact list
    64     -> fact list
    64   val mash_learn_proof :
    65   val mash_learn_proof :
    65     Proof.context -> params -> string -> term -> ('a * thm) list -> thm list
    66     Proof.context -> params -> string -> term -> ('a * thm) list -> thm list
    66     -> unit
    67     -> unit
    67   val mash_learn :
    68   val mash_learn :
    68     Proof.context -> params -> fact_override -> thm list -> bool -> unit
    69     Proof.context -> params -> fact_override -> thm list -> bool -> unit
   296 
   297 
   297 val empty_state = {fact_G = Graph.empty, dirty = SOME []}
   298 val empty_state = {fact_G = Graph.empty, dirty = SOME []}
   298 
   299 
   299 local
   300 local
   300 
   301 
   301 val version = "*** MaSh version 20121204a ***"
   302 val version = "*** MaSh version 20121205a ***"
   302 
   303 
   303 exception Too_New of unit
   304 exception Too_New of unit
   304 
   305 
   305 fun extract_node line =
   306 fun extract_node line =
   306   case space_explode ":" line of
   307   case space_explode ":" line of
   423     val tab = Symtab.empty |> fold add_fact facts
   424     val tab = Symtab.empty |> fold add_fact facts
   424     fun find_sugg (name, weight) =
   425     fun find_sugg (name, weight) =
   425       Symtab.lookup tab name |> Option.map (rpair weight)
   426       Symtab.lookup tab name |> Option.map (rpair weight)
   426   in map_filter find_sugg suggs end
   427   in map_filter find_sugg suggs end
   427 
   428 
   428 fun sum_avg [] = 0
   429 fun scaled_avg [] = 0
   429   | sum_avg xs =
   430   | scaled_avg xs =
   430     Real.ceil (100000000.0 * fold (curry (op +)) xs 0.0) div length xs
   431     Real.ceil (100000000.0 * fold (curry (op +)) xs 0.0) div length xs
   431 
   432 
   432 fun normalize_scores [] = []
   433 fun avg [] = 0.0
   433   | normalize_scores ((fact, score) :: tail) =
   434   | avg xs = fold (curry (op +)) xs 0.0 / Real.fromInt (length xs)
   434     (fact, 1.0) :: map (apsnd (curry Real.* (1.0 / score))) tail
   435 
   435 
   436 fun normalize_scores _ [] = []
   436 fun mesh_facts max_facts [(sels, unks)] =
   437   | normalize_scores max_facts xs =
       
   438     let val avg = avg (map snd (take max_facts xs)) in
       
   439       map (apsnd (curry Real.* (1.0 / avg))) xs
       
   440     end
       
   441 
       
   442 fun mesh_facts max_facts [(_, (sels, unks))] =
   437     map fst (take max_facts sels) @ take (max_facts - length sels) unks
   443     map fst (take max_facts sels) @ take (max_facts - length sels) unks
   438   | mesh_facts max_facts mess =
   444   | mesh_facts max_facts mess =
   439     let
   445     let
   440       val mess = mess |> map (apfst (normalize_scores #> `length))
   446       val mess =
       
   447         mess |> map (apsnd (apfst (normalize_scores max_facts #> `length)))
   441       val fact_eq = Thm.eq_thm o pairself snd
   448       val fact_eq = Thm.eq_thm o pairself snd
   442       fun score_at sels = try (nth sels) #> Option.map snd
   449       fun score_in fact (global_weight, ((sel_len, sels), unks)) =
   443       fun score_in fact ((sel_len, sels), unks) =
   450         let
   444         case find_index (curry fact_eq fact o fst) sels of
   451           fun score_at j =
   445           ~1 => (case find_index (curry fact_eq fact) unks of
   452             case try (nth sels) j of
   446                    ~1 => score_at sels sel_len
   453               SOME (_, score) => SOME (global_weight * score)
   447                  | _ => NONE)
   454             | NONE => NONE
   448         | rank => score_at sels rank
   455         in
   449       fun weight_of fact = mess |> map_filter (score_in fact) |> sum_avg
   456           case find_index (curry fact_eq fact o fst) sels of
       
   457             ~1 => (case find_index (curry fact_eq fact) unks of
       
   458                      ~1 => score_at sel_len
       
   459                    | _ => NONE)
       
   460           | rank => score_at rank
       
   461         end
       
   462       fun weight_of fact = mess |> map_filter (score_in fact) |> scaled_avg
   450       val facts =
   463       val facts =
   451         fold (union fact_eq o map fst o take max_facts o snd o fst) mess []
   464         fold (union fact_eq o map fst o take max_facts o snd o fst o snd) mess
       
   465              []
   452     in
   466     in
   453       facts |> map (`weight_of) |> sort (int_ord o swap o pairself fst)
   467       facts |> map (`weight_of) |> sort (int_ord o swap o pairself fst)
   454             |> map snd |> take max_facts
   468             |> map snd |> take max_facts
   455     end
   469     end
   456 
   470 
   457 fun thy_feature_of s = ("y" ^ s, 1.0 (* FUDGE *))
   471 fun thy_feature_of s = ("y" ^ s, 1.0 (* FUDGE *))
   458 fun term_feature_of s = ("c" ^ s, 1.0 (* FUDGE *))
   472 fun term_feature_of s = ("c" ^ s, 1.0 (* FUDGE *))
   459 fun type_feature_of s = ("t" ^ s, 1.0 (* FUDGE *))
   473 fun type_feature_of s = ("t" ^ s, 1.0 (* FUDGE *))
   460 fun class_feature_of s = ("s" ^ s, 1.0 (* FUDGE *))
   474 fun class_feature_of s = ("s" ^ s, 1.0 (* FUDGE *))
   461 fun status_feature_of status = (string_of_status status, 1.0 (* FUDGE *))
   475 fun status_feature_of status = (string_of_status status, 1.0 (* FUDGE *))
   462 val local_feature = ("local", 20.0 (* FUDGE *))
   476 val local_feature = ("local", 1.0 (* FUDGE *))
   463 val lams_feature = ("lams", 1.0 (* FUDGE *))
   477 val lams_feature = ("lams", 1.0 (* FUDGE *))
   464 val skos_feature = ("skos", 1.0 (* FUDGE *))
   478 val skos_feature = ("skos", 1.0 (* FUDGE *))
   465 
   479 
   466 fun theory_ord p =
   480 fun theory_ord p =
   467   if Theory.eq_thy p then
   481   if Theory.eq_thy p then
   529       | patternify_term _ 0 _ = []
   543       | patternify_term _ 0 _ = []
   530       | patternify_term args depth (t $ u) =
   544       | patternify_term args depth (t $ u) =
   531         let
   545         let
   532           val ps = patternify_term (u :: args) depth t
   546           val ps = patternify_term (u :: args) depth t
   533           val qs = "" :: patternify_term [] (depth - 1) u
   547           val qs = "" :: patternify_term [] (depth - 1) u
   534         in map_product (fn p => fn "" => p | q => "(" ^ q ^ ")") ps qs end
   548         in map_product (fn p => fn "" => p | q => p ^ "(" ^ q ^ ")") ps qs end
   535       | patternify_term _ _ _ = []
   549       | patternify_term _ _ _ = []
   536     val add_term_pattern =
   550     val add_term_pattern =
   537       union (op = o pairself fst) o map term_feature_of oo patternify_term []
   551       union (op = o pairself fst) o map term_feature_of oo patternify_term []
   538     fun add_term_patterns ~1 _ = I
   552     fun add_term_patterns ~1 _ = I
   539       | add_term_patterns depth t =
   553       | add_term_patterns depth t =
   690              else
   704              else
   691                (maxs, Graph.Keys.fold (insert_new seen)
   705                (maxs, Graph.Keys.fold (insert_new seen)
   692                                       (Graph.imm_preds fact_G new) news))
   706                                       (Graph.imm_preds fact_G new) news))
   693   in find_maxes Symtab.empty ([], Graph.maximals fact_G) end
   707   in find_maxes Symtab.empty ([], Graph.maximals fact_G) end
   694 
   708 
   695 (* Generate more suggestions than requested, because some might be thrown out
       
   696    later for various reasons and "meshing" gives better results with some
       
   697    slack. *)
       
   698 fun max_suggs_of max_facts = max_facts + Int.min (50, max_facts)
       
   699 
       
   700 fun is_fact_in_graph fact_G (_, th) =
   709 fun is_fact_in_graph fact_G (_, th) =
   701   can (Graph.get_node fact_G) (nickname_of th)
   710   can (Graph.get_node fact_G) (nickname_of th)
   702 
   711 
   703 fun interleave 0 _ _ = []
       
   704   | interleave n [] ys = take n ys
       
   705   | interleave n xs [] = take n xs
       
   706   | interleave 1 (x :: _) _ = [x]
       
   707   | interleave n (x :: xs) (y :: ys) = x :: y :: interleave (n - 2) xs ys
       
   708 
       
   709 (* factor that controls whether unknown global facts should be included *)
   712 (* factor that controls whether unknown global facts should be included *)
   710 val include_unk_global_factor = 15
   713 val include_unk_global_factor = 15
   711 
   714 
   712 val weight_mash_facts = weight_mepo_facts (* use MePo weights for now *)
   715 (* use MePo weights for now *)
       
   716 val weight_raw_mash_facts = weight_mepo_facts
       
   717 val weight_mash_facts = weight_raw_mash_facts
       
   718 
       
   719 (* FUDGE *)
       
   720 fun weight_of_proximity_fact rank =
       
   721   Math.pow (1.3, 15.5 - 0.05 * Real.fromInt rank) + 15.0
       
   722 
       
   723 fun weight_proximity_facts facts =
       
   724   facts ~~ map weight_of_proximity_fact (0 upto length facts - 1)
   713 
   725 
   714 fun mash_suggested_facts ctxt ({overlord, ...} : params) prover max_facts hyp_ts
   726 fun mash_suggested_facts ctxt ({overlord, ...} : params) prover max_facts hyp_ts
   715                          concl_t facts =
   727                          concl_t facts =
   716   let
   728   let
   717     val thy = Proof_Context.theory_of ctxt
   729     val thy = Proof_Context.theory_of ctxt
   723             let
   735             let
   724               val parents = maximal_in_graph fact_G facts
   736               val parents = maximal_in_graph fact_G facts
   725               val feats =
   737               val feats =
   726                 features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts)
   738                 features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts)
   727             in
   739             in
   728               (fact_G, mash_QUERY ctxt overlord (max_suggs_of max_facts)
   740               (fact_G, mash_QUERY ctxt overlord max_facts (parents, feats))
   729                                   (parents, feats))
       
   730             end)
   741             end)
   731     val (chained, unchained) =
   742     val (chained, unchained) =
   732       List.partition (fn ((_, (scope, _)), _) => scope = Chained) facts
   743       List.partition (fn ((_, (scope, _)), _) => scope = Chained) facts
   733     val sels =
   744     val raw_mash =
   734       facts |> suggested_facts suggs
   745       facts |> suggested_facts suggs
   735             (* The weights currently returned by "mash.py" are too spaced out to
   746             (* The weights currently returned by "mash.py" are too spaced out to
   736                make any sense. *)
   747                make any sense. *)
   737             |> map fst
   748             |> map fst
   738             |> filter_out (member (Thm.eq_thm_prop o pairself snd) chained)
   749     val proximity =
   739     val (unk_global, unk_local) =
   750       chained @ (facts |> subtract (Thm.eq_thm_prop o pairself snd) chained
   740       unchained |> filter_out (is_fact_in_graph fact_G)
   751                        |> sort (thm_ord o pairself snd o swap))
   741                 |> List.partition (fn ((_, (scope, _)), _) => scope = Global)
   752     val unknown = facts |> filter_out (is_fact_in_graph fact_G)
   742     val (small_unk_global, big_unk_global) =
   753     val mess =
   743       ([], unk_global)
   754       [(0.667 (* FUDGE *), (weight_raw_mash_facts raw_mash, unknown)),
   744       |> include_unk_global_factor * length unk_global <= max_facts ? swap
   755        (0.333 (* FUDGE *), (weight_proximity_facts proximity, []))]
   745   in
   756   in mesh_facts max_facts mess end
   746     (interleave max_facts (chained @ unk_local @ small_unk_global) sels,
       
   747      big_unk_global)
       
   748   end
       
   749 
   757 
   750 fun add_wrt_fact_graph ctxt (name, parents, feats, deps) (adds, graph) =
   758 fun add_wrt_fact_graph ctxt (name, parents, feats, deps) (adds, graph) =
   751   let
   759   let
   752     fun maybe_add_from from (accum as (parents, graph)) =
   760     fun maybe_add_from from (accum as (parents, graph)) =
   753       try_graph ctxt "updating graph" accum (fn () =>
   761       try_graph ctxt "updating graph" accum (fn () =>
   993   end
  1001   end
   994 
  1002 
   995 fun is_mash_enabled () = (getenv "MASH" = "yes")
  1003 fun is_mash_enabled () = (getenv "MASH" = "yes")
   996 fun mash_can_suggest_facts ctxt = not (Graph.is_empty (#fact_G (mash_get ctxt)))
  1004 fun mash_can_suggest_facts ctxt = not (Graph.is_empty (#fact_G (mash_get ctxt)))
   997 
  1005 
       
  1006 (* Generate more suggestions than requested, because some might be thrown out
       
  1007    later for various reasons. *)
       
  1008 fun generous_max_facts max_facts = max_facts + Int.min (50, max_facts)
       
  1009 
   998 (* The threshold should be large enough so that MaSh doesn't kick in for Auto
  1010 (* The threshold should be large enough so that MaSh doesn't kick in for Auto
   999    Sledgehammer and Try. *)
  1011    Sledgehammer and Try. *)
  1000 val min_secs_for_learning = 15
  1012 val min_secs_for_learning = 15
  1001 
  1013 
  1002 fun relevant_facts ctxt (params as {learn, fact_filter, timeout, ...}) prover
  1014 fun relevant_facts ctxt (params as {learn, fact_filter, timeout, ...}) prover
  1038       fun mepo () =
  1050       fun mepo () =
  1039         mepo_suggested_facts ctxt params prover max_facts NONE hyp_ts concl_t
  1051         mepo_suggested_facts ctxt params prover max_facts NONE hyp_ts concl_t
  1040                              facts
  1052                              facts
  1041         |> weight_mepo_facts
  1053         |> weight_mepo_facts
  1042       fun mash () =
  1054       fun mash () =
  1043         mash_suggested_facts ctxt params prover max_facts hyp_ts concl_t facts
  1055         mash_suggested_facts ctxt params prover (generous_max_facts max_facts)
  1044         |>> weight_mash_facts
  1056             hyp_ts concl_t facts
       
  1057         |> weight_mash_facts
  1045       val mess =
  1058       val mess =
  1046         [] |> (if fact_filter <> mashN then cons (mepo (), []) else I)
  1059         [] |> (if fact_filter <> mashN then cons (0.5, (mepo (), [])) else I)
  1047            |> (if fact_filter <> mepoN then cons (mash ()) else I)
  1060            |> (if fact_filter <> mepoN then cons (0.5, (mash (), [])) else I)
  1048     in
  1061     in
  1049       mesh_facts max_facts mess
  1062       mesh_facts max_facts mess
  1050       |> not (null add_ths) ? prepend_facts add_ths
  1063       |> not (null add_ths) ? prepend_facts add_ths
  1051     end
  1064     end
  1052 
  1065