src/HOL/Tools/Sledgehammer/sledgehammer_isar_annotate.ML
author blanchet
Fri Jan 31 16:10:39 2014 +0100 (2014-01-31 ago)
changeset 55212 5832470d956e
parent 55205 8450622db0c5
child 55213 dcb36a2540bc
permissions -rw-r--r--
tuning
blanchet@55202
     1
(*  Title:      HOL/Tools/Sledgehammer/sledgehammer_isar_annotate.ML
blanchet@55212
     2
    Author:     Steffen Juilf Smolka, TU Muenchen
smolkas@50263
     3
    Author:     Jasmin Blanchette, TU Muenchen
smolkas@50263
     4
blanchet@55212
     5
Supplements term with a locally minimal, complete set of type constraints. Complete: The constraints
blanchet@55212
     6
suffice to infer the term's types. Minimal: Reducing the set of constraints further will make it
blanchet@55212
     7
incomplete.
smolkas@52369
     8
blanchet@55212
     9
When configuring the pretty printer appropriately, the constraints will show up as type annotations
blanchet@55212
    10
when printing the term. This allows the term to be printed and reparsed without a change of types.
smolkas@52369
    11
blanchet@55212
    12
NOTE: Terms should be unchecked before calling annotate_types to avoid awkward syntax.
smolkas@50263
    13
*)
smolkas@50263
    14
blanchet@55202
    15
(* Public interface of the type-annotation pass: a single entry point. *)
signature SLEDGEHAMMER_ISAR_ANNOTATE =
smolkas@50258
    16
sig
smolkas@50258
    17
  (* annotate_types ctxt t: returns t with type constraints inserted at selected subterms
     (see the file header for the intended completeness/minimality of the constraint set). *)
  val annotate_types : Proof.context -> term -> term
blanchet@54504
    18
end;
smolkas@50258
    19
blanchet@55202
    20
structure Sledgehammer_Isar_Annotate : SLEDGEHAMMER_ISAR_ANNOTATE =
smolkas@50258
    21
struct
smolkas@50258
    22
smolkas@50258
    23
(* Util *)
smolkas@50258
    24
(* Post-order traversal of a term that threads a state and computes the type of every subterm.

   f is called on each (possibly rebuilt) subterm together with its type and the current state; it
   must return ((subterm', state'), type). env holds the types of the enclosing abstractions,
   innermost first, and is used to look up the type of a Bound variable.

   Raises Fail if the type reported for the function part of an application does not have the
   shape Type (_, [_, T]). *)
fun post_traverse_term_type' f env t s =
  (case t of
    Const (_, T) => f t T s
  | Free (_, T) => f t T s
  | Var (_, T) => f t T s
  | Bound i => f t (nth env i) s
  | Abs (x, T1, body) =>
      let
        (* the bound variable's type becomes the innermost entry of the environment *)
        val ((body', s1), T2) = post_traverse_term_type' f (T1 :: env) body s
      in
        f (Abs (x, T1, body')) (T1 --> T2) s1
      end
  | u $ v =>
      (* the handler covers exactly this branch: the refutable pattern below raises Bind *)
      (let
        val ((u', s1), Type (_, [_, T])) = post_traverse_term_type' f env u s
        val ((v', s2), _) = post_traverse_term_type' f env v s1
      in
        f (u' $ v') T s2
      end
      handle Bind => raise Fail "Sledgehammer_Isar_Annotate: post_traverse_term_type'"))
smolkas@50258
    38
smolkas@50258
    39
(* Post-order map over term t with accumulator s: f receives each subterm, its type and the
   current state and returns the replacement subterm paired with the new state. The result is
   the rebuilt term paired with the final state. *)
fun post_traverse_term_type f s t =
  let
    (* pass the subterm's type through unchanged, as the primed traversal expects *)
    fun visit u U acc = (f u U acc, U)
    val (result, _) = post_traverse_term_type' visit [] t s
  in
    result
  end
smolkas@50258
    41
(* Post-order fold over term t: f receives each subterm, its type and the accumulator and returns
   the new accumulator. Subterms are left unchanged; only the final accumulator is returned. *)
fun post_fold_term_type f s t =
  let
    fun step u U acc = (u, f u U acc)
    val (_, final) = post_traverse_term_type step s t
  in
    final
  end
smolkas@50258
    43
smolkas@52452
    44
(* Combined fold and map over the non-Type leaves of a type: constructor applications are
   rebuilt from their transformed arguments (processed left to right with the state threaded
   through), and every other type is handed to f together with the state. *)
fun fold_map_atypes f (Type (name, args)) s =
      fold_map (fold_map_atypes f) args s
      |>> (fn args' => Type (name, args'))
  | fold_map_atypes f atom s = f atom s
smolkas@52452
    51
smolkas@52452
    52
(** get unique elements of a list **)
smolkas@52452
    53
local
  (* Walks a sorted list. cur is the element under inspection, solo says whether no duplicate of
     cur has been seen so far. Only elements that never meet an equal neighbour are kept. *)
  fun singles solo cur [] = if solo then [cur] else []
    | singles solo cur (next :: rest) =
        if cur = next then singles false cur rest
        else
          let val tail = singles true next rest
          in if solo then cur :: tail else tail end
in
  (* Sorts xs by ord and returns, in sorted order, the elements that occur exactly once. *)
  fun unique ord xs =
    (case sort ord xs of
      [] => []
    | first :: rest => singles true first rest)
end
smolkas@52452
    62
smolkas@52452
    63
(** Data structures, orders **)
smolkas@52452
    64
(* Order on index names used throughout this file (alias for Term_Ord.fast_indexname_ord). *)
val indexname_ord = Term_Ord.fast_indexname_ord
smolkas@50258
    65
(* Order on costs of the shape (int, (int, int)), built with prod_ord from int_ord.
   update_tab below constructs costs as (size of type, (size of term, position)). *)
val cost_ord = prod_ord int_ord (prod_ord int_ord int_ord)
smolkas@50258
    66
(* Tables keyed by lists of index names, compared with list_ord over indexname_ord. *)
structure Var_Set_Tab = Table(
smolkas@50258
    67
  type key = indexname list
smolkas@52452
    68
  val ord = list_ord indexname_ord)
smolkas@50258
    69
smolkas@50258
    70
(* (1) Generalize types *)
smolkas@50258
    71
(* Replaces every type in t by dummyT and runs type inference on the result, with the context
   switched to pattern mode so that schematic type variables are used. *)
fun generalize_types ctxt t =
  let
    val pattern_ctxt = Proof_Context.set_mode Proof_Context.mode_pattern ctxt
    val untyped = map_types (K dummyT) t
  in
    singleton (Type_Infer_Context.infer_types pattern_ctxt) untyped
  end
smolkas@50258
    80
smolkas@52452
    81
(* (2) match types *)
smolkas@52452
    82
(* Collects the types of all subterms of t1 and of t2 (same traversal order for both), pairs them
   up and accumulates a type matching of each t1 type against its t2 counterpart.
   Raises Fail if some pair cannot be matched. *)
fun match_types ctxt t1 t2 =
  let
    val thy = Proof_Context.theory_of ctxt
    fun types_of t = post_fold_term_type (fn _ => fn T => fn Ts => T :: Ts) [] t
  in
    fold (Sign.typ_match thy) (types_of t1 ~~ types_of t2) Vartab.empty
    handle Type.TYPE_MATCH => raise Fail "Sledgehammer_Isar_Annotate: match_types"
  end
smolkas@52452
    90
smolkas@52452
    91
smolkas@52452
    92
(* (3) handle trivial tfrees  *)
smolkas@52452
    93
(* Removes "trivial" type information from the generalized term t' and the matching subst.

   A TFree name is trivial if it is not declared in ctxt and occurs exactly once among the types
   in the range of subst. A TVar is trivial if subst maps it directly to a trivial TFree.
   Trivial TVars are replaced by dummyT in t' and dropped from subst; trivial TFrees that remain
   in the range of subst are replaced by dummyT. Returns the updated (term, subst) pair. *)
fun handle_trivial_tfrees ctxt (t', subst) =
  let
    (* names of all TFrees in the type an entry of subst maps to, with repetitions *)
    fun tfree_names_of_entry (_, (_, T)) names =
      fold_atyps (fn TFree (a, _) => cons a | _ => I) T names

    (* sorted list, as required by Ord_List.member below *)
    val lone_tfrees =
      unique fast_string_ord
        (filter_out (Variable.is_declared ctxt) (Vartab.fold tfree_names_of_entry subst []))
    fun is_lone_tfree a = Ord_List.member fast_string_ord lone_tfrees a

    fun add_trivial_tvar (xi, (_, TFree (a, _))) acc = if is_lone_tfree a then xi :: acc else acc
      | add_trivial_tvar _ acc = acc
    val trivial_tvars = sort indexname_ord (Vartab.fold add_trivial_tvar subst [])
    fun is_trivial_tvar xi = Ord_List.member indexname_ord trivial_tvars xi

    fun erase_tvar (xi, S) = if is_trivial_tvar xi then dummyT else TVar (xi, S)
    fun erase_tfree (a, S) = if is_lone_tfree a then dummyT else TFree (a, S)

    val erased_term = map_types (map_type_tvar erase_tvar) t'
    val pruned_subst =
      Vartab.map (fn _ => fn (S, T) => (S, map_type_tfree erase_tfree T))
        (fold Vartab.delete trivial_tvars subst)
  in
    (erased_term, pruned_subst)
  end
smolkas@52452
   130
smolkas@52452
   131
(* (4) Typing-spot table *)
smolkas@50258
   132
local
  (* ordered, duplicate-free list of the index names of all TVars occurring in T *)
  fun tvars_of T =
    fold_atyps (fn TVar (xi, _) => Ord_List.insert indexname_ord xi | _ => I) T []

  (* Visits one subterm t of type T at position pos. If T contains TVars, the table entry for
     that set of TVars is created, or replaced when this spot is strictly cheaper than the stored
     one. The position counter advances for every subterm, whether or not it is recorded. *)
  fun record_spot t T (tab, pos) =
    let
      val key = tvars_of T
      val tab' =
        if null key then tab
        else
          let val cost = (size_of_typ T, (size_of_term t, pos)) in
            (case Var_Set_Tab.lookup tab key of
              NONE => Var_Set_Tab.update_new (key, cost) tab
            | SOME old_cost =>
                if cost_ord (cost, old_cost) = LESS then Var_Set_Tab.update (key, cost) tab
                else tab)
          end
    in
      (tab', pos + 1)
    end
in
  (* Maps each set of TVars occurring in some subterm's type to the cheapest spot
     (type size, (term size, post-order position)) whose type mentions exactly that set. *)
  val typing_spot_table = fst o post_fold_term_type record_spot (Var_Set_Tab.empty, 0)
end
smolkas@50258
   153
smolkas@52452
   154
(* (5) Reverse-greedy *)
smolkas@50258
   155
(* Selects a subset of the typing spots in the table and returns their positions.

   Every TVar gets a count of the table entries mentioning it. Entries are then visited from
   most to least expensive; an entry is discarded when each of its TVars is still counted more
   than once (the counts are decremented accordingly), otherwise its position is kept. *)
fun reverse_greedy typing_spot_tab =
  let
    (* add delta to the counter of every listed TVar; missing counters start at 0 *)
    fun bump delta tvars counts =
      fold (fn tvar => fn tab =>
        Vartab.update (tvar, the_default 0 (Vartab.lookup tab tvar) + delta) tab) tvars counts

    fun all_covered_elsewhere counts tvars =
      forall (fn tvar => the (Vartab.lookup counts tvar) > 1) tvars

    fun collect (entry as (tvars, _)) (entries, counts) =
      (entry :: entries, bump 1 tvars counts)
    val (unsorted_entries, initial_counts) =
      Var_Set_Tab.fold collect typing_spot_tab ([], Vartab.empty)

    (* most expensive first; entries are compared by cost only *)
    val by_decreasing_cost =
      sort_distinct (fn ((_, c1), (_, c2)) => rev_order (cost_ord (c1, c2))) unsorted_entries

    fun consider (tvars, (_, (_, spot))) (kept, counts) =
      if all_covered_elsewhere counts tvars then (kept, bump ~1 tvars counts)
      else (spot :: kept, counts)

    val (kept_spots, _) = fold consider by_decreasing_cost ([], initial_counts)
  in
    kept_spots
  end
smolkas@50258
   173
smolkas@52452
   174
(* (6) Introduce annotations *)
smolkas@52452
   175
(* Inserts type constraints into t at the given post-order positions.

   spots is consumed front to back while a position counter runs over the subterms, so it is
   expected in ascending order. The annotation types are computed by traversing the generalized
   term t' and instantiating its subterm types with subst; the constraints are then attached to
   the subterms of t found at the same positions. *)
fun introduce_annotations subst spots t t' =
  let
    (* Instantiate a TVar with the current environment, then rebind it to dummyT so that later
       occurrences of the same TVar are instantiated to dummyT instead. *)
    fun instantiate_once (T as TVar (xi, S)) env =
          (Envir.subst_type env T, Vartab.update (xi, (S, dummyT)) env)
      | instantiate_once T env = (T, env)

    (* state: (environment, current position, remaining spots, annotations found so far) *)
    fun collect_annot _ T (state as (env, pos, pending, found)) =
      (case pending of
        [] => state
      | next :: later =>
          if next = pos then
            let val (T', env') = fold_map_atypes instantiate_once T env
            in (env', pos + 1, later, (next, T') :: found) end
          else (env, pos + 1, pending, found))

    (* annotations come out in reverse order of position *)
    val (_, _, _, rev_annots) =
      post_fold_term_type collect_annot (subst, 0, spots, []) t'

    (* state: (current position, remaining annotations) *)
    fun insert_annot u _ (state as (pos, pending)) =
      (case pending of
        [] => (u, state)
      | (next, T) :: later =>
          if next = pos then (Type.constraint T u, (pos + 1, later))
          else (u, (pos + 1, pending)))

    val (annotated, _) = post_traverse_term_type insert_annot (0, rev rev_annots) t
  in
    annotated
  end
smolkas@50258
   198
smolkas@52452
   199
(* (7) Annotate *)
smolkas@50258
   200
(* Entry point: runs the steps defined in this file in order. It generalizes the types of t,
   matches the generalized types against the original ones, strips trivial type variables,
   selects the typing spots, and finally inserts the corresponding constraints into t. *)
fun annotate_types ctxt t =
  let
    val generalized = generalize_types ctxt t
    val matching = match_types ctxt generalized t
    val (pruned, subst) = handle_trivial_tfrees ctxt (generalized, matching)
    (* introduce_annotations consumes the positions front to back, hence the sort *)
    val typing_spots = sort int_ord (reverse_greedy (typing_spot_table pruned))
  in
    introduce_annotations subst typing_spots t pruned
  end
smolkas@50258
   210
blanchet@54504
   211
end;