src/HOL/Tools/ATP/reduce_axiomsN.ML
author paulson
Fri Mar 10 12:27:36 2006 +0100 (2006-03-10)
changeset 19231 c8879dd3a953
parent 19212 ec53c138277a
child 19315 b218cc3d1bb4
permissions -rw-r--r--
Frequency analysis of constants (with types).

Ability to restrict the number of accepted clauses.
     1 (* Authors: Jia Meng, NICTA and Lawrence C Paulson, Cambridge University Computer Laboratory
     2    ID: $Id$
     3    Filtering strategies *)
     4 
     5 structure ReduceAxiomsN =
     6 struct
     7 
     8 val pass_mark = ref 0.5;
     9 val strategy = ref 3;
    10 val max_filtered = ref 2000;
    11 
    12 fun pol_to_int true = 1
    13   | pol_to_int false = ~1;
    14 
    15 fun part refs [] (s1,s2) = (s1,s2)
    16   | part refs (s::ss) (s1,s2) = 
    17       if (s mem refs) then part refs ss (s::s1,s2) else part refs ss (s1,s::s2);
    18 
    19 
    20 fun pol_mem _ [] = false
    21   | pol_mem (pol,const) ((p,c)::pcs) =
    22       (pol = not p andalso const = c) orelse pol_mem (pol,const) pcs;
    23 
    24 
    25 fun part_w_pol refs [] (s1,s2) = (s1,s2)
    26   | part_w_pol refs (s::ss) (s1,s2) =
    27       if (pol_mem s refs) then part_w_pol refs ss (s::s1,s2) 
    28       else part_w_pol refs ss (s1,s::s2);
    29 
    30 
    31 fun add_term_consts_rm ncs (Const(c, _)) cs =
    32       if (c mem ncs) then cs else (c ins_string cs)
    33   | add_term_consts_rm ncs (t $ u) cs =
    34       add_term_consts_rm ncs t (add_term_consts_rm ncs u cs)
    35   | add_term_consts_rm ncs (Abs(_,_,t)) cs = add_term_consts_rm ncs t cs
    36   | add_term_consts_rm ncs _ cs = cs;
    37 
    38 fun term_consts_rm ncs t = add_term_consts_rm ncs t [];
    39 
    40 (*Including equality in this list might be expected to stop rules like subset_antisym from
    41   being chosen, but for some reason filtering works better with them listed.*)
    42 val standard_consts =
    43   ["Trueprop","==>","all","Ex","op &","op |","Not","All","op -->","op =","==","True","False"];
    44 
    45 val consts_of_term = term_consts_rm standard_consts;
    46 
    47 
    48 fun add_term_pconsts_rm ncs (Const(c,_)) pol cs = if c mem ncs then cs else ((pol,c) ins cs)
    49   | add_term_pconsts_rm ncs (Const("Not",_)$P) pol cs = add_term_pconsts_rm ncs P (not pol) cs
    50   | add_term_pconsts_rm ncs (P$Q) pol cs = 
    51     add_term_pconsts_rm ncs P pol (add_term_pconsts_rm ncs Q pol cs)
    52   | add_term_pconsts_rm ncs (Abs(_,_,t)) pol cs = add_term_pconsts_rm ncs t pol cs
    53   | add_term_pconsts_rm ncs _ _ cs = cs;
    54 
    55 
    56 fun term_pconsts_rm ncs t = add_term_pconsts_rm ncs t true [];
    57 
    58 val pconsts_of_term = term_pconsts_rm standard_consts;
    59 
    60 fun consts_in_goal goal = consts_of_term goal;
    61 fun get_goal_consts cs = foldl (op union_string) [] (map consts_in_goal cs);
    62 
    63 fun pconsts_in_goal goal = pconsts_of_term goal;
    64 fun get_goal_pconsts cs = foldl (op union) [] (map pconsts_in_goal cs);
    65 
    66 
    67 (*************************************************************************)
    68 (*            the first relevance filtering strategy                     *)
    69 (*************************************************************************)
    70 
    71 fun find_clause_weight_s_1 (refconsts : string list) consts wa = 
    72     let val (rel,irrel) = part refconsts consts ([],[])
    73     in
    74 	(real (length rel) / real (length consts)) * wa
    75     end;
    76 
    77 fun find_clause_weight_m_1 [] (_,w) = w 
    78   | find_clause_weight_m_1 ((_,(refconsts,wa))::y) (consts,w) =
    79       let val w' = find_clause_weight_s_1 refconsts consts wa
    80       in
    81 	if w < w' then find_clause_weight_m_1 y (consts,w')
    82 	else find_clause_weight_m_1 y (consts,w)
    83       end;
    84 
    85 
    86 fun relevant_clauses_ax_g_1 _ []  _ (ax,r) = (ax,r)
    87   | relevant_clauses_ax_g_1 gconsts  ((clstm,(consts,_))::y) P (ax,r) =
    88       let val weight = find_clause_weight_s_1 gconsts consts 1.0
    89       in
    90 	if  P <= weight 
    91 	then relevant_clauses_ax_g_1 gconsts y P ((clstm,(consts,weight))::ax,r)
    92 	else relevant_clauses_ax_g_1 gconsts y P (ax,(clstm,(consts,weight))::r)
    93       end;
    94 
    95 
    96 fun relevant_clauses_ax_1 rel_axs  [] P (addc,tmpc) keep = 
    97     (case addc of [] => rel_axs @ keep
    98 		| _ => case tmpc of [] => addc @ rel_axs @ keep
    99 				  | _ => relevant_clauses_ax_1 addc tmpc P ([],[]) (rel_axs @ keep))
   100   | relevant_clauses_ax_1 rel_axs ((clstm,(consts,weight))::e_axs) P (addc,tmpc) keep = 
   101       let val weight' = find_clause_weight_m_1 rel_axs (consts,weight) 
   102 	  val e_ax' = (clstm,(consts, weight'))
   103       in
   104 	if P <= weight' 
   105 	then relevant_clauses_ax_1 rel_axs e_axs P ((clstm,(consts,weight'))::addc,tmpc) keep
   106 	else relevant_clauses_ax_1 rel_axs e_axs P (addc,(clstm,(consts,weight'))::tmpc) keep 
   107       end;
   108 
   109 
   110 fun initialize [] ax_weights = ax_weights
   111   | initialize ((tm,name)::tms_names) ax_weights =
   112       let val consts = consts_of_term tm
   113       in
   114 	  initialize tms_names (((tm,name),(consts,0.0))::ax_weights)
   115       end;
   116 
   117 fun relevance_filter1_aux axioms goals = 
   118     let val pass = !pass_mark
   119 	val axioms_weights = initialize axioms []
   120 	val goals_consts = get_goal_consts goals
   121 	val (rel_clauses,nrel_clauses) = relevant_clauses_ax_g_1 goals_consts axioms_weights pass ([],[]) 
   122     in
   123 	relevant_clauses_ax_1 rel_clauses nrel_clauses pass ([],[]) []
   124     end;
   125 
   126 fun relevance_filter1 axioms goals = map fst (relevance_filter1_aux axioms goals);
   127 
   128 
   129 (*************************************************************************)
   130 (*            the second relevance filtering strategy                    *)
   131 (*************************************************************************)
   132 
   133 fun find_clause_weight_s_2 (refpconsts : (bool * string) list) pconsts wa = 
   134     let val (rel,irrel) = part_w_pol refpconsts pconsts ([],[])
   135     in
   136 	((real (length rel))/(real (length pconsts))) * wa
   137     end;
   138 
   139 fun find_clause_weight_m_2 [] (_,w) = w 
   140   | find_clause_weight_m_2 ((_,(refpconsts,wa))::y) (pconsts,w) =
   141     let val w' = find_clause_weight_s_2 refpconsts pconsts wa
   142     in
   143 	if (w < w') then find_clause_weight_m_2 y (pconsts,w')
   144 	else find_clause_weight_m_2 y (pconsts,w)
   145     end;
   146 
   147 
   148 fun relevant_clauses_ax_g_2 _ []  _ (ax,r) = (ax,r)
   149   | relevant_clauses_ax_g_2 gpconsts  ((clstm,(pconsts,_))::y) P (ax,r) =
   150     let val weight = find_clause_weight_s_2 gpconsts pconsts 1.0
   151     in
   152 	if  P <= weight then relevant_clauses_ax_g_2 gpconsts y P ((clstm,(pconsts,weight))::ax,r)
   153 	else relevant_clauses_ax_g_2 gpconsts y P (ax,(clstm,(pconsts,weight))::r)
   154     end;
   155 
   156 
   157 fun relevant_clauses_ax_2 rel_axs  [] P (addc,tmpc) keep = 
   158     (case addc of [] => rel_axs @ keep
   159 		| _ => case tmpc of [] => addc @ rel_axs @ keep
   160 				  | _ => relevant_clauses_ax_2 addc tmpc P ([],[]) (rel_axs @ keep))
   161   | relevant_clauses_ax_2 rel_axs ((clstm,(pconsts,weight))::e_axs) P (addc,tmpc) keep = 
   162     let val weight' = find_clause_weight_m_2 rel_axs (pconsts,weight) 
   163 	val e_ax' = (clstm,(pconsts, weight'))
   164     in
   165 	
   166 	if P <= weight' then relevant_clauses_ax_2 rel_axs e_axs P ((clstm,(pconsts,weight'))::addc,tmpc) keep
   167 	else relevant_clauses_ax_2 rel_axs e_axs P (addc,(clstm,(pconsts,weight'))::tmpc) keep 
   168     end;
   169 
   170 
   171 fun initialize_w_pol [] ax_weights = ax_weights
   172   | initialize_w_pol ((tm,name)::tms_names) ax_weights =
   173     let val consts = pconsts_of_term tm
   174     in
   175 	initialize_w_pol tms_names (((tm,name),(consts,0.0))::ax_weights)
   176     end;
   177 
   178 
   179 fun relevance_filter2_aux axioms goals = 
   180     let val pass = !pass_mark
   181 	val axioms_weights = initialize_w_pol axioms []
   182 	val goals_consts = get_goal_pconsts goals
   183 	val (rel_clauses,nrel_clauses) = relevant_clauses_ax_g_2 goals_consts axioms_weights pass ([],[]) 
   184     in
   185 	relevant_clauses_ax_2 rel_clauses nrel_clauses pass ([],[]) []
   186     end;
   187 
   188 fun relevance_filter2 axioms goals = map fst (relevance_filter2_aux axioms goals);
   189 
   190 (******************************************************************)
   191 (*       the third relevance filtering strategy                   *)
   192 (******************************************************************)
   193 
   194 (*** unit clauses ***)
   195 datatype clause_kind = Unit_neq | Unit_geq | Other
   196 
   197 (*Whether all "simple" unit clauses should be included*)
   198 val add_unit = ref true;
   199 
   200 fun literals_of_term args (Const ("Trueprop",_) $ P) = literals_of_term args P
   201   | literals_of_term args (Const ("op |",_) $ P $ Q) = 
   202     literals_of_term (literals_of_term args P) Q
   203   | literals_of_term args P = P::args;
   204 
   205 fun is_ground t = (term_vars t = []) andalso (term_frees t = []);
   206 
   207 fun eq_clause_type (P,Q) = 
   208     if ((is_ground P) orelse (is_ground Q)) then Unit_geq else Other;
   209 
   210 fun unit_clause_type (Const ("op =",_) $ P $ Q) = eq_clause_type (P,Q)
   211   | unit_clause_type _ = Unit_neq;
   212 
   213 fun clause_kind tm = 
   214     case literals_of_term [] tm of
   215         [lit] => unit_clause_type lit
   216       | _ => Other;
   217 
   218 (*** constants with types ***)
   219 
   220 (*An abstraction of Isabelle types*)
   221 datatype const_typ =  CTVar | CType of string * const_typ list
   222 
   223 fun uni_type (CType(con1,args1)) (CType(con2,args2)) = con1=con2 andalso uni_types args1 args2
   224   | uni_type (CType _) CTVar = true
   225   | uni_type CTVar CTVar = true
   226   | uni_type CTVar _ = false
   227 and uni_types [] [] = true
   228   | uni_types (a1::as1) (a2::as2) = uni_type a1 a2 andalso uni_types as1 as2;
   229 
   230 
   231 fun uni_constants (c1,ctp1) (c2,ctp2) = (c1=c2) andalso uni_types ctp1 ctp2;
   232 
   233 fun uni_mem _ [] = false
   234   | uni_mem (c,c_typ) ((c1,c_typ1)::ctyps) =
   235       uni_constants (c1,c_typ1) (c,c_typ) orelse uni_mem (c,c_typ) ctyps;
   236 
   237 fun const_typ_of (Type (c,typs)) = CType (c, map const_typ_of typs) 
   238   | const_typ_of (TFree _) = CTVar
   239   | const_typ_of (TVar _) = CTVar
   240 
   241 
   242 fun const_w_typ thy (c,typ) = 
   243     let val tvars = Sign.const_typargs thy (c,typ)
   244     in (c, map const_typ_of tvars) end;
   245 
   246 fun add_term_consts_typs_rm thy ncs (Const(c, typ)) cs =
   247       if (c mem ncs) then cs else (const_w_typ thy (c,typ) ins cs)
   248   | add_term_consts_typs_rm thy ncs (t $ u) cs =
   249       add_term_consts_typs_rm thy ncs  t (add_term_consts_typs_rm thy ncs u cs)
   250   | add_term_consts_typs_rm thy ncs (Abs(_,_,t)) cs = add_term_consts_typs_rm thy ncs t cs
   251   | add_term_consts_typs_rm thy ncs _ cs = cs;
   252 
   253 fun term_consts_typs_rm thy ncs t = add_term_consts_typs_rm thy ncs t [];
   254 
   255 fun consts_typs_of_term thy = term_consts_typs_rm thy standard_consts;
   256 
   257 fun get_goal_consts_typs thy cs = foldl (op union) [] (map (consts_typs_of_term thy) cs)
   258 
   259 
   260 (**** Constant / Type Frequencies ****)
   261 
   262 local
   263 
   264 fun cons_nr CTVar = 0
   265   | cons_nr (CType _) = 1;
   266 
   267 in
   268 
   269 fun const_typ_ord TU =
   270   case TU of
   271     (CType (a, Ts), CType (b, Us)) =>
   272       (case fast_string_ord(a,b) of EQUAL => dict_ord const_typ_ord (Ts,Us) | ord => ord)
   273   | (T, U) => int_ord (cons_nr T, cons_nr U);
   274 
   275 end;
   276 
   277 structure CTtab = TableFun(type key = const_typ list val ord = dict_ord const_typ_ord);
   278 
   279 fun count_axiom_consts thy ((tm,_), tab) = 
   280   let fun count_term_consts (Const cT, tab) =
   281 	    let val (c, cts) = const_w_typ thy cT
   282 		val cttab = Option.getOpt (Symtab.lookup tab c, CTtab.empty)
   283 		val n = Option.getOpt (CTtab.lookup cttab cts, 0)
   284 	    in 
   285 		Symtab.update (c, CTtab.update (cts, n+1) cttab) tab
   286             end
   287 	| count_term_consts (t $ u, tab) =
   288 	    count_term_consts (t, count_term_consts (u, tab))
   289 	| count_term_consts (Abs(_,_,t), tab) = count_term_consts (t, tab)
   290 	| count_term_consts (_, tab) = tab
   291   in  count_term_consts (tm, tab) end;
   292 
   293 
   294 (******** filter clauses ********)
   295 
   296 (*The default ignores the constant-count and gives the old Strategy 3*)
   297 val weight_fn = ref (fn x : real => 1.0);
   298 
   299 fun const_weight ctab (c, cts) =
   300   let val pairs = CTtab.dest (Option.valOf (Symtab.lookup ctab c))
   301       fun add ((cts',m), n) = if uni_types cts cts' then m+n else n
   302   in  List.foldl add 0 pairs  end;
   303 
   304 fun add_ct_weight ctab ((c,T), w) =
   305   w + !weight_fn (real (const_weight ctab (c,T)));
   306 
   307 fun consts_typs_weight ctab =
   308     List.foldl (add_ct_weight ctab) 0.0;
   309 
   310 (*Relevant constants are weighted according to frequency, 
   311   but irrelevant constants are simply counted. Otherwise, Skolem functions,
   312   which are rare, would harm a clause's chances of being picked.*)
   313 fun clause_weight_s_3 ctab gctyps consts_typs =
   314     let val rel = filter (fn s => uni_mem s gctyps) consts_typs
   315         val rel_weight = consts_typs_weight ctab rel
   316     in
   317 	rel_weight / (rel_weight + real (length consts_typs - length rel))
   318     end;
   319 
   320 fun relevant_clauses_ax_3 ctab rel_axs [] P (addc,tmpc) keep =
   321       if null addc orelse null tmpc 
   322       then (addc @ rel_axs @ keep, tmpc)   (*termination!*)
   323       else relevant_clauses_ax_3 ctab addc tmpc P ([],[]) (rel_axs @ keep)
   324   | relevant_clauses_ax_3 ctab rel_axs ((clstm,(consts_typs,weight))::e_axs) P (addc,tmpc) keep =
   325       let fun clause_weight_ax (_,(refconsts_typs,wa)) =
   326               wa * clause_weight_s_3 ctab refconsts_typs consts_typs;
   327           val weight' = List.foldl Real.max weight (map clause_weight_ax rel_axs)
   328 	  val e_ax' = (clstm, (consts_typs,weight'))
   329       in
   330 	if P <= weight' 
   331 	then relevant_clauses_ax_3 ctab rel_axs e_axs P (e_ax'::addc, tmpc) keep
   332 	else relevant_clauses_ax_3 ctab rel_axs e_axs P (addc, e_ax'::tmpc) keep
   333       end;
   334 
   335 fun pair_consts_typs_axiom thy (tm,name) =
   336     ((tm,name), (consts_typs_of_term thy tm));
   337 
   338 fun safe_unit_clause ((t,_), _) = 
   339       case clause_kind t of
   340 	  Unit_neq => true
   341 	| Unit_geq => true
   342 	| Other => false;
   343 	
   344 fun axiom_ord ((_,(_,w1)), (_,(_,w2))) = Real.compare (w2,w1);
   345 
   346 fun showconst (c,cttab) = 
   347       List.app (fn n => Output.debug (Int.toString n ^ " occurrences of " ^ c))
   348 	        (map #2 (CTtab.dest cttab))
   349 
   350 fun show_cname (name,k) = name ^ "__" ^ Int.toString k;
   351 
   352 fun showax ((_,cname), (_,w)) = 
   353     Output.debug ("Axiom " ^ show_cname cname ^ " has weight " ^ Real.toString w)
   354 	      
   355 	      fun relevance_filter3_aux thy axioms goals = 
   356   let val pass = !pass_mark
   357       val const_tab = List.foldl (count_axiom_consts thy) Symtab.empty axioms
   358       val goals_consts_typs = get_goal_consts_typs thy goals
   359       fun relevant [] (ax,r) = (ax,r)
   360 	| relevant ((clstm,consts_typs)::y) (ax,r) =
   361 	    let val weight = clause_weight_s_3 const_tab goals_consts_typs consts_typs
   362 		val ccc = (clstm, (consts_typs,weight))
   363 	    in
   364 	      if pass <= weight 
   365 	      then relevant y (ccc::ax, r)
   366 	      else relevant y (ax, ccc::r)
   367 	    end
   368       val (rel_clauses,nrel_clauses) =
   369 	  relevant (map (pair_consts_typs_axiom thy) axioms) ([],[]) 
   370       val (ax,r) = relevant_clauses_ax_3 const_tab rel_clauses nrel_clauses pass ([],[]) []
   371       val ax' = Library.take(!max_filtered, Library.sort axiom_ord ax)
   372   in
   373       if !Output.show_debug_msgs then
   374 	   (List.app showconst (Symtab.dest const_tab);
   375 	    List.app showax ax)
   376       else ();
   377       if !add_unit then (filter safe_unit_clause r) @ ax'
   378       else ax'
   379   end;
   380 
   381 fun relevance_filter3 thy axioms goals =
   382   map #1 (relevance_filter3_aux thy axioms goals);
   383     
   384 
   385 (******************************************************************)
   386 (* Generic functions for relevance filtering                      *)
   387 (******************************************************************)
   388 
   389 exception RELEVANCE_FILTER of string;
   390 
   391 fun relevance_filter thy axioms goals = 
   392   case (!strategy) of 1 => relevance_filter1 axioms goals
   393 		    | 2 => relevance_filter2 axioms goals
   394 		    | 3 => relevance_filter3 thy axioms goals
   395 		    | _ => raise RELEVANCE_FILTER("strategy doesn't exist");
   396 
   397 end;