src/HOL/SMT/Tools/smt_monomorph.ML
author wenzelm
Thu, 15 Oct 2009 21:28:39 +0200
changeset 32950 5d5e123443b3
parent 32618 42865636d006
child 33038 8f9594c31de4
permissions -rw-r--r--
normalized aliases of Output operations;

(*  Title:      HOL/SMT/Tools/smt_monomorph.ML
    Author:     Sascha Boehme, TU Muenchen

Monomorphization of terms, i.e., computation of all (necessary) instances.
*)

signature SMT_MONOMORPH =
sig
  (* monomorph thy ts: the terms of ts without schematic type variables,
     extended by type instances of the remaining (polymorphic) terms of ts *)
  val monomorph: theory -> term list -> term list
end

structure SMT_Monomorph: SMT_MONOMORPH =
struct

(* Pair every element of a list with the list of all other elements
   (kept in their original order); the empty list yields no pairs. *)
fun selection [] = []
  | selection (x :: rest) =
      let val later = map (fn (y, others) => (y, x :: others)) (selection rest)
      in (x, rest) :: later end

(* All orderings of a non-empty list.  Note that the empty list yields
   no ordering at all (the result is [], not [[]]). *)
fun permute [] = []
  | permute [x] = [[x]]
  | permute xs =
      let fun headed_by (y, others) = map (fn p => y :: p) (permute others)
      in maps headed_by (selection xs) end

(* Thread a list of states through xs from left to right: for each x, every
   current state y is replaced by all the states in f x y. *)
fun fold_all f xs ys =
  List.foldl (fn (x, states) => List.concat (map (f x) states)) ys xs


(* Does a type contain a schematic type variable (TVar) as a subtype? *)
fun typ_has_tvars T =
  let
    fun is_tvar (TVar _) = true
      | is_tvar _ = false
  in Term.exists_subtype is_tvar T end

(* Does some type occurring in a term contain a schematic type variable? *)
fun term_has_tvars t = Term.exists_type typ_has_tvars t

(* Predicate recognizing the names of constants that consts_of drops. *)
val ignored =
  let
    val names = [
      @{const_name All}, @{const_name Ex}, @{const_name Let}, @{const_name If},
      @{const_name "op ="}, @{const_name zero_class.zero},
      @{const_name one_class.one}, @{const_name number_of}]
  in fn n => member (op =) names n end
(* Collect the constants of the given terms, grouped as an association list
   from constant name to the types it occurs with; ignored names are dropped. *)
fun consts_of ts =
  let val grouped = AList.group (op =) (fold Term.add_consts ts [])
  in filter_out (fn (n, _) => ignored n) grouped end

(* Join two association lists of constants; the type lists of names bound in
   both are merged. *)
fun join_consts cs ds = AList.join (op =) (fn _ => merge (op =)) (cs, ds)
(* Remove from ds every type already recorded in cs for the same constant
   name; entries whose type list becomes empty are dropped altogether. *)
fun diff_consts cs ds =
  let
    fun remaining (n, Ts) =
      (case AList.lookup (op =) cs n of
        SOME Us =>
          (case fold (remove (op =)) Us Ts of
            [] => NONE
          | Ts' => SOME (n, Ts'))
      | NONE => SOME (n, Ts))
  in map_filter remaining ds end

(* Extend the type environment env in all possible ways by matching the
   still-schematic types Ts of constant n (after applying env) against the
   types recorded for n in is.  If there is nothing to match, env itself is
   the only result; pairs that fail to match are silently skipped. *)
fun instances thy is (n, Ts) env =
  let
    val targets = these (AList.lookup (op =) is n)
    val schematic = Ts |> map (Envir.subst_type env) |> filter typ_has_tvars
    val candidates = map_product pair schematic targets

    fun match TU = try (fn () => Sign.typ_match thy TU env) ()
  in
    if null candidates then [env]
    else map_filter match candidates
  end

(* An environment is a proper match if, after applying it, none of the types
   listed in ps contains a schematic type variable anymore. *)
fun proper_match ps env =
  let fun is_ground T = not (typ_has_tvars (Envir.subst_type env T))
  in forall (fn (_, Ts) => forall is_ground Ts) ps end

(* Equality of two variable tables, comparing their entry lists as sets. *)
fun eq_tab (tab1, tab2) = gen_eq_set (op =) (Vartab.dest tab1, Vartab.dest tab2)

(* One specialization step for the polymorphic term r with schematic constants
   ps and the type environments ces already used for it.  New environments are
   computed from the constants in is; the resulting instances of r are inserted
   into ts, and the constants of these instances that are neither in cs nor in
   is are joined into ns.  Returns the extended list of environments together
   with the updated (ts, ns). *)
fun specialize thy cs is ((r, ps), ces) (ts, ns) =
  let
    (* only constants with known instances can contribute *)
    val relevant = filter (fn (n, _) => AList.defined (op =) is n) ps

    fun envs_for order = fold_all (instances thy is) order [Vartab.empty]
    val fresh_envs =
      maps envs_for (permute relevant)
      |> filter (proper_match relevant)
      |> filter_out (member eq_tab ces)
      |> distinct eq_tab

    val copies = map (fn env => Envir.subst_term_types env r) fresh_envs
    val unseen = diff_consts is (diff_consts cs (consts_of copies))
    val ts' = fold (insert (op aconv)) copies ts
  in (fresh_envs @ ces, (ts', join_consts unseen ns)) end


(* Increment the indices of all schematic type variables occurring in the
   types of a term by i.  The types attached to constants, free variables,
   schematic variables and abstractions are shifted; the indices of schematic
   term variables themselves are left unchanged, as are bound variables. *)
fun incr_tvar_indices i t =
  let
    val incrT = Logic.incr_tvar i

    fun incr t =
      (case t of
        Const (n, T) => Const (n, incrT T)
      | Free (n, T) => Free (n, incrT T)
        (* shift the type of schematic variables as well, so that it stays
           consistent with the shifted types of the surrounding term *)
      | Var (xi, T) => Var (xi, incrT T)
      | Abs (n, T, t1) => Abs (n, incrT T, incr t1)
      | t1 $ t2 => incr t1 $ incr t2
      | _ => t)
  in incr t end


(* bound on the round counter of the fixpoint iteration in monomorph: once the
   counter exceeds this value, iteration stops with a warning *)
val monomorph_limit = 10

(* Instantiate all polymorphic constants (i.e., constants occurring both with
   ground types and with type variables) with all (necessary) ground types,
   creating copies of the terms that contain those constants.
   The construction is a fixpoint iteration: every round specializes the
   polymorphic terms with respect to the constants discovered in the previous
   round.  To prevent non-termination, the number of rounds is bounded by
   monomorph_limit; a warning is issued when this bound is exceeded. *)
fun monomorph thy ts =
  let
    val (polys, monos) = List.partition term_has_tvars ts

    (* keep only the schematic types of a constant, if there are any *)
    fun schematic_only (n, Ts) =
      (case filter typ_has_tvars Ts of
        [] => NONE
      | Ts' => SOME (n, Ts'))

    (* give each polymorphic term its own range of type variable indices *)
    fun shift t idx = (incr_tvar_indices idx t, idx + Term.maxidx_of_term t + 1)
    val (shifted, _) = fold_map shift polys 0
    val rps = map (fn r => (r, map_filter schematic_only (consts_of [r]))) shifted

    fun mono count is ces cs ts =
      let
        val (ces', (ts', is')) =
          fold_map (specialize thy cs is) (rps ~~ ces) (ts, [])
        val cs' = join_consts is cs
      in
        if null is' then ts'
        else if count > monomorph_limit then
          (warning "monomorphization limit reached"; ts')
        else mono (count + 1) is' ces' cs' ts'
      end
  in mono 0 (consts_of monos) (map (K []) rps) [] monos end

end