when "max_thm_instances" is hit, choose more carefully which instances should be kept
authorblanchet
Tue, 24 Sep 2013 16:21:03 +0200
changeset 53823 191ec7f873d5
parent 53822 6304b12c7627
child 53824 b81cea96a85e
when "max_thm_instances" is hit, choose more carefully which instances should be kept
src/HOL/Tools/monomorph.ML
--- a/src/HOL/Tools/monomorph.ML	Tue Sep 24 15:16:59 2013 +0200
+++ b/src/HOL/Tools/monomorph.ML	Tue Sep 24 16:21:03 2013 +0200
@@ -177,11 +177,11 @@
   in Term.fold_aterms add (Thm.prop_of thm) end
 
 
-fun add_insts max_instances max_thm_instances ctxt round used_grounds
+fun add_insts max_instances max_thm_insts ctxt round used_grounds
     new_grounds id thm tvars schematics cx =
   let
     exception ENOUGH of
-      typ list Symtab.table * (int * (int * thm) list Inttab.table)
+      typ list Symtab.table * (int * ((int * (sort * typ) Vartab.table) * thm) list Inttab.table)
 
     val thy = Proof_Context.theory_of ctxt
 
@@ -191,14 +191,14 @@
       else
         let
           val thm' = instantiate thy subst thm
-          val rthm = (round, thm')
+          val rthm = ((round, subst), thm')
           val rthms = Inttab.lookup_list insts id;
           val n_insts' =
-            if member (eq_snd Thm.eq_thm) rthms rthm orelse
-               length rthms >= max_thm_instances then
+            if member (eq_snd Thm.eq_thm) rthms rthm then
               (n, insts)
             else
-              (n + 1, Inttab.cons_list (id, rthm) insts)
+              (if length rthms >= max_thm_insts then n else n + 1,
+               Inttab.cons_list (id, rthm) insts)
           val next_grounds' =
             add_new_grounds used_grounds new_grounds thm' next_grounds
         in (next_grounds', n_insts') end
@@ -247,10 +247,10 @@
 fun is_active round initial_round = (round > initial_round)
 
 
-fun find_instances max_instances max_thm_instances thm_infos ctxt round
+fun find_instances max_instances max_thm_insts thm_infos ctxt round
     (known_grounds, new_grounds, insts) =
   let
-    val add_new = add_insts max_instances max_thm_instances ctxt round
+    val add_new = add_insts max_instances max_thm_insts ctxt round
     fun consider_all pred f (cx as (_, (n, _))) =
       if n >= max_instances then cx else fold_schematics pred f thm_infos cx
 
@@ -270,17 +270,15 @@
   in Term.fold_aterms (fn Const c => add c | _ => I) (Thm.prop_of thm) end
 
 
-fun collect_instances ctxt thm_infos consts =
+fun collect_instances ctxt max_thm_insts thm_infos consts =
   let
     val known_grounds = fold_grounds add_ground_types thm_infos consts
     val empty_grounds = clear_grounds known_grounds
     val max_instances = Config.get ctxt max_new_instances
       |> fold (fn Schematic _ => Integer.add 1 | _ => I) thm_infos
-    val max_thm_instances = Config.get ctxt max_thm_instances
   in
     (empty_grounds, known_grounds, (0, Inttab.empty))
-    |> limit_rounds ctxt
-      (find_instances max_instances max_thm_instances thm_infos)
+    |> limit_rounds ctxt (find_instances max_instances max_thm_insts thm_infos)
     |> (fn (_, _, (_, insts)) => insts)
   end
 
@@ -288,17 +286,28 @@
 
 (* monomorphization *)
 
-fun instantiated_thms _ (Ground thm) = [(0, thm)]
-  | instantiated_thms _ Ignored = []
-  | instantiated_thms insts (Schematic {id, ...}) = Inttab.lookup_list insts id
+
+fun size_of_subst subst =
+  Vartab.fold (Integer.add o size_of_typ o snd o snd) subst 0
+
+val subst_ord = int_ord o pairself size_of_subst
 
+fun instantiated_thms _ _ (Ground thm) = [(0, thm)]
+  | instantiated_thms _ _ Ignored = []
+  | instantiated_thms max_thm_insts insts (Schematic {id, ...}) =
+    Inttab.lookup_list insts id
+    |> (fn rthms => if length rthms <= max_thm_insts then rthms
+      else take max_thm_insts
+        (sort (prod_ord int_ord subst_ord o pairself fst) rthms))
+    |> map (apfst fst)
 
 fun monomorph schematic_consts_of ctxt rthms =
   let
+    val max_thm_insts = Config.get ctxt max_thm_instances
     val (thm_infos, consts) = prepare schematic_consts_of rthms
     val insts =
       if Symtab.is_empty consts then Inttab.empty
-      else collect_instances ctxt thm_infos consts
-  in map (instantiated_thms insts) thm_infos end
+      else collect_instances ctxt max_thm_insts thm_infos consts
+  in map (instantiated_thms max_thm_insts insts) thm_infos end
 
 end