src/HOL/Tools/monomorph.ML
author huffman
Fri Mar 30 12:32:35 2012 +0200 (2012-03-30)
changeset 47220 52426c62b5d0
parent 44717 c9cf0780cd4f
child 51575 907efc894051
permissions -rw-r--r--
replace lemmas eval_nat_numeral with a simpler reformulation
     1 (*  Title:      HOL/Tools/monomorph.ML
     2     Author:     Sascha Boehme, TU Muenchen
     3 
     4 Monomorphization of theorems, i.e., computation of all (necessary)
     5 instances.  This procedure is incomplete in general, but works well for
     6 most practical problems.
     7 
     8 For a list of universally closed theorems (without schematic term
     9 variables), monomorphization computes a list of theorems with schematic
    10 term variables: all polymorphic constants (i.e., constants occurring both
    11 with ground types and schematic type variables) are instantiated with all
    12 (necessary) ground types; thereby theorems containing these constants are
    13 copied.  To prevent nontermination, there is an upper limit for the number
    14 of iterations involved in the fixpoint construction.
    15 
    16 The search for instances is performed on the constants with schematic
    17 types, which are extracted from the initial set of theorems.  The search
    18 constructs, for each theorem with those constants, a set of substitutions,
    19 which, in the end, is applied to all corresponding theorems.  Remaining
    20 schematic type variables are substituted with fresh types.
    21 
    22 Searching for necessary substitutions is an iterative fixpoint
    23 construction: each iteration computes all instances required by
    24 the ground instances computed in the previous step and which haven't been
    25 found before.  Computed substitutions are always nontrivial: schematic type
    26 variables are never mapped to schematic type variables.
    27 *)
    28 
    29 signature MONOMORPH =
    30 sig
    31   (* utility function *)
         (* holds iff the type contains a schematic type variable (TVar) *)
    32   val typ_has_tvars: typ -> bool
         (* table from constant names to the types with which the constants occur
            in the term, restricted to types containing schematic type variables;
            the result is built from the empty table *)
    33   val all_schematic_consts_of: term -> typ list Symtab.table
         (* same collection, but extending the table passed as second argument *)
    34   val add_schematic_consts_of: term -> typ list Symtab.table ->
    35     typ list Symtab.table
    36 
    37   (* configuration options *)
         (* number of the last round of the fixpoint construction *)
    38   val max_rounds: int Config.T
         (* added to the number of schematic input theorems to obtain the limit on
            collected substitutions *)
    39   val max_new_instances: int Config.T
         (* whether partial substitutions count against the limit and are turned
            into instances by mapping leftover type variables to a fresh type *)
    40   val keep_partial_instances: bool Config.T
    41 
    42   (* monomorphization *)
         (* arguments: a function extracting the schematic constants of a term, the
            input theorems (the integer of each pair is stored as the initial round
            of the theorem) and a context; result: one list of (generation,
            instance) pairs per input theorem, plus a context *)
    43   val monomorph: (term -> typ list Symtab.table) -> (int * thm) list ->
    44     Proof.context -> (int * thm) list list * Proof.context
    45 end
    46 
    47 structure Monomorph: MONOMORPH =
    48 struct
    49 
    50 (* utility functions *)
    51 
    52 val typ_has_tvars =
         (* a type qualifies iff one of its subtypes is a schematic type variable *)
         let
           fun is_schematic (TVar _) = true
             | is_schematic _ = false
         in Term.exists_subtype is_schematic end
    53 
    54 fun add_schematic_const (const as (_, ty)) =
         (* insert the constant under its name only if its type is schematic;
            otherwise the table is passed through unchanged *)
    55   typ_has_tvars ty ? Symtab.insert_list (op =) const
    56 
    57 fun add_schematic_consts_of t =
         (* visit all atomic subterms of t and collect the schematic constants *)
    58   let
           fun collect (Const c) = add_schematic_const c
             | collect _ = I
         in Term.fold_aterms collect t end
    59 
    60 fun all_schematic_consts_of t =
         (* like add_schematic_consts_of, starting from the empty table *)
         Symtab.empty |> add_schematic_consts_of t
    61 
    62 
    63 
    64 (* configuration options *)
    65 
    66 val max_rounds = Attrib.setup_config_int @{binding monomorph_max_rounds} (K 5)
         (* ^ integer option monomorph_max_rounds with constant default 5; read by
            limit_rounds as the number of the last round to run *)
    67 val max_new_instances =
    68   Attrib.setup_config_int @{binding monomorph_max_new_instances} (K 300)
         (* ^ integer option monomorph_max_new_instances with constant default 300;
            read by make_subst_ctxt when computing the substitution limit *)
    69 val keep_partial_instances =
    70   Attrib.setup_config_bool @{binding monomorph_keep_partial_instances} (K true)
         (* ^ boolean option monomorph_keep_partial_instances with constant default
            true; read by refine and instantiate_all *)
    71 
    72 
    73 
    74 (* monomorphization *)
    75 
    76 (** preparing the problem **)
    77 
    78 datatype thm_info =
         (* per-theorem record produced by prepare *)
    79   Ground of thm |  (* proposition contains no schematic type variables *)
    80   Schematic of {
    81     index: int,  (* position in the input list; key of the substitution table *)
    82     theorem: thm,  (* the theorem after its variable indexes were increased *)
    83     tvars: (indexname * sort) list,  (* schematic type variables of the proposition *)
    84     schematics: typ list Symtab.table,  (* result of schematic_consts_of on the proposition *)
    85     initial_round: int }  (* the integer paired with the theorem in the input *)
    86 
    87 fun prepare schematic_consts_of rthms =
         (* Classify every input pair (r, thm) as Ground or Schematic and build the
            initial search data.  The result is (thm_infos, (consts, subs)): consts
            maps the name of every schematic constant to the empty list, subs maps
            the index of every schematic theorem to a single empty substitution. *)
    88   let
           (* the flags are (generation, full, refined); no type variable is bound *)
    89     val empty_sub = ((0, false, false), Vartab.empty)
    90 
           (* threaded state: i is the running theorem index, idx the increment
              handed to Thm.incr_indexes for the next schematic theorem *)
    91     fun prep (r, thm) ((i, idx), (consts, subs)) =
    92       if not (Term.exists_type typ_has_tvars (Thm.prop_of thm)) then
               (* ground theorems still advance both counters *)
    93         (Ground thm, ((i+1, idx + Thm.maxidx_of thm + 1), (consts, subs)))
    94       else
    95         let
    96           (* increase indices to avoid clashes of type variables *)
    97           val thm' = Thm.incr_indexes idx thm
    98           val idx' = Thm.maxidx_of thm' + 1
    99           val schematics = schematic_consts_of (Thm.prop_of thm')
                 (* register every schematic constant name; any list stored for the
                    name so far is overwritten by the empty list *)
   100           val consts' =
   101             Symtab.fold (fn (n, _) => Symtab.update (n, [])) schematics consts
                 (* each schematic theorem starts with the empty substitution only *)
   102           val subs' = Inttab.update (i, [empty_sub]) subs
   103           val thm_info = Schematic {
   104             index = i,
   105             theorem = thm',
   106             tvars = Term.add_tvars (Thm.prop_of thm') [],
   107             schematics = schematics,
   108             initial_round = r }
   109       in (thm_info, ((i+1, idx'), (consts', subs'))) end
         (* the two counters are dropped from the final state *)
   110   in fold_map prep rthms ((0, 0), (Symtab.empty, Inttab.empty)) ||> snd end
   111 
   112 
   113 
   114 (** collecting substitutions **)
   115 
   116 fun exceeded limit = not (limit > 0)  (* the remaining budget is used up *)
   117 fun exceeded_limit (remaining, _, _) = exceeded remaining  (* same test on a state triple *)
   118 
   119 
   120 fun derived_subst subst' subst =
         (* every binding of subst' also occurs in subst with the same type *)
   121   let
           fun agrees (ix, (_, T)) =
             (case Vartab.lookup subst ix of
               SOME (_, U) => T = U
             | NONE => false)
         in Vartab.forall agrees subst' end
   122 
   123 fun eq_subst (subst1, subst2) =
         (* two substitutions are equal iff each one is derived from the other *)
   124   if derived_subst subst1 subst2 then derived_subst subst2 subst1 else false
   125 
   126 
   127 fun with_all_grounds cx grounds f =
         (* fold f over all entries of grounds, unless the limit of cx is exhausted *)
   128   (case exceeded_limit cx of true => I | false => Symtab.fold f grounds)
   129 
   130 fun with_all_type_combinations cx schematics f (n, Ts) =
         (* apply f to every pair of a type stored for n in schematics and a type
            from Ts, unless the limit of cx is exhausted *)
   131   let fun combine () = fold_product f (Symtab.lookup_list schematics n) Ts
   132   in if exceeded_limit cx then I else combine () end
   133 
   134 fun derive_new_substs thy cx new_grounds schematics subst =
         (* Match each type stored in schematics for a constant against each type
            stored for the same constant in new_grounds, extending subst.  Failed
            matches are skipped; the successful extensions are collected without
            duplicates with respect to eq_subst. *)
   135   let
   136     fun extend T U =
   137       (case try (Sign.typ_match thy (T, U)) subst of
   138         SOME subst' => insert eq_subst subst'
   139       | NONE => I)
           fun per_const entry = with_all_type_combinations cx schematics extend entry
         in [] |> with_all_grounds cx new_grounds per_const end
   140 
   141 
   142 fun known_subst sub subs1 subs2 subst' =
         (* subst' is derived from sub or from some entry of subs1 or subs2 *)
   143   let fun covers entry = derived_subst subst' (snd entry)
   144   in exists covers (sub :: subs1) orelse exists covers subs2 end
   145 
   146 fun within_limit f cx = cx |> not (exceeded_limit cx) ? f  (* apply f only while the limit of cx is not exhausted *)
   147 
   148 fun fold_partial_substs derive add = within_limit (
         (* Walk over the substitutions of a state (limit, subs, next) and try to
            refine each of them.  derive maps a substitution to candidate
            substitutions; add inserts one new substitution into a state. *)
   149   let
   150     fun fold_partial [] cx = cx
   151       | fold_partial (sub :: subs) (limit, subs', next) =
                 (* limit exhausted: keep sub and all unprocessed entries as they are *)
   152           if exceeded limit then (limit, sub :: subs @ subs', next)
   153           else sub |> (fn ((generation, full, _), subst) =>
                   (* an entry flagged as full is kept without calling derive *)
   154             if full then fold_partial subs (limit, sub :: subs', next)
   155             else
                     (* drop candidates derived from sub itself, from an unprocessed
                        entry, or from an entry already placed in the accumulator *)
   156               (case filter_out (known_subst sub subs subs') (derive subst) of
   157                 [] => fold_partial subs (limit, sub :: subs', next)
   158               | substs =>
                         (* re-insert sub with its third flag set to true, then add
                            the new substitutions while the limit permits *)
   159                   (limit, ((generation, full, true), subst) :: subs', next)
   160                   |> fold (within_limit o add) substs
   161                   |> fold_partial subs))
         (* processing starts with an empty accumulator of substitutions *)
   162   in (fn (limit, subs, next) => fold_partial subs (limit, [], next)) end)
   163 
   164 
   165 fun refine ctxt round known_grounds new_grounds (tvars, schematics) cx =
         (* Refine the substitutions in state cx = (limit, subs, next_grounds) of
            one theorem, given by its tvars and schematics, against new_grounds.
            New substitutions are tagged with the current round. *)
   166   let
   167     val thy = Proof_Context.theory_of ctxt
           (* if set, partial substitutions consume the limit as well *)
   168     val count_partial = Config.get ctxt keep_partial_instances
   169 
           (* record type T of constant n, instantiated by subst, as a ground
              instance for the next round, unless it is not ground or is already
              among the known grounds of n *)
   170     fun add_new_ground subst n T =
   171       let val T' = Envir.subst_type subst T
   172       in
   173         (* FIXME: maybe keep types in a table or net for known_grounds,
   174            that might improve efficiency here
   175         *)
   176         if typ_has_tvars T' then I
   177         else if member (op =) (Symtab.lookup_list known_grounds n) T' then I
   178         else Symtab.cons_list (n, T')
   179       end
   180 
           (* insert substitution subst into the state *)
   181     fun add_new_subst subst (limit, subs, next_grounds) =
   182       let
               (* full: every type variable of the theorem is bound by subst *)
   183         val full = forall (Vartab.defined subst o fst) tvars
   184         val limit' =
   185           if full orelse count_partial then limit - 1 else limit
               (* flags: (generation = round, full, refined = false) *)
   186         val sub = ((round, full, false), subst)
               (* apply add_new_ground to every type of every schematic constant *)
   187         val next_grounds' =
   188           (schematics, next_grounds)
   189           |-> Symtab.fold (uncurry (fold o add_new_ground subst))
   190       in (limit', sub :: subs, next_grounds') end
   191   in
           (* derive_new_substs receives the state cx as it is on entry, so its
              limit check does not see decrements made during this call *)
   192     fold_partial_substs (derive_new_substs thy cx new_grounds schematics)
   193       add_new_subst cx
   194   end
   195 
   196 
   197 (*
   198   'known_grounds' are all constant names known to occur schematically
   199   associated with all ground instances considered so far
   200 *)
   201 fun add_relevant_instances known_grounds t =
         (* A constant is recorded iff its type is ground, its name is a key of
            known_grounds and its type is not yet listed there; every other term
            leaves the table unchanged. *)
   202   (case t of
   203     Const (c as (n, T)) =>
   204       if not (typ_has_tvars T) andalso Symtab.defined known_grounds n andalso
   205         not (member (op =) (Symtab.lookup_list known_grounds n) T)
             then Symtab.insert_list (op =) c
             else I
         | _ => I)
   206 
   207 fun collect_instances known_grounds thm =
         (* collect the relevant ground instances from the proposition of thm *)
   208   thm |> Thm.prop_of |> Term.fold_aterms (add_relevant_instances known_grounds)
   209 
   210 
   211 fun make_subst_ctxt ctxt thm_infos known_grounds substitutions =
         (* Build the initial search context: the known grounds paired with the
            state (limit, substitutions, ground instances of the ground theorems). *)
   212   let
   213     (* Ground input facts are returned in any case, so only the schematic
   214        facts are counted on top of max_new_instances when computing the
   215        limit of returned facts. *)
   216     fun is_schematic (Schematic _) = true
   217       | is_schematic (Ground _) = false
   218     val limit =
   219       Config.get ctxt max_new_instances + length (filter is_schematic thm_infos)
   220 
   221     fun ground_consts (Ground thm) = collect_instances known_grounds thm
   222       | ground_consts (Schematic _) = I
   223     val initial_grounds = Symtab.empty |> fold ground_consts thm_infos
         in (known_grounds, (limit, substitutions, initial_grounds)) end
   224 
   225 fun is_new round initial_round = (initial_round = round)  (* theorem enters in this round *)
   226 fun is_active round initial_round = (initial_round < round)  (* theorem entered in an earlier round *)
   227 
   228 fun fold_schematic pred f =
         (* fold over thm_info entries, applying f to the schematic ones whose
            initial round satisfies pred; all other entries are skipped *)
   229   let
   230     fun visit (Ground _) = I
   231       | visit (Schematic {index, theorem, tvars, schematics, initial_round}) =
               if pred initial_round then f theorem (index, tvars, schematics) else I
         in fold visit end
   232 
   233 fun focus f _ (index, tvars, schematics) (limit, subs, next_grounds) =
         (* run f on the substitutions stored under index only, then write the
            resulting list back into the table *)
   234   let
   235     val local_cx = (limit, Inttab.lookup_list subs index, next_grounds)
   236   in
   237     (case f (tvars, schematics) local_cx of
   238       (limit', isubs', next_grounds') =>
               (limit', Inttab.update (index, isubs') subs, next_grounds'))
         end
   239 
   240 fun collect_substitutions thm_infos ctxt round subst_ctxt =
         (* One round of the search.  subst_ctxt is the known grounds paired with
            the state (limit, subs, next_grounds); the result has the same shape. *)
   241   let val (known_grounds, (limit, subs, next_grounds)) = subst_ctxt
   242   in
           (* nothing is done once the limit is exhausted *)
   243     if exceeded limit then subst_ctxt
   244     else
   245       let
               (* add the ground instances occurring in theorems that start in this
                  round to the grounds handed over from the previous round *)
   246         fun collect thm _ = collect_instances known_grounds thm
   247         val new = fold_schematic (is_new round) collect thm_infos next_grounds
   248 
   249         val known' = Symtab.merge_list (op =) (known_grounds, new)
               (* step g = focus (refine ctxt round known' g) *)
   250         val step = focus o refine ctxt round known'
   251       in
               (* the grounds for the following round are collected from scratch *)
   252         (limit, subs, Symtab.empty)
               (* theorems from earlier rounds are refined against the new grounds
                  only, and only if there are any *)
   253         |> not (Symtab.is_empty new) ?
   254             fold_schematic (is_active round) (step new) thm_infos
               (* theorems starting in this round are refined against all grounds *)
   255         |> fold_schematic (is_new round) (step known') thm_infos
   256         |> pair known'
   257       end
   258   end
   259 
   260 
   261 
   262 (** instantiating schematic theorems **)
   263 
   264 fun super_sort thm_info S =
         (* merge the sorts of all type variables of a schematic theorem into S;
            ground theorems leave S unchanged *)
   265   (case thm_info of
           Schematic {tvars, ...} => merge (op =) (S, maps snd tvars)
         | Ground _ => S)
   266 
   267 fun new_super_type ctxt thm_infos =
         (* invent one type of the sort accumulated by super_sort over all theorems,
            starting from sort type; returned as SOME TFree with the new context *)
   268   let
   269     val S = @{sort type} |> fold super_sort thm_infos
           val (tfree, ctxt') = yield_singleton Variable.invent_types S ctxt
         in (SOME (TFree tfree), ctxt') end
   270 
   271 fun add_missing_tvar T (ix, S) subst =
         (* bind type variable ix to T unless subst already has a binding for it *)
   272   subst |> not (Vartab.defined subst ix) ? Vartab.update (ix, (S, T))
   274 
   275 fun complete tvars subst T =
         (* replace every schematic type variable in the bound types by T, then
            bind all type variables of tvars that are still unbound to T *)
   276   let
   277     fun close_atyp (TVar _) = T
   278       | close_atyp U = U
           fun close_entry _ (S, U) = (S, Term.map_atyps close_atyp U)
         in fold (add_missing_tvar T) tvars (Vartab.map close_entry subst) end
   279 
   280 fun instantiate_all' (mT, ctxt) subs thm_infos =
         (* Instantiate every theorem with its substitutions from subs.  mT is an
            optional type used to complete partial substitutions.  The result is
            one list of (generation, theorem) pairs per theorem, plus ctxt. *)
   281   let
   282     val thy = Proof_Context.theory_of ctxt
   283 
           (* certify the bindings of a substitution for Thm.instantiate *)
   284     fun cert (ix, (S, T)) = pairself (Thm.ctyp_of thy) (TVar (ix, S), T)
   285     fun cert' subst = Vartab.fold (cons o cert) subst []
           (* only type variables are instantiated; the term part is empty *)
   286     fun instantiate thm subst = Thm.instantiate (cert' subst, []) thm
   287 
           (* full substitutions are applied directly; partial ones are completed
              with the type in mT, or dropped when mT is NONE *)
   288     fun with_subst tvars f ((generation, full, _), subst) =
   289       if full then SOME (generation, f subst)
   290       else Option.map (pair generation o f o complete tvars subst) mT
   291 
           (* ground theorems are returned unchanged with generation 0 *)
   292     fun inst (Ground thm) = [(0, thm)]
   293       | inst (Schematic {theorem, tvars, index, ...}) =
   294           Inttab.lookup_list subs index
   295           |> map_filter (with_subst tvars (instantiate theorem))
   296   in (map inst thm_infos, ctxt) end
   297 
   298 fun instantiate_all ctxt thm_infos (_, (_, subs, _)) =
         (* With keep_partial_instances, substitutions flagged as refined are
            removed and the remaining ones are instantiated with a freshly
            invented type available for completion; otherwise all substitutions
            are instantiated without such a type. *)
   299   let
   300     fun unrefined ((_, _, refined), _) = not refined
   301     fun drop_refined _ isubs = filter unrefined isubs
   302   in
   303     (case Config.get ctxt keep_partial_instances of
   304       false => instantiate_all' (NONE, ctxt) subs thm_infos
   305     | true =>
               let
                 val subs' = Inttab.map drop_refined subs
                 val super = new_super_type ctxt thm_infos
               in instantiate_all' super subs' thm_infos end)
         end
   306 
   307 
   308 
   309 (** overall procedure **)
   310 
   311 fun limit_rounds ctxt f =
         (* apply f ctxt i for i = 1, 2, ... up to and including max_rounds,
            threading the state through all rounds *)
   312   let
   313     val last_round = Config.get ctxt max_rounds
   314     fun iterate i state =
   315       if i <= last_round then iterate (i + 1) (f ctxt i state) else state
         in iterate 1 end
   316 
   317 fun monomorph schematic_consts_of rthms ctxt =
         (* Entry point: prepare the theorems, search for substitutions in a bounded
            number of rounds and instantiate the schematic theorems.  If no
            schematic constant was found, only the ground theorems are returned. *)
   318   let
   319     val (thm_infos, (known_grounds, subs)) = prepare schematic_consts_of rthms
   320 
   321     fun only_ground (Ground thm) = [(0, thm)]
   322       | only_ground (Schematic _) = []
   323 
   324     fun search_and_instantiate () =
   325       let
   326         val start = make_subst_ctxt ctxt thm_infos known_grounds subs
   327         val final = limit_rounds ctxt (collect_substitutions thm_infos) start
             in instantiate_all ctxt thm_infos final end
         in
           if Symtab.is_empty known_grounds then (map only_ground thm_infos, ctxt)
           else search_and_instantiate ()
         end
   328 
   329 
   330 end
   331