src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 50814 4247cbd78aaf
parent 50755 4c781d65c0d6
child 50825 aed1d7242050
equal deleted inserted replaced
50813:b6659475b5af 50814:4247cbd78aaf
    47   val mash_unlearn : Proof.context -> unit
    47   val mash_unlearn : Proof.context -> unit
    48   val nickname_of_thm : thm -> string
    48   val nickname_of_thm : thm -> string
    49   val find_suggested_facts :
    49   val find_suggested_facts :
    50     (string * 'a) list -> ('b * thm) list -> (('b * thm) * 'a) list
    50     (string * 'a) list -> ('b * thm) list -> (('b * thm) * 'a) list
    51   val mesh_facts :
    51   val mesh_facts :
    52     int -> (real * ((('a * thm) * real) list * ('a * thm) list)) list
    52     ('a * 'a -> bool) -> int -> (real * (('a * real) list * 'a list)) list
    53     -> ('a * thm) list
    53     -> 'a list
    54   val theory_ord : theory * theory -> order
    54   val theory_ord : theory * theory -> order
    55   val thm_ord : thm * thm -> order
    55   val thm_ord : thm * thm -> order
    56   val goal_of_thm : theory -> thm -> thm
    56   val goal_of_thm : theory -> thm -> thm
    57   val run_prover_for_mash :
    57   val run_prover_for_mash :
    58     Proof.context -> params -> string -> fact list -> thm -> prover_result
    58     Proof.context -> params -> string -> fact list -> thm -> prover_result
    79   val mash_learn :
    79   val mash_learn :
    80     Proof.context -> params -> fact_override -> thm list -> bool -> unit
    80     Proof.context -> params -> fact_override -> thm list -> bool -> unit
    81   val is_mash_enabled : unit -> bool
    81   val is_mash_enabled : unit -> bool
    82   val mash_can_suggest_facts : Proof.context -> bool
    82   val mash_can_suggest_facts : Proof.context -> bool
    83   val generous_max_facts : int -> int
    83   val generous_max_facts : int -> int
       
    84   val mepo_weight : real
       
    85   val mash_weight : real
    84   val relevant_facts :
    86   val relevant_facts :
    85     Proof.context -> params -> string -> int -> fact_override -> term list
    87     Proof.context -> params -> string -> int -> fact_override -> term list
    86     -> term -> fact list -> fact list
    88     -> term -> fact list -> fact list
    87   val kill_learners : unit -> unit
    89   val kill_learners : unit -> unit
    88   val running_learners : unit -> unit
    90   val running_learners : unit -> unit
   442   | normalize_scores max_facts xs =
   444   | normalize_scores max_facts xs =
   443     let val avg = avg (map snd (take max_facts xs)) in
   445     let val avg = avg (map snd (take max_facts xs)) in
   444       map (apsnd (curry Real.* (1.0 / avg))) xs
   446       map (apsnd (curry Real.* (1.0 / avg))) xs
   445     end
   447     end
   446 
   448 
   447 fun mesh_facts max_facts [(_, (sels, unks))] =
   449 fun mesh_facts _ max_facts [(_, (sels, unks))] =
   448     map fst (take max_facts sels) @ take (max_facts - length sels) unks
   450     map fst (take max_facts sels) @ take (max_facts - length sels) unks
   449   | mesh_facts max_facts mess =
   451   | mesh_facts eq max_facts mess =
   450     let
   452     let
   451       val mess =
   453       val mess =
   452         mess |> map (apsnd (apfst (normalize_scores max_facts #> `length)))
   454         mess |> map (apsnd (apfst (normalize_scores max_facts #> `length)))
   453       val fact_eq = Thm.eq_thm o pairself snd
   455       val fact_eq = eq
   454       fun score_in fact (global_weight, ((sel_len, sels), unks)) =
   456       fun score_in fact (global_weight, ((sel_len, sels), unks)) =
   455         let
   457         let
   456           fun score_at j =
   458           fun score_at j =
   457             case try (nth sels) j of
   459             case try (nth sels) j of
   458               SOME (_, score) => SOME (global_weight * score)
   460               SOME (_, score) => SOME (global_weight * score)
   767        (0.16 (* FUDGE *), (weight_raw_mash_facts raw_mash, raw_unknown)),
   769        (0.16 (* FUDGE *), (weight_raw_mash_facts raw_mash, raw_unknown)),
   768        (0.04 (* FUDGE *), (weight_proximity_facts proximity, []))]
   770        (0.04 (* FUDGE *), (weight_proximity_facts proximity, []))]
   769     val unknown =
   771     val unknown =
   770       raw_unknown
   772       raw_unknown
   771       |> fold (subtract (Thm.eq_thm_prop o pairself snd)) [chained, proximity]
   773       |> fold (subtract (Thm.eq_thm_prop o pairself snd)) [chained, proximity]
   772   in (mesh_facts max_facts mess, unknown) end
   774   in (mesh_facts (Thm.eq_thm o pairself snd) max_facts mess, unknown) end
   773 
   775 
   774 fun mash_suggested_facts ctxt ({overlord, ...} : params) prover max_facts hyp_ts
   776 fun mash_suggested_facts ctxt ({overlord, ...} : params) prover max_facts hyp_ts
   775                          concl_t facts =
   777                          concl_t facts =
   776   let
   778   let
   777     val thy = Proof_Context.theory_of ctxt
   779     val thy = Proof_Context.theory_of ctxt
  1052 
  1054 
  1053 (* Generate more suggestions than requested, because some might be thrown out
  1055 (* Generate more suggestions than requested, because some might be thrown out
  1054    later for various reasons. *)
  1056    later for various reasons. *)
  1055 fun generous_max_facts max_facts = max_facts + Int.min (50, max_facts div 2)
  1057 fun generous_max_facts max_facts = max_facts + Int.min (50, max_facts div 2)
  1056 
  1058 
       
  1059 val mepo_weight = 0.5
       
  1060 val mash_weight = 0.5
       
  1061 
  1057 (* The threshold should be large enough so that MaSh doesn't kick in for Auto
  1062 (* The threshold should be large enough so that MaSh doesn't kick in for Auto
  1058    Sledgehammer and Try. *)
  1063    Sledgehammer and Try. *)
  1059 val min_secs_for_learning = 15
  1064 val min_secs_for_learning = 15
  1060 
  1065 
  1061 fun relevant_facts ctxt (params as {learn, fact_filter, timeout, ...}) prover
  1066 fun relevant_facts ctxt (params as {learn, fact_filter, timeout, ...}) prover
  1104       fun mash () =
  1109       fun mash () =
  1105         mash_suggested_facts ctxt params prover (generous_max_facts max_facts)
  1110         mash_suggested_facts ctxt params prover (generous_max_facts max_facts)
  1106             hyp_ts concl_t facts
  1111             hyp_ts concl_t facts
  1107         |>> weight_mash_facts
  1112         |>> weight_mash_facts
  1108       val mess =
  1113       val mess =
  1109         [] |> (if fact_filter <> mashN then cons (0.5, (mepo (), [])) else I)
  1114         [] |> (if fact_filter <> mashN then cons (mepo_weight, (mepo (), []))
  1110            |> (if fact_filter <> mepoN then cons (0.5, (mash ())) else I)
  1115                else I)
       
  1116            |> (if fact_filter <> mepoN then cons (mash_weight, (mash ()))
       
  1117                else I)
  1111     in
  1118     in
  1112       mesh_facts max_facts mess
  1119       mesh_facts (Thm.eq_thm o pairself snd) max_facts mess
  1113       |> not (null add_ths) ? prepend_facts add_ths
  1120       |> not (null add_ths) ? prepend_facts add_ths
  1114     end
  1121     end
  1115 
  1122 
  1116 fun kill_learners () = Async_Manager.kill_threads MaShN "learner"
  1123 fun kill_learners () = Async_Manager.kill_threads MaShN "learner"
  1117 fun running_learners () = Async_Manager.running_threads MaShN "learner"
  1124 fun running_learners () = Async_Manager.running_threads MaShN "learner"