src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 57367 e64c1b174f4b
parent 57366 d01d1befe4a3
child 57368 b89937ed6099
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:35:30 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:35:36 2014 +0200
@@ -629,7 +629,10 @@
   c_plus_plus_tool ("newknn/knn" ^ " " ^ string_of_int number_of_nearest_neighbors)
 val naive_bayes_cpp = c_plus_plus_tool "predict/nbayes"
 
+val empty_xtab = (0, Symtab.empty)
+
 fun add_to_xtab key (next, tab) = (next + 1, Symtab.update_new (key, next) tab)
+fun maybe_add_to_xtab key = perhaps (try (add_to_xtab key))
 
 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
 
@@ -643,28 +646,18 @@
       naive_bayes_cpp max_suggs learns feats
     else
       let
-        val (rev_depss, rev_featss, ((num_facts, fact_tab), rev_facts), (num_feats, feat_tab)) =
-          fold (fn (fact, feats, deps) =>
-                fn (rev_depss, rev_featss, (fact_xtab as (_, fact_tab), rev_facts), feat_xtab) =>
-              let
-                fun add_feat feat (xtab as (n, tab)) =
-                  (case Symtab.lookup tab feat of
-                    SOME i => (i, xtab)
-                  | NONE => (n, add_to_xtab feat xtab))
+        val facts = map #1 learns
+        val fact_vec = Vector.fromList facts
+
+        val fact_xtab as (num_facts, fact_tab) = fold add_to_xtab facts empty_xtab
+        val feat_xtab as (num_feats, feat_tab) = fold (fold maybe_add_to_xtab o #2) learns empty_xtab
 
-                val (feats', feat_xtab') = fold_map add_feat feats feat_xtab
-              in
-                (map_filter (Symtab.lookup fact_tab) deps :: rev_depss, feats' :: rev_featss,
-                 (add_to_xtab fact fact_xtab, fact :: rev_facts), feat_xtab')
-              end)
-            learns ([], [], ((0, Symtab.empty), []), (0, Symtab.empty))
+        val featss = map (map_filter (Symtab.lookup feat_tab) o #2) learns
 
-        val facts = rev rev_facts
-        val fact_vec = Vector.fromList facts
+        val deps_vec = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns)
+
         val int_visible_facts = map_filter (Symtab.lookup fact_tab) visible_facts
 
-        val deps_vec = Vector.fromList (rev rev_depss)
-
         val get_deps = curry Vector.sub deps_vec
 
         val int_feats = map_filter (Symtab.lookup feat_tab) feats
@@ -676,10 +669,8 @@
              val facts_ary = Array.array (num_feats, [])
              val _ =
                fold (fn feats => fn fact =>
-                   let val fact' = fact - 1 in
-                     List.app (map_array_at facts_ary (cons fact')) feats; fact'
-                   end)
-                 rev_featss num_facts
+                   (List.app (map_array_at facts_ary (cons fact)) feats; fact + 1))
+                 featss 0
              val get_facts = curry Array.sub facts_ary
            in
              k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts num_feats
@@ -687,7 +678,7 @@
            end
          else
            let
-             val unweighted_feats_ary = Vector.fromList (rev rev_featss)
+             val unweighted_feats_ary = Vector.fromList featss
              val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
            in
              (case engine of