src/HOL/Tools/Sledgehammer/sledgehammer_annotate.ML
author smolkas
Tue Jun 11 19:58:09 2013 -0400 (2013-06-11 ago)
changeset 52369 0b395800fdf0
parent 52366 ff89424b5094
child 52452 2207825d67f3
permissions -rw-r--r--
uncheck terms before annotation to avoid awkward syntax
smolkas@50263
     1
(*  Title:      HOL/Tools/Sledgehammer/sledgehammer_annotate.ML
smolkas@50263
     2
    Author:     Jasmin Blanchette, TU Muenchen
smolkas@50263
     3
    Author:     Steffen Juilf Smolka, TU Muenchen
smolkas@50263
     4
smolkas@52369
     5
Supplements term with a locally minimal, complete set of type constraints.
smolkas@52369
     6
Complete: The constraints suffice to infer the term's types.
smolkas@52369
     7
Minimal: Reducing the set of constraints further will make it incomplete.
smolkas@52369
     8
smolkas@52369
     9
When configuring the pretty printer appropriately, the constraints will show up
smolkas@52369
    10
as type annotations when printing the term. This allows the term to be reparsed
smolkas@52369
    11
without a change of types.
smolkas@52369
    12
smolkas@52369
    13
NOTE: Terms should be unchecked before calling annotate_types to avoid awkward
smolkas@52369
    14
syntax.
smolkas@50263
    15
*)
smolkas@50263
    16
smolkas@50258
    17
(* Public interface of the type-annotation module: a single entry point. *)
signature SLEDGEHAMMER_ANNOTATE =
smolkas@50258
    18
sig
smolkas@50258
    19
  (* Given a context and a term, returns the term with type constraints
     inserted at selected subterms (see the file header for the intended
     "complete" and "locally minimal" properties of the constraint set). *)
  val annotate_types : Proof.context -> term -> term
smolkas@50258
    20
end
smolkas@50258
    21
smolkas@50258
    22
structure Sledgehammer_Annotate : SLEDGEHAMMER_ANNOTATE =
smolkas@50258
    23
struct
smolkas@50258
    24
smolkas@50258
    25
(* Util *)
smolkas@50258
    26
(* Post-order traversal of a term that threads a state through all subterms
   and computes the type of each subterm on the way.

   f   : called on every subterm after its children have been processed, as
         "f subterm typ state"; it yields ((subterm', state'), typ').
   env : types of the enclosing bound variables, innermost binder first.

   For an application the result type is read off the second type argument of
   the function's type; the binding below is refutable and raises Bind if the
   function's type is not a type constructor with exactly two arguments. *)
fun post_traverse_term_type' f env tm st =
  (case tm of
    Const (_, T) => f tm T st
  | Free (_, T) => f tm T st
  | Var (_, T) => f tm T st
  | Bound i => f tm (nth env i) st
  | Abs (x, T1, body) =>
      let
        val ((body', st'), T2) = post_traverse_term_type' f (T1 :: env) body st
      in f (Abs (x, T1, body')) (T1 --> T2) st' end
  | fun_tm $ arg_tm =>
      let
        val ((fun_tm', st'), Type (_, [_, T])) = post_traverse_term_type' f env fun_tm st
        val ((arg_tm', st''), _) = post_traverse_term_type' f env arg_tm st'
      in f (fun_tm' $ arg_tm') T st'' end)
smolkas@50258
    39
smolkas@50258
    40
(* Post-order traversal starting with an empty bound-variable environment.
   f receives "subterm typ state" and returns (subterm', state'); the type of
   each subterm is passed on unchanged.  Returns the rebuilt term together
   with the final state. *)
fun post_traverse_term_type f s t =
  let
    fun keep_type u T acc = (f u T acc, T)
  in
    fst (post_traverse_term_type' keep_type [] t s)
  end
smolkas@50258
    42
(* Post-order fold over all subterms of t: f receives "subterm typ state" and
   returns the new state.  The term itself is left as it is; only the final
   state is returned. *)
fun post_fold_term_type f s t =
  let
    fun keep_term u T acc = (u, f u T acc)
  in
    snd (post_traverse_term_type keep_term s t)
  end
smolkas@50258
    44
smolkas@50258
    45
(* Data structures, orders *)
smolkas@50258
    46
(* Order on costs of the shape (int, (int, int)), built from int_ord with
   prod_ord on both nesting levels. *)
val cost_ord =
  let
    val inner_ord = prod_ord int_ord int_ord
  in
    prod_ord int_ord inner_ord
  end
smolkas@50258
    47
(* Table whose keys are lists of indexnames, ordered by list_ord over
   Term_Ord.fast_indexname_ord.  In this file the keys are the sorted lists of
   type-variable names produced by key_of_type. *)
structure Var_Set_Tab = Table(
smolkas@50258
    48
  type key = indexname list
smolkas@50258
    49
  val ord = list_ord Term_Ord.fast_indexname_ord)
smolkas@50258
    50
smolkas@50258
    51
(* (1) Generalize types *)
smolkas@50258
    52
(* Replaces every type in t by dummyT and then re-runs type inference on the
   result.  The context is switched to pattern mode beforehand so that
   inference uses schematic type variables. *)
fun generalize_types ctxt t =
  let
    val pattern_ctxt = Proof_Context.set_mode Proof_Context.mode_pattern ctxt
    val untyped = map_types (K dummyT) t
  in
    singleton (Type_Infer_Context.infer_types pattern_ctxt) untyped
  end
smolkas@50258
    61
smolkas@50258
    62
(* (2) Typing-spot table *)
smolkas@50258
    63
local

(* insert the name of a schematic type variable into a sorted name list;
   every other atomic type leaves the list untouched *)
fun add_tvar_name (TVar (xi, _)) names =
      Ord_List.insert Term_Ord.fast_indexname_ord xi names
  | add_tvar_name _ names = names

(* sorted list of the schematic type variable names occurring in T *)
fun tvar_key T = fold_atyps add_tvar_name T []

(* Visit one subterm t of type T at position pos.  If T mentions schematic
   type variables, the table entry for that variable set is created, or
   replaced when this spot is strictly cheaper than the recorded one.  The
   position counter is advanced in every case. *)
fun record_spot t T (tab, pos) =
  let
    val key = tvar_key T
    val tab' =
      if null key then tab
      else
        let
          val cost = (size_of_typ T, (size_of_term t, pos))
        in
          (case Var_Set_Tab.lookup tab key of
            NONE => Var_Set_Tab.update_new (key, cost) tab
          | SOME old_cost =>
              if cost_ord (cost, old_cost) = LESS
              then Var_Set_Tab.update (key, cost) tab
              else tab)
        end
  in (tab', pos + 1) end

in

(* Maps each set of schematic type variables occurring as the type of some
   subterm to the cheapest spot (cost and post-order position) showing it. *)
val typing_spot_table =
  fst o post_fold_term_type record_spot (Var_Set_Tab.empty, 0)

end
smolkas@50258
    85
smolkas@50258
    86
(* (3) Reverse-greedy *)
smolkas@50258
    87
(* Selects typing spots from the table.  Entries are visited in order of
   descending cost; an entry is dropped when each of its type variables is
   still counted in more than one remaining entry, otherwise its position is
   kept.  Returns the kept positions. *)
fun reverse_greedy typing_spot_tab =
  let
    (* add delta to the counter of every listed type variable (default 0) *)
    fun bump delta tvars counts =
      fold (fn tvar => fn tab =>
          Vartab.update (tvar, the_default 0 (Vartab.lookup tab tvar) + delta) tab)
        tvars counts

    (* does every listed type variable have a counter above 1? *)
    fun covered_elsewhere counts tvars =
      forall (fn tvar => the (Vartab.lookup counts tvar) > 1) tvars

    (* gather all table entries and count in how many entries each type
       variable occurs *)
    fun collect (entry as (tvars, _)) (entries, counts) =
      (entry :: entries, bump 1 tvars counts)

    val (all_entries, all_counts) =
      Var_Set_Tab.fold collect typing_spot_tab ([], Vartab.empty)

    val by_descending_cost =
      sort_distinct (rev_order o cost_ord o pairself snd) all_entries

    fun prune (tvars, (_, (_, spot))) (kept, counts) =
      if covered_elsewhere counts tvars then (kept, bump ~1 tvars counts)
      else (spot :: kept, counts)
  in
    fst (fold prune by_descending_cost ([], all_counts))
  end
smolkas@50258
   105
smolkas@50258
   106
(* (4) Introduce annotations *)
smolkas@50258
   107
(* Inserts type constraints into t.

   spots : positions of subterms, in the post-order numbering used by
           post_fold_term_type; they are consumed front to back while the
           position counter increases, so the list is expected in ascending
           order (annotate_types sorts it with int_ord).
   t     : the term to annotate.
   t'    : the type-generalized version of t, as passed in by annotate_types.

   The work is done in two passes: post1 walks t' and decides, per spot,
   which annotation type (if any) is needed; post2 walks t and wraps the
   chosen subterms in Type.constraint. *)
fun introduce_annotations ctxt spots t t' =
smolkas@50258
   108
  let
smolkas@50258
   109
    val thy = Proof_Context.theory_of ctxt
smolkas@50258
   110
    (* types of all subterms, collected by consing during the traversal *)
    val get_types = post_fold_term_type (K cons) []
smolkas@50258
   111
    (* match the subterm types of the first term of the pair against those of
       the second, position by position, accumulating a type environment *)
    fun match_types tp =
smolkas@50258
   112
      fold (Sign.typ_match thy) (op ~~ (pairself get_types tp)) Vartab.empty
smolkas@50258
   113
    (* unica' b x rest: x is the element currently being scanned in a sorted
       list, b tells whether no duplicate of x has been seen so far *)
    fun unica' b x [] = if b then [x] else []
smolkas@50258
   114
      | unica' b x (y :: ys) =
smolkas@50258
   115
        if x = y then unica' false x ys
smolkas@50258
   116
        else unica' true y ys |> b ? cons x
smolkas@50258
   117
    (* elements of xs that occur exactly once, in the order given by ord *)
    fun unica ord xs =
smolkas@50258
   118
      case sort ord xs of x :: ys => unica' true x ys | [] => []
smolkas@50258
   119
    (* conses the names of all TFrees of a type onto a list *)
    val add_all_tfree_namesT = fold_atyps (fn TFree (x, _) => cons x | _ => I)
smolkas@50258
   120
    (* In the types stored in env, replace by dummyT every TFree whose name is
       not declared in ctxt and occurs exactly once among all TFree
       occurrences in env. *)
    fun erase_unica_tfrees env =
smolkas@50258
   121
      let
smolkas@50258
   122
        (* note: this value shadows the function unica from here on; the
           right-hand side still refers to the function *)
        val unica =
smolkas@50258
   123
          Vartab.fold (add_all_tfree_namesT o snd o snd) env []
smolkas@50258
   124
          |> filter_out (Variable.is_declared ctxt)
smolkas@50258
   125
          |> unica fast_string_ord
smolkas@50258
   126
        val erase_unica = map_atyps
smolkas@50258
   127
          (fn T as TFree (s, _) =>
smolkas@50258
   128
              if Ord_List.member fast_string_ord unica s then dummyT else T
smolkas@50258
   129
            | T => T)
smolkas@50258
   130
      in Vartab.map (K (apsnd erase_unica)) env end
smolkas@50258
   131
    (* t' supplies the pattern types, t the types they are matched against *)
    val env = match_types (t', t) |> erase_unica_tfrees
smolkas@50258
   132
    (* get_annot env T = (needed, (env', T')): T' is the annotation type for T
       and needed tells whether an annotation is required at all.
       - TFree: never needs an annotation.
       - TVar: instantiated via env; if the instance is dummyT nothing is
         needed, otherwise the variable is rebound to dummyT in the returned
         environment, so later lookups of it yield dummyT.
       - Type: the arguments are processed right to left with the environment
         threaded through; an annotation is needed iff one of them needs one,
         otherwise the whole type collapses to dummyT. *)
    fun get_annot env (TFree _) = (false, (env, dummyT))
smolkas@50258
   133
      | get_annot env (T as TVar (v, S)) =
smolkas@50258
   134
        let val T' = Envir.subst_type env T in
smolkas@50258
   135
          if T' = dummyT then (false, (env, dummyT))
smolkas@50258
   136
          else (true, (Vartab.update (v, (S, dummyT)) env, T'))
smolkas@50258
   137
        end
smolkas@50258
   138
      | get_annot env (Type (S, Ts)) =
smolkas@50258
   139
        (case fold_rev (fn T => fn (b, (env, Ts)) =>
smolkas@50258
   140
                  let
smolkas@50258
   141
                    val (b', (env', T)) = get_annot env T
smolkas@50258
   142
                  in (b orelse b', (env', T :: Ts)) end)
smolkas@50258
   143
                Ts (false, (env, [])) of
smolkas@50258
   144
           (true, (env', Ts)) => (true, (env', Type (S, Ts)))
smolkas@50258
   145
         | (false, (env', _)) => (false, (env', dummyT)))
smolkas@50258
   146
    (* First pass, over t'.  State: (environment, current position, remaining
       spots, annotations found so far).  When the current position is the
       next spot, the spot is consumed and an annotation is recorded if
       get_annot says one is necessary.  Once no spots remain, the state is
       returned unchanged. *)
    fun post1 _ T (env, cp, ps as p :: ps', annots) =
smolkas@50258
   147
        if p <> cp then
smolkas@50258
   148
          (env, cp + 1, ps, annots)
smolkas@50258
   149
        else
smolkas@51877
   150
          let val (annot_necessary, (env', T')) = get_annot env T in
smolkas@51877
   151
            (env', cp + 1, ps', annots |> annot_necessary ? cons (p, T'))
smolkas@50258
   152
          end
smolkas@50258
   153
      | post1 _ _ accum = accum
smolkas@50258
   154
    (* annots is built by consing, i.e. with the last position first *)
    val (_, _, _, annots) = post_fold_term_type post1 (env, 0, spots, []) t'
smolkas@50258
   155
    (* Second pass, over t.  State: (current position, remaining annotations).
       A subterm whose position matches the next annotation is wrapped in a
       type constraint; with no annotations left the term is passed through. *)
    fun post2 t _ (cp, annots as (p, T) :: annots') =
smolkas@50258
   156
        if p <> cp then (t, (cp + 1, annots))
smolkas@50258
   157
        else (Type.constraint T t, (cp + 1, annots'))
smolkas@50258
   158
      | post2 t _ x = (t, x)
smolkas@52110
   159
  in
smolkas@52366
   160
    (* rev restores ascending position order for the second pass *)
    t |> post_traverse_term_type post2 (0, rev annots)
smolkas@52110
   161
      |> fst
smolkas@52110
   162
  end
smolkas@50258
   163
smolkas@50258
   164
(* (5) Annotate *)
smolkas@50258
   165
(* Entry point: generalizes the types of t, builds the typing-spot table of
   the generalized term, selects spots with reverse_greedy, sorts the selected
   positions in ascending order and finally inserts the corresponding type
   constraints into t. *)
fun annotate_types ctxt t =
  let
    val generalized = generalize_types ctxt t
    val spot_table = typing_spot_table generalized
    val spots = sort int_ord (reverse_greedy spot_table)
  in
    introduce_annotations ctxt spots t generalized
  end
smolkas@50258
   173
smolkas@50258
   174
end