src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 53127 60801776d8af
parent 53121 5f727525b1ac
child 53128 ea1b62ed5a54
equal deleted inserted replaced
53126:f4d2c64c7aa8 53127:60801776d8af
    60   val thm_less : thm * thm -> bool
    60   val thm_less : thm * thm -> bool
    61   val goal_of_thm : theory -> thm -> thm
    61   val goal_of_thm : theory -> thm -> thm
    62   val run_prover_for_mash :
    62   val run_prover_for_mash :
    63     Proof.context -> params -> string -> fact list -> thm -> prover_result
    63     Proof.context -> params -> string -> fact list -> thm -> prover_result
    64   val features_of :
    64   val features_of :
    65     Proof.context -> string -> theory -> stature -> term list
    65     Proof.context -> string -> theory -> int -> int Symtab.table -> stature
    66     -> (string * real) list
    66     -> term list -> (string * real) list
    67   val trim_dependencies : string list -> string list option
    67   val trim_dependencies : string list -> string list option
    68   val isar_dependencies_of :
    68   val isar_dependencies_of :
    69     string Symtab.table * string Symtab.table -> thm -> string list
    69     string Symtab.table * string Symtab.table -> thm -> string list
    70   val prover_dependencies_of :
    70   val prover_dependencies_of :
    71     Proof.context -> params -> string -> int -> raw_fact list
    71     Proof.context -> params -> string -> int -> raw_fact list
    76   val weight_mepo_facts : 'a list -> ('a * real) list
    76   val weight_mepo_facts : 'a list -> ('a * real) list
    77   val weight_mash_facts : 'a list -> ('a * real) list
    77   val weight_mash_facts : 'a list -> ('a * real) list
    78   val find_mash_suggestions :
    78   val find_mash_suggestions :
    79     Proof.context -> int -> string list -> ('b * thm) list -> ('b * thm) list
    79     Proof.context -> int -> string list -> ('b * thm) list -> ('b * thm) list
    80     -> ('b * thm) list -> ('b * thm) list * ('b * thm) list
    80     -> ('b * thm) list -> ('b * thm) list * ('b * thm) list
       
    81   val add_const_counts : term -> int Symtab.table -> int Symtab.table
    81   val mash_suggested_facts :
    82   val mash_suggested_facts :
    82     Proof.context -> params -> string -> int -> term list -> term
    83     Proof.context -> params -> string -> int -> term list -> term
    83     -> raw_fact list -> fact list * fact list
    84     -> raw_fact list -> fact list * fact list
    84   val mash_learn_proof :
    85   val mash_learn_proof :
    85     Proof.context -> params -> string -> term -> ('a * thm) list -> thm list
    86     Proof.context -> params -> string -> term -> ('a * thm) list -> thm list
   515 fun status_feature_of status = (string_of_status status, 2.0 (* FUDGE *))
   516 fun status_feature_of status = (string_of_status status, 2.0 (* FUDGE *))
   516 val local_feature = ("local", 8.0 (* FUDGE *))
   517 val local_feature = ("local", 8.0 (* FUDGE *))
   517 val lams_feature = ("lams", 2.0 (* FUDGE *))
   518 val lams_feature = ("lams", 2.0 (* FUDGE *))
   518 val skos_feature = ("skos", 2.0 (* FUDGE *))
   519 val skos_feature = ("skos", 2.0 (* FUDGE *))
   519 
   520 
       
   521 fun weighted_const_feature_of num_facts const_tab const_s s =
       
   522   ("c" ^ s,
       
   523    if num_facts = 0 then
       
   524      0.0
       
   525    else
       
   526      let val count = Symtab.lookup const_tab const_s |> the_default 0 in
       
   527        16.0 + (Real.fromInt num_facts / Real.fromInt count)
       
   528      end)
       
   529 
   520 (* The following "crude" functions should be progressively phased out, since
   530 (* The following "crude" functions should be progressively phased out, since
   521    they create visibility edges that do not exist in Isabelle, resulting in
   531    they create visibility edges that do not exist in Isabelle, resulting in
   522    failed lookups later on. *)
   532    failed lookups later on. *)
   523 
   533 
   524 fun crude_theory_ord p =
   534 fun crude_theory_ord p =
   586   | crude_str_of_typ (TFree (_, S)) = crude_str_of_sort S
   596   | crude_str_of_typ (TFree (_, S)) = crude_str_of_sort S
   587   | crude_str_of_typ (TVar (_, S)) = crude_str_of_sort S
   597   | crude_str_of_typ (TVar (_, S)) = crude_str_of_sort S
   588 
   598 
   589 val max_pat_breadth = 10
   599 val max_pat_breadth = 10
   590 
   600 
   591 fun term_features_of ctxt prover thy_name term_max_depth type_max_depth ts =
   601 fun term_features_of ctxt prover thy_name num_facts const_tab term_max_depth
       
   602                      type_max_depth ts =
   592   let
   603   let
   593     val thy = Proof_Context.theory_of ctxt
   604     val thy = Proof_Context.theory_of ctxt
   594 
   605 
   595     val pass_args = map_product (fn p => fn "" => p | q => p ^ "(" ^ q ^ ")")
   606     val pass_args = map_product (fn p => fn "" => p | q => p ^ "(" ^ q ^ ")")
   596     fun is_built_in (x as (s, _)) args =
   607     fun is_built_in (x as (s, _)) args =
   655         add_term_pat Ts feature_of depth t
   666         add_term_pat Ts feature_of depth t
   656         #> add_term_pats Ts feature_of (depth - 1) t
   667         #> add_term_pats Ts feature_of (depth - 1) t
   657     fun add_term Ts feature_of = add_term_pats Ts feature_of term_max_depth
   668     fun add_term Ts feature_of = add_term_pats Ts feature_of term_max_depth
   658     fun add_subterms Ts t =
   669     fun add_subterms Ts t =
   659       case strip_comb t of
   670       case strip_comb t of
   660         (Const (x as (_, T)), args) =>
   671         (Const (x as (s, T)), args) =>
   661         let val (built_in, args) = is_built_in x args in
   672         let val (built_in, args) = is_built_in x args in
   662           (not built_in ? add_term Ts const_feature_of t)
   673           (not built_in
       
   674              ? add_term Ts (weighted_const_feature_of num_facts const_tab s) t)
   663           #> add_subtypes T
   675           #> add_subtypes T
   664           #> fold (add_subterms Ts) args
   676           #> fold (add_subterms Ts) args
   665         end
   677         end
   666       | (head, args) =>
   678       | (head, args) =>
   667         (case head of
   679         (case head of
   676 
   688 
   677 val term_max_depth = 2
   689 val term_max_depth = 2
   678 val type_max_depth = 2
   690 val type_max_depth = 2
   679 
   691 
   680 (* TODO: Generate type classes for types? *)
   692 (* TODO: Generate type classes for types? *)
   681 fun features_of ctxt prover thy (scope, status) ts =
   693 fun features_of ctxt prover thy num_facts const_tab (scope, status) ts =
   682   let val thy_name = Context.theory_name thy in
   694   let val thy_name = Context.theory_name thy in
   683     thy_feature_of thy_name ::
   695     thy_feature_of thy_name ::
   684     term_features_of ctxt prover thy_name term_max_depth type_max_depth ts
   696     term_features_of ctxt prover thy_name num_facts const_tab term_max_depth
       
   697                      type_max_depth ts
   685     |> status <> General ? cons (status_feature_of status)
   698     |> status <> General ? cons (status_feature_of status)
   686     |> scope <> Global ? cons local_feature
   699     |> scope <> Global ? cons local_feature
   687     |> exists (not o is_lambda_free) ts ? cons lams_feature
   700     |> exists (not o is_lambda_free) ts ? cons lams_feature
   688     |> exists (exists_Const is_exists) ts ? cons skos_feature
   701     |> exists (exists_Const is_exists) ts ? cons skos_feature
   689   end
   702   end
   920         raw_unknown
   933         raw_unknown
   921         |> fold (subtract (Thm.eq_thm_prop o pairself snd))
   934         |> fold (subtract (Thm.eq_thm_prop o pairself snd))
   922                 [unknown_chained, proximity]
   935                 [unknown_chained, proximity]
   923     in (mesh_facts (Thm.eq_thm_prop o pairself snd) max_facts mess, unknown) end
   936     in (mesh_facts (Thm.eq_thm_prop o pairself snd) max_facts mess, unknown) end
   924 
   937 
   925 val max_learn_on_query = 500
   938 fun add_const_counts t =
   926 
   939   fold (fn s => Symtab.map_default (s, ~1) (Integer.add 1))
   927 fun mash_suggested_facts ctxt ({overlord, learn, ...} : params) prover max_facts
   940        (Term.add_const_names t [])
   928                          hyp_ts concl_t facts =
   941 
       
   942 fun mash_suggested_facts ctxt ({overlord, ...} : params) prover max_facts hyp_ts
       
   943                          concl_t facts =
   929   let
   944   let
   930     val thy = Proof_Context.theory_of ctxt
   945     val thy = Proof_Context.theory_of ctxt
   931     val chained = facts |> filter (fn ((_, (scope, _)), _) => scope = Chained)
   946     val chained = facts |> filter (fn ((_, (scope, _)), _) => scope = Chained)
       
   947     val const_tab = fold (add_const_counts o prop_of o snd) facts Symtab.empty
   932     val (access_G, suggs) =
   948     val (access_G, suggs) =
   933       peek_state ctxt (fn {access_G, num_known_facts, ...} =>
   949       peek_state ctxt (fn {access_G, ...} =>
   934           if Graph.is_empty access_G then
   950           if Graph.is_empty access_G then
   935             (access_G, [])
   951             (access_G, [])
   936           else
   952           else
   937             let
   953             let
   938               val parents = maximal_wrt_access_graph access_G facts
   954               val parents = maximal_wrt_access_graph access_G facts
   939               val feats =
   955               val feats =
   940                 features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts)
   956                 features_of ctxt prover thy (length facts) const_tab
       
   957                             (Local, General) (concl_t :: hyp_ts)
   941               val hints =
   958               val hints =
   942                 chained |> filter (is_fact_in_graph access_G o snd)
   959                 chained |> filter (is_fact_in_graph access_G o snd)
   943                         |> map (nickname_of_thm o snd)
   960                         |> map (nickname_of_thm o snd)
   944             in
   961             in
   945               (access_G, MaSh.query ctxt overlord max_facts
   962               (access_G, MaSh.query ctxt overlord max_facts
   987                      used_ths =
  1004                      used_ths =
   988   launch_thread (timeout |> the_default one_day) (fn () =>
  1005   launch_thread (timeout |> the_default one_day) (fn () =>
   989       let
  1006       let
   990         val thy = Proof_Context.theory_of ctxt
  1007         val thy = Proof_Context.theory_of ctxt
   991         val name = freshish_name ()
  1008         val name = freshish_name ()
   992         val feats = features_of ctxt prover thy (Local, General) [t] |> map fst
  1009         val feats =
       
  1010           features_of ctxt prover thy 0 Symtab.empty (Local, General) [t]
       
  1011           |> map fst
   993       in
  1012       in
   994         peek_state ctxt (fn {access_G, ...} =>
  1013         peek_state ctxt (fn {access_G, ...} =>
   995             let
  1014             let
   996               val parents = maximal_wrt_access_graph access_G facts
  1015               val parents = maximal_wrt_access_graph access_G facts
   997               val deps =
  1016               val deps =
  1084           | learn_new_fact (parents, ((_, stature as (_, status)), th))
  1103           | learn_new_fact (parents, ((_, stature as (_, status)), th))
  1085                            (learns, (n, next_commit, _)) =
  1104                            (learns, (n, next_commit, _)) =
  1086             let
  1105             let
  1087               val name = nickname_of_thm th
  1106               val name = nickname_of_thm th
  1088               val feats =
  1107               val feats =
  1089                 features_of ctxt prover (theory_of_thm th) stature [prop_of th]
  1108                 features_of ctxt prover (theory_of_thm th) 0 Symtab.empty
       
  1109                             stature [prop_of th]
  1090                 |> map fst
  1110                 |> map fst
  1091               val deps = deps_of status th |> these
  1111               val deps = deps_of status th |> these
  1092               val n = n |> not (null deps) ? Integer.add 1
  1112               val n = n |> not (null deps) ? Integer.add 1
  1093               val learns = (name, parents, feats, deps) :: learns
  1113               val learns = (name, parents, feats, deps) :: learns
  1094               val (learns, next_commit) =
  1114               val (learns, next_commit) =