src/HOL/Tools/ATP/reduce_axiomsN.ML
changeset 19334 96ca738055a6
parent 19321 30b5bb35dd33
child 19335 9e82f341a71b
equal deleted inserted replaced
19333:99dbefd7bc2e 19334:96ca738055a6
     5 structure ReduceAxiomsN =
     5 structure ReduceAxiomsN =
     6 struct
     6 struct
     7 
     7 
     8 val pass_mark = ref 0.6;
     8 val pass_mark = ref 0.6;
     9 val reduction_factor = ref 1.0;
     9 val reduction_factor = ref 1.0;
    10 
    10 val convergence = ref 4.0;   (*Higher numbers allow longer inference chains*)
    11 (*Whether all "simple" unit clauses should be included*)
    11 
       
    12 (*FIXME DELETE Whether all "simple" unit clauses should be included*)
    12 val add_unit = ref false;
    13 val add_unit = ref false;
    13 val unit_pass_mark = ref 0.0;
    14 val unit_pass_mark = ref 0.0;
       
    15 
       
    16 (*The default ignores the constant-count and gives the old Strategy 3*)
       
    17 val weight_fn = ref (fn x : real => 1.0);
    14 
    18 
    15 
    19 
    16 (*Including equality in this list might be expected to stop rules like subset_antisym from
    20 (*Including equality in this list might be expected to stop rules like subset_antisym from
    17   being chosen, but for some reason filtering works better with them listed.*)
    21   being chosen, but for some reason filtering works better with them listed.*)
    18 val standard_consts =
    22 val standard_consts =
    86 
    90 
    87 fun get_goal_consts_typs thy cs = foldl (op union) [] (map (consts_typs_of_term thy) cs)
    91 fun get_goal_consts_typs thy cs = foldl (op union) [] (map (consts_typs_of_term thy) cs)
    88 
    92 
    89 
    93 
    90 (**** Constant / Type Frequencies ****)
    94 (**** Constant / Type Frequencies ****)
    91 
       
    92 
    95 
    93 local
    96 local
    94 
    97 
    95 fun cons_nr CTVar = 0
    98 fun cons_nr CTVar = 0
    96   | cons_nr (CType _) = 1;
    99   | cons_nr (CType _) = 1;
   124   in  count_term_consts (t, tab)  end;
   127   in  count_term_consts (t, tab)  end;
   125 
   128 
   126 
   129 
   127 (******** filter clauses ********)
   130 (******** filter clauses ********)
   128 
   131 
   129 (*The default ignores the constant-count and gives the old Strategy 3*)
       
   130 val weight_fn = ref (fn x : real => 1.0);
       
   131 
       
   132 fun const_weight ctab (c, cts) =
   132 fun const_weight ctab (c, cts) =
   133   let val pairs = CTtab.dest (Option.valOf (Symtab.lookup ctab c))
   133   let val pairs = CTtab.dest (Option.valOf (Symtab.lookup ctab c))
   134       fun add ((cts',m), n) = if uni_types cts cts' then m+n else n
   134       fun add ((cts',m), n) = if uni_types cts cts' then m+n else n
   135   in  List.foldl add 0 pairs  end;
   135   in  List.foldl add 0 pairs  end;
   136 
   136 
   148         val rel_weight = consts_typs_weight ctab rel
   148         val rel_weight = consts_typs_weight ctab rel
   149     in
   149     in
   150 	rel_weight / (rel_weight + real (length consts_typs - length rel))
   150 	rel_weight / (rel_weight + real (length consts_typs - length rel))
   151     end;
   151     end;
   152     
   152     
   153 fun relevant_clauses ctab rel_axs [] (addc,tmpc) keep =
       
   154       if null addc orelse null tmpc 
       
   155       then (addc @ rel_axs @ keep, tmpc)   (*termination!*)
       
   156       else relevant_clauses ctab addc tmpc ([],[]) (rel_axs @ keep)
       
   157   | relevant_clauses ctab rel_axs ((clstm,(consts_typs,w))::e_axs) (addc,tmpc) keep =
       
   158       let fun clause_weight_ax (_,(refconsts_typs,wa)) =
       
   159               wa * clause_weight ctab refconsts_typs consts_typs;
       
   160           val weight' = List.foldl Real.max w (map clause_weight_ax rel_axs)
       
   161 	  val e_ax' = (clstm, (consts_typs,weight'))
       
   162       in
       
   163 	if !pass_mark <= weight' 
       
   164 	then relevant_clauses ctab rel_axs e_axs (e_ax'::addc, tmpc) keep
       
   165 	else relevant_clauses ctab rel_axs e_axs (addc, e_ax'::tmpc) keep
       
   166       end;
       
   167 
       
   168 fun pair_consts_typs_axiom thy (tm,name) =
   153 fun pair_consts_typs_axiom thy (tm,name) =
   169     ((tm,name), (consts_typs_of_term thy tm));
   154     ((tm,name), (consts_typs_of_term thy tm));
       
   155 
       
   156 fun relevant_clauses ctab p rel_consts =
       
   157   let fun relevant (newrels,rejects) []  =
       
   158 	    if null newrels then [] 
       
   159 	    else 
       
   160 	      let val new_consts = map #2 newrels
       
   161 	          val rel_consts' = foldl (op union) rel_consts new_consts
       
   162                   val newp = p + (1.0-p) / !convergence
       
   163 	      in Output.debug ("found relevant: " ^ Int.toString (length newrels));
       
   164                  newrels @ relevant_clauses ctab newp rel_consts' rejects
       
   165 	      end
       
   166 	| relevant (newrels,rejects) ((ax as (clstm,consts_typs)) :: axs) =
       
   167 	    let val weight = clause_weight ctab rel_consts consts_typs
       
   168 	    in
       
   169 	      if p <= weight 
       
   170 	      then relevant (ax::newrels, rejects) axs
       
   171 	      else relevant (newrels, ax::rejects) axs
       
   172 	    end
       
   173     in  Output.debug ("relevant_clauses: " ^ Real.toString p);
       
   174         relevant ([],[]) end;
   170 
   175 
   171 (*Unit clauses other than non-trivial equations can be included subject to
   176 (*Unit clauses other than non-trivial equations can be included subject to
   172   a separate (presumably lower) mark. *)
   177   a separate (presumably lower) mark. *)
   173 fun good_unit_clause ((t,_), (_,w)) = 
   178 fun good_unit_clause ((t,_), (_,w)) = 
   174      !unit_pass_mark <= w andalso
   179      !unit_pass_mark <= w andalso
   175      (case clause_kind t of
   180      (case clause_kind t of
   176 	  Unit_neq => true
   181 	  Unit_neq => true
   177 	| Unit_geq => true
   182 	| Unit_geq => true
   178 	| Other => false);
   183 	| Other => false);
   179 	
   184 	
   180 fun axiom_ord ((_,(_,w1)), (_,(_,w2))) = Real.compare (w2,w1);
       
   181 
       
   182 fun showconst (c,cttab) = 
   185 fun showconst (c,cttab) = 
   183       List.app (fn n => Output.debug (Int.toString n ^ " occurrences of " ^ c))
   186       List.app (fn n => Output.debug (Int.toString n ^ " occurrences of " ^ c))
   184 	        (map #2 (CTtab.dest cttab))
   187 	        (map #2 (CTtab.dest cttab))
   185 
       
   186 fun show_cname (name,k) = name ^ "__" ^ Int.toString k;
       
   187 
       
   188 fun showax ((_,cname), (_,w)) = 
       
   189     Output.debug ("Axiom " ^ show_cname cname ^ " has weight " ^ Real.toString w)
       
   190 	      
   188 	      
   191 exception ConstFree;
   189 exception ConstFree;
   192 fun dest_ConstFree (Const aT) = aT
   190 fun dest_ConstFree (Const aT) = aT
   193   | dest_ConstFree (Free aT) = aT
   191   | dest_ConstFree (Free aT) = aT
   194   | dest_ConstFree _ = raise ConstFree;
   192   | dest_ConstFree _ = raise ConstFree;
   208   end
   206   end
   209 
   207 
   210 fun relevance_filter_aux thy axioms goals = 
   208 fun relevance_filter_aux thy axioms goals = 
   211   let val const_tab = List.foldl (count_axiom_consts thy) Symtab.empty axioms
   209   let val const_tab = List.foldl (count_axiom_consts thy) Symtab.empty axioms
   212       val goals_consts_typs = get_goal_consts_typs thy goals
   210       val goals_consts_typs = get_goal_consts_typs thy goals
   213       fun relevant [] (rels,nonrels) = (rels,nonrels)
   211       val rels = relevant_clauses const_tab (!pass_mark) goals_consts_typs 
   214 	| relevant ((clstm,consts_typs)::axs) (rels,nonrels) =
   212                    (map (pair_consts_typs_axiom thy) axioms)
   215 	    let val weight = clause_weight const_tab goals_consts_typs consts_typs
       
   216 		val ccc = (clstm, (consts_typs,weight))
       
   217 	    in
       
   218 	      if !pass_mark <= weight orelse defines thy clstm goals_consts_typs
       
   219 	      then relevant axs (ccc::rels, nonrels)
       
   220 	      else relevant axs (rels, ccc::nonrels)
       
   221 	    end
       
   222       val (rel_clauses,nrel_clauses) =
       
   223 	  relevant (map (pair_consts_typs_axiom thy) axioms) ([],[]) 
       
   224       val (rels,nonrels) = relevant_clauses const_tab rel_clauses nrel_clauses ([],[]) []
       
   225       val max_filtered = floor (!reduction_factor * real (length rels))
       
   226       val rels' = Library.take(max_filtered, Library.sort axiom_ord rels)
       
   227   in
   213   in
   228       if !Output.show_debug_msgs then
   214       Output.debug ("Total relevant: " ^ Int.toString (length rels));
   229 	   (List.app showconst (Symtab.dest const_tab);
   215       rels
   230 	    List.app showax rels)
       
   231       else ();
       
   232       if !add_unit then (filter good_unit_clause nonrels) @ rels'
       
   233       else rels'
       
   234   end;
   216   end;
   235 
   217 
   236 fun relevance_filter thy axioms goals =
   218 fun relevance_filter thy axioms goals =
   237   map #1 (relevance_filter_aux thy axioms goals);
   219   map #1 (relevance_filter_aux thy axioms goals);
   238     
   220