proper nesting of loops in new monomorphizer;
authorboehmes
Tue, 31 May 2011 19:27:19 +0200
changeset 43117 5de84843685f
parent 43116 e0add071fa10
child 43118 e3c7b07704bc
proper nesting of loops in new monomorphizer; less duplication of code in new monomorphizer; drop output of warnings of new monomorphizer
src/HOL/Tools/monomorph.ML
--- a/src/HOL/Tools/monomorph.ML	Tue May 31 19:21:20 2011 +0200
+++ b/src/HOL/Tools/monomorph.ML	Tue May 31 19:27:19 2011 +0200
@@ -38,7 +38,6 @@
   val max_rounds: int Config.T
   val max_new_instances: int Config.T
   val complete_instances: bool Config.T
-  val verbose: bool Config.T
 
   (* monomorphization *)
   val monomorph: (term -> typ list Symtab.table) -> (int * thm) list ->
@@ -50,9 +49,6 @@
 
 (* utility functions *)
 
-fun fold_env _ [] y = y
-  | fold_env f (x :: xs) y = fold_env f xs (f xs x y)
-
 val typ_has_tvars = Term.exists_subtype (fn TVar _ => true | _ => false)
 
 fun add_schematic_const (c as (_, T)) =
@@ -72,9 +68,6 @@
   Attrib.setup_config_int @{binding monomorph_max_new_instances} (K 300)
 val complete_instances =
   Attrib.setup_config_bool @{binding monomorph_complete_instances} (K true)
-val verbose = Attrib.setup_config_bool @{binding monomorph_verbose} (K true)
-
-fun show_info ctxt msg = if Config.get ctxt verbose then tracing msg else ()
 
 
 
@@ -93,11 +86,11 @@
 
 fun prepare schematic_consts_of rthms =
   let
-    val empty_subst = ((0, false, false), Vartab.empty)
+    val empty_sub = ((0, false, false), Vartab.empty)
 
-    fun prep (r, thm) ((i, idx), (consts, substs)) =
+    fun prep (r, thm) ((i, idx), (consts, subs)) =
       if not (Term.exists_type typ_has_tvars (Thm.prop_of thm)) then
-        (Ground thm, ((i+1, idx + Thm.maxidx_of thm + 1), (consts, substs)))
+        (Ground thm, ((i+1, idx + Thm.maxidx_of thm + 1), (consts, subs)))
       else
         let
           (* increase indices to avoid clashes of type variables *)
@@ -106,20 +99,101 @@
           val schematics = schematic_consts_of (Thm.prop_of thm')
           val consts' =
             Symtab.fold (fn (n, _) => Symtab.update (n, [])) schematics consts
-          val substs' = Inttab.update (i, [empty_subst]) substs
+          val subs' = Inttab.update (i, [empty_sub]) subs
           val thm_info = Schematic {
             index = i,
             theorem = thm',
             tvars = Term.add_tvars (Thm.prop_of thm') [],
             schematics = schematics,
             initial_round = r }
-      in (thm_info, ((i+1, idx'), (consts', substs'))) end
+      in (thm_info, ((i+1, idx'), (consts', subs'))) end
   in fold_map prep rthms ((0, 0), (Symtab.empty, Inttab.empty)) ||> snd end
 
 
 
 (** collecting substitutions **)
 
+fun exceeded limit = (limit <= 0)
+fun exceeded_limit (limit, _, _) = exceeded limit
+
+
+fun with_all_grounds cx grounds f =
+  if exceeded_limit cx then I else Symtab.fold f grounds
+
+fun with_all_type_combinations cx schematics f (n, Ts) =
+  if exceeded_limit cx then I
+  else fold_product f (Symtab.lookup_list schematics n) Ts
+
+fun derive_new_substs thy cx new_grounds schematics subst =
+  with_all_grounds cx new_grounds
+    (with_all_type_combinations cx schematics (fn T => fn U =>
+      (case try (Sign.typ_match thy (T, U)) subst of
+        NONE => I
+      | SOME subst' => cons subst'))) []
+
+
+fun same_subst subst' (_, subst) = subst' |> Vartab.forall (fn (n, (_, T)) => 
+  Vartab.lookup subst n |> Option.map (equal T o snd) |> the_default false)
+
+fun known_subst sub subs1 subs2 subst =
+  same_subst subst sub orelse exists (same_subst subst) subs1 orelse
+  exists (same_subst subst) subs2
+
+fun within_limit f cx = if exceeded_limit cx then cx else f cx
+
+fun fold_partial_substs derive add = within_limit (
+  let
+    fun fold_partial [] cx = cx
+      | fold_partial (sub :: subs) (limit, subs', next) =
+          if exceeded limit then (limit, sub :: subs @ subs', next)
+          else sub |> (fn ((generation, full, _), subst) =>
+            if full then fold_partial subs (limit, sub :: subs', next)
+            else
+              (case filter_out (known_subst sub subs subs') (derive subst) of
+                [] => fold_partial subs (limit, sub :: subs', next)
+              | substs =>
+                  (limit, ((generation, full, true), subst) :: subs', next)
+                  |> fold (within_limit o add) substs
+                  |> fold_partial subs))
+  in (fn (limit, subs, next) => fold_partial subs (limit, [], next)) end)
+
+
+fun refine ctxt round known_grounds new_grounds (tvars, schematics) cx =
+  let
+    val thy = Proof_Context.theory_of ctxt
+    val count_partial = Config.get ctxt complete_instances
+
+    fun add_new_ground subst n T =
+      let val T' = Envir.subst_type subst T
+      in
+        (* FIXME: maybe keep types in a table or net for known_grounds,
+           that might improve efficiency here
+        *)
+        if typ_has_tvars T' then I
+        else if member (op =) (Symtab.lookup_list known_grounds n) T' then I
+        else Symtab.cons_list (n, T')
+      end
+
+    fun add_new_subst subst (limit, subs, next_grounds) =
+      let
+        val full = forall (Vartab.defined subst o fst) tvars
+        val limit' =
+          if full orelse count_partial then limit - 1 else limit
+        val sub = ((round, full, false), subst)
+        val next_grounds' =
+          (schematics, next_grounds)
+          |-> Symtab.fold (uncurry (fold o add_new_ground subst))
+      in (limit', sub :: subs, next_grounds') end
+  in
+    fold_partial_substs (derive_new_substs thy cx new_grounds schematics)
+      add_new_subst cx
+  end
+
+
+(*
+  'known_grounds' are all constant names known to occur schematically
+  associated with all ground instances considered so far
+*)
 fun add_relevant_instances known_grounds (Const (c as (n, T))) =
       if typ_has_tvars T orelse not (Symtab.defined known_grounds n) then I
       else if member (op =) (Symtab.lookup_list known_grounds n) T then I
@@ -130,86 +204,6 @@
   Term.fold_aterms (add_relevant_instances known_grounds) (Thm.prop_of thm)
 
 
-fun exceeded_limit (limit, _, _) = (limit <= 0)
-
-fun with_substs index f (limit, substitutions, next_grounds) =
-  let
-    val substs = Inttab.lookup_list substitutions index
-    val (limit', substs', next_grounds') = f (limit, substs, next_grounds)
-  in (limit', Inttab.update (index, substs') substitutions, next_grounds') end
-
-fun with_grounds grounds f cx =
-  if exceeded_limit cx then cx else Symtab.fold f grounds cx
-
-fun with_all_combinations schematics f (n, Ts) cx =
-  if exceeded_limit cx then cx
-  else fold_product f (Symtab.lookup_list schematics n) Ts cx
-
-fun with_partial_substs f T U (cx as (limit, substs, next_grounds)) =
-  if exceeded_limit cx then cx
-  else fold_env (f (T, U)) substs (limit, [], next_grounds)
-
-
-fun same_subst subst =
-  Vartab.forall (fn (n, (_, T)) => 
-    Vartab.lookup subst n |> Option.map (equal T o snd) |> the_default false)
-
-(* FIXME: necessary? would it have an impact?
-   comparing substitutions can be tricky ... *)
-fun known substs1 substs2 subst = false
-
-fun refine ctxt known_grounds new_grounds info =
-  let
-    val thy = Proof_Context.theory_of ctxt
-    val count_partial = Config.get ctxt complete_instances
-    val (round, index, _, tvars, schematics) = info
-
-    fun refine_subst TU = try (Sign.typ_match thy TU)
-
-    fun add_new_ground subst n T =
-      let val T' = Envir.subst_type subst T
-      in
-        (* FIXME: maybe keep types in a table or net for known_grounds,
-           that might improve efficiency here
-        *)
-        if member (op =) (Symtab.lookup_list known_grounds n) T' then I
-        else Symtab.cons_list (n, T')
-      end
-
-    fun refine_step subst limit next_grounds substs =
-      let
-        val full = forall (Vartab.defined subst o fst) tvars
-        val limit' =
-          if full orelse count_partial then limit - 1 else limit
-        val sub = ((round, full, false), subst)
-        val next_grounds' =
-          (schematics, next_grounds)
-          |-> Symtab.fold (uncurry (fold o add_new_ground subst))
-      in (limit', sub :: substs, next_grounds') end
-
-    fun refine_substs TU substs sub (cx as (limit, substs', next_grounds)) =
-      let val ((generation, full, _), subst) = sub
-      in
-        if exceeded_limit cx orelse full then
-          (limit, sub :: substs', next_grounds)
-        else
-          (case refine_subst TU subst of
-            NONE => (limit, sub :: substs', next_grounds)
-          | SOME subst' =>
-              if (same_subst subst orf known substs substs') subst' then
-                (limit, sub :: substs', next_grounds)
-              else
-                substs'
-                |> cons ((generation, full, true), subst)
-                |> refine_step subst' limit next_grounds)
-      end
-  in
-    with_substs index (
-      with_grounds new_grounds (with_all_combinations schematics (
-        with_partial_substs refine_substs)))
-  end
-
-
 fun make_subst_ctxt ctxt thm_infos known_grounds substitutions =
   let
     val limit = Config.get ctxt max_new_instances
@@ -219,40 +213,38 @@
     val initial_grounds = fold add_ground_consts thm_infos Symtab.empty
   in (known_grounds, (limit, substitutions, initial_grounds)) end
 
-fun with_new round f thm_info =
-  (case thm_info of
+fun is_new round initial_round = (round = initial_round)
+fun is_active round initial_round = (round > initial_round)
+
+fun fold_schematic pred f = fold (fn
     Schematic {index, theorem, tvars, schematics, initial_round} =>
-      if initial_round <> round then I
-      else f (round, index, theorem, tvars, schematics)
-  | Ground _ => I)
-
-fun with_active round f thm_info =
-  (case thm_info of
-    Schematic {index, theorem, tvars, schematics, initial_round} =>
-      if initial_round < round then I
-      else f (round, index, theorem, tvars, schematics)
+      if pred initial_round then f theorem (index, tvars, schematics) else I
   | Ground _ => I)
 
-fun collect_substitutions thm_infos ctxt round (known_grounds, subst_ctxt) =
-  let val (limit, substitutions, next_grounds) = subst_ctxt
+fun focus f _ (index, tvars, schematics) (limit, subs, next_grounds) =
+  let
+    val (limit', isubs', next_grounds') =
+      (limit, Inttab.lookup_list subs index, next_grounds)
+      |> f (tvars, schematics)
+  in (limit', Inttab.update (index, isubs') subs, next_grounds') end
+
+fun collect_substitutions thm_infos ctxt round subst_ctxt =
+  let val (known_grounds, (limit, subs, next_grounds)) = subst_ctxt
   in
-    (*
-      'known_grounds' are all constant names known to occur schematically
-      associated with all ground instances considered so far
-    *)
-    if exceeded_limit subst_ctxt then (true, (known_grounds, subst_ctxt))
+    if exceeded limit then subst_ctxt
     else
       let
-        fun collect (_, _, thm, _, _) = collect_instances known_grounds thm
-        val new = fold (with_new round collect) thm_infos next_grounds
+        fun collect thm _ = collect_instances known_grounds thm
+        val new = fold_schematic (is_new round) collect thm_infos next_grounds
+
         val known' = Symtab.merge_list (op =) (known_grounds, new)
+        val step = focus o refine ctxt round known'
       in
-        if Symtab.is_empty new then (true, (known_grounds, subst_ctxt))
-        else
-          (limit, substitutions, Symtab.empty)
-          |> fold (with_active round (refine ctxt known_grounds new)) thm_infos
-          |> fold (with_new round (refine ctxt Symtab.empty known')) thm_infos
-          |> pair false o pair known'
+        (limit, subs, Symtab.empty)
+        |> not (Symtab.is_empty new) ?
+            fold_schematic (is_active round) (step new) thm_infos
+        |> fold_schematic (is_new round) (step known') thm_infos
+        |> pair known'
       end
   end
 
@@ -276,7 +268,7 @@
   |> Vartab.map (K (apsnd (Term.map_atyps (fn TVar _ => T | U => U))))
   |> fold (add_missing_tvar T) tvars
 
-fun instantiate_all' (mT, ctxt) substitutions thm_infos =
+fun instantiate_all' (mT, ctxt) subs thm_infos =
   let
     val thy = Proof_Context.theory_of ctxt
 
@@ -290,20 +282,18 @@
 
     fun inst (Ground thm) = [(0, thm)]
       | inst (Schematic {theorem, tvars, index, ...}) =
-          Inttab.lookup_list substitutions index
+          Inttab.lookup_list subs index
           |> map_filter (with_subst tvars (instantiate theorem))
   in (map inst thm_infos, ctxt) end
 
-fun instantiate_all ctxt thm_infos (_, (_, substitutions, _)) =
+fun instantiate_all ctxt thm_infos (_, (_, subs, _)) =
   if Config.get ctxt complete_instances then
-    let
-      fun refined ((_, _, true), _) = true
-        | refined _ = false
+    let fun is_refined ((_, _, refined), _) = refined
     in
-      (Inttab.map (K (filter_out refined)) substitutions, thm_infos)
+      (Inttab.map (K (filter_out is_refined)) subs, thm_infos)
       |-> instantiate_all' (new_super_type ctxt thm_infos)
     end
-  else instantiate_all' (NONE, ctxt) substitutions thm_infos
+  else instantiate_all' (NONE, ctxt) subs thm_infos
 
 
 
@@ -312,24 +302,17 @@
 fun limit_rounds ctxt f =
   let
     val max = Config.get ctxt max_rounds
-
-    fun round _ (true, x) = x
-      | round i (_, x) =
-          if i <= max then round (i + 1) (f ctxt i x)
-          else (
-            show_info ctxt "Warning: Monomorphization limit reached";
-            x)
-  in round 1 o pair false end
+    fun round i x = if i > max then x else round (i + 1) (f ctxt i x)
+  in round 1 end
 
 fun monomorph schematic_consts_of rthms ctxt =
   let
-    val (thm_infos, (known_grounds, substitutions)) =
-      prepare schematic_consts_of rthms
+    val (thm_infos, (known_grounds, subs)) = prepare schematic_consts_of rthms
   in
     if Symtab.is_empty known_grounds then
       (map (single o pair 0 o snd) rthms, ctxt)
     else
-      make_subst_ctxt ctxt thm_infos known_grounds substitutions
+      make_subst_ctxt ctxt thm_infos known_grounds subs
       |> limit_rounds ctxt (collect_substitutions thm_infos)
       |> instantiate_all ctxt thm_infos
   end