Tidying and restructuring.
authorpaulson
Tue, 07 Mar 2006 16:49:48 +0100
changeset 19208 3e8006cbc925
parent 19207 33f1b4515ce4
child 19209 27b91724809f
Tidying and restructuring.
src/HOL/Tools/ATP/reduce_axiomsN.ML
--- a/src/HOL/Tools/ATP/reduce_axiomsN.ML	Tue Mar 07 16:49:12 2006 +0100
+++ b/src/HOL/Tools/ATP/reduce_axiomsN.ML	Tue Mar 07 16:49:48 2006 +0100
@@ -1,8 +1,8 @@
+(* Authors: Jia Meng, NICTA and Lawrence C Paulson, Cambridge University Computer Laboratory
+   ID: $Id$
+   Filtering strategies *)
+
 structure ReduceAxiomsN =
-(* Author: Jia Meng, Cambridge University Computer Laboratory
-   
-   Two filtering strategies *)
-
 struct
 
 val pass_mark = ref 0.5;
@@ -12,18 +12,23 @@
   | pol_to_int false = ~1;
 
 fun part refs [] (s1,s2) = (s1,s2)
-  | part refs (s::ss) (s1,s2) = if (s mem refs) then (part refs ss (s::s1,s2)) else (part refs ss (s1,s::s2));
+  | part refs (s::ss) (s1,s2) = 
+      if (s mem refs) then part refs ss (s::s1,s2) else part refs ss (s1,s::s2);
 
 
 fun pol_mem _ [] = false
-  | pol_mem (pol,const) ((p,c)::pcs) = if ((pol = not p) andalso (const = c)) then true else pol_mem (pol,const) pcs;
+  | pol_mem (pol,const) ((p,c)::pcs) =
+      (pol = not p andalso const = c) orelse pol_mem (pol,const) pcs;
 
 
 fun part_w_pol refs [] (s1,s2) = (s1,s2)
-  | part_w_pol refs (s::ss) (s1,s2) = if (pol_mem s refs) then (part_w_pol refs ss (s::s1,s2)) else (part_w_pol refs ss (s1,s::s2));
+  | part_w_pol refs (s::ss) (s1,s2) =
+      if (pol_mem s refs) then part_w_pol refs ss (s::s1,s2) 
+      else part_w_pol refs ss (s1,s::s2);
 
 
-fun add_term_consts_rm ncs (Const(c, _)) cs = if (c mem ncs) then cs else (c ins_string cs)
+fun add_term_consts_rm ncs (Const(c, _)) cs =
+      if (c mem ncs) then cs else (c ins_string cs)
   | add_term_consts_rm ncs (t $ u) cs =
       add_term_consts_rm ncs t (add_term_consts_rm ncs u cs)
   | add_term_consts_rm ncs (Abs(_,_,t)) cs = add_term_consts_rm ncs t cs
@@ -31,10 +36,13 @@
 
 fun term_consts_rm ncs t = add_term_consts_rm ncs t [];
 
-fun consts_of_term term = term_consts_rm ["Trueprop","==>","all","Ex","op &", "op |", "Not", "All", "op -->", "op =", "==", "True", "False"] term;
+val standard_consts =
+  ["Trueprop","==>","all","Ex","op &","op |","Not","All","op -->","op =","==","True","False"];
+
+val consts_of_term = term_consts_rm standard_consts;
 
 
-fun add_term_pconsts_rm ncs (Const(c,_)) pol cs = if (c mem ncs) then cs else ((pol,c) ins cs)
+fun add_term_pconsts_rm ncs (Const(c,_)) pol cs = if c mem ncs then cs else ((pol,c) ins cs)
   | add_term_pconsts_rm ncs (Const("Not",_)$P) pol cs = add_term_pconsts_rm ncs P (not pol) cs
   | add_term_pconsts_rm ncs (P$Q) pol cs = 
     add_term_pconsts_rm ncs P pol (add_term_pconsts_rm ncs Q pol cs)
@@ -44,13 +52,11 @@
 
 fun term_pconsts_rm ncs t = add_term_pconsts_rm ncs t true [];
 
-
-fun pconsts_of_term term = term_pconsts_rm ["Trueprop","==>","all","Ex","op &", "op |", "Not", "All", "op -->", "op =", "==", "True", "False"] term;
+val pconsts_of_term = term_pconsts_rm standard_consts;
 
 fun consts_in_goal goal = consts_of_term goal;
 fun get_goal_consts cs = foldl (op union_string) [] (map consts_in_goal cs);
 
-
 fun pconsts_in_goal goal = pconsts_of_term goal;
 fun get_goal_pconsts cs = foldl (op union) [] (map pconsts_in_goal cs);
 
@@ -62,25 +68,26 @@
 fun find_clause_weight_s_1 (refconsts : string list) consts wa = 
     let val (rel,irrel) = part refconsts consts ([],[])
     in
-	((real (length rel))/(real (length consts))) * wa
+	(real (length rel) / real (length consts)) * wa
     end;
 
 fun find_clause_weight_m_1 [] (_,w) = w 
   | find_clause_weight_m_1 ((_,(refconsts,wa))::y) (consts,w) =
-    let val w' = find_clause_weight_s_1 refconsts consts wa
-    in
-	if (w < w') then find_clause_weight_m_1 y (consts,w')
+      let val w' = find_clause_weight_s_1 refconsts consts wa
+      in
+	if w < w' then find_clause_weight_m_1 y (consts,w')
 	else find_clause_weight_m_1 y (consts,w)
-    end;
+      end;
 
 
 fun relevant_clauses_ax_g_1 _ []  _ (ax,r) = (ax,r)
   | relevant_clauses_ax_g_1 gconsts  ((clstm,(consts,_))::y) P (ax,r) =
-    let val weight = find_clause_weight_s_1 gconsts consts 1.0
-    in
-	if  P <= weight then relevant_clauses_ax_g_1 gconsts y P ((clstm,(consts,weight))::ax,r)
+      let val weight = find_clause_weight_s_1 gconsts consts 1.0
+      in
+	if  P <= weight 
+	then relevant_clauses_ax_g_1 gconsts y P ((clstm,(consts,weight))::ax,r)
 	else relevant_clauses_ax_g_1 gconsts y P (ax,(clstm,(consts,weight))::r)
-    end;
+      end;
 
 
 fun relevant_clauses_ax_1 rel_axs  [] P (addc,tmpc) keep = 
@@ -88,21 +95,21 @@
 		| _ => case tmpc of [] => addc @ rel_axs @ keep
 				  | _ => relevant_clauses_ax_1 addc tmpc P ([],[]) (rel_axs @ keep))
   | relevant_clauses_ax_1 rel_axs ((clstm,(consts,weight))::e_axs) P (addc,tmpc) keep = 
-    let val weight' = find_clause_weight_m_1 rel_axs (consts,weight) 
-	val e_ax' = (clstm,(consts, weight'))
-    in
-	
-	if P <= weight' then relevant_clauses_ax_1 rel_axs e_axs P ((clstm,(consts,weight'))::addc,tmpc) keep
+      let val weight' = find_clause_weight_m_1 rel_axs (consts,weight) 
+	  val e_ax' = (clstm,(consts, weight'))
+      in
+	if P <= weight' 
+	then relevant_clauses_ax_1 rel_axs e_axs P ((clstm,(consts,weight'))::addc,tmpc) keep
 	else relevant_clauses_ax_1 rel_axs e_axs P (addc,(clstm,(consts,weight'))::tmpc) keep 
-    end;
+      end;
 
 
 fun initialize [] ax_weights = ax_weights
   | initialize ((tm,name)::tms_names) ax_weights =
-    let val consts = consts_of_term tm
-    in
-	initialize tms_names (((tm,name),(consts,0.0))::ax_weights)
-    end;
+      let val consts = consts_of_term tm
+      in
+	  initialize tms_names (((tm,name),(consts,0.0))::ax_weights)
+      end;
 
 fun relevance_filter1_aux axioms goals = 
     let val pass = !pass_mark
@@ -116,8 +123,6 @@
 fun relevance_filter1 axioms goals = map fst (relevance_filter1_aux axioms goals);
 
 
-
-
 (*************************************************************************)
 (*            the second relevance filtering strategy                    *)
 (*************************************************************************)
@@ -186,17 +191,15 @@
 (*** unit clauses ***)
 datatype clause_type = Unit_neq | Unit_geq | Other
 
+(*Whether all "simple" unit clauses should be included*)
 val add_unit = ref true;
 
-
 fun literals_of_term args (Const ("Trueprop",_) $ P) = literals_of_term args P
   | literals_of_term args (Const ("op |",_) $ P $ Q) = 
     literals_of_term (literals_of_term args P) Q
-  | literals_of_term args P = (P::args);
-
+  | literals_of_term args P = P::args;
 
-fun is_ground t = if (term_vars t = []) then (term_frees t = []) else false;
-
+fun is_ground t = (term_vars t = []) andalso (term_frees t = []);
 
 fun eq_clause_type (P,Q) = 
     if ((is_ground P) orelse (is_ground Q)) then Unit_geq else Other;
@@ -205,34 +208,27 @@
   | unit_clause_type _ = Unit_neq;
 
 fun clause_type tm = 
-    let val lits = literals_of_term [] tm
-	val nlits = length lits
-    in 
-	if (nlits > 1) then Other
-	else unit_clause_type (hd lits)
-    end;
+    case literals_of_term [] tm of
+        [lit] => unit_clause_type lit
+      | _ => Other;
 
 (*** constants with types ***)
 
 datatype const_typ =  CTVar | CType of string * const_typ list
 
-fun uni_type (CType (con1,args1)) (CType (con2,args2)) = (con1 = con2) andalso (uni_types args1 args2)
-  | uni_type (CType (_,_)) CTVar = true
+fun uni_type (CType(con1,args1)) (CType(con2,args2)) = con1=con2 andalso uni_types args1 args2
+  | uni_type (CType _) CTVar = true
   | uni_type CTVar CTVar = true
   | uni_type CTVar _ = false
-
-and 
-     uni_types [] [] = true
-      | uni_types (a1::as1) (a2::as2) = (uni_type a1 a2) andalso (uni_types as1 as2);
-
+and uni_types [] [] = true
+  | uni_types (a1::as1) (a2::as2) = uni_type a1 a2 andalso uni_types as1 as2;
 
 
 fun uni_constants (c1,ctp1) (c2,ctp2) = (c1 = c2) andalso (uni_types ctp1 ctp2);
 
 fun uni_mem _ [] = false
-  | uni_mem (c,c_typ) ((c1,c_typ1)::ctyps) = (uni_constants (c1,c_typ1) (c,c_typ)) orelse (uni_mem (c,c_typ) ctyps);
-
-
+  | uni_mem (c,c_typ) ((c1,c_typ1)::ctyps) =
+      uni_constants (c1,c_typ1) (c,c_typ) orelse uni_mem (c,c_typ) ctyps;
 
 fun const_typ_of (Type (c,typs)) = CType (c,map const_typ_of typs) 
   | const_typ_of (TFree(_,_)) = CTVar
@@ -245,7 +241,8 @@
 	(c,map const_typ_of tvars)
     end;
 
-fun add_term_consts_typs_rm thy ncs (Const(c, tp)) cs = if (c mem ncs) then cs else (const_w_typ thy (c,tp) ins cs)
+fun add_term_consts_typs_rm thy ncs (Const(c, tp)) cs =
+      if (c mem ncs) then cs else (const_w_typ thy (c,tp) ins cs)
   | add_term_consts_typs_rm thy ncs (t $ u) cs =
       add_term_consts_typs_rm thy ncs  t (add_term_consts_typs_rm thy ncs u cs)
   | add_term_consts_typs_rm thy ncs (Abs(_,_,t)) cs = add_term_consts_typs_rm thy ncs t cs
@@ -253,55 +250,44 @@
 
 fun term_consts_typs_rm thy ncs t = add_term_consts_typs_rm thy ncs t [];
 
-fun consts_typs_of_term thy term = term_consts_typs_rm thy ["Trueprop","==>","all","Ex","op &", "op |", "Not", "All", "op -->", "op =", "==", "True", "False"] term;
-
+fun consts_typs_of_term thy = term_consts_typs_rm thy standard_consts;
 
-fun consts_typs_in_goal thy goal = consts_typs_of_term thy goal;
-
-fun get_goal_consts_typs thy cs = foldl (op union) [] (map (consts_typs_in_goal thy) cs)
+fun get_goal_consts_typs thy cs = foldl (op union) [] (map (consts_typs_of_term thy) cs)
 
 
 (******** filter clauses ********)
 
-fun part3 gctyps [] (s1,s2) = (s1,s2)
-  | part3 gctyps (s::ss) (s1,s2) = if (uni_mem s gctyps) then part3 gctyps ss (s::s1,s2) else part3 gctyps ss (s1,s::s2);
-
-
 fun find_clause_weight_s_3 gctyps consts_typs wa =
-    let val (rel,irrel) = part3 gctyps consts_typs ([],[])
+    let val rel = filter (fn s => uni_mem s gctyps) consts_typs
     in
-	((real (length rel))/(real (length consts_typs))) * wa
+	(real (length rel) / real (length consts_typs)) * wa
     end;
 
-
-fun find_clause_weight_m_3 [] (_,w) = w
-  | find_clause_weight_m_3 ((_,(_,(refconsts_typs,wa)))::y) (consts_typs,w) =
-    let val w' = find_clause_weight_s_3 refconsts_typs consts_typs wa
-    in
-	if (w < w') then find_clause_weight_m_3 y (consts_typs,w')
-	else find_clause_weight_m_3 y (consts_typs,w)
-    end;
-
-
 fun relevant_clauses_ax_g_3 _ [] _ (ax,r) = (ax,r)
   | relevant_clauses_ax_g_3 gctyps ((cls_typ,(clstm,(consts_typs,_)))::y) P (ax,r) =
-    let val weight = find_clause_weight_s_3 gctyps consts_typs 1.0
-    in
-	if P <= weight then relevant_clauses_ax_g_3 gctyps y P ((cls_typ,(clstm,(consts_typs,weight)))::ax,r)
-	else relevant_clauses_ax_g_3 gctyps y P (ax,(cls_typ,(clstm,(consts_typs,weight)))::r)
-    end;
+      let val weight = find_clause_weight_s_3 gctyps consts_typs 1.0
+          val ccc = (cls_typ, (clstm, (consts_typs,weight)))
+      in
+	if P <= weight 
+	then relevant_clauses_ax_g_3 gctyps y P (ccc::ax, r)
+	else relevant_clauses_ax_g_3 gctyps y P (ax, ccc::r)
+      end;
+
+fun find_clause_weight_s_3_alt consts_typs (_,(_,(refconsts_typs,wa))) =
+    find_clause_weight_s_3 refconsts_typs consts_typs wa;
 
 fun relevant_clauses_ax_3 rel_axs [] P (addc,tmpc) keep =
-    (case addc of [] => (rel_axs @ keep,tmpc)
-		| _ => case tmpc of [] => (addc @ rel_axs @ keep,[])
-				  | _ => relevant_clauses_ax_3 addc tmpc P ([],[]) (rel_axs @ keep))
+      if null addc orelse null tmpc 
+      then (addc @ rel_axs @ keep, tmpc)   (*termination!*)
+      else relevant_clauses_ax_3 addc tmpc P ([],[]) (rel_axs @ keep)
   | relevant_clauses_ax_3 rel_axs ((cls_typ,(clstm,(consts_typs,weight)))::e_axs) P (addc,tmpc) keep =
-    let val weight' = find_clause_weight_m_3 rel_axs (consts_typs,weight)
-	val e_ax' = (cls_typ,(clstm,(consts_typs,weight')))
-    in
+      let val weights = map (find_clause_weight_s_3_alt consts_typs) rel_axs
+          val weight' = List.foldl Real.max weight weights
+	  val e_ax' = (cls_typ,(clstm,(consts_typs,weight')))
+      in
 	if P <= weight' then relevant_clauses_ax_3 rel_axs e_axs P (e_ax'::addc,tmpc) keep
 	else relevant_clauses_ax_3 rel_axs e_axs P (addc,e_ax'::tmpc) keep
-    end;
+      end;
 
 fun initialize3 thy [] ax_weights = ax_weights
   | initialize3 thy ((tm,name)::tms_names) ax_weights =
@@ -321,16 +307,17 @@
     let val pass = !pass_mark
 	val axioms_weights = initialize3 thy axioms []
 	val goals_consts_typs = get_goal_consts_typs thy goals
-	val (rel_clauses,nrel_clauses) = relevant_clauses_ax_g_3 goals_consts_typs axioms_weights pass ([],[]) 
+	val (rel_clauses,nrel_clauses) =
+	    relevant_clauses_ax_g_3 goals_consts_typs axioms_weights pass ([],[]) 
 	val (ax,r) = relevant_clauses_ax_3 rel_clauses nrel_clauses pass ([],[]) []
     in
-	if (!add_unit) then add_unit_clauses ax r else ax
+	if !add_unit then add_unit_clauses ax r else ax
     end;
 
-fun relevance_filter3 thy axioms goals = map fst (map snd (relevance_filter3_aux thy axioms goals));
+fun relevance_filter3 thy axioms goals =
+  map (#1 o #2) (relevance_filter3_aux thy axioms goals);
     
 
-
 (******************************************************************)
 (* Generic functions for relevance filtering                      *)
 (******************************************************************)
@@ -338,13 +325,9 @@
 exception RELEVANCE_FILTER of string;
 
 fun relevance_filter thy axioms goals = 
-    let val cls = (case (!strategy) of 1 => relevance_filter1 axioms goals
-				     | 2 => relevance_filter2 axioms goals
-				     | 3 => relevance_filter3 thy axioms goals
-				     | _ => raise RELEVANCE_FILTER("strategy doesn't exists"))
-    in
-	cls
-    end;
-
+  case (!strategy) of 1 => relevance_filter1 axioms goals
+		    | 2 => relevance_filter2 axioms goals
+		    | 3 => relevance_filter3 thy axioms goals
+		    | _ => raise RELEVANCE_FILTER("strategy doesn't exist");
 
 end;
\ No newline at end of file