# HG changeset patch # User paulson # Date 1141809597 -3600 # Node ID ec53c138277a2a2d8a9c2ddcdd1e08927a2d2e81 # Parent 307dfa3f9e6668efb880a5cad7c3d32a683dcb3f Frequency strategy. Revised indentation, etc. diff -r 307dfa3f9e66 -r ec53c138277a src/HOL/Tools/ATP/reduce_axiomsN.ML --- a/src/HOL/Tools/ATP/reduce_axiomsN.ML Wed Mar 08 10:06:31 2006 +0100 +++ b/src/HOL/Tools/ATP/reduce_axiomsN.ML Wed Mar 08 10:19:57 2006 +0100 @@ -235,14 +235,12 @@ | const_typ_of (TVar(_,_)) = CTVar -fun const_w_typ thy (c,tp) = - let val tvars = Sign.const_typargs thy (c,tp) - in - (c,map const_typ_of tvars) - end; +fun const_w_typ thy (c,typ) = + let val tvars = Sign.const_typargs thy (c,typ) + in (c, map const_typ_of tvars) end; -fun add_term_consts_typs_rm thy ncs (Const(c, tp)) cs = - if (c mem ncs) then cs else (const_w_typ thy (c,tp) ins cs) +fun add_term_consts_typs_rm thy ncs (Const(c, typ)) cs = + if (c mem ncs) then cs else (const_w_typ thy (c,typ) ins cs) | add_term_consts_typs_rm thy ncs (t $ u) cs = add_term_consts_typs_rm thy ncs t (add_term_consts_typs_rm thy ncs u cs) | add_term_consts_typs_rm thy ncs (Abs(_,_,t)) cs = add_term_consts_typs_rm thy ncs t cs @@ -254,68 +252,87 @@ fun get_goal_consts_typs thy cs = foldl (op union) [] (map (consts_typs_of_term thy) cs) +fun lookup_or_zero (c,tab) = + case Symtab.lookup tab c of + NONE => 0 + | SOME n => n + +fun count_term_consts (Const(c,_), tab) = + Symtab.update (c, 1 + lookup_or_zero (c,tab)) tab + | count_term_consts (t $ u, tab) = + count_term_consts (t, count_term_consts (u, tab)) + | count_term_consts (Abs(_,_,t), tab) = count_term_consts (t, tab) + | count_term_consts (_, tab) = tab; + +fun count_axiom_consts ((tm,_), tab) = count_term_consts (tm, tab); + (******** filter clauses ********) -fun find_clause_weight_s_3 gctyps consts_typs wa = +(*The default ignores the constant-count and gives the old Strategy 3*) +val weight_fn = ref (fn x : real => 1.0); + +fun add_ct_weight ctab ((c,_), w) = + w + !weight_fn (100.0 / real (Option.valOf (Symtab.lookup ctab c))); + +fun consts_typs_weight ctab = + List.foldl (add_ct_weight ctab) 0.0; + +fun clause_weight_s_3 ctab gctyps consts_typs = let val rel = filter (fn s => uni_mem s gctyps) consts_typs in - (real (length rel) / real (length consts_typs)) * wa + (consts_typs_weight ctab rel) / (consts_typs_weight ctab consts_typs) end; -fun relevant_clauses_ax_g_3 _ [] _ (ax,r) = (ax,r) - | relevant_clauses_ax_g_3 gctyps ((cls_typ,(clstm,(consts_typs,_)))::y) P (ax,r) = - let val weight = find_clause_weight_s_3 gctyps consts_typs 1.0 - val ccc = (cls_typ, (clstm, (consts_typs,weight))) - in - if P <= weight - then relevant_clauses_ax_g_3 gctyps y P (ccc::ax, r) - else relevant_clauses_ax_g_3 gctyps y P (ax, ccc::r) - end; +fun find_clause_weight_s_3_alt ctab consts_typs (_,(refconsts_typs,wa)) = + wa * clause_weight_s_3 ctab refconsts_typs consts_typs; -fun find_clause_weight_s_3_alt consts_typs (_,(_,(refconsts_typs,wa))) = - find_clause_weight_s_3 refconsts_typs consts_typs wa; - -fun relevant_clauses_ax_3 rel_axs [] P (addc,tmpc) keep = +fun relevant_clauses_ax_3 ctab rel_axs [] P (addc,tmpc) keep = if null addc orelse null tmpc then (addc @ rel_axs @ keep, tmpc) (*termination!*) - else relevant_clauses_ax_3 addc tmpc P ([],[]) (rel_axs @ keep) - | relevant_clauses_ax_3 rel_axs ((cls_typ,(clstm,(consts_typs,weight)))::e_axs) P (addc,tmpc) keep = - let val weights = map (find_clause_weight_s_3_alt consts_typs) rel_axs + else relevant_clauses_ax_3 ctab addc tmpc P ([],[]) (rel_axs @ keep) + | relevant_clauses_ax_3 ctab rel_axs ((clstm,(consts_typs,weight))::e_axs) P (addc,tmpc) keep = + let val weights = map (find_clause_weight_s_3_alt ctab consts_typs) rel_axs val weight' = List.foldl Real.max weight weights - val e_ax' = (cls_typ,(clstm,(consts_typs,weight'))) + val e_ax' = (clstm, (consts_typs,weight')) in - if P <= weight' then relevant_clauses_ax_3 rel_axs e_axs P (e_ax'::addc,tmpc) keep - else relevant_clauses_ax_3 rel_axs e_axs P (addc,e_ax'::tmpc) keep + if P <= weight' + then relevant_clauses_ax_3 ctab rel_axs e_axs P (e_ax'::addc, tmpc) keep + else relevant_clauses_ax_3 ctab rel_axs e_axs P (addc, e_ax'::tmpc) keep end; -fun initialize3 thy [] ax_weights = ax_weights - | initialize3 thy ((tm,name)::tms_names) ax_weights = - let val cls_type = clause_type tm - val consts = consts_typs_of_term thy tm - in - initialize3 thy tms_names ((cls_type,((tm,name),(consts,0.0)))::ax_weights) - end; +fun weight_of_axiom thy (tm,name) = + ((tm,name), (consts_typs_of_term thy tm)); -fun add_unit_clauses ax [] = ax - | add_unit_clauses ax ((cls_typ,consts_weight)::cs) = - case cls_typ of Unit_neq => add_unit_clauses ((cls_typ,consts_weight)::ax) cs - | Unit_geq => add_unit_clauses ((cls_typ,consts_weight)::ax) cs - | Other => add_unit_clauses ax cs; +fun safe_unit_clause ((clstm,_), _) = + case clause_type clstm of + Unit_neq => true + | Unit_geq => true + | Other => false; fun relevance_filter3_aux thy axioms goals = let val pass = !pass_mark - val axioms_weights = initialize3 thy axioms [] + val const_tab = List.foldl count_axiom_consts Symtab.empty axioms val goals_consts_typs = get_goal_consts_typs thy goals + fun relevant [] (ax,r) = (ax,r) + | relevant ((clstm,consts_typs)::y) (ax,r) = + let val weight = clause_weight_s_3 const_tab goals_consts_typs consts_typs + val ccc = (clstm, (consts_typs,weight)) + in + if pass <= weight + then relevant y (ccc::ax, r) + else relevant y (ax, ccc::r) + end val (rel_clauses,nrel_clauses) = - relevant_clauses_ax_g_3 goals_consts_typs axioms_weights pass ([],[]) - val (ax,r) = relevant_clauses_ax_3 rel_clauses nrel_clauses pass ([],[]) [] + relevant (map (weight_of_axiom thy) axioms) ([],[]) + val (ax,r) = relevant_clauses_ax_3 const_tab rel_clauses nrel_clauses pass ([],[]) [] in - if !add_unit then add_unit_clauses ax r else ax + if !add_unit then (filter safe_unit_clause r) @ ax + else ax end; fun relevance_filter3 thy axioms goals = - map (#1 o #2) (relevance_filter3_aux thy axioms goals); + map #1 (relevance_filter3_aux thy axioms goals); (******************************************************************)