src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 53159 a5805fe4e91c
parent 53156 f79f4693868b
child 53197 6c5e7143e1f6
equal deleted inserted replaced
53158:4b9df3461eda 53159:a5805fe4e91c
   611       | add_classes S =
   611       | add_classes S =
   612         fold (`(Sorts.super_classes classes)
   612         fold (`(Sorts.super_classes classes)
   613               #> swap #> op ::
   613               #> swap #> op ::
   614               #> subtract (op =) @{sort type} #> map massage_long_name
   614               #> subtract (op =) @{sort type} #> map massage_long_name
   615               #> map class_feature_of
   615               #> map class_feature_of
   616               #> union (op = o pairself fst)) S
   616               #> union (eq_fst (op =))) S
   617 
   617 
   618     fun pattify_type 0 _ = []
   618     fun pattify_type 0 _ = []
   619       | pattify_type _ (Type (s, [])) =
   619       | pattify_type _ (Type (s, [])) =
   620         if member (op =) bad_types s then [] else [massage_long_name s]
   620         if member (op =) bad_types s then [] else [massage_long_name s]
   621       | pattify_type depth (Type (s, U :: Ts)) =
   621       | pattify_type depth (Type (s, U :: Ts)) =
   627       | pattify_type _ (TFree (_, S)) =
   627       | pattify_type _ (TFree (_, S)) =
   628         maybe_singleton_str pat_tvar_prefix (crude_str_of_sort S)
   628         maybe_singleton_str pat_tvar_prefix (crude_str_of_sort S)
   629       | pattify_type _ (TVar (_, S)) =
   629       | pattify_type _ (TVar (_, S)) =
   630         maybe_singleton_str pat_tvar_prefix (crude_str_of_sort S)
   630         maybe_singleton_str pat_tvar_prefix (crude_str_of_sort S)
   631     fun add_type_pat depth T =
   631     fun add_type_pat depth T =
   632       union (op = o pairself fst) (map type_feature_of (pattify_type depth T))
   632       union (eq_fst (op =)) (map type_feature_of (pattify_type depth T))
   633     fun add_type_pats 0 _ = I
   633     fun add_type_pats 0 _ = I
   634       | add_type_pats depth t =
   634       | add_type_pats depth t =
   635         add_type_pat depth t #> add_type_pats (depth - 1) t
   635         add_type_pat depth t #> add_type_pats (depth - 1) t
   636     fun add_type T =
   636     fun add_type T =
   637       add_type_pats type_max_depth T
   637       add_type_pats type_max_depth T
   674           map_product (fn ppw as (p, pw) =>
   674           map_product (fn ppw as (p, pw) =>
   675               fn ("", _) => ppw
   675               fn ("", _) => ppw
   676                | (q, qw) => (p ^ "(" ^ q ^ ")", pw + qw)) ps qs
   676                | (q, qw) => (p ^ "(" ^ q ^ ")", pw + qw)) ps qs
   677         end
   677         end
   678       | pattify_term _ _ _ _ = []
   678       | pattify_term _ _ _ _ = []
   679     fun add_term_pat Ts depth =
   679     fun add_term_pat Ts = union (eq_fst (op =)) oo pattify_term Ts []
   680       union (op = o pairself fst) o pattify_term Ts [] depth
       
   681     fun add_term_pats _ 0 _ = I
   680     fun add_term_pats _ 0 _ = I
   682       | add_term_pats Ts depth t =
   681       | add_term_pats Ts depth t =
   683         add_term_pat Ts depth t #> add_term_pats Ts (depth - 1) t
   682         add_term_pat Ts depth t #> add_term_pats Ts (depth - 1) t
   684     fun add_term Ts = add_term_pats Ts term_max_depth
   683     fun add_term Ts = add_term_pats Ts term_max_depth
   685     fun add_subterms Ts t =
   684     fun add_subterms Ts t =
   903 
   902 
   904 fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm
   903 fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm
   905 
   904 
   906 val chained_feature_factor = 0.5
   905 val chained_feature_factor = 0.5
   907 val extra_feature_factor = 0.1
   906 val extra_feature_factor = 0.1
   908 val num_extra_feature_facts = 0 (* FUDGE *)
   907 val num_extra_feature_facts = 10 (* FUDGE *)
   909 
   908 
   910 (* FUDGE *)
   909 (* FUDGE *)
   911 fun weight_of_proximity_fact rank =
   910 fun weight_of_proximity_fact rank =
   912   Math.pow (1.3, 15.5 - 0.2 * Real.fromInt rank) + 15.0
   911   Math.pow (1.3, 15.5 - 0.2 * Real.fromInt rank) + 15.0
   913 
   912 
   924 val max_proximity_facts = 100
   923 val max_proximity_facts = 100
   925 
   924 
   926 fun find_mash_suggestions _ _ [] _ _ raw_unknown = ([], raw_unknown)
   925 fun find_mash_suggestions _ _ [] _ _ raw_unknown = ([], raw_unknown)
   927   | find_mash_suggestions ctxt max_facts suggs facts chained raw_unknown =
   926   | find_mash_suggestions ctxt max_facts suggs facts chained raw_unknown =
   928     let
   927     let
   929       val inter_fact = inter (Thm.eq_thm_prop o pairself snd)
   928       val inter_fact = inter (eq_snd Thm.eq_thm_prop)
   930       val raw_mash = find_suggested_facts ctxt facts suggs
   929       val raw_mash = find_suggested_facts ctxt facts suggs
   931       val proximate = take max_proximity_facts facts
   930       val proximate = take max_proximity_facts facts
   932       val unknown_chained = inter_fact raw_unknown chained
   931       val unknown_chained = inter_fact raw_unknown chained
   933       val unknown_proximate = inter_fact raw_unknown proximate
   932       val unknown_proximate = inter_fact raw_unknown proximate
   934       val mess =
   933       val mess =
   935         [(0.9 (* FUDGE *), (map (rpair 1.0) unknown_chained, [])),
   934         [(0.9 (* FUDGE *), (map (rpair 1.0) unknown_chained, [])),
   936          (0.4 (* FUDGE *), (weight_facts_smoothly unknown_proximate, [])),
   935          (0.4 (* FUDGE *), (weight_facts_smoothly unknown_proximate, [])),
   937          (0.1 (* FUDGE *), (weight_facts_steeply raw_mash, raw_unknown))]
   936          (0.1 (* FUDGE *), (weight_facts_steeply raw_mash, raw_unknown))]
   938       val unknown =
   937       val unknown =
   939         raw_unknown
   938         raw_unknown
   940         |> fold (subtract (Thm.eq_thm_prop o pairself snd))
   939         |> fold (subtract (eq_snd Thm.eq_thm_prop))
   941                 [unknown_chained, unknown_proximate]
   940                 [unknown_chained, unknown_proximate]
   942     in (mesh_facts (Thm.eq_thm_prop o pairself snd) max_facts mess, unknown) end
   941     in (mesh_facts (eq_snd Thm.eq_thm_prop) max_facts mess, unknown) end
   943 
   942 
   944 fun add_const_counts t =
   943 fun add_const_counts t =
   945   fold (fn s => Symtab.map_default (s, 0) (Integer.add 1))
   944   fold (fn s => Symtab.map_default (s, 0) (Integer.add 1))
   946        (Term.add_const_names t [])
   945        (Term.add_const_names t [])
   947 
   946 
   969                             (Local, General) (concl_t :: hyp_ts)
   968                             (Local, General) (concl_t :: hyp_ts)
   970               val chained_feats =
   969               val chained_feats =
   971                 chained
   970                 chained
   972                 |> map (rpair 1.0)
   971                 |> map (rpair 1.0)
   973                 |> map (chained_or_extra_features_of chained_feature_factor)
   972                 |> map (chained_or_extra_features_of chained_feature_factor)
   974                 |> rpair [] |-> fold (union (op = o pairself fst))
   973                 |> rpair [] |-> fold (union (eq_fst (op =)))
   975               val extra_feats =
   974               val extra_feats =
   976                 facts
   975                 facts
   977                 |> take (Int.max (0, num_extra_feature_facts - length chained))
   976                 |> take (Int.max (0, num_extra_feature_facts - length chained))
   978                 |> weight_facts_steeply
   977                 |> weight_facts_steeply
   979                 |> map (chained_or_extra_features_of extra_feature_factor)
   978                 |> map (chained_or_extra_features_of extra_feature_factor)
   980                 |> rpair [] |-> fold (union (op = o pairself fst))
   979                 |> rpair [] |-> fold (union (eq_fst (op =)))
   981               val feats =
   980               val feats =
   982                 fold (union (op = o pairself fst)) [chained_feats, extra_feats]
   981                 fold (union (eq_fst (op =))) [chained_feats, extra_feats]
   983                      goal_feats
   982                      goal_feats
   984               val hints =
   983               val hints =
   985                 chained |> filter (is_fact_in_graph access_G o snd)
   984                 chained |> filter (is_fact_in_graph access_G o snd)
   986                         |> map (nickname_of_thm o snd)
   985                         |> map (nickname_of_thm o snd)
   987             in
   986             in
  1347            |> (if effective_fact_filter <> mashN then
  1346            |> (if effective_fact_filter <> mashN then
  1348                  cons (mepo_weight, (mepo (), []))
  1347                  cons (mepo_weight, (mepo (), []))
  1349                else
  1348                else
  1350                  I)
  1349                  I)
  1351       val mesh =
  1350       val mesh =
  1352         mesh_facts (Thm.eq_thm_prop o pairself snd) max_facts mess
  1351         mesh_facts (eq_snd Thm.eq_thm_prop) max_facts mess
  1353         |> add_and_take
  1352         |> add_and_take
  1354     in
  1353     in
  1355       if save then MaSh.save ctxt overlord else ();
  1354       if save then MaSh.save ctxt overlord else ();
  1356       case (fact_filter, mess) of
  1355       case (fact_filter, mess) of
  1357         (NONE, [(_, (mepo, _)), (_, (mash, _))]) =>
  1356         (NONE, [(_, (mepo, _)), (_, (mash, _))]) =>