# HG changeset patch # User blanchet # Date 1404226030 -7200 # Node ID 9cc802a8ab06c11871efb325ea0805d454916eef # Parent 22023ab4df3c1d32075f4793a222d45bbbb0e7c2 speed up MaSh a bit diff -r 22023ab4df3c -r 9cc802a8ab06 src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue Jul 01 16:47:10 2014 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Tue Jul 01 16:47:10 2014 +0200 @@ -187,7 +187,7 @@ fun weight_of fact = mess |> map_filter (score_in fact) |> scaled_avg in fold (union fact_eq o map fst o take max_facts o fst o snd) mess [] - |> map (`weight_of) |> sort (int_ord o swap o pairself fst) + |> map (`weight_of) |> sort (int_ord o pairself fst o swap) |> map snd |> take max_facts end @@ -197,16 +197,12 @@ fun weight_facts_smoothly facts = facts ~~ map smooth_weight_of_fact (0 upto length facts - 1) fun weight_facts_steeply facts = facts ~~ map steep_weight_of_fact (0 upto length facts - 1) - -(*** Isabelle-agnostic machine learning ***) - -structure MaSh = -struct - -fun heap cmp bnd al a = +fun rev_sort_array_prefix cmp bnd a = let exception BOTTOM of int + val al = Array.length a + fun maxson l i = let val i31 = i + i + i + 1 in if i31 + 2 < l then @@ -270,6 +266,18 @@ () end +fun rev_sort_list_prefix cmp bnd xs = + let val ary = Array.fromList xs in + rev_sort_array_prefix cmp bnd ary; + Array.foldr (op ::) [] ary + end + + +(*** Isabelle-agnostic machine learning ***) + +structure MaSh = +struct + fun select_visible_facts big_number recommends = List.app (fn at => let val (j, ov) = Array.sub (recommends, at) in @@ -354,7 +362,7 @@ if at = num_facts then acc else ret (at + 1) (Array.sub (posterior, at) :: acc) in select_visible_facts 100000.0 posterior visible_facts; - heap (Real.compare o pairself snd) max_suggs num_facts posterior; + rev_sort_array_prefix (Real.compare o pairself snd) max_suggs posterior; ret (Integer.max 0 (num_facts - max_suggs)) [] end @@ -387,7 +395,7 @@ end val _ = List.app do_feat goal_feats - val _ = heap (Real.compare o pairself snd) num_facts num_facts overlaps_sqr + val _ = rev_sort_array_prefix (Real.compare o pairself snd) num_facts overlaps_sqr val no_recommends = Unsynchronized.ref 0 val recommends = Array.tabulate (num_facts, rpair 0.0) val age = Unsynchronized.ref 500000000.0 @@ -432,7 +440,7 @@ while1 (); while2 (); select_visible_facts 1000000000.0 recommends visible_facts; - heap (Real.compare o pairself snd) max_suggs num_facts recommends; + rev_sort_array_prefix (Real.compare o pairself snd) max_suggs recommends; ret [] (Integer.max 0 (num_facts - max_suggs)) end @@ -1110,9 +1118,13 @@ find_maxes Symtab.empty ([], Graph.maximals G) end -fun maximal_wrt_access_graph access_G facts = - map (nickname_of_thm o snd) facts - |> maximal_wrt_graph access_G +fun maximal_wrt_access_graph _ [] = [] + | maximal_wrt_access_graph access_G ((fact as (_, th)) :: facts) = + let val thy = theory_of_thm th in + fact :: filter_out (fn (_, th') => Theory.subthy (theory_of_thm th', thy)) facts + |> map (nickname_of_thm o snd) + |> maximal_wrt_graph access_G + end fun is_fact_in_graph access_G = can (Graph.get_node access_G) o nickname_of_thm @@ -1144,8 +1156,11 @@ val thy_name = Context.theory_name thy val engine = the_mash_engine () - val facts = facts |> sort (crude_thm_ord o pairself snd o swap) - val chained = facts |> filter (fn ((_, (scope, _)), _) => scope = Chained) + val facts = facts + |> rev_sort_list_prefix (crude_thm_ord o pairself snd) + (Int.max (num_extra_feature_facts, max_proximity_facts)) + + val chained = filter (fn ((_, (scope, _)), _) => scope = Chained) facts fun fact_has_right_theory (_, th) = thy_name = Context.theory_name (theory_of_thm th) @@ -1155,53 +1170,44 @@ |> features_of ctxt (theory_of_thm th) stature |> map (rpair (weight * factor)) - fun query_args access_G = - let - val parents = maximal_wrt_access_graph access_G facts - - val goal_feats = features_of ctxt thy (Local, General) (concl_t :: hyp_ts) - val chained_feats = chained - |> map (rpair 1.0) - |> map (chained_or_extra_features_of chained_feature_factor) - |> rpair [] |-> fold (union (eq_fst (op =))) - val extra_feats = - facts - |> take (Int.max (0, num_extra_feature_facts - length chained)) - |> filter fact_has_right_theory - |> weight_facts_steeply - |> map (chained_or_extra_features_of extra_feature_factor) - |> rpair [] |-> fold (union (eq_fst (op =))) - val feats = - fold (union (eq_fst (op =))) [chained_feats, extra_feats] (map (rpair 1.0) goal_feats) - |> debug ? sort (Real.compare o swap o pairself snd) - in - (parents, feats) - end - val {access_G, xtabs = ((num_facts, fact_tab), (num_feats, feat_tab)), ffds, freqs, ...} = peek_state ctxt + val goal_feats0 = features_of ctxt thy (Local, General) (concl_t :: hyp_ts) + val chained_feats = chained + |> map (rpair 1.0) + |> map (chained_or_extra_features_of chained_feature_factor) + |> rpair [] |-> fold (union (eq_fst (op =))) + val extra_feats = facts + |> take (Int.max (0, num_extra_feature_facts - length chained)) + |> filter fact_has_right_theory + |> weight_facts_steeply + |> map (chained_or_extra_features_of extra_feature_factor) + |> rpair [] |-> fold (union (eq_fst (op =))) + + val goal_feats = + fold (union (eq_fst (op =))) [chained_feats, extra_feats] (map (rpair 1.0) goal_feats0) + |> debug ? sort (Real.compare o swap o pairself snd) + + val parents = maximal_wrt_access_graph access_G facts + val visible_facts = map_filter (Symtab.lookup fact_tab) (Graph.all_preds access_G parents) + val suggs = - let - val (parents, goal_feats) = query_args access_G - val visible_facts = map_filter (Symtab.lookup fact_tab) (Graph.all_preds access_G parents) - in - if engine = MaSh_NB_Ext orelse engine = MaSh_kNN_Ext then - let - val learns = - Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G - in - MaSh.query_external ctxt engine max_suggs learns goal_feats - end - else - let - val int_goal_feats = - map_filter (fn (s, w) => Option.map (rpair w) (Symtab.lookup feat_tab s)) goal_feats - in - MaSh.query_internal ctxt engine num_facts num_feats ffds freqs visible_facts max_suggs - goal_feats int_goal_feats - end - end + if engine = MaSh_NB_Ext orelse engine = MaSh_kNN_Ext then + let + val learns = + Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G + in + MaSh.query_external ctxt engine max_suggs learns goal_feats + end + else + let + val int_goal_feats = + map_filter (fn (s, w) => Option.map (rpair w) (Symtab.lookup feat_tab s)) goal_feats + in + MaSh.query_internal ctxt engine num_facts num_feats ffds freqs visible_facts max_suggs + goal_feats int_goal_feats + end val unknown = filter_out (is_fact_in_graph access_G o snd) facts in @@ -1264,6 +1270,7 @@ let val thy = Proof_Context.theory_of ctxt val feats = features_of ctxt thy (Local, General) [t] + val facts = rev_sort_list_prefix (crude_thm_ord o pairself snd) 1 facts in map_state ctxt (fn {access_G, xtabs as ((num_facts0, _), _), ffds, freqs, dirty_facts} => diff -r 22023ab4df3c -r 9cc802a8ab06 src/HOL/Tools/Sledgehammer/sledgehammer_prover_atp.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_prover_atp.ML Tue Jul 01 16:47:10 2014 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_prover_atp.ML Tue Jul 01 16:47:10 2014 +0200 @@ -383,7 +383,7 @@ val atp_proof = atp_proof |> termify_atp_proof ctxt name format type_enc pool lifted sym_tab - |> introduce_spass_skolem + |> spass ? introduce_spass_skolem |> factify_atp_proof (map fst used_from) hyp_ts concl_t in (verbose, (metis_type_enc, metis_lam_trans), preplay_timeout, compress, try0,