src/HOL/Tools/SMT/smt_monomorph.ML
author boehmes
Thu Mar 31 14:02:03 2011 +0200 (2011-03-31)
changeset 42183 173b0f488428
parent 42181 8f25605e646c
child 42190 b6b5846504cd
permissions -rw-r--r--
provide a flag controlling whether all provided facts should be instantiated, possibly inventing new types (which does not work well with Sledgehammer)
     1 (*  Title:      HOL/Tools/SMT/smt_monomorph.ML
     2     Author:     Sascha Boehme, TU Muenchen
     3 
     4 Monomorphization of theorems, i.e., computation of all (necessary)
     5 instances.  This procedure is incomplete in general, but works well for
     6 most practical problems.
     7 
For a list of universally closed theorems (without schematic term
variables), monomorphization computes a list of theorems without schematic
type variables: all polymorphic constants (i.e., constants occurring both
    11 with ground types and schematic type variables) are instantiated with all
    12 (necessary) ground types; thereby theorems containing these constants are
copied.  To prevent non-termination, there are upper limits for the number
of iterations involved in the fixpoint construction and for the number of
computed instances.
    15 
    16 The search for instances is performed on the constants with schematic
    17 types, which are extracted from the initial set of theorems.  The search
    18 constructs, for each theorem with those constants, a set of substitutions,
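which, in the end, is applied to the corresponding theorem.  If full
instantiation is requested, remaining schematic type variables are
substituted with fresh types; otherwise, instances that would still contain
schematic type variables are dropped.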
    21 
Searching for necessary substitutions is an iterative fixpoint
construction: each iteration computes all instances required by the
ground instances computed in the previous step that haven't been found
before.  Computed substitutions are always nontrivial: schematic type
variables are never mapped to schematic type variables.
    27 *)
    28 
    29 signature SMT_MONOMORPH =
    30 sig
    31   val typ_has_tvars: typ -> bool
    32   val monomorph: bool -> ('a * thm) list -> Proof.context ->
    33     ('a * thm) list * Proof.context
    34 end
    35 
structure SMT_Monomorph: SMT_MONOMORPH =
struct

(* utility functions *)

(* fold_maps f xs (ys, s): for each x in xs (in order), map every element of
   the current list ys with f x, threading the state s, and flatten the
   resulting lists into the list used for the next x *)
fun fold_maps f = fold (fn x => uncurry (fold_map (f x)) #>> flat)

(* reassociate a nested pair: ((x, y), z) becomes (x, (y, z)) *)
fun pair_trans ((x, y), z) = (x, (y, z))

(* does the type contain a schematic type variable (TVar)? *)
val typ_has_tvars = Term.exists_subtype (fn TVar _ => true | _ => false)

(* names of constants that are never considered for monomorphization *)
val ignored = member (op =) [@{const_name All}, @{const_name Ex},
  @{const_name Let}, @{const_name If}, @{const_name HOL.eq}]

(* a constant is relevant if its name is not ignored and its type
   satisfies pred *)
fun is_const pred (n, T) = not (ignored n) andalso pred T

(* collect_consts_if pred f thm: fold f over all relevant constants (see
   is_const) occurring in the proposition of thm.  For a trigger annotation
   (trigger p t), constants of the pattern list p are collected only from
   SMT.pat and SMT.nopat entries; p is destructed as a list of lists, and
   anything that is not a HOL list is skipped (try ... these). *)
fun collect_consts_if pred f =
  let
    fun collect (@{const trigger} $ p $ t) = collect_trigger p #> collect t
      | collect (t $ u) = collect t #> collect u
      | collect (Abs (_, _, t)) = collect t
      | collect (Const c) = if is_const pred c then f c else I
      | collect _ = I
    and collect_trigger t =
      let val dest = these o try HOLogic.dest_list 
      in fold (fold collect_pat o dest) (dest t) end
    and collect_pat (Const (@{const_name SMT.pat}, _) $ t) = collect t
      | collect_pat (Const (@{const_name SMT.nopat}, _) $ t) = collect t
      | collect_pat _ = I
  in collect o Thm.prop_of end

(* insert a constant into an ordered list (ordered by name, then by type) *)
val insert_const = Ord_List.insert (prod_ord fast_string_ord Term_Ord.typ_ord)

(* the relevant constants of thm whose types contain schematic type
   variables, as an ordered list *)
fun tvar_consts_of thm = collect_consts_if typ_has_tvars insert_const thm []

(* for every relevant constant (n, T) of a theorem, add T to the entry of n
   in the table; map_entry only touches names that already have an entry *)
fun add_const_types pred =
  collect_consts_if pred (fn (n, T) => Symtab.map_entry n (insert (op =) T))

(* shift the schematic indexes of each theorem past the maximal index of the
   previously processed theorems, so that the schematic variables of
   different theorems cannot clash *)
fun incr_indexes ithms =
  let
    fun inc (i, thm) idx =
      ((i, Thm.incr_indexes idx thm), Thm.maxidx_of thm + idx + 1)
  in fst (fold_map inc ithms 0) end



(* search for necessary substitutions *)

(* new_substitutions thy limit grounds (n, T) subst instances: if T still
   has schematic type variables, extend subst by matching T against every
   known ground type of n in grounds.  Returns subst followed by at most
   limit successful extensions, together with the instance budget
   decremented by the number of all successful matches (also those cut off
   by take limit). *)
fun new_substitutions thy limit grounds (n, T) subst instances =
  if not (typ_has_tvars T) then ([subst], instances)
  else
    Symtab.lookup_list grounds n
    |> map_filter (try (fn U => Sign.typ_match thy (T, U) subst))
    |> (fn substs => (substs, instances - length substs))
    |>> take limit (* limit bounds the breadth here and the depth (number of iterations) in search_substitutions *)
    |>> cons subst

(* apply_subst grounds consts subst new_grounds: apply subst to the types of
   consts; every resulting ground instance not yet listed in grounds is
   added to new_grounds.  Result: ((subst, instantiated consts), new_grounds) *)
fun apply_subst grounds consts subst =
  let
    fun is_new_ground (n, T) = not (typ_has_tvars T) andalso
      not (member (op =) (Symtab.lookup_list grounds n) T)

    fun apply_const (n, T) new_grounds =
      let val c = (n, Envir.subst_type subst T)
      in
        new_grounds
        |> is_new_ground c ? Symtab.insert_list (op =) c
        |> pair c
      end
  in fold_map apply_const consts #>> pair subst end

(* one search step for the (substitution, constants) pairs of one theorem:
   each pair is extended using only the ground instances found in the
   previous step (new_grounds), and the resulting pairs are deduplicated by
   their instantiated constants.  The state threaded through is
   (ground instances found in this step, remaining instance budget). *)
fun specialize thy limit all_grounds new_grounds scs =
  let
    fun spec (subst, consts) (next_grounds, instances) =
      ([subst], instances)
      |> fold_maps (new_substitutions thy limit new_grounds) consts
      |>> rpair next_grounds
      |>> uncurry (fold_map (apply_subst all_grounds consts))
      |> pair_trans
  in
    fold_map spec scs #>> (fn scss =>
    fold (fold (insert (eq_snd (op =)))) scss [])
  end

val limit_reached_warning = "Warning: Monomorphization limit reached"

(* repeat specialize on all theorems until a step yields no new ground
   instances.  The search also stops when the iteration limit is exhausted
   or the instance budget is used up; then limit_reached_warning is passed
   to SMT_Config.verbose_msg.  Returns the final (substitution, constants)
   pairs per theorem. *)
fun search_substitutions ctxt limit instances all_grounds new_grounds scss =
  let
    val thy = ProofContext.theory_of ctxt
    val all_grounds' = Symtab.merge_list (op =) (all_grounds, new_grounds)
    val spec = specialize thy limit all_grounds' new_grounds
    val (scss', (new_grounds', instances')) =
      fold_map spec scss (Symtab.empty, instances)
  in
    if Symtab.is_empty new_grounds' then scss'
    else if limit > 0 andalso instances' > 0 then
      search_substitutions ctxt (limit-1) instances' all_grounds' new_grounds'
        scss'
    else (SMT_Config.verbose_msg ctxt (K limit_reached_warning) (); scss')
  end



(* instantiation *)

(* keep only the most specific substitutions of one theorem: a candidate is
   dropped if its instantiated constants can be matched, by a nontrivial
   type substitution, onto those of another (earlier or later) candidate,
   i.e. if it is strictly more general than that candidate *)
fun filter_most_specific thy =
  let
    fun typ_match (_, T) (_, U) = Sign.typ_match thy (T, U)

    (* a substitution is trivial if it maps every variable to itself *)
    fun is_trivial subst = Vartab.is_empty subst orelse
      forall (fn (v, (S, T)) => TVar (v, S) = T) (Vartab.dest subst)

    (* general can be nontrivially matched onto specific; lists of different
       length or non-matching types yield false via try *)
    fun match general specific =
      (case try (fold2 typ_match general specific) Vartab.empty of
        NONE => false
      | SOME subst => not (is_trivial subst))

    (* css: constants of the candidates already visited *)
    fun most_specific _ [] = []
      | most_specific css ((ss, cs) :: scs) =
          let val substs = most_specific (cs :: css) scs
          in
            if exists (match cs) css orelse exists (match cs o snd) scs
            then substs else ss :: substs
          end

  in most_specific [] end

(* instantiate full (i, thm) substs (ithms, ctxt): instantiate thm with each
   of substs, pair each instance with i and prepend them to ithms.  Schematic
   type variables of thm not bound by a substitution are bound to fresh type
   variables invented in ctxt if full is true; otherwise such a substitution
   yields no instance.  The returned context contains the invented types only
   if full is true. *)
fun instantiate full (i, thm) substs (ithms, ctxt) =
  let
    val thy = ProofContext.theory_of ctxt

    (* Tenv: every schematic type variable of thm paired with a fresh type *)
    val (vs, Ss) = split_list (Term.add_tvars (Thm.prop_of thm) [])
    val (Tenv, ctxt') =
      ctxt
      |> Variable.invent_types Ss
      |>> map2 (fn v => fn (n, S) => (v, (S, TFree (n, S)))) vs

    (* aborts a substitution that leaves variables unbound when not full *)
    exception PARTIAL_INST of unit

    fun update_subst vT subst =
      if full then Vartab.update vT subst
      else raise PARTIAL_INST ()

    fun replace (v, (_, T)) (U as TVar (u, _)) = if u = v then T else U
      | replace _ T = T

    (* bind v if it is unbound (see update_subst), then replace occurrences
       of TVar v in the ranges of subst by v's fresh type *)
    fun complete (vT as (v, _)) subst =
      subst
      |> not (Vartab.defined subst v) ? update_subst vT
      |> Vartab.map (K (apsnd (Term.map_atyps (replace vT))))

    fun cert (ix, (S, T)) = pairself (Thm.ctyp_of thy) (TVar (ix, S), T)

    fun inst subst =
      let val cTs = Vartab.fold (cons o cert) (fold complete Tenv subst) []
      in SOME (i, Thm.instantiate (cTs, []) thm) end
      handle PARTIAL_INST () => NONE

  in (map_filter inst substs @ ithms, if full then ctxt' else ctxt) end



(* overall procedure *)

(* mono_all full ctxt polys monos: compute substitutions for the theorems in
   polys, which contain schematic type variables, and instantiate them (see
   instantiate).  Ground instances of the polymorphic constants are taken
   from monos and polys.  Returns the instances prepended to monos, together
   with the (possibly extended) context. *)
fun mono_all full ctxt polys monos =
  let
    (* initially, each theorem has only the empty substitution *)
    val scss = map (single o pair Vartab.empty o tvar_consts_of o snd) polys

    (* all known non-schematic instances of polymorphic constants: find all
       names of polymorphic constants, then add all known ground types *)
    val grounds =
      Symtab.empty
      |> fold (fold (fold (Symtab.update o rpair [] o fst) o snd)) scss
      |> fold (add_const_types (K true) o snd) monos
      |> fold (add_const_types (not o typ_has_tvars) o snd) polys

    val limit = Config.get ctxt SMT_Config.monomorph_limit
    val instances = Config.get ctxt SMT_Config.monomorph_instances
  in
    scss
    |> search_substitutions ctxt limit instances Symtab.empty grounds
    |> map (filter_most_specific (ProofContext.theory_of ctxt))
    |> rpair (monos, ctxt)
    |-> fold2 (instantiate full) polys
  end

(* monomorph full irules ctxt: split irules into theorems with and without
   schematic type variables.  If there are polymorphic ones, monomorphize
   them (see mono_all); otherwise return the theorems unchanged.  full
   controls whether fresh types may be invented (see instantiate). *)
fun monomorph full irules ctxt =
  irules
  |> List.partition (Term.exists_type typ_has_tvars o Thm.prop_of o snd)
  |>> incr_indexes  (* avoid clashes of schematic type variables *)
  |-> (fn [] => rpair ctxt | polys => mono_all full ctxt polys)

end