generate deep type patterns in MaSh
authorblanchet
Mon, 19 Aug 2013 14:26:59 +0200
changeset 53082 369e39511555
parent 53081 2a62d848a56a
child 53083 019ecbb18e3f
generate deep type patterns in MaSh
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Mon Aug 19 12:05:33 2013 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Mon Aug 19 14:26:59 2013 +0200
@@ -339,7 +339,7 @@
 
 local
 
-val version = "*** MaSh version 20130207a ***"
+val version = "*** MaSh version 20130819a ***"
 
 exception Too_New of unit
 
@@ -567,16 +567,20 @@
 val logical_consts =
   [@{const_name prop}, @{const_name Pure.conjunction}] @ atp_logical_consts
 
-val max_pattern_breadth = 10
+val max_pat_breadth = 10
 
 fun term_features_of ctxt prover thy_name term_max_depth type_max_depth ts =
   let
     val thy = Proof_Context.theory_of ctxt
+
+    val pass_args = map_product (fn p => fn "" => p | q => p ^ "(" ^ q ^ ")")
     fun is_built_in (x as (s, _)) args =
       if member (op =) logical_consts s then (true, args)
       else is_built_in_const_of_prover ctxt prover x args
+
     val fixes = map snd (Variable.dest_fixes ctxt)
     val classes = Sign.classes_of thy
+
     fun add_classes @{sort type} = I
       | add_classes S =
         fold (`(Sorts.super_classes classes)
@@ -584,58 +588,70 @@
               #> subtract (op =) @{sort type}
               #> map class_feature_of
               #> union (op = o pairself fst)) S
-    fun do_add_type (Type (s, Ts)) =
-        (not (member (op =) bad_types s)
-         ? insert (op = o pairself fst) (type_feature_of s))
-        #> fold do_add_type Ts
-      | do_add_type (TFree (_, S)) = add_classes S
-      | do_add_type (TVar (_, S)) = add_classes S
-    fun add_type T = type_max_depth >= 0 ? do_add_type T
-    fun patternify_term _ 0 _ = []
-      | patternify_term args _ (Const (x as (s, _))) =
+
+    fun pattify_type 0 _ = []
+      | pattify_type _ (Type (s, [])) =
+        if member (op =) bad_types s then [] else [s]
+      | pattify_type depth (Type (s, U :: Ts)) =
+        let
+          val T = Type (s, Ts)
+          val ps = take max_pat_breadth (pattify_type depth T)
+          val qs = take max_pat_breadth ("" :: pattify_type (depth - 1) U)
+        in pass_args ps qs end
+      | pattify_type _ _ = []
+    fun add_type_pat depth T =
+      union (op = o pairself fst)
+            (map type_feature_of (pattify_type depth T) @
+             fold_atyps_sorts (fn (_, S) => add_classes S) T [])
+    fun add_type_pats 0 _ = I
+      | add_type_pats depth t =
+        add_type_pat depth t #> add_type_pats (depth - 1) t
+    val add_type = add_type_pats type_max_depth
+
+    fun pattify_term _ 0 _ = []
+      | pattify_term args _ (Const (x as (s, _))) =
         if fst (is_built_in x args) then [] else [s]
-      | patternify_term _ depth (Free (s, _)) =
+      | pattify_term _ depth (Free (s, _)) =
         if depth = term_max_depth andalso member (op =) fixes s then
           [thy_name ^ Long_Name.separator ^ s]
         else
           []
-      | patternify_term args depth (t $ u) =
+      | pattify_term args depth (t $ u) =
         let
-          val ps =
-            take max_pattern_breadth (patternify_term (u :: args) depth t)
-          val qs =
-            take max_pattern_breadth ("" :: patternify_term [] (depth - 1) u)
-        in map_product (fn p => fn "" => p | q => p ^ "(" ^ q ^ ")") ps qs end
-      | patternify_term _ _ _ = []
-    fun add_term_pattern feature_of =
-      union (op = o pairself fst) o map feature_of oo patternify_term []
-    fun add_term_patterns _ 0 _ = I
-      | add_term_patterns feature_of depth t =
-        add_term_pattern feature_of depth t
-        #> add_term_patterns feature_of (depth - 1) t
-    fun add_term feature_of = add_term_patterns feature_of term_max_depth
-    fun add_patterns t =
+          val ps = take max_pat_breadth (pattify_term (u :: args) depth t)
+          val qs = take max_pat_breadth ("" :: pattify_term [] (depth - 1) u)
+        in pass_args ps qs end
+      | pattify_term _ _ _ = []
+    fun add_term_pat feature_of depth =
+      union (op = o pairself fst) o map feature_of o pattify_term [] depth
+    fun add_term_pats _ 0 _ = I
+      | add_term_pats feature_of depth t =
+        add_term_pat feature_of depth t
+        #> add_term_pats feature_of (depth - 1) t
+    fun add_term feature_of = add_term_pats feature_of term_max_depth
+
+    fun add_pats t =
       case strip_comb t of
         (Const (x as (_, T)), args) =>
         let val (built_in, args) = is_built_in x args in
           (not built_in ? add_term const_feature_of t)
           #> add_type T
-          #> fold add_patterns args
+          #> fold add_pats args
         end
       | (head, args) =>
         (case head of
            Const (_, T) => add_term const_feature_of t #> add_type T
          | Free (_, T) => add_term free_feature_of t #> add_type T
          | Var (_, T) => add_type T
-         | Abs (_, T, body) => add_type T #> add_patterns body
+         | Abs (_, T, body) => add_type T #> add_pats body
          | _ => I)
-        #> fold add_patterns args
-  in [] |> fold add_patterns ts end
+        #> fold add_pats args
+  in [] |> fold add_pats ts end
 
 fun is_exists (s, _) = (s = @{const_name Ex} orelse s = @{const_name Ex1})
 
-val term_max_depth = 2
-val type_max_depth = 2
+val term_max_depth = 3
+val type_max_depth = 3
 
 (* TODO: Generate type classes for types? *)
 fun features_of ctxt prover thy (scope, status) ts =