src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 50412 e83ab94e3e6e
parent 50401 8e5d7ef3da76
child 50434 960a3429615c
equal deleted inserted replaced
50411:c9023d78d1a6 50412:e83ab94e3e6e
    39   val mash_QUERY :
    39   val mash_QUERY :
    40     Proof.context -> bool -> int -> string list * (string * real) list
    40     Proof.context -> bool -> int -> string list * (string * real) list
    41     -> (string * real) list
    41     -> (string * real) list
    42   val mash_unlearn : Proof.context -> unit
    42   val mash_unlearn : Proof.context -> unit
    43   val nickname_of : thm -> string
    43   val nickname_of : thm -> string
    44   val suggested_facts :
    44   val find_suggested_facts :
    45     (string * 'a) list -> ('b * thm) list -> (('b * thm) * 'a) list
    45     (string * 'a) list -> ('b * thm) list -> (('b * thm) * 'a) list
    46   val mesh_facts :
    46   val mesh_facts :
    47     int -> (real * ((('a * thm) * real) list * ('a * thm) list)) list
    47     int -> (real * ((('a * thm) * real) list * ('a * thm) list)) list
    48     -> ('a * thm) list
    48     -> ('a * thm) list
    49   val theory_ord : theory * theory -> order
    49   val theory_ord : theory * theory -> order
    57   val isar_dependencies_of : unit Symtab.table -> thm -> string list option
    57   val isar_dependencies_of : unit Symtab.table -> thm -> string list option
    58   val atp_dependencies_of :
    58   val atp_dependencies_of :
    59     Proof.context -> params -> string -> int -> fact list -> unit Symtab.table
    59     Proof.context -> params -> string -> int -> fact list -> unit Symtab.table
    60     -> thm -> bool * string list option
    60     -> thm -> bool * string list option
    61   val weight_mash_facts : ('a * thm) list -> (('a * thm) * real) list
    61   val weight_mash_facts : ('a * thm) list -> (('a * thm) * real) list
       
    62   val find_mash_suggestions :
       
    63     int -> (Symtab.key * 'a) list -> ('b * thm) list -> ('b * thm) list
       
    64     -> ('b * thm) list -> ('b * thm) list
    62   val mash_suggested_facts :
    65   val mash_suggested_facts :
    63     Proof.context -> params -> string -> int -> term list -> term -> fact list
    66     Proof.context -> params -> string -> int -> term list -> term -> fact list
    64     -> fact list
    67     -> fact list
    65   val mash_learn_proof :
    68   val mash_learn_proof :
    66     Proof.context -> params -> string -> term -> ('a * thm) list -> thm list
    69     Proof.context -> params -> string -> term -> ('a * thm) list -> thm list
    67     -> unit
    70     -> unit
    68   val mash_learn :
    71   val mash_learn :
    69     Proof.context -> params -> fact_override -> thm list -> bool -> unit
    72     Proof.context -> params -> fact_override -> thm list -> bool -> unit
    70   val is_mash_enabled : unit -> bool
    73   val is_mash_enabled : unit -> bool
    71   val mash_can_suggest_facts : Proof.context -> bool
    74   val mash_can_suggest_facts : Proof.context -> bool
       
    75   val generous_max_facts : int -> int
    72   val relevant_facts :
    76   val relevant_facts :
    73     Proof.context -> params -> string -> int -> fact_override -> term list
    77     Proof.context -> params -> string -> int -> fact_override -> term list
    74     -> term -> fact list -> fact list
    78     -> term -> fact list -> fact list
    75   val kill_learners : unit -> unit
    79   val kill_learners : unit -> unit
    76   val running_learners : unit -> unit
    80   val running_learners : unit -> unit
   408       | NONE => hint
   412       | NONE => hint
   409     end
   413     end
   410   else
   414   else
   411     backquote_thm (Proof_Context.init_global (Thm.theory_of_thm th)) th
   415     backquote_thm (Proof_Context.init_global (Thm.theory_of_thm th)) th
   412 
   416 
   413 fun suggested_facts suggs facts =
   417 fun find_suggested_facts suggs facts =
   414   let
   418   let
   415     fun add_fact (fact as (_, th)) = Symtab.default (nickname_of th, fact)
   419     fun add_fact (fact as (_, th)) = Symtab.default (nickname_of th, fact)
   416     val tab = Symtab.empty |> fold add_fact facts
   420     val tab = Symtab.empty |> fold add_fact facts
   417     fun find_sugg (name, weight) =
   421     fun find_sugg (name, weight) =
   418       Symtab.lookup tab name |> Option.map (rpair weight)
   422       Symtab.lookup tab name |> Option.map (rpair weight)
   730   Math.pow (1.3, 15.5 - 0.2 * Real.fromInt rank) + 15.0
   734   Math.pow (1.3, 15.5 - 0.2 * Real.fromInt rank) + 15.0
   731 
   735 
   732 fun weight_proximity_facts facts =
   736 fun weight_proximity_facts facts =
   733   facts ~~ map weight_of_proximity_fact (0 upto length facts - 1)
   737   facts ~~ map weight_of_proximity_fact (0 upto length facts - 1)
   734 
   738 
       
   739 fun find_mash_suggestions max_facts suggs facts chained unknown =
       
   740   let
       
   741     val raw_mash =
       
   742       facts |> find_suggested_facts suggs
       
   743             (* The weights currently returned by "mash.py" are too spaced out to
       
   744                make any sense. *)
       
   745             |> map fst
       
   746     val proximity = facts |> sort (thm_ord o pairself snd o swap)
       
   747     val mess =
       
   748       [(0.8000 (* FUDGE *), (map (rpair 1.0) chained, [])),
       
   749        (0.1333 (* FUDGE *), (weight_raw_mash_facts raw_mash, unknown)),
       
   750        (0.0667 (* FUDGE *), (weight_proximity_facts proximity, []))]
       
   751   in mesh_facts max_facts mess end
       
   752 
   735 fun mash_suggested_facts ctxt ({overlord, ...} : params) prover max_facts hyp_ts
   753 fun mash_suggested_facts ctxt ({overlord, ...} : params) prover max_facts hyp_ts
   736                          concl_t facts =
   754                          concl_t facts =
   737   let
   755   let
   738     val thy = Proof_Context.theory_of ctxt
   756     val thy = Proof_Context.theory_of ctxt
   739     val (fact_G, suggs) =
   757     val (fact_G, suggs) =
   747                 features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts)
   765                 features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts)
   748             in
   766             in
   749               (fact_G, mash_QUERY ctxt overlord max_facts (parents, feats))
   767               (fact_G, mash_QUERY ctxt overlord max_facts (parents, feats))
   750             end)
   768             end)
   751     val chained = facts |> filter (fn ((_, (scope, _)), _) => scope = Chained)
   769     val chained = facts |> filter (fn ((_, (scope, _)), _) => scope = Chained)
   752     val raw_mash =
       
   753       facts |> suggested_facts suggs
       
   754             (* The weights currently returned by "mash.py" are too spaced out to
       
   755                make any sense. *)
       
   756             |> map fst
       
   757     val proximity = facts |> sort (thm_ord o pairself snd o swap)
       
   758     val unknown = facts |> filter_out (is_fact_in_graph fact_G)
   770     val unknown = facts |> filter_out (is_fact_in_graph fact_G)
   759     val mess =
   771   in find_mash_suggestions max_facts suggs facts chained unknown end
   760       [(0.8000 (* FUDGE *), (map (rpair 1.0) chained, [])),
       
   761        (0.1333 (* FUDGE *), (weight_raw_mash_facts raw_mash, unknown)),
       
   762        (0.0667 (* FUDGE *), (weight_proximity_facts proximity, []))]
       
   763   in mesh_facts max_facts mess end
       
   764 
   772 
   765 fun add_wrt_fact_graph ctxt (name, parents, feats, deps) (adds, graph) =
   773 fun add_wrt_fact_graph ctxt (name, parents, feats, deps) (adds, graph) =
   766   let
   774   let
   767     fun maybe_add_from from (accum as (parents, graph)) =
   775     fun maybe_add_from from (accum as (parents, graph)) =
   768       try_graph ctxt "updating graph" accum (fn () =>
   776       try_graph ctxt "updating graph" accum (fn () =>