proper nesting of loops in new monomorphizer;
authorboehmes
Tue May 31 19:27:19 2011 +0200 (2011-05-31)
changeset 431175de84843685f
parent 43116 e0add071fa10
child 43118 e3c7b07704bc
proper nesting of loops in new monomorphizer;
less duplication of code in new monomorphizer;
drop output of warnings of new monomorphizer
src/HOL/Tools/monomorph.ML
     1.1 --- a/src/HOL/Tools/monomorph.ML	Tue May 31 19:21:20 2011 +0200
     1.2 +++ b/src/HOL/Tools/monomorph.ML	Tue May 31 19:27:19 2011 +0200
     1.3 @@ -38,7 +38,6 @@
     1.4    val max_rounds: int Config.T
     1.5    val max_new_instances: int Config.T
     1.6    val complete_instances: bool Config.T
     1.7 -  val verbose: bool Config.T
     1.8  
     1.9    (* monomorphization *)
    1.10    val monomorph: (term -> typ list Symtab.table) -> (int * thm) list ->
    1.11 @@ -50,9 +49,6 @@
    1.12  
    1.13  (* utility functions *)
    1.14  
    1.15 -fun fold_env _ [] y = y
    1.16 -  | fold_env f (x :: xs) y = fold_env f xs (f xs x y)
    1.17 -
    1.18  val typ_has_tvars = Term.exists_subtype (fn TVar _ => true | _ => false)
    1.19  
    1.20  fun add_schematic_const (c as (_, T)) =
    1.21 @@ -72,9 +68,6 @@
    1.22    Attrib.setup_config_int @{binding monomorph_max_new_instances} (K 300)
    1.23  val complete_instances =
    1.24    Attrib.setup_config_bool @{binding monomorph_complete_instances} (K true)
    1.25 -val verbose = Attrib.setup_config_bool @{binding monomorph_verbose} (K true)
    1.26 -
    1.27 -fun show_info ctxt msg = if Config.get ctxt verbose then tracing msg else ()
    1.28  
    1.29  
    1.30  
    1.31 @@ -93,11 +86,11 @@
    1.32  
    1.33  fun prepare schematic_consts_of rthms =
    1.34    let
    1.35 -    val empty_subst = ((0, false, false), Vartab.empty)
    1.36 +    val empty_sub = ((0, false, false), Vartab.empty)
    1.37  
    1.38 -    fun prep (r, thm) ((i, idx), (consts, substs)) =
    1.39 +    fun prep (r, thm) ((i, idx), (consts, subs)) =
    1.40        if not (Term.exists_type typ_has_tvars (Thm.prop_of thm)) then
    1.41 -        (Ground thm, ((i+1, idx + Thm.maxidx_of thm + 1), (consts, substs)))
    1.42 +        (Ground thm, ((i+1, idx + Thm.maxidx_of thm + 1), (consts, subs)))
    1.43        else
    1.44          let
    1.45            (* increase indices to avoid clashes of type variables *)
    1.46 @@ -106,20 +99,101 @@
    1.47            val schematics = schematic_consts_of (Thm.prop_of thm')
    1.48            val consts' =
    1.49              Symtab.fold (fn (n, _) => Symtab.update (n, [])) schematics consts
    1.50 -          val substs' = Inttab.update (i, [empty_subst]) substs
    1.51 +          val subs' = Inttab.update (i, [empty_sub]) subs
    1.52            val thm_info = Schematic {
    1.53              index = i,
    1.54              theorem = thm',
    1.55              tvars = Term.add_tvars (Thm.prop_of thm') [],
    1.56              schematics = schematics,
    1.57              initial_round = r }
    1.58 -      in (thm_info, ((i+1, idx'), (consts', substs'))) end
    1.59 +      in (thm_info, ((i+1, idx'), (consts', subs'))) end
    1.60    in fold_map prep rthms ((0, 0), (Symtab.empty, Inttab.empty)) ||> snd end
    1.61  
    1.62  
    1.63  
    1.64  (** collecting substitutions **)
    1.65  
    1.66 +fun exceeded limit = (limit <= 0)
    1.67 +fun exceeded_limit (limit, _, _) = exceeded limit
    1.68 +
    1.69 +
    1.70 +fun with_all_grounds cx grounds f =
    1.71 +  if exceeded_limit cx then I else Symtab.fold f grounds
    1.72 +
    1.73 +fun with_all_type_combinations cx schematics f (n, Ts) =
    1.74 +  if exceeded_limit cx then I
    1.75 +  else fold_product f (Symtab.lookup_list schematics n) Ts
    1.76 +
    1.77 +fun derive_new_substs thy cx new_grounds schematics subst =
    1.78 +  with_all_grounds cx new_grounds
    1.79 +    (with_all_type_combinations cx schematics (fn T => fn U =>
    1.80 +      (case try (Sign.typ_match thy (T, U)) subst of
    1.81 +        NONE => I
    1.82 +      | SOME subst' => cons subst'))) []
    1.83 +
    1.84 +
    1.85 +fun same_subst subst' (_, subst) = subst' |> Vartab.forall (fn (n, (_, T)) => 
    1.86 +  Vartab.lookup subst n |> Option.map (equal T o snd) |> the_default false)
    1.87 +
    1.88 +fun known_subst sub subs1 subs2 subst =
    1.89 +  same_subst subst sub orelse exists (same_subst subst) subs1 orelse
    1.90 +  exists (same_subst subst) subs2
    1.91 +
    1.92 +fun within_limit f cx = if exceeded_limit cx then cx else f cx
    1.93 +
    1.94 +fun fold_partial_substs derive add = within_limit (
    1.95 +  let
    1.96 +    fun fold_partial [] cx = cx
    1.97 +      | fold_partial (sub :: subs) (limit, subs', next) =
    1.98 +          if exceeded limit then (limit, sub :: subs @ subs', next)
    1.99 +          else sub |> (fn ((generation, full, _), subst) =>
   1.100 +            if full then fold_partial subs (limit, sub :: subs', next)
   1.101 +            else
   1.102 +              (case filter_out (known_subst sub subs subs') (derive subst) of
   1.103 +                [] => fold_partial subs (limit, sub :: subs', next)
   1.104 +              | substs =>
   1.105 +                  (limit, ((generation, full, true), subst) :: subs', next)
   1.106 +                  |> fold (within_limit o add) substs
   1.107 +                  |> fold_partial subs))
   1.108 +  in (fn (limit, subs, next) => fold_partial subs (limit, [], next)) end)
   1.109 +
   1.110 +
   1.111 +fun refine ctxt round known_grounds new_grounds (tvars, schematics) cx =
   1.112 +  let
   1.113 +    val thy = Proof_Context.theory_of ctxt
   1.114 +    val count_partial = Config.get ctxt complete_instances
   1.115 +
   1.116 +    fun add_new_ground subst n T =
   1.117 +      let val T' = Envir.subst_type subst T
   1.118 +      in
   1.119 +        (* FIXME: maybe keep types in a table or net for known_grounds,
   1.120 +           that might improve efficiency here
   1.121 +        *)
   1.122 +        if typ_has_tvars T' then I
   1.123 +        else if member (op =) (Symtab.lookup_list known_grounds n) T' then I
   1.124 +        else Symtab.cons_list (n, T')
   1.125 +      end
   1.126 +
   1.127 +    fun add_new_subst subst (limit, subs, next_grounds) =
   1.128 +      let
   1.129 +        val full = forall (Vartab.defined subst o fst) tvars
   1.130 +        val limit' =
   1.131 +          if full orelse count_partial then limit - 1 else limit
   1.132 +        val sub = ((round, full, false), subst)
   1.133 +        val next_grounds' =
   1.134 +          (schematics, next_grounds)
   1.135 +          |-> Symtab.fold (uncurry (fold o add_new_ground subst))
   1.136 +      in (limit', sub :: subs, next_grounds') end
   1.137 +  in
   1.138 +    fold_partial_substs (derive_new_substs thy cx new_grounds schematics)
   1.139 +      add_new_subst cx
   1.140 +  end
   1.141 +
   1.142 +
   1.143 +(*
   1.144 +  'known_grounds' are all constant names known to occur schematically
   1.145 +  associated with all ground instances considered so far
   1.146 +*)
   1.147  fun add_relevant_instances known_grounds (Const (c as (n, T))) =
   1.148        if typ_has_tvars T orelse not (Symtab.defined known_grounds n) then I
   1.149        else if member (op =) (Symtab.lookup_list known_grounds n) T then I
   1.150 @@ -130,86 +204,6 @@
   1.151    Term.fold_aterms (add_relevant_instances known_grounds) (Thm.prop_of thm)
   1.152  
   1.153  
   1.154 -fun exceeded_limit (limit, _, _) = (limit <= 0)
   1.155 -
   1.156 -fun with_substs index f (limit, substitutions, next_grounds) =
   1.157 -  let
   1.158 -    val substs = Inttab.lookup_list substitutions index
   1.159 -    val (limit', substs', next_grounds') = f (limit, substs, next_grounds)
   1.160 -  in (limit', Inttab.update (index, substs') substitutions, next_grounds') end
   1.161 -
   1.162 -fun with_grounds grounds f cx =
   1.163 -  if exceeded_limit cx then cx else Symtab.fold f grounds cx
   1.164 -
   1.165 -fun with_all_combinations schematics f (n, Ts) cx =
   1.166 -  if exceeded_limit cx then cx
   1.167 -  else fold_product f (Symtab.lookup_list schematics n) Ts cx
   1.168 -
   1.169 -fun with_partial_substs f T U (cx as (limit, substs, next_grounds)) =
   1.170 -  if exceeded_limit cx then cx
   1.171 -  else fold_env (f (T, U)) substs (limit, [], next_grounds)
   1.172 -
   1.173 -
   1.174 -fun same_subst subst =
   1.175 -  Vartab.forall (fn (n, (_, T)) => 
   1.176 -    Vartab.lookup subst n |> Option.map (equal T o snd) |> the_default false)
   1.177 -
   1.178 -(* FIXME: necessary? would it have an impact?
   1.179 -   comparing substitutions can be tricky ... *)
   1.180 -fun known substs1 substs2 subst = false
   1.181 -
   1.182 -fun refine ctxt known_grounds new_grounds info =
   1.183 -  let
   1.184 -    val thy = Proof_Context.theory_of ctxt
   1.185 -    val count_partial = Config.get ctxt complete_instances
   1.186 -    val (round, index, _, tvars, schematics) = info
   1.187 -
   1.188 -    fun refine_subst TU = try (Sign.typ_match thy TU)
   1.189 -
   1.190 -    fun add_new_ground subst n T =
   1.191 -      let val T' = Envir.subst_type subst T
   1.192 -      in
   1.193 -        (* FIXME: maybe keep types in a table or net for known_grounds,
   1.194 -           that might improve efficiency here
   1.195 -        *)
   1.196 -        if member (op =) (Symtab.lookup_list known_grounds n) T' then I
   1.197 -        else Symtab.cons_list (n, T')
   1.198 -      end
   1.199 -
   1.200 -    fun refine_step subst limit next_grounds substs =
   1.201 -      let
   1.202 -        val full = forall (Vartab.defined subst o fst) tvars
   1.203 -        val limit' =
   1.204 -          if full orelse count_partial then limit - 1 else limit
   1.205 -        val sub = ((round, full, false), subst)
   1.206 -        val next_grounds' =
   1.207 -          (schematics, next_grounds)
   1.208 -          |-> Symtab.fold (uncurry (fold o add_new_ground subst))
   1.209 -      in (limit', sub :: substs, next_grounds') end
   1.210 -
   1.211 -    fun refine_substs TU substs sub (cx as (limit, substs', next_grounds)) =
   1.212 -      let val ((generation, full, _), subst) = sub
   1.213 -      in
   1.214 -        if exceeded_limit cx orelse full then
   1.215 -          (limit, sub :: substs', next_grounds)
   1.216 -        else
   1.217 -          (case refine_subst TU subst of
   1.218 -            NONE => (limit, sub :: substs', next_grounds)
   1.219 -          | SOME subst' =>
   1.220 -              if (same_subst subst orf known substs substs') subst' then
   1.221 -                (limit, sub :: substs', next_grounds)
   1.222 -              else
   1.223 -                substs'
   1.224 -                |> cons ((generation, full, true), subst)
   1.225 -                |> refine_step subst' limit next_grounds)
   1.226 -      end
   1.227 -  in
   1.228 -    with_substs index (
   1.229 -      with_grounds new_grounds (with_all_combinations schematics (
   1.230 -        with_partial_substs refine_substs)))
   1.231 -  end
   1.232 -
   1.233 -
   1.234  fun make_subst_ctxt ctxt thm_infos known_grounds substitutions =
   1.235    let
   1.236      val limit = Config.get ctxt max_new_instances
   1.237 @@ -219,40 +213,38 @@
   1.238      val initial_grounds = fold add_ground_consts thm_infos Symtab.empty
   1.239    in (known_grounds, (limit, substitutions, initial_grounds)) end
   1.240  
   1.241 -fun with_new round f thm_info =
   1.242 -  (case thm_info of
   1.243 +fun is_new round initial_round = (round = initial_round)
   1.244 +fun is_active round initial_round = (round > initial_round)
   1.245 +
   1.246 +fun fold_schematic pred f = fold (fn
   1.247      Schematic {index, theorem, tvars, schematics, initial_round} =>
   1.248 -      if initial_round <> round then I
   1.249 -      else f (round, index, theorem, tvars, schematics)
   1.250 -  | Ground _ => I)
   1.251 -
   1.252 -fun with_active round f thm_info =
   1.253 -  (case thm_info of
   1.254 -    Schematic {index, theorem, tvars, schematics, initial_round} =>
   1.255 -      if initial_round < round then I
   1.256 -      else f (round, index, theorem, tvars, schematics)
   1.257 +      if pred initial_round then f theorem (index, tvars, schematics) else I
   1.258    | Ground _ => I)
   1.259  
   1.260 -fun collect_substitutions thm_infos ctxt round (known_grounds, subst_ctxt) =
   1.261 -  let val (limit, substitutions, next_grounds) = subst_ctxt
   1.262 +fun focus f _ (index, tvars, schematics) (limit, subs, next_grounds) =
   1.263 +  let
   1.264 +    val (limit', isubs', next_grounds') =
   1.265 +      (limit, Inttab.lookup_list subs index, next_grounds)
   1.266 +      |> f (tvars, schematics)
   1.267 +  in (limit', Inttab.update (index, isubs') subs, next_grounds') end
   1.268 +
   1.269 +fun collect_substitutions thm_infos ctxt round subst_ctxt =
   1.270 +  let val (known_grounds, (limit, subs, next_grounds)) = subst_ctxt
   1.271    in
   1.272 -    (*
   1.273 -      'known_grounds' are all constant names known to occur schematically
   1.274 -      associated with all ground instances considered so far
   1.275 -    *)
   1.276 -    if exceeded_limit subst_ctxt then (true, (known_grounds, subst_ctxt))
   1.277 +    if exceeded limit then subst_ctxt
   1.278      else
   1.279        let
   1.280 -        fun collect (_, _, thm, _, _) = collect_instances known_grounds thm
   1.281 -        val new = fold (with_new round collect) thm_infos next_grounds
   1.282 +        fun collect thm _ = collect_instances known_grounds thm
   1.283 +        val new = fold_schematic (is_new round) collect thm_infos next_grounds
   1.284 +
   1.285          val known' = Symtab.merge_list (op =) (known_grounds, new)
   1.286 +        val step = focus o refine ctxt round known'
   1.287        in
   1.288 -        if Symtab.is_empty new then (true, (known_grounds, subst_ctxt))
   1.289 -        else
   1.290 -          (limit, substitutions, Symtab.empty)
   1.291 -          |> fold (with_active round (refine ctxt known_grounds new)) thm_infos
   1.292 -          |> fold (with_new round (refine ctxt Symtab.empty known')) thm_infos
   1.293 -          |> pair false o pair known'
   1.294 +        (limit, subs, Symtab.empty)
   1.295 +        |> not (Symtab.is_empty new) ?
   1.296 +            fold_schematic (is_active round) (step new) thm_infos
   1.297 +        |> fold_schematic (is_new round) (step known') thm_infos
   1.298 +        |> pair known'
   1.299        end
   1.300    end
   1.301  
   1.302 @@ -276,7 +268,7 @@
   1.303    |> Vartab.map (K (apsnd (Term.map_atyps (fn TVar _ => T | U => U))))
   1.304    |> fold (add_missing_tvar T) tvars
   1.305  
   1.306 -fun instantiate_all' (mT, ctxt) substitutions thm_infos =
   1.307 +fun instantiate_all' (mT, ctxt) subs thm_infos =
   1.308    let
   1.309      val thy = Proof_Context.theory_of ctxt
   1.310  
   1.311 @@ -290,20 +282,18 @@
   1.312  
   1.313      fun inst (Ground thm) = [(0, thm)]
   1.314        | inst (Schematic {theorem, tvars, index, ...}) =
   1.315 -          Inttab.lookup_list substitutions index
   1.316 +          Inttab.lookup_list subs index
   1.317            |> map_filter (with_subst tvars (instantiate theorem))
   1.318    in (map inst thm_infos, ctxt) end
   1.319  
   1.320 -fun instantiate_all ctxt thm_infos (_, (_, substitutions, _)) =
   1.321 +fun instantiate_all ctxt thm_infos (_, (_, subs, _)) =
   1.322    if Config.get ctxt complete_instances then
   1.323 -    let
   1.324 -      fun refined ((_, _, true), _) = true
   1.325 -        | refined _ = false
   1.326 +    let fun is_refined ((_, _, refined), _) = refined
   1.327      in
   1.328 -      (Inttab.map (K (filter_out refined)) substitutions, thm_infos)
   1.329 +      (Inttab.map (K (filter_out is_refined)) subs, thm_infos)
   1.330        |-> instantiate_all' (new_super_type ctxt thm_infos)
   1.331      end
   1.332 -  else instantiate_all' (NONE, ctxt) substitutions thm_infos
   1.333 +  else instantiate_all' (NONE, ctxt) subs thm_infos
   1.334  
   1.335  
   1.336  
   1.337 @@ -312,24 +302,17 @@
   1.338  fun limit_rounds ctxt f =
   1.339    let
   1.340      val max = Config.get ctxt max_rounds
   1.341 -
   1.342 -    fun round _ (true, x) = x
   1.343 -      | round i (_, x) =
   1.344 -          if i <= max then round (i + 1) (f ctxt i x)
   1.345 -          else (
   1.346 -            show_info ctxt "Warning: Monomorphization limit reached";
   1.347 -            x)
   1.348 -  in round 1 o pair false end
   1.349 +    fun round i x = if i > max then x else round (i + 1) (f ctxt i x)
   1.350 +  in round 1 end
   1.351  
   1.352  fun monomorph schematic_consts_of rthms ctxt =
   1.353    let
   1.354 -    val (thm_infos, (known_grounds, substitutions)) =
   1.355 -      prepare schematic_consts_of rthms
   1.356 +    val (thm_infos, (known_grounds, subs)) = prepare schematic_consts_of rthms
   1.357    in
   1.358      if Symtab.is_empty known_grounds then
   1.359        (map (single o pair 0 o snd) rthms, ctxt)
   1.360      else
   1.361 -      make_subst_ctxt ctxt thm_infos known_grounds substitutions
   1.362 +      make_subst_ctxt ctxt thm_infos known_grounds subs
   1.363        |> limit_rounds ctxt (collect_substitutions thm_infos)
   1.364        |> instantiate_all ctxt thm_infos
   1.365    end