src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 53140 a1235e90da5f
parent 53137 a33298b49d9f
child 53141 d27e99a6a679
equal deleted inserted replaced
53139:07a6e11f1631 53140:a1235e90da5f
    71     Proof.context -> params -> string -> int -> raw_fact list
    71     Proof.context -> params -> string -> int -> raw_fact list
    72     -> string Symtab.table * string Symtab.table -> thm
    72     -> string Symtab.table * string Symtab.table -> thm
    73     -> bool * string list
    73     -> bool * string list
    74   val attach_parents_to_facts :
    74   val attach_parents_to_facts :
    75     ('a * thm) list -> ('a * thm) list -> (string list * ('a * thm)) list
    75     ('a * thm) list -> ('a * thm) list -> (string list * ('a * thm)) list
    76   val weight_mepo_facts : 'a list -> ('a * real) list
    76   val num_extra_feature_facts : int
    77   val weight_mash_facts : 'a list -> ('a * real) list
    77   val extra_feature_weight_factor : real
       
    78   val weight_facts_smoothly : 'a list -> ('a * real) list
       
    79   val weight_facts_steeply : 'a list -> ('a * real) list
    78   val find_mash_suggestions :
    80   val find_mash_suggestions :
    79     Proof.context -> int -> string list -> ('b * thm) list -> ('b * thm) list
    81     Proof.context -> int -> string list -> ('b * thm) list -> ('b * thm) list
    80     -> ('b * thm) list -> ('b * thm) list * ('b * thm) list
    82     -> ('b * thm) list -> ('b * thm) list * ('b * thm) list
    81   val add_const_counts : term -> int Symtab.table -> int Symtab.table
    83   val add_const_counts : term -> int Symtab.table -> int Symtab.table
    82   val mash_suggested_facts :
    84   val mash_suggested_facts :
   901   map (nickname_of_thm o snd)
   903   map (nickname_of_thm o snd)
   902   #> maximal_wrt_graph access_G
   904   #> maximal_wrt_graph access_G
   903 
   905 
   904 fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm
   906 fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm
   905 
   907 
   906 (* FUDGE *)
   908 val num_extra_feature_facts = 10 (* FUDGE *)
   907 fun weight_of_mepo_fact rank =
   909 val extra_feature_weight_factor = 0.1
   908   Math.pow (0.62, log2 (Real.fromInt (rank + 1)))
       
   909 
       
   910 fun weight_mepo_facts facts =
       
   911   facts ~~ map weight_of_mepo_fact (0 upto length facts - 1)
       
   912 
       
   913 val weight_raw_mash_facts = weight_mepo_facts
       
   914 val weight_mash_facts = weight_raw_mash_facts
       
   915 
   910 
   916 (* FUDGE *)
   911 (* FUDGE *)
   917 fun weight_of_proximity_fact rank =
   912 fun weight_of_proximity_fact rank =
   918   Math.pow (1.3, 15.5 - 0.2 * Real.fromInt rank) + 15.0
   913   Math.pow (1.3, 15.5 - 0.2 * Real.fromInt rank) + 15.0
   919 
   914 
   920 fun weight_proximity_facts facts =
   915 fun weight_facts_smoothly facts =
   921   facts ~~ map weight_of_proximity_fact (0 upto length facts - 1)
   916   facts ~~ map weight_of_proximity_fact (0 upto length facts - 1)
       
   917 
       
   918 (* FUDGE *)
       
   919 fun steep_weight_of_fact rank =
       
   920   Math.pow (0.62, log2 (Real.fromInt (rank + 1)))
       
   921 
       
   922 fun weight_facts_steeply facts =
       
   923   facts ~~ map steep_weight_of_fact (0 upto length facts - 1)
   922 
   924 
   923 val max_proximity_facts = 100
   925 val max_proximity_facts = 100
   924 
   926 
   925 fun find_mash_suggestions _ _ [] _ _ raw_unknown = ([], raw_unknown)
   927 fun find_mash_suggestions _ _ [] _ _ raw_unknown = ([], raw_unknown)
   926   | find_mash_suggestions ctxt max_facts suggs facts chained raw_unknown =
   928   | find_mash_suggestions ctxt max_facts suggs facts chained raw_unknown =
   931       val proximity =
   933       val proximity =
   932         facts |> sort (crude_thm_ord o pairself snd o swap)
   934         facts |> sort (crude_thm_ord o pairself snd o swap)
   933               |> take max_proximity_facts
   935               |> take max_proximity_facts
   934       val mess =
   936       val mess =
   935         [(0.90 (* FUDGE *), (map (rpair 1.0) unknown_chained, [])),
   937         [(0.90 (* FUDGE *), (map (rpair 1.0) unknown_chained, [])),
   936          (0.08 (* FUDGE *), (weight_raw_mash_facts raw_mash, raw_unknown)),
   938          (0.08 (* FUDGE *), (weight_facts_steeply raw_mash, raw_unknown)),
   937          (0.02 (* FUDGE *), (weight_proximity_facts proximity, []))]
   939          (0.02 (* FUDGE *), (weight_facts_smoothly proximity, []))]
   938       val unknown =
   940       val unknown =
   939         raw_unknown
   941         raw_unknown
   940         |> fold (subtract (Thm.eq_thm_prop o pairself snd))
   942         |> fold (subtract (Thm.eq_thm_prop o pairself snd))
   941                 [unknown_chained, proximity]
   943                 [unknown_chained, proximity]
   942     in (mesh_facts (Thm.eq_thm_prop o pairself snd) max_facts mess, unknown) end
   944     in (mesh_facts (Thm.eq_thm_prop o pairself snd) max_facts mess, unknown) end
  1287                 (accepts |> filter_out in_add))
  1289                 (accepts |> filter_out in_add))
  1288         |> take max_facts
  1290         |> take max_facts
  1289       fun mepo () =
  1291       fun mepo () =
  1290         mepo_suggested_facts ctxt params prover max_facts NONE hyp_ts concl_t
  1292         mepo_suggested_facts ctxt params prover max_facts NONE hyp_ts concl_t
  1291                              facts
  1293                              facts
  1292         |> weight_mepo_facts
  1294         |> weight_facts_steeply
  1293       fun mash () =
  1295       fun mash () =
  1294         mash_suggested_facts ctxt params prover (generous_max_facts max_facts)
  1296         mash_suggested_facts ctxt params prover (generous_max_facts max_facts)
  1295             hyp_ts concl_t facts
  1297             hyp_ts concl_t facts
  1296         |>> weight_mash_facts
  1298         |>> weight_facts_steeply
  1297       val mess =
  1299       val mess =
  1298         (* the order is important for the "case" expression below *)
  1300         (* the order is important for the "case" expression below *)
  1299         [] |> (if effective_fact_filter <> mepoN then
  1301         [] |> (if effective_fact_filter <> mepoN then
  1300                  cons (mash_weight, (mash ()))
  1302                  cons (mash_weight, (mash ()))
  1301                else
  1303                else