MaSh improvements: deeper patterns + more respect for chained facts
authorblanchet
Tue, 04 Dec 2012 00:37:11 +0100
changeset 50339 d8dae91f3107
parent 50338 73f2f0cd4aea
child 50340 72519bf5f135
MaSh improvements: deeper patterns + more respect for chained facts
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Mon Dec 03 23:43:53 2012 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Tue Dec 04 00:37:11 2012 +0100
@@ -281,7 +281,7 @@
 
 local
 
-val version = "*** MaSh 0.0 ***"
+val version = "*** MaSh 0.1 ***"
 
 fun extract_node line =
   case space_explode ":" line of
@@ -488,18 +488,17 @@
       | do_add_type (TFree (_, S)) = add_classes S
       | do_add_type (TVar (_, S)) = add_classes S
     fun add_type T = type_max_depth >= 0 ? do_add_type T
-    fun mk_app s args =
-      if member (op <>) args "" then s ^ "(" ^ space_implode "," args ^ ")"
-      else s
-    fun patternify_term ~1 _ = ""
-      | patternify_term depth t =
-        case strip_comb t of
-          (Const (x as (s, _)), args) =>
-          if is_bad_const x args then ""
-          else mk_app (const_name_of s) (map (patternify_term (depth - 1)) args)
-        | _ => ""
-    fun add_term_pattern depth t =
-      case patternify_term depth t of "" => I | s => insert (op =) s
+    fun patternify_term _ ~1 _ = []
+      | patternify_term args _ (Const (x as (s, _))) =
+        if is_bad_const x args then [] else [const_name_of s]
+      | patternify_term _ 0 _ = []
+      | patternify_term args depth (t $ u) =
+        let
+          val ps = patternify_term (u :: args) depth t
+          val qs = "" :: patternify_term [] (depth - 1) u
+        in map_product (fn p => fn q => p ^ "(" ^ q ^ ")") ps qs end
+      | patternify_term _ _ _ = []
+    val add_term_pattern = union (op =) oo patternify_term []
     fun add_term_patterns ~1 _ = I
       | add_term_patterns depth t =
         add_term_pattern depth t #> add_term_patterns (depth - 1) t
@@ -518,8 +517,8 @@
 
 fun is_exists (s, _) = (s = @{const_name Ex} orelse s = @{const_name Ex1})
 
-val term_max_depth = 1
-val type_max_depth = 1
+val term_max_depth = 2
+val type_max_depth = 2
 
 (* TODO: Generate type classes for types? *)
 fun features_of ctxt prover thy (scope, status) ts =
@@ -688,15 +687,21 @@
               (fact_G, mash_QUERY ctxt overlord (max_suggs_of max_facts)
                                   (parents, feats))
             end)
+    val (chained, unchained) =
+      List.partition (fn ((_, (scope, _)), _) => scope = Chained) facts
     val sels =
       facts |> suggested_facts suggs
-            (* The weights currently returned by "mash.py" are too extreme to
+            (* The weights currently returned by "mash.py" are too spaced out to
                make any sense. *)
             |> map fst
+            |> filter_out (member (Thm.eq_thm_prop o pairself snd) chained)
     val (unk_global, unk_local) =
-      facts |> filter_out (is_fact_in_graph fact_G)
-            |> List.partition (fn ((_, (scope, _)), _) => scope = Global)
-  in (interleave max_facts unk_local sels |> weight_mepo_facts, unk_global) end
+      unchained |> filter_out (is_fact_in_graph fact_G)
+                |> List.partition (fn ((_, (scope, _)), _) => scope = Global)
+  in
+    (interleave max_facts (chained @ unk_local) sels |> weight_mepo_facts,
+     unk_global)
+  end
 
 fun add_wrt_fact_graph ctxt (name, parents, feats, deps) (adds, graph) =
   let