src/HOL/Tools/SMT/smt_monomorph.ML
author wenzelm
Tue, 05 Apr 2011 14:25:18 +0200
changeset 42224 578a51fae383
parent 42190 b6b5846504cd
child 42361 23f352990944
permissions -rw-r--r--
discontinued special treatment of structure Ast: no pervasive content, no inclusion in structure Syntax;

(*  Title:      HOL/Tools/SMT/smt_monomorph.ML
    Author:     Sascha Boehme, TU Muenchen

Monomorphization of theorems, i.e., computation of all (necessary)
instances.  This procedure is incomplete in general, but works well for
most practical problems.

For a list of universally closed theorems (without schematic term
variables), monomorphization computes a list of theorems with schematic
term variables: all polymorphic constants (i.e., constants occurring both
with ground types and schematic type variables) are instantiated with all
(necessary) ground types; thereby theorems containing these constants are
copied.  To prevent non-termination, there is an upper limit on the number
of iterations involved in the fixpoint construction, as well as on the
number of computed instances.

The search for instances is performed on the constants with schematic
types, which are extracted from the initial set of theorems.  The search
constructs, for each theorem with those constants, a set of substitutions,
which, in the end, is applied to all corresponding theorems.  Remaining
schematic type variables are substituted with fresh types if full
monomorphization is enabled; otherwise, incompletely instantiated theorems
are discarded.

Searching for necessary substitutions is an iterative fixpoint
construction: each iteration computes all instances required by the
ground instances computed in the previous step that have not been found
before.  Computed substitutions are always nontrivial: schematic type
variables are never mapped to schematic type variables.
*)

signature SMT_MONOMORPH =
sig
  (*monomorph irules ctxt: the theorems of irules whose propositions contain
    schematic type variables are replaced by computed instances free of
    schematic type variables; all other theorems are kept.  The returned
    context may be extended with freshly invented types.*)
  val monomorph:
    ('a * thm) list -> Proof.context -> ('a * thm) list * Proof.context

  (*does the given type contain a schematic type variable?*)
  val typ_has_tvars: typ -> bool
end

structure SMT_Monomorph: SMT_MONOMORPH =
struct

(* utility functions *)

(*fold_maps f xs (ys, s): for each x of xs in turn, map every element of the
  current list ys through "f x" (threading the state s) and flatten the
  resulting lists; the flattened list replaces ys for the next x*)
fun fold_maps f = fold (fn x => uncurry (fold_map (f x)) #>> flat)

(*reassociate a nested pair: ((x, y), z) becomes (x, (y, z))*)
fun pair_trans ((x, y), z) = (x, (y, z))

(*does the type contain a schematic type variable (TVar)?*)
val typ_has_tvars = Term.exists_subtype (fn TVar _ => true | _ => false)

(*constants never considered for instantiation: the quantifiers, Let, If and
  equality*)
val ignored = member (op =) [@{const_name All}, @{const_name Ex},
  @{const_name Let}, @{const_name If}, @{const_name HOL.eq}]

(*a constant is relevant if it is not ignored and its type satisfies pred*)
fun is_const pred (n, T) = not (ignored n) andalso pred T

(*collect_consts_if pred f thm: traverse the proposition of thm and apply f to
  each relevant constant (see is_const) whose type satisfies pred.  For a
  trigger annotation, both the body and the terms wrapped by SMT.pat or
  SMT.nopat in its list-of-lists pattern argument are traversed; pattern
  arguments that are not term lists, and other pattern shapes, are skipped.*)
fun collect_consts_if pred f =
  let
    fun collect (@{const trigger} $ p $ t) = collect_trigger p #> collect t
      | collect (t $ u) = collect t #> collect u
      | collect (Abs (_, _, t)) = collect t
      | collect (Const c) = if is_const pred c then f c else I
      | collect _ = I
    and collect_trigger t =
      let val dest = these o try HOLogic.dest_list 
      in fold (fold collect_pat o dest) (dest t) end
    and collect_pat (Const (@{const_name SMT.pat}, _) $ t) = collect t
      | collect_pat (Const (@{const_name SMT.nopat}, _) $ t) = collect t
      | collect_pat _ = I
  in collect o Thm.prop_of end

(*insert a constant into a list ordered by name (fast_string_ord) and then by
  type (Term_Ord.typ_ord)*)
val insert_const = Ord_List.insert (prod_ord fast_string_ord Term_Ord.typ_ord)

(*ordered list of the relevant constants of thm whose types contain schematic
  type variables*)
fun tvar_consts_of thm = collect_consts_if typ_has_tvars insert_const thm []

(*add_const_types pred thm tab: record in tab the types of the relevant
  constants of thm that satisfy pred.  Symtab.map_entry only updates names that
  already have an entry, so constants with other names are ignored.*)
fun add_const_types pred =
  collect_consts_if pred (fn (n, T) => Symtab.map_entry n (insert (op =) T))

(*shift the indexes of schematic variables apart so that no two theorems share
  a schematic (type) variable: each theorem is shifted past the maximal index
  of all preceding, already shifted theorems*)
fun incr_indexes ithms =
  let
    fun inc (i, thm) idx =
      ((i, Thm.incr_indexes idx thm), Thm.maxidx_of thm + idx + 1)
  in fst (fold_map inc ithms 0) end



(* search for necessary substitutions *)

(*new_substitutions thy limit grounds (n, T) subst instances: extend subst by
  matching the type T of constant n against each ground type recorded for n in
  grounds.  Returns subst followed by at most limit successful extensions,
  paired with the instance budget reduced by the number of successful matches
  (counted before the truncation to limit).  If T has no schematic type
  variables, the result is just [subst] with the budget unchanged.*)
fun new_substitutions thy limit grounds (n, T) subst instances =
  if not (typ_has_tvars T) then ([subst], instances)
  else
    Symtab.lookup_list grounds n
    |> map_filter (try (fn U => Sign.typ_match thy (T, U) subst))
    |> (fn substs => (substs, instances - length substs))
    (*keep at most limit extensions per constant; since search_substitutions
      decrements limit in each iteration, the same value bounds both the
      width (extensions per constant) and the depth (iterations) of the
      search*)
    |>> take limit
    |>> cons subst

(*apply_subst grounds consts subst new_grounds: instantiate the types of consts
  with subst and add every resulting constant instance that has a ground type
  not yet recorded in grounds to new_grounds.  Returns ((subst, instantiated
  consts), new_grounds).*)
fun apply_subst grounds consts subst =
  let
    fun is_new_ground (n, T) = not (typ_has_tvars T) andalso
      not (member (op =) (Symtab.lookup_list grounds n) T)

    fun apply_const (n, T) new_grounds =
      let val c = (n, Envir.subst_type subst T)
      in
        new_grounds
        |> is_new_ground c ? Symtab.insert_list (op =) c
        |> pair c
      end
  in fold_map apply_const consts #>> pair subst end

(*specialize thy limit all_grounds new_grounds scs (next_grounds, instances):
  one search step for the candidates of a single theorem.  Every
  (subst, consts) candidate is extended (see new_substitutions) by matching
  its constants against the ground types found in the previous step
  (new_grounds).  Each extended substitution is then applied to consts (see
  apply_subst), which collects newly found ground instances into next_grounds.
  Of several candidates with equal instantiated constant lists only one is
  kept.*)
fun specialize thy limit all_grounds new_grounds scs =
  let
    (*fold_maps lets every constant of consts in turn refine the list of
      substitutions, starting from [subst]*)
    fun spec (subst, consts) (next_grounds, instances) =
      ([subst], instances)
      |> fold_maps (new_substitutions thy limit new_grounds) consts
      |>> rpair next_grounds
      |>> uncurry (fold_map (apply_subst all_grounds consts))
      |> pair_trans
  in
    fold_map spec scs #>> (fn scss =>
    fold (fold (insert (eq_snd (op =)))) scss [])
  end

(*message reported when the search stops before reaching a fixpoint*)
val limit_reached_warning = "Warning: Monomorphization limit reached"

(*search_substitutions ctxt limit instances all_grounds new_grounds scss:
  iterate specialize over the candidate lists scss of all theorems.  The set
  of newly found ground instances and the instance budget are threaded through
  all theorems of one step, and all_grounds accumulates the ground instances
  of all previous steps.  The search stops when no new ground instances are
  found (fixpoint), or when the iteration limit or the instance budget is
  exhausted.  In the latter case limit_reached_warning is passed to
  SMT_Config.verbose_msg.  In all cases the candidates of the last step are
  returned.*)
fun search_substitutions ctxt limit instances all_grounds new_grounds scss =
  let
    val thy = ProofContext.theory_of ctxt
    val all_grounds' = Symtab.merge_list (op =) (all_grounds, new_grounds)
    val spec = specialize thy limit all_grounds' new_grounds
    val (scss', (new_grounds', instances')) =
      fold_map spec scss (Symtab.empty, instances)
  in
    if Symtab.is_empty new_grounds' then scss'
    else if limit > 0 andalso instances' > 0 then
      search_substitutions ctxt (limit-1) instances' all_grounds' new_grounds'
        scss'
    else (SMT_Config.verbose_msg ctxt (K limit_reached_warning) (); scss')
  end



(* instantiation *)

(*filter_most_specific thy scs: return the substitutions of those candidates
  whose instantiated constant list cannot be matched, by a nontrivial type
  substitution, onto the constant list of another candidate.  This drops
  candidates that are more general than some other candidate.*)
fun filter_most_specific thy =
  let
    (*match the type of one constant against another; names are not
      compared*)
    fun typ_match (_, T) (_, U) = Sign.typ_match thy (T, U)

    (*empty, or mapping every type variable to itself*)
    fun is_trivial subst = Vartab.is_empty subst orelse
      forall (fn (v, (S, T)) => TVar (v, S) = T) (Vartab.dest subst)

    (*can general be instantiated nontrivially to specific?  Lists of
      different length or non-matching types yield false, because try turns
      the raised exception into NONE.*)
    fun match general specific =
      (case try (fold2 typ_match general specific) Vartab.empty of
        NONE => false
      | SOME subst => not (is_trivial subst))

    (*css holds the constant lists of the candidates already visited*)
    fun most_specific _ [] = []
      | most_specific css ((ss, cs) :: scs) =
          let val substs = most_specific (cs :: css) scs
          in
            if exists (match cs) css orelse exists (match cs o snd) scs
            then substs else ss :: substs
          end

  in most_specific [] end

(*instantiate full (i, thm) substs (ithms, ctxt): instantiate thm with each
  substitution of substs and prepend the resulting (i, thm') pairs to ithms.
  Schematic type variables of thm that a substitution leaves unassigned are
  mapped to fresh free types invented in ctxt if full is set.  Otherwise that
  instance is dropped.  The context is extended with the invented types only
  if full is set.*)
fun instantiate full (i, thm) substs (ithms, ctxt) =
  let
    val thy = ProofContext.theory_of ctxt

    (*pair every schematic type variable of thm with a fresh free type of the
      same sort*)
    val (vs, Ss) = split_list (Term.add_tvars (Thm.prop_of thm) [])
    val (Tenv, ctxt') =
      ctxt
      |> Variable.invent_types Ss
      |>> map2 (fn v => fn (n, S) => (v, (S, TFree (n, S)))) vs

    (*raised when an instance would remain partial and full is not set*)
    exception PARTIAL_INST of unit

    fun update_subst vT subst =
      if full then Vartab.update vT subst
      else raise PARTIAL_INST ()

    fun replace (v, (_, T)) (U as TVar (u, _)) = if u = v then T else U
      | replace _ T = T

    (*assign the fresh type of v unless subst already assigns v, then replace
      every occurrence of TVar v within the range of subst by that fresh
      type*)
    fun complete (vT as (v, _)) subst =
      subst
      |> not (Vartab.defined subst v) ? update_subst vT
      |> Vartab.map (K (apsnd (Term.map_atyps (replace vT))))

    fun cert (ix, (S, T)) = pairself (Thm.ctyp_of thy) (TVar (ix, S), T)

    fun inst subst =
      let val cTs = Vartab.fold (cons o cert) (fold complete Tenv subst) []
      in SOME (i, Thm.instantiate (cTs, []) thm) end
      handle PARTIAL_INST () => NONE

  in (map_filter inst substs @ ithms, if full then ctxt' else ctxt) end



(* overall procedure *)

(*mono_all ctxt polys monos: compute the instances of the polymorphic theorems
  polys and return them, prepended to monos, together with the possibly
  extended context.  The search is bounded by SMT_Config.monomorph_limit
  (iterations) and SMT_Config.monomorph_instances (instance budget).
  SMT_Config.monomorph_full decides whether incomplete instances receive
  fresh types or are dropped.*)
fun mono_all ctxt polys monos =
  let
    (*initially every polymorphic theorem has a single candidate: the empty
      substitution together with the theorem's constants with schematic
      types*)
    val scss = map (single o pair Vartab.empty o tvar_consts_of o snd) polys

    (* all known non-schematic instances of polymorphic constants: find all
       names of polymorphic constants, then add all known ground types *)
    val grounds =
      Symtab.empty
      |> fold (fold (fold (Symtab.update o rpair [] o fst) o snd)) scss
      |> fold (add_const_types (K true) o snd) monos
      |> fold (add_const_types (not o typ_has_tvars) o snd) polys

    val limit = Config.get ctxt SMT_Config.monomorph_limit
    val instances = Config.get ctxt SMT_Config.monomorph_instances
    val full = Config.get ctxt SMT_Config.monomorph_full
  in
    scss
    |> search_substitutions ctxt limit instances Symtab.empty grounds
    |> map (filter_most_specific (ProofContext.theory_of ctxt))
    |> rpair (monos, ctxt)
    |-> fold2 (instantiate full) polys
  end

(*monomorph irules ctxt: split irules into polymorphic theorems (propositions
  containing schematic type variables) and monomorphic ones, and replace the
  polymorphic ones by their computed instances.  Without polymorphic theorems,
  irules and ctxt are returned unchanged.*)
fun monomorph irules ctxt =
  irules
  |> List.partition (Term.exists_type typ_has_tvars o Thm.prop_of o snd)
  |>> incr_indexes  (* avoid clashes of schematic type variables *)
  |-> (fn [] => rpair ctxt | polys => mono_all ctxt polys)

end