src/HOL/Tools/Sledgehammer/sledgehammer_annotate.ML
author blanchet
Tue, 19 Nov 2013 18:38:25 +0100
changeset 54504 096f7d452164
parent 52627 ecb4a858991d
child 54821 a12796872603
permissions -rw-r--r--
tuning

(*  Title:      HOL/Tools/Sledgehammer/sledgehammer_annotate.ML
    Author:     Jasmin Blanchette, TU Muenchen
    Author:     Steffen Juilf Smolka, TU Muenchen

Supplements a term with a locally minimal, complete set of type constraints.
Complete: The constraints suffice to infer the term's types.
Minimal: Reducing the set of constraints further will make it incomplete.

When configuring the pretty printer appropriately, the constraints will show up
as type annotations when printing the term. This allows the term to be printed
and reparsed without a change of types.

NOTE: Terms should be unchecked before calling annotate_types to avoid awkward
syntax.
*)

signature SLEDGEHAMMER_ANNOTATE =
sig
  (* annotate_types ctxt t: returns t with type constraints attached to
     selected subterms; the context is used for type inference and for
     looking up declared type variables. *)
  val annotate_types : Proof.context -> term -> term
end;

structure Sledgehammer_Annotate : SLEDGEHAMMER_ANNOTATE =
struct

(* Util *)
(* Post-order traversal of a term that threads a state through a callback.
   The callback f receives each (already rebuilt) subterm, its type and the
   current state, and returns ((subterm', state'), type'). env lists the types
   of the enclosing abstractions, innermost first, and is used to type Bound
   variables. A failing pattern match while typing an application is reported
   as Fail. *)
fun post_traverse_term_type' f env t s =
  (case t of
    Const (_, T) => f t T s
  | Free (_, T) => f t T s
  | Var (_, T) => f t T s
  | Bound i => f t (nth env i) s
  | Abs (x, T1, body) =>
      let
        (* the body is visited with the bound variable's type pushed on env *)
        val ((body', s1), T2) =
          post_traverse_term_type' f (T1 :: env) body s
      in f (Abs (x, T1, body')) (T1 --> T2) s1 end
  | fun_t $ arg_t =>
      (let
         (* the function part must have a binary type constructor whose
            second argument is taken as the type of the application *)
         val ((fun_t', s1), Type (_, [_, res_T])) =
           post_traverse_term_type' f env fun_t s
         val ((arg_t', s2), _) = post_traverse_term_type' f env arg_t s1
       in f (fun_t' $ arg_t') res_T s2 end
       handle Bind =>
         raise Fail "Sledgehammer_Annotate: post_traverse_term_type'"))

(* Post-order traversal where the callback f maps a subterm, its type and the
   state to (subterm', state'); types are passed through unchanged. Returns the
   rebuilt term paired with the final state. *)
fun post_traverse_term_type f s t =
  let
    fun with_type u T st = (f u T st, T)
  in
    fst (post_traverse_term_type' with_type [] t s)
  end
(* Post-order fold over all subterms together with their types: the term is
   left untouched and only the final state is returned. *)
fun post_fold_term_type f s t =
  let
    fun keep_term u T acc = (u, f u T acc)
  in
    snd (post_traverse_term_type keep_term s t)
  end

(* Applies f, threading a state, to every non-Type leaf of a type (from left
   to right) and rebuilds the type from the results. *)
fun fold_map_atypes f (Type (name, Ts)) s =
      fold_map (fold_map_atypes f) Ts s
      |>> (fn Ts' => Type (name, Ts'))
  | fold_map_atypes f T s = f T s

(** keep the elements that occur exactly once in a list (result is sorted) **)
local
  (* still_unique tells whether the candidate has been seen only once so far;
     the input list is expected to be sorted, so duplicates are adjacent *)
  fun unique' still_unique candidate rest =
    (case rest of
      [] => if still_unique then [candidate] else []
    | next :: rest' =>
        if candidate = next then unique' false candidate rest'
        else
          let val tail = unique' true next rest'
          in if still_unique then candidate :: tail else tail end)
in
  fun unique ord xs =
    (case sort ord xs of
      [] => []
    | first :: others => unique' true first others)
end

(** Data structures, orders **)
(* order on index names, used below for the names of schematic type variables *)
val indexname_ord = Term_Ord.fast_indexname_ord
(* order on typing-spot costs; a cost has the shape
   (size of type, (size of term, position)) *)
val cost_ord = prod_ord int_ord (prod_ord int_ord int_ord)
(* tables keyed by lists of index names; the keys used in this file are the
   ordered lists of type variable names occurring in a type *)
structure Var_Set_Tab = Table(
  type key = indexname list
  val ord = list_ord indexname_ord)

(* (1) Generalize types *)
(* Erases all type information from t (replacing it by dummyT) and re-infers
   the types in pattern mode, so that schematic type variables are used. *)
fun generalize_types ctxt t =
  let
    val pattern_ctxt =
      Proof_Context.set_mode Proof_Context.mode_pattern ctxt
    val untyped_t = map_types (K dummyT) t
  in
    singleton (Type_Infer_Context.infer_types pattern_ctxt) untyped_t
  end

(* (2) match types *)
(* Collects the types of all subterms of t1 and t2 (in the same traversal
   order) and matches them pairwise, accumulating a type substitution.
   A matching failure is reported as Fail. *)
fun match_types ctxt t1 t2 =
  let
    val thy = Proof_Context.theory_of ctxt
    fun types_of t =
      post_fold_term_type (fn _ => fn T => fn Ts => T :: Ts) [] t
  in
    (Vartab.empty |> fold (Sign.typ_match thy) (types_of t1 ~~ types_of t2))
    handle Type.TYPE_MATCH => raise Fail "Sledgehammer_Annotate: match_types"
  end


(* (3) handle trivial tfrees  *)
(* A TFree is treated as trivial if it is not declared in the context and
   occurs exactly once among the right-hand sides of the substitution. Type
   variables mapped directly to a trivial TFree are replaced by dummyT in the
   term and removed from the substitution; remaining occurrences of trivial
   TFrees inside the substitution are replaced by dummyT as well. *)
fun handle_trivial_tfrees ctxt (t', subst) =
  let
    (* adds the names of all TFrees in the image of one substitution entry *)
    fun add_tfree_names (_, (_, T)) =
      fold_atyps (fn TFree (x, _) => cons x | _ => I) T

    val undeclared_tfree_names =
      filter_out (Variable.is_declared ctxt)
        (Vartab.fold add_tfree_names subst [])
    val trivial_tfree_names = unique fast_string_ord undeclared_tfree_names
    fun tfree_name_trivial name =
      Ord_List.member fast_string_ord trivial_tfree_names name

    (* type variables whose image is exactly one trivial TFree *)
    fun add_trivial_tvar (tvar_name, (_, TFree (tfree_name, _))) acc =
          if tfree_name_trivial tfree_name then tvar_name :: acc else acc
      | add_trivial_tvar _ acc = acc
    val trivial_tvar_names =
      sort indexname_ord (Vartab.fold add_trivial_tvar subst [])
    fun tvar_name_trivial idxn =
      Ord_List.member indexname_ord trivial_tvar_names idxn

    fun erase_tvar (v as (idxn, _)) =
      if tvar_name_trivial idxn then dummyT else TVar v
    fun erase_tfree (a as (name, _)) =
      if tfree_name_trivial name then dummyT else TFree a

    val pruned_t = map_types (map_type_tvar erase_tvar) t'
    val pruned_subst =
      Vartab.map (fn _ => apsnd (map_type_tfree erase_tfree))
        (fold Vartab.delete trivial_tvar_names subst)
  in
    (pruned_t, pruned_subst)
  end


(* (4) Typing-spot table *)
(* Maps each non-empty set of type variables (as an ordered list of index
   names) to the cheapest subterm position whose type contains exactly these
   variables. Positions are numbered in post-order starting from 0; a cost is
   (size of type, (size of term, position)). *)
local
  fun add_tvar_key (TVar (z, _)) key = Ord_List.insert indexname_ord z key
    | add_tvar_key _ key = key
  fun key_of_type T = fold_atyps add_tvar_key T []
  fun update_tab t T (tab, pos) =
    let
      val key = key_of_type T
      val tab' =
        if null key then tab  (* types without type variables are skipped *)
        else
          let val cost = (size_of_typ T, (size_of_term t, pos)) in
            (case Var_Set_Tab.lookup tab key of
              NONE => Var_Set_Tab.update_new (key, cost) tab
            | SOME old_cost =>
                (* only a strictly cheaper spot replaces the recorded one *)
                if cost_ord (cost, old_cost) = LESS
                then Var_Set_Tab.update (key, cost) tab
                else tab)
          end
    in
      (tab', pos + 1)
    end
in
fun typing_spot_table t =
  fst (post_fold_term_type update_tab (Var_Set_Tab.empty, 0) t)
end

(* (5) Reverse-greedy *)
(* Starts from all typing spots and, going from the most expensive to the
   cheapest, drops every spot whose type variables are all still covered by
   more than one remaining spot. Returns the positions of the spots kept. *)
fun reverse_greedy typing_spot_tab =
  let
    (* adds delta to the occurrence count of one type variable *)
    fun bump delta tvar tab =
      Vartab.update (tvar, the_default 0 (Vartab.lookup tab tvar) + delta) tab
    fun update_count delta tvars tab = fold (bump delta) tvars tab

    fun superfluous tcount tvars =
      forall (fn tvar => the (Vartab.lookup tcount tvar) > 1) tvars

    fun drop_superfluous (tvars, (_, (_, spot))) (spots, tcount) =
      if superfluous tcount tvars
      then (spots, update_count ~1 tvars tcount)
      else (spot :: spots, tcount)

    (* gathers all table entries and counts the spots covering each variable *)
    fun collect (entry as (tvars, _)) (entries, tcount) =
      (entry :: entries, update_count 1 tvars tcount)

    val (all_entries, tvar_count_tab) =
      Var_Set_Tab.fold collect typing_spot_tab ([], Vartab.empty)
    val typing_spots =
      sort_distinct (rev_order o cost_ord o pairself snd) all_entries
  in
    fst (fold drop_superfluous typing_spots ([], tvar_count_tab))
  end

(* (6) Introduce annotations *)
(* spots is the ascending list of post-order positions to annotate. The types
   to annotate are read off the generalized term t' and instantiated with
   subst; the resulting constraints are then attached to the subterms of t at
   the same positions. *)
fun introduce_annotations subst spots t t' =
  let
    (* instantiates a type variable and afterwards rebinds it to dummyT in the
       substitution, so that later occurrences are instantiated to dummyT *)
    fun subst_atype T env =
      (case T of
        TVar (idxn, S) =>
          (Envir.subst_type env T, Vartab.update (idxn, (S, dummyT)) env)
      | _ => (T, env))

    (* cp is the current position; pending holds the spots not yet reached *)
    fun collect_annot _ T (state as (env, cp, pending, annots)) =
      (case pending of
        [] => state
      | p :: pending' =>
          if p = cp then
            let val (T', env') = fold_map_atypes subst_atype T env
            in (env', cp + 1, pending', (p, T') :: annots) end
          else (env, cp + 1, pending, annots))

    val (_, _, _, rev_annots) =
      post_fold_term_type collect_annot (subst, 0, spots, []) t'

    fun insert_annot u _ (state as (cp, pending)) =
      (case pending of
        [] => (u, state)
      | (p, T) :: pending' =>
          if p = cp then (Type.constraint T u, (cp + 1, pending'))
          else (u, (cp + 1, pending)))
  in
    fst (post_traverse_term_type insert_annot (0, rev rev_annots) t)
  end

(* (7) Annotate *)
(* Entry point: generalizes the types of t, matches the generalized term
   against t, removes trivial type variables, selects the typing spots and
   finally inserts the corresponding type constraints into t. *)
fun annotate_types ctxt t =
  let
    val general_t = generalize_types ctxt t
    val (pruned_t, subst) =
      handle_trivial_tfrees ctxt (general_t, match_types ctxt general_t t)
    val typing_spots =
      sort int_ord (reverse_greedy (typing_spot_table pruned_t))
  in
    introduce_annotations subst typing_spots t pruned_t
  end

end;