src/HOL/Tools/ATP/reduce_axiomsN.ML
changeset 20566 499500b1e348
parent 20527 958ec4833d87
child 20825 4b48fd429b18
equal deleted inserted replaced
20565:4440dd392048 20566:499500b1e348
    11 structure ReduceAxiomsN =
    11 structure ReduceAxiomsN =
    12 struct
    12 struct
    13 
    13 
    14 val run_relevance_filter = ref true;
    14 val run_relevance_filter = ref true;
    15 val theory_const = ref false;
    15 val theory_const = ref false;
    16 val pass_mark = ref 0.6;
    16 val pass_mark = ref 0.5;
    17 val convergence = ref 2.4;   (*Higher numbers allow longer inference chains*)
    17 val convergence = ref 3.2;    (*Higher numbers allow longer inference chains*)
       
    18 val max_new = ref 60;         (*Limits how many clauses can be picked up per stage*)
    18 val follow_defs = ref false;  (*Follow definitions. Makes problems bigger.*)
    19 val follow_defs = ref false;  (*Follow definitions. Makes problems bigger.*)
    19 
    20 
    20 fun log_weight2 (x:real) = 1.0 + 2.0/Math.ln (x+1.0);
    21 fun log_weight2 (x:real) = 1.0 + 2.0/Math.ln (x+1.0);
    21 
    22 
    22 (*The default seems best in practice. A constant function of one ignores
    23 (*The default seems best in practice. A constant function of one ignores
    37 datatype const_typ =  CTVar | CType of string * const_typ list
    38 datatype const_typ =  CTVar | CType of string * const_typ list
    38 
    39 
    39 (*Is the second type an instance of the first one?*)
    40 (*Is the second type an instance of the first one?*)
    40 fun match_type (CType(con1,args1)) (CType(con2,args2)) = 
    41 fun match_type (CType(con1,args1)) (CType(con2,args2)) = 
    41       con1=con2 andalso match_types args1 args2
    42       con1=con2 andalso match_types args1 args2
    42   | match_type CTVar (CType _) = true
    43   | match_type CTVar _ = true
    43   | match_type CTVar CTVar = true
       
    44   | match_type _ CTVar = false
    44   | match_type _ CTVar = false
    45 and match_types [] [] = true
    45 and match_types [] [] = true
    46   | match_types (a1::as1) (a2::as2) = match_type a1 a2 andalso match_types as1 as2;
    46   | match_types (a1::as1) (a2::as2) = match_type a1 a2 andalso match_types as1 as2;
    47 
    47 
    48 (*Is there a unifiable constant?*)
    48 (*Is there a unifiable constant?*)
    49 fun uni_mem gctab (c,c_typ) =
    49 fun uni_mem gctab (c,c_typ) =
    50   case Symtab.lookup gctab c of
    50   case Symtab.lookup gctab c of
    51       NONE => false
    51       NONE => false
    52     | SOME ctyps_list => exists (match_types c_typ) ctyps_list;
    52     | SOME ctyps_list => exists (match_types c_typ) ctyps_list;
    53   
    53   
       
    54 (*Maps a "real" type to a const_typ*)
    54 fun const_typ_of (Type (c,typs)) = CType (c, map const_typ_of typs) 
    55 fun const_typ_of (Type (c,typs)) = CType (c, map const_typ_of typs) 
    55   | const_typ_of (TFree _) = CTVar
    56   | const_typ_of (TFree _) = CTVar
    56   | const_typ_of (TVar _) = CTVar
    57   | const_typ_of (TVar _) = CTVar
    57 
    58 
       
    59 (*Pairs a constant with the list of its type instantiations (using const_typ)*)
    58 fun const_with_typ thy (c,typ) = 
    60 fun const_with_typ thy (c,typ) = 
    59     let val tvars = Sign.const_typargs thy (c,typ)
    61     let val tvars = Sign.const_typargs thy (c,typ)
    60     in (c, map const_typ_of tvars) end
    62     in (c, map const_typ_of tvars) end
    61     handle TYPE _ => (c,[]);   (*Variable (locale constant): monomorphic*)   
    63     handle TYPE _ => (c,[]);   (*Variable (locale constant): monomorphic*)   
    62 
    64 
    64   which we ignore.*)
    66   which we ignore.*)
    65 fun add_const_typ_table ((c,ctyps), tab) =
    67 fun add_const_typ_table ((c,ctyps), tab) =
    66   Symtab.map_default (c, [ctyps]) (fn [] => [] | ctyps_list => ctyps ins ctyps_list) 
    68   Symtab.map_default (c, [ctyps]) (fn [] => [] | ctyps_list => ctyps ins ctyps_list) 
    67     tab;
    69     tab;
    68 
    70 
    69 (*Free variables are counted, as well as constants, to handle locales*)
    71 (*Free variables are included, as well as constants, to handle locales*)
    70 fun add_term_consts_typs_rm thy (Const(c, typ), tab) =
    72 fun add_term_consts_typs_rm thy (Const(c, typ), tab) =
    71       add_const_typ_table (const_with_typ thy (c,typ), tab) 
    73       add_const_typ_table (const_with_typ thy (c,typ), tab) 
    72   | add_term_consts_typs_rm thy (Free(c, typ), tab) =
    74   | add_term_consts_typs_rm thy (Free(c, typ), tab) =
    73       add_const_typ_table (const_with_typ thy (c,typ), tab) 
    75       add_const_typ_table (const_with_typ thy (c,typ), tab) 
    74   | add_term_consts_typs_rm thy (t $ u, tab) =
    76   | add_term_consts_typs_rm thy (t $ u, tab) =
    93   in  t $ prop_of th  end
    95   in  t $ prop_of th  end
    94  else prop_of th;
    96  else prop_of th;
    95 
    97 
    96 (**** Constant / Type Frequencies ****)
    98 (**** Constant / Type Frequencies ****)
    97 
    99 
       
   100 (*A two-dimensional symbol table counts frequencies of constants. It's keyed first by
       
   101   constant name and second by its list of type instantiations. For the latter, we need
       
   102   a linear ordering on type const_typ list.*)
       
   103   
    98 local
   104 local
    99 
   105 
   100 fun cons_nr CTVar = 0
   106 fun cons_nr CTVar = 0
   101   | cons_nr (CType _) = 1;
   107   | cons_nr (CType _) = 1;
   102 
   108 
   126 	| count_term_consts (Abs(_,_,t), tab) = count_term_consts (t, tab)
   132 	| count_term_consts (Abs(_,_,t), tab) = count_term_consts (t, tab)
   127 	| count_term_consts (_, tab) = tab
   133 	| count_term_consts (_, tab) = tab
   128   in  count_term_consts (const_prop_of thm, tab)  end;
   134   in  count_term_consts (const_prop_of thm, tab)  end;
   129 
   135 
   130 
   136 
   131 (******** filter clauses ********)
   137 (**** Actual Filtering Code ****)
   132 
   138 
   133 fun const_weight ctab (c, cts) =
   139 (*The frequency of a constant is the sum of those of all instances of its type.*)
       
   140 fun const_frequency ctab (c, cts) =
   134   let val pairs = CTtab.dest (Option.valOf (Symtab.lookup ctab c))
   141   let val pairs = CTtab.dest (Option.valOf (Symtab.lookup ctab c))
   135       fun add ((cts',m), n) = if match_types cts cts' then m+n else n
   142       fun add ((cts',m), n) = if match_types cts cts' then m+n else n
   136   in  List.foldl add 0 pairs  end;
   143   in  List.foldl add 0 pairs  end;
   137 
   144 
       
   145 (*Add in a constant's weight, as determined by its frequency.*)
   138 fun add_ct_weight ctab ((c,T), w) =
   146 fun add_ct_weight ctab ((c,T), w) =
   139   w + !weight_fn (real (const_weight ctab (c,T)));
   147   w + !weight_fn (real (const_frequency ctab (c,T)));
   140 
       
   141 fun consts_typs_weight ctab =
       
   142     List.foldl (add_ct_weight ctab) 0.0;
       
   143 
   148 
   144 (*Relevant constants are weighted according to frequency, 
   149 (*Relevant constants are weighted according to frequency, 
   145   but irrelevant constants are simply counted. Otherwise, Skolem functions,
   150   but irrelevant constants are simply counted. Otherwise, Skolem functions,
   146   which are rare, would harm a clause's chances of being picked.*)
   151   which are rare, would harm a clause's chances of being picked.*)
   147 fun clause_weight ctab gctyps consts_typs =
   152 fun clause_weight ctab gctyps consts_typs =
   148     let val rel = filter (uni_mem gctyps) consts_typs
   153     let val rel = filter (uni_mem gctyps) consts_typs
   149         val rel_weight = consts_typs_weight ctab rel
   154         val rel_weight = List.foldl (add_ct_weight ctab) 0.0 rel
   150     in
   155     in
   151 	rel_weight / (rel_weight + real (length consts_typs - length rel))
   156 	rel_weight / (rel_weight + real (length consts_typs - length rel))
   152     end;
   157     end;
   153     
   158     
   154 (*Multiplies out to a list of pairs: 'a * 'b list -> ('a * 'b) list -> ('a * 'b) list*)
   159 (*Multiplies out to a list of pairs: 'a * 'b list -> ('a * 'b) list -> ('a * 'b) list*)
   155 fun add_expand_pairs (c, ctyps_list) cpairs =
   160 fun add_expand_pairs (x,ys) xys = foldl (fn (y,acc) => (x,y)::acc) xys ys;
   156       foldl (fn (ctyps,cpairs) => (c,ctyps)::cpairs) cpairs ctyps_list;
       
   157 
   161 
   158 fun consts_typs_of_term thy t = 
   162 fun consts_typs_of_term thy t = 
   159   let val tab = add_term_consts_typs_rm thy (t, null_const_tab)
   163   let val tab = add_term_consts_typs_rm thy (t, null_const_tab)
   160   in  Symtab.fold add_expand_pairs tab []  end;
   164   in  Symtab.fold add_expand_pairs tab []  end;
   161 
   165 
   182 		   defs lhs rhs andalso
   186 		   defs lhs rhs andalso
   183 		   (Output.debug ("Definition found: " ^ name ^ "_" ^ Int.toString n); true)
   187 		   (Output.debug ("Definition found: " ^ name ^ "_" ^ Int.toString n); true)
   184 		 | _ => false
   188 		 | _ => false
   185     end;
   189     end;
   186 
   190 
       
   191 (*For a reverse sort, putting the largest values first.*)
       
   192 fun compare_pairs ((_,w1),(_,w2)) = Real.compare (w2,w1);
       
   193 
       
   194 (*Limit the number of new clauses, to prevent runaway acceptance.*)
       
   195 fun take_best newpairs =
       
   196   let val nnew = length newpairs
       
   197   in
       
   198     if nnew <= !max_new then (map #1 newpairs, [])
       
   199     else 
       
   200       let val cls = map #1 (sort compare_pairs newpairs)
       
   201       in  Output.debug ("Number of candidates, " ^ Int.toString nnew ^ 
       
   202 			", exceeds the limit of " ^ Int.toString (!max_new));
       
   203 	  (List.take (cls, !max_new), List.drop (cls, !max_new))
       
   204       end
       
   205   end;
       
   206 
   187 fun relevant_clauses thy ctab p rel_consts =
   207 fun relevant_clauses thy ctab p rel_consts =
   188   let fun relevant (newrels,rejects) []  =
   208   let fun relevant ([],rejects) [] = []     (*Nothing added this iteration, so stop*)
   189 	    if null newrels then [] 
   209 	| relevant (newpairs,rejects) [] =
   190 	    else 
   210 	    let val (newrels,more_rejeccts) = take_best newpairs
   191 	      let val new_consts = List.concat (map #2 newrels)
   211 		val new_consts = List.concat (map #2 newrels)
   192 	          val rel_consts' = foldl add_const_typ_table rel_consts new_consts
   212 		val rel_consts' = foldl add_const_typ_table rel_consts new_consts
   193                   val newp = p + (1.0-p) / !convergence
   213 		val newp = p + (1.0-p) / !convergence
   194 	      in Output.debug ("found relevant: " ^ Int.toString (length newrels));
   214 	    in Output.debug ("relevant this iteration: " ^ Int.toString (length newrels));
   195                  newrels @ relevant_clauses thy ctab newp rel_consts' rejects
   215 	       (map #1 newrels) @ 
   196 	      end
   216 	       (relevant_clauses thy ctab newp rel_consts' (more_rejeccts@rejects))
       
   217 	    end
   197 	| relevant (newrels,rejects) ((ax as (clsthm as (_,(name,n)),consts_typs)) :: axs) =
   218 	| relevant (newrels,rejects) ((ax as (clsthm as (_,(name,n)),consts_typs)) :: axs) =
   198 	    let val weight = clause_weight ctab rel_consts consts_typs
   219 	    let val weight = clause_weight ctab rel_consts consts_typs
   199 	    in
   220 	    in
   200 	      if p <= weight orelse (!follow_defs andalso defines thy clsthm rel_consts)
   221 	      if p <= weight orelse (!follow_defs andalso defines thy clsthm rel_consts)
   201 	      then (Output.debug name; Output.debug "\n";
   222 	      then (Output.debug name; 
   202 	            relevant (ax::newrels, rejects) axs)
   223 	            relevant ((ax,weight)::newrels, rejects) axs)
   203 	      else relevant (newrels, ax::rejects) axs
   224 	      else relevant (newrels, ax::rejects) axs
   204 	    end
   225 	    end
   205     in  Output.debug ("relevant_clauses: " ^ Real.toString p);
   226     in  Output.debug ("relevant_clauses: " ^ Real.toString p);
   206         relevant ([],[]) end;
   227         relevant ([],[]) 
       
   228     end;
   207 	
   229 	
   208      
   230 fun relevance_filter thy axioms goals = 
   209 fun relevance_filter_aux thy axioms goals = 
   231  if !run_relevance_filter andalso !pass_mark >= 0.1
   210   let val const_tab = List.foldl (count_axiom_consts thy) Symtab.empty axioms
   232  then
   211       val goals_consts_typs = get_goal_consts_typs thy goals
   233   let val _ = Output.debug "Start of relevance filtering";
   212       val rels = relevant_clauses thy const_tab (!pass_mark) goals_consts_typs 
   234       val const_tab = List.foldl (count_axiom_consts thy) Symtab.empty axioms
       
   235       val rels = relevant_clauses thy const_tab (!pass_mark) 
       
   236                    (get_goal_consts_typs thy goals)
   213                    (map (pair_consts_typs_axiom thy) axioms)
   237                    (map (pair_consts_typs_axiom thy) axioms)
   214   in
   238   in
   215       Output.debug ("Total relevant: " ^ Int.toString (length rels));
   239       Output.debug ("Total relevant: " ^ Int.toString (length rels));
   216       rels
   240       rels
   217   end;
   241   end
   218 
   242  else axioms;
   219 fun relevance_filter thy axioms goals =
       
   220   if !run_relevance_filter andalso !pass_mark >= 0.1
       
   221   then map #1 (relevance_filter_aux thy axioms goals)
       
   222   else axioms
       
   223 
   243 
   224 end;
   244 end;