Frequency analysis of constants (with types).
author paulson
Fri, 10 Mar 2006 12:27:36 +0100
changeset 19231 c8879dd3a953
parent 19230 3342e7554b77
child 19232 1f5b5dc3f48a
Frequency analysis of constants (with types). Ability to restrict the number of accepted clauses.
src/HOL/Tools/ATP/reduce_axiomsN.ML
--- a/src/HOL/Tools/ATP/reduce_axiomsN.ML	Fri Mar 10 04:03:48 2006 +0100
+++ b/src/HOL/Tools/ATP/reduce_axiomsN.ML	Fri Mar 10 12:27:36 2006 +0100
@@ -6,7 +6,8 @@
 struct
 
 val pass_mark = ref 0.5;
-val strategy = ref 1;
+val strategy = ref 3;
+val max_filtered = ref 2000;
 
 fun pol_to_int true = 1
   | pol_to_int false = ~1;
@@ -36,6 +37,8 @@
 
 fun term_consts_rm ncs t = add_term_consts_rm ncs t [];
 
+(*Including equality in this list might be expected to stop rules like subset_antisym from
+  being chosen, but for some reason filtering works better with them listed.*)
 val standard_consts =
   ["Trueprop","==>","all","Ex","op &","op |","Not","All","op -->","op =","==","True","False"];
 
@@ -189,7 +192,7 @@
 (******************************************************************)
 
 (*** unit clauses ***)
-datatype clause_type = Unit_neq | Unit_geq | Other
+datatype clause_kind = Unit_neq | Unit_geq | Other
 
 (*Whether all "simple" unit clauses should be included*)
 val add_unit = ref true;
@@ -207,13 +210,14 @@
 fun unit_clause_type (Const ("op =",_) $ P $ Q) = eq_clause_type (P,Q)
   | unit_clause_type _ = Unit_neq;
 
-fun clause_type tm = 
+fun clause_kind tm = 
     case literals_of_term [] tm of
         [lit] => unit_clause_type lit
       | _ => Other;
 
 (*** constants with types ***)
 
+(*An abstraction of Isabelle types*)
 datatype const_typ =  CTVar | CType of string * const_typ list
 
 fun uni_type (CType(con1,args1)) (CType(con2,args2)) = con1=con2 andalso uni_types args1 args2
@@ -224,15 +228,15 @@
   | uni_types (a1::as1) (a2::as2) = uni_type a1 a2 andalso uni_types as1 as2;
 
 
-fun uni_constants (c1,ctp1) (c2,ctp2) = (c1 = c2) andalso (uni_types ctp1 ctp2);
+fun uni_constants (c1,ctp1) (c2,ctp2) = (c1=c2) andalso uni_types ctp1 ctp2;
 
 fun uni_mem _ [] = false
   | uni_mem (c,c_typ) ((c1,c_typ1)::ctyps) =
       uni_constants (c1,c_typ1) (c,c_typ) orelse uni_mem (c,c_typ) ctyps;
 
-fun const_typ_of (Type (c,typs)) = CType (c,map const_typ_of typs) 
-  | const_typ_of (TFree(_,_)) = CTVar
-  | const_typ_of (TVar(_,_)) = CTVar
+fun const_typ_of (Type (c,typs)) = CType (c, map const_typ_of typs) 
+  | const_typ_of (TFree _) = CTVar
+  | const_typ_of (TVar _) = CTVar
 
 
 fun const_w_typ thy (c,typ) = 
@@ -252,19 +256,39 @@
 
 fun get_goal_consts_typs thy cs = foldl (op union) [] (map (consts_typs_of_term thy) cs)
 
-fun lookup_or_zero (c,tab) =
-    case Symtab.lookup tab c of
-        NONE => 0
-      | SOME n => n
+
+(**** Constant / Type Frequencies ****)
+
+local
+
+fun cons_nr CTVar = 0
+  | cons_nr (CType _) = 1;
+
+in
+
+fun const_typ_ord TU =
+  case TU of
+    (CType (a, Ts), CType (b, Us)) =>
+      (case fast_string_ord(a,b) of EQUAL => dict_ord const_typ_ord (Ts,Us) | ord => ord)
+  | (T, U) => int_ord (cons_nr T, cons_nr U);
 
-fun count_term_consts (Const(c,_), tab) =
-      Symtab.update (c, 1 + lookup_or_zero (c,tab)) tab
-  | count_term_consts (t $ u, tab) =
-      count_term_consts (t, count_term_consts (u, tab))
-  | count_term_consts (Abs(_,_,t), tab) = count_term_consts (t, tab)
-  | count_term_consts (_, tab) = tab;
+end;
+
+structure CTtab = TableFun(type key = const_typ list val ord = dict_ord const_typ_ord);
 
-fun count_axiom_consts ((tm,_), tab) = count_term_consts (tm, tab);
+fun count_axiom_consts thy ((tm,_), tab) = 
+  let fun count_term_consts (Const cT, tab) =
+	    let val (c, cts) = const_w_typ thy cT
+		val cttab = Option.getOpt (Symtab.lookup tab c, CTtab.empty)
+		val n = Option.getOpt (CTtab.lookup cttab cts, 0)
+	    in 
+		Symtab.update (c, CTtab.update (cts, n+1) cttab) tab
+            end
+	| count_term_consts (t $ u, tab) =
+	    count_term_consts (t, count_term_consts (u, tab))
+	| count_term_consts (Abs(_,_,t), tab) = count_term_consts (t, tab)
+	| count_term_consts (_, tab) = tab
+  in  count_term_consts (tm, tab) end;
 
 
 (******** filter clauses ********)
@@ -272,28 +296,35 @@
 (*The default ignores the constant-count and gives the old Strategy 3*)
 val weight_fn = ref (fn x : real => 1.0);
 
-fun add_ct_weight ctab ((c,_), w) =
-  w + !weight_fn (100.0 / real (Option.valOf (Symtab.lookup ctab c)));
+fun const_weight ctab (c, cts) =
+  let val pairs = CTtab.dest (Option.valOf (Symtab.lookup ctab c))
+      fun add ((cts',m), n) = if uni_types cts cts' then m+n else n
+  in  List.foldl add 0 pairs  end;
+
+fun add_ct_weight ctab ((c,T), w) =
+  w + !weight_fn (real (const_weight ctab (c,T)));
 
 fun consts_typs_weight ctab =
     List.foldl (add_ct_weight ctab) 0.0;
 
+(*Relevant constants are weighted according to frequency, 
+  but irrelevant constants are simply counted. Otherwise, Skolem functions,
+  which are rare, would harm a clause's chances of being picked.*)
 fun clause_weight_s_3 ctab gctyps consts_typs =
     let val rel = filter (fn s => uni_mem s gctyps) consts_typs
+        val rel_weight = consts_typs_weight ctab rel
     in
-	(consts_typs_weight ctab rel) / (consts_typs_weight ctab consts_typs)
+	rel_weight / (rel_weight + real (length consts_typs - length rel))
     end;
 
-fun find_clause_weight_s_3_alt ctab consts_typs (_,(refconsts_typs,wa)) =
-    wa * clause_weight_s_3 ctab refconsts_typs consts_typs;
-
 fun relevant_clauses_ax_3 ctab rel_axs [] P (addc,tmpc) keep =
       if null addc orelse null tmpc 
       then (addc @ rel_axs @ keep, tmpc)   (*termination!*)
       else relevant_clauses_ax_3 ctab addc tmpc P ([],[]) (rel_axs @ keep)
   | relevant_clauses_ax_3 ctab rel_axs ((clstm,(consts_typs,weight))::e_axs) P (addc,tmpc) keep =
-      let val weights = map (find_clause_weight_s_3_alt ctab consts_typs) rel_axs
-          val weight' = List.foldl Real.max weight weights
+      let fun clause_weight_ax (_,(refconsts_typs,wa)) =
+              wa * clause_weight_s_3 ctab refconsts_typs consts_typs;
+          val weight' = List.foldl Real.max weight (map clause_weight_ax rel_axs)
 	  val e_ax' = (clstm, (consts_typs,weight'))
       in
 	if P <= weight' 
@@ -301,35 +332,51 @@
 	else relevant_clauses_ax_3 ctab rel_axs e_axs P (addc, e_ax'::tmpc) keep
       end;
 
-fun weight_of_axiom thy (tm,name) =
+fun pair_consts_typs_axiom thy (tm,name) =
     ((tm,name), (consts_typs_of_term thy tm));
 
-fun safe_unit_clause ((clstm,_), _) = 
-      case clause_type clstm of
+fun safe_unit_clause ((t,_), _) = 
+      case clause_kind t of
 	  Unit_neq => true
 	| Unit_geq => true
 	| Other => false;
+	
+fun axiom_ord ((_,(_,w1)), (_,(_,w2))) = Real.compare (w2,w1);
 
-fun relevance_filter3_aux thy axioms goals = 
-    let val pass = !pass_mark
-	val const_tab = List.foldl count_axiom_consts Symtab.empty axioms
-	val goals_consts_typs = get_goal_consts_typs thy goals
-	fun relevant [] (ax,r) = (ax,r)
-	  | relevant ((clstm,consts_typs)::y) (ax,r) =
-	      let val weight = clause_weight_s_3 const_tab goals_consts_typs consts_typs
-		  val ccc = (clstm, (consts_typs,weight))
-	      in
-		if pass <= weight 
-		then relevant y (ccc::ax, r)
-		else relevant y (ax, ccc::r)
-	      end
-	val (rel_clauses,nrel_clauses) =
-	    relevant (map (weight_of_axiom thy) axioms) ([],[]) 
-	val (ax,r) = relevant_clauses_ax_3 const_tab rel_clauses nrel_clauses pass ([],[]) []
-    in
-	if !add_unit then (filter safe_unit_clause r) @ ax 
-	else ax
-    end;
+fun showconst (c,cttab) = 
+      List.app (fn n => Output.debug (Int.toString n ^ " occurrences of " ^ c))
+	        (map #2 (CTtab.dest cttab))
+
+fun show_cname (name,k) = name ^ "__" ^ Int.toString k;
+
+fun showax ((_,cname), (_,w)) = 
+    Output.debug ("Axiom " ^ show_cname cname ^ " has weight " ^ Real.toString w)
+	      
+	      fun relevance_filter3_aux thy axioms goals = 
+  let val pass = !pass_mark
+      val const_tab = List.foldl (count_axiom_consts thy) Symtab.empty axioms
+      val goals_consts_typs = get_goal_consts_typs thy goals
+      fun relevant [] (ax,r) = (ax,r)
+	| relevant ((clstm,consts_typs)::y) (ax,r) =
+	    let val weight = clause_weight_s_3 const_tab goals_consts_typs consts_typs
+		val ccc = (clstm, (consts_typs,weight))
+	    in
+	      if pass <= weight 
+	      then relevant y (ccc::ax, r)
+	      else relevant y (ax, ccc::r)
+	    end
+      val (rel_clauses,nrel_clauses) =
+	  relevant (map (pair_consts_typs_axiom thy) axioms) ([],[]) 
+      val (ax,r) = relevant_clauses_ax_3 const_tab rel_clauses nrel_clauses pass ([],[]) []
+      val ax' = Library.take(!max_filtered, Library.sort axiom_ord ax)
+  in
+      if !Output.show_debug_msgs then
+	   (List.app showconst (Symtab.dest const_tab);
+	    List.app showax ax)
+      else ();
+      if !add_unit then (filter safe_unit_clause r) @ ax'
+      else ax'
+  end;
 
 fun relevance_filter3 thy axioms goals =
   map #1 (relevance_filter3_aux thy axioms goals);