src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 50841 087e3c531e86
parent 50840 a5cc092156da
child 50857 80768e28c9ee
equal deleted inserted replaced
50840:a5cc092156da 50841:087e3c531e86
   533 val logical_consts =
   533 val logical_consts =
   534   [@{const_name prop}, @{const_name Pure.conjunction}] @ atp_logical_consts
   534   [@{const_name prop}, @{const_name Pure.conjunction}] @ atp_logical_consts
   535 
   535 
   536 val max_pattern_breadth = 10
   536 val max_pattern_breadth = 10
   537 
   537 
   538 fun interesting_terms_types_and_classes ctxt prover thy_name term_max_depth
   538 fun term_features_of ctxt prover thy_name term_max_depth type_max_depth ts =
   539                                         type_max_depth ts =
       
   540   let
   539   let
   541     val thy = Proof_Context.theory_of ctxt
   540     val thy = Proof_Context.theory_of ctxt
   542     val fixes = map snd (Variable.dest_fixes ctxt)
   541     val fixes = map snd (Variable.dest_fixes ctxt)
   543     val classes = Sign.classes_of thy
   542     val classes = Sign.classes_of thy
   544     fun is_bad_const (x as (s, _)) args =
       
   545       member (op =) logical_consts s orelse
       
   546       fst (is_built_in_const_for_prover ctxt prover x args)
       
   547     fun add_classes @{sort type} = I
   543     fun add_classes @{sort type} = I
   548       | add_classes S =
   544       | add_classes S =
   549         fold (`(Sorts.super_classes classes)
   545         fold (`(Sorts.super_classes classes)
   550               #> swap #> op ::
   546               #> swap #> op ::
   551               #> subtract (op =) @{sort type}
   547               #> subtract (op =) @{sort type}
   556          ? insert (op = o pairself fst) (type_feature_of s))
   552          ? insert (op = o pairself fst) (type_feature_of s))
   557         #> fold do_add_type Ts
   553         #> fold do_add_type Ts
   558       | do_add_type (TFree (_, S)) = add_classes S
   554       | do_add_type (TFree (_, S)) = add_classes S
   559       | do_add_type (TVar (_, S)) = add_classes S
   555       | do_add_type (TVar (_, S)) = add_classes S
   560     fun add_type T = type_max_depth >= 0 ? do_add_type T
   556     fun add_type T = type_max_depth >= 0 ? do_add_type T
   561     fun patternify_term _ 0 _ = []
   557     fun patternify_term _ 0 _ = ([], [])
   562       | patternify_term args _ (Const (x as (s, _))) =
   558       | patternify_term args _ (Const (x as (s, _))) =
   563         if is_bad_const x args then [] else [s]
   559         (if member (op =) logical_consts s then (true, args)
   564       | patternify_term _ depth (Free (s, _)) =
   560          else is_built_in_const_for_prover ctxt prover x args)
   565         if depth = term_max_depth andalso member (op =) fixes s then
   561         |>> (fn true => [] | false => [s])
   566           [thy_name ^ Long_Name.separator ^ s]
   562       | patternify_term args depth (Free (s, _)) =
   567         else
   563         (if depth = term_max_depth andalso member (op =) fixes s then
   568           []
   564            [thy_name ^ Long_Name.separator ^ s]
       
   565          else
       
   566            [], args)
   569       | patternify_term args depth (t $ u) =
   567       | patternify_term args depth (t $ u) =
   570         let
   568         let
   571           val ps =
   569           val (ps, u_args) =
   572             take max_pattern_breadth (patternify_term (u :: args) depth t)
   570             patternify_term (u :: args) depth t
   573           val qs =
   571             |>> take max_pattern_breadth
   574             take max_pattern_breadth ("" :: patternify_term [] (depth - 1) u)
   572           val (qs, args) =
   575         in map_product (fn p => fn "" => p | q => p ^ "(" ^ q ^ ")") ps qs end
   573             case u_args of
   576       | patternify_term _ _ _ = []
   574               [] => ([], [])
       
   575             | arg :: args' =>
       
   576               if arg = u then
       
   577                 (patternify_term [] (depth - 1) u
       
   578                  |> fst |> cons "" |> take max_pattern_breadth,
       
   579                  args')
       
   580               else
       
   581                 ([], args')
       
   582         in
       
   583           (map_product (fn p => fn "" => p | q => p ^ "(" ^ q ^ ")") ps qs,
       
   584            args)
       
   585         end
       
   586       | patternify_term _ _ _ = ([], [])
   577     fun add_term_pattern feature_of =
   587     fun add_term_pattern feature_of =
   578       union (op = o pairself fst) o map feature_of oo patternify_term []
   588       union (op = o pairself fst) o map feature_of o fst oo patternify_term []
   579     fun add_term_patterns _ 0 _ = I
   589     fun add_term_patterns _ 0 _ = I
   580       | add_term_patterns feature_of depth t =
   590       | add_term_patterns feature_of depth t =
   581         add_term_pattern feature_of depth t
   591         add_term_pattern feature_of depth t
   582         #> add_term_patterns feature_of (depth - 1) t
   592         #> add_term_patterns feature_of (depth - 1) t
   583     fun add_term feature_of = add_term_patterns feature_of term_max_depth
   593     fun add_term feature_of = add_term_patterns feature_of term_max_depth
   600 
   610 
   601 (* TODO: Generate type classes for types? *)
   611 (* TODO: Generate type classes for types? *)
   602 fun features_of ctxt prover thy (scope, status) ts =
   612 fun features_of ctxt prover thy (scope, status) ts =
   603   let val thy_name = Context.theory_name thy in
   613   let val thy_name = Context.theory_name thy in
   604     thy_feature_of thy_name ::
   614     thy_feature_of thy_name ::
   605     interesting_terms_types_and_classes ctxt prover thy_name term_max_depth
   615     term_features_of ctxt prover thy_name term_max_depth type_max_depth ts
   606         type_max_depth ts
       
   607     |> status <> General ? cons (status_feature_of status)
   616     |> status <> General ? cons (status_feature_of status)
   608     |> scope <> Global ? cons local_feature
   617     |> scope <> Global ? cons local_feature
   609     |> exists (not o is_lambda_free) ts ? cons lams_feature
   618     |> exists (not o is_lambda_free) ts ? cons lams_feature
   610     |> exists (exists_Const is_exists) ts ? cons skos_feature
   619     |> exists (exists_Const is_exists) ts ? cons skos_feature
   611   end
   620   end