src/HOL/Tools/Sledgehammer/sledgehammer_isar_annotate.ML
author blanchet
Mon Feb 03 15:33:18 2014 +0100 (2014-02-03 ago)
changeset 55286 7bbbd9393ce0
parent 55243 66709d41601e
child 57467 03345dad8430
permissions -rw-r--r--
tuning
blanchet@55202
     1
(*  Title:      HOL/Tools/Sledgehammer/sledgehammer_isar_annotate.ML
blanchet@55212
     2
    Author:     Steffen Juilf Smolka, TU Muenchen
smolkas@50263
     3
    Author:     Jasmin Blanchette, TU Muenchen
smolkas@50263
     4
blanchet@55212
     5
Supplements a term with a locally minimal, complete set of type constraints. Complete: The constraints
blanchet@55212
     6
suffice to infer the term's types. Minimal: Reducing the set of constraints further will make it
blanchet@55212
     7
incomplete.
smolkas@52369
     8
blanchet@55212
     9
When configuring the pretty printer appropriately, the constraints will show up as type annotations
blanchet@55212
    10
when printing the term. This allows the term to be printed and reparsed without a change of types.
smolkas@52369
    11
blanchet@55213
    12
Note: Terms should be unchecked before calling "annotate_types_in_term" to avoid awkward syntax.
smolkas@50263
    13
*)
smolkas@50263
    14
blanchet@55202
    15
signature SLEDGEHAMMER_ISAR_ANNOTATE =
sig
  (* Returns the given term supplemented with a locally minimal, complete set of type constraints,
     so that it can be printed and reparsed without a change of types. *)
  val annotate_types_in_term : Proof.context -> term -> term
end;
smolkas@50258
    19
blanchet@55202
    20
structure Sledgehammer_Isar_Annotate : SLEDGEHAMMER_ISAR_ANNOTATE =
struct

(* Post-order traversal of a term that threads a state and the type of every subterm.
   "f t T s" is called on each subterm "t" (already rebuilt from its children) with its type "T"
   and the current state "s"; it must return "((t', s'), T')", where "T'" is the type reported to
   the parent node. "env" holds the binder types of the enclosing abstractions, innermost first,
   so that loose "Bound i" can be typed. Raises "Fail" when the type reported for the function
   part of an application is not of the form "Type (_, [_, _])". *)
fun post_traverse_term_type' f _ (t as Const (_, T)) s = f t T s
  | post_traverse_term_type' f _ (t as Free (_, T)) s = f t T s
  | post_traverse_term_type' f _ (t as Var (_, T)) s = f t T s
  | post_traverse_term_type' f env (t as Bound i) s = f t (nth env i) s
  | post_traverse_term_type' f env (Abs (x, T1, b)) s =
    let val ((b', s'), T2) = post_traverse_term_type' f (T1 :: env) b s in
      f (Abs (x, T1, b')) (T1 --> T2) s'
    end
  | post_traverse_term_type' f env (u $ v) s =
    let
      (* the type of "u $ v" is the range of the (function) type of "u" *)
      val ((u', s'), Type (_, [_, T])) = post_traverse_term_type' f env u s
      val ((v', s''), _) = post_traverse_term_type' f env v s'
    in f (u' $ v') T s'' end
    handle Bind => raise Fail "Sledgehammer_Isar_Annotate: post_traverse_term_type'"

(* Post-order traversal that rebuilds the term: "f t T s" returns "(t', s')". The original type of
   each subterm is reported to its parent. Returns the rebuilt term and the final state. *)
fun post_traverse_term_type f s t =
  post_traverse_term_type' (fn t => fn T => fn s => (f t T s, T)) [] t s |> fst

(* Post-order fold over all subterms and their types: "f t T s" returns the new state. *)
fun post_fold_term_type f s t =
  post_traverse_term_type (fn t => fn T => fn s => (t, f t T s)) s t |> snd

(* "fold_map" over the atomic (non-"Type") leaves of a type, rebuilding the type around them. *)
fun fold_map_atypes f T s =
  (case T of
    Type (name, Ts) =>
    let val (Ts, s) = fold_map (fold_map_atypes f) Ts s in
      (Type (name, Ts), s)
    end
  | _ => f T s)

val indexname_ord = Term_Ord.fast_indexname_ord

(* Cost of an annotation spot: (size of its type, (size of its subterm, traversal position)),
   compared lexicographically. *)
val cost_ord = prod_ord int_ord (prod_ord int_ord int_ord)

(* Tables keyed by sorted lists of schematic type variable names. *)
structure Var_Set_Tab = Table(
  type key = indexname list
  val ord = list_ord indexname_ord)

(* Erases all types in "t" and re-infers them in pattern mode, so that types left unconstrained by
   inference come out as schematic type variables. *)
fun generalize_types ctxt t =
  let
    val erase_types = map_types (fn _ => dummyT)
    (* use schematic type variables *)
    val ctxt = ctxt |> Proof_Context.set_mode Proof_Context.mode_pattern
    val infer_types = singleton (Type_Infer_Context.infer_types ctxt)
  in
     t |> erase_types |> infer_types
  end

(* Matches the subterm types of "t1" (patterns) against those of "t2" (objects), position by
   position in post-order, and returns the resulting type instantiation. Expects both terms to
   have the same shape; raises "Fail" if some pair of types does not match. *)
fun match_types ctxt t1 t2 =
  let
    val thy = Proof_Context.theory_of ctxt
    val get_types = post_fold_term_type (K cons) []
  in
    fold (Sign.typ_match thy) (get_types t1 ~~ get_types t2) Vartab.empty
    handle Type.TYPE_MATCH => raise Fail "Sledgehammer_Isar_Annotate: match_types"
  end

(* A TFree occurring in the instantiation "subst" is "trivial" if it is not declared in "ctxt".
   TVars instantiated to exactly such a TFree are replaced by "dummyT" in "t'" and removed from
   "subst"; the remaining occurrences of trivial TFrees in "subst" become "dummyT". *)
fun handle_trivial_tfrees ctxt (t', subst) =
  let
    val add_tfree_names = snd #> snd #> fold_atyps (fn TFree (x, _) => cons x | _ => I)

    val trivial_tfree_names =
      Vartab.fold add_tfree_names subst []
      |> filter_out (Variable.is_declared ctxt)
      (* "Ord_List.member" below only works on lists sorted by the same order, so sort (and
         deduplicate) here rather than merely removing duplicates *)
      |> sort_distinct fast_string_ord
    val tfree_name_trivial = Ord_List.member fast_string_ord trivial_tfree_names

    val trivial_tvar_names =
      Vartab.fold
        (fn (tvar_name, (_, TFree (tfree_name, _))) =>
               tfree_name_trivial tfree_name ? cons tvar_name
          | _ => I)
        subst
        []
      |> sort indexname_ord
    val tvar_name_trivial = Ord_List.member indexname_ord trivial_tvar_names

    val t' =
      t' |> map_types
              (map_type_tvar
                (fn (idxn, sort) =>
                  if tvar_name_trivial idxn then dummyT else TVar (idxn, sort)))

    val subst =
      subst |> fold Vartab.delete trivial_tvar_names
            |> Vartab.map
               (K (apsnd (map_type_tfree
                           (fn (name, sort) =>
                              if tfree_name_trivial name then dummyT
                              else TFree (name, sort)))))
  in
    (t', subst)
  end

(* Sorted list of the schematic type variables occurring in a type. *)
fun key_of_atype (TVar (z, _)) = Ord_List.insert indexname_ord z
  | key_of_atype _ = I
fun key_of_type T = fold_atyps key_of_atype T []

(* For every non-empty set of TVars, keeps the cheapest (by "cost_ord") traversal position whose
   type mentions exactly that set. The state is (table, current position). *)
fun update_tab t T (tab, pos) =
  ((case key_of_type T of
     [] => tab
   | key =>
     let val cost = (size_of_typ T, (size_of_term t, pos)) in
       (case Var_Set_Tab.lookup tab key of
         NONE => Var_Set_Tab.update_new (key, cost) tab
       | SOME old_cost =>
         (case cost_ord (cost, old_cost) of
           LESS => Var_Set_Tab.update (key, cost) tab
         | _ => tab))
     end),
   pos + 1)

val typing_spot_table = post_fold_term_type update_tab (Var_Set_Tab.empty, 0) #> fst

(* Starts from all candidate spots and, in order of decreasing cost, drops every spot whose TVars
   are all still covered by at least one other remaining spot. Returns the kept positions. *)
fun reverse_greedy typing_spot_tab =
  let
    fun update_count z =
      fold (fn tvar => fn tab =>
        let val c = Vartab.lookup tab tvar |> the_default 0 in
          Vartab.update (tvar, c + z) tab
        end)
    fun superfluous tcount = forall (fn tvar => the (Vartab.lookup tcount tvar) > 1)
    fun drop_superfluous (tvars, (_, (_, spot))) (spots, tcount) =
      if superfluous tcount tvars then (spots, update_count ~1 tvars tcount)
      else (spot :: spots, tcount)

    val (typing_spots, tvar_count_tab) =
      Var_Set_Tab.fold (fn kv as (k, _) => apfst (cons kv) #> apsnd (update_count 1 k))
        typing_spot_tab ([], Vartab.empty)
      |>> sort_distinct (rev_order o cost_ord o pairself snd)
  in
    fold drop_superfluous typing_spots ([], tvar_count_tab) |> fst
  end

(* Inserts "Type.constraint" annotations into "t" at the given positions ("spots", which must be
   sorted in ascending order). Each annotation type is the type at that position in "t'" with
   "subst" applied; once a TVar has been used in an annotation, later annotations map it to
   "dummyT". "t" and "t'" are expected to have the same shape. *)
fun introduce_annotations subst spots t t' =
  let
    fun subst_atype (T as TVar (idxn, S)) subst =
        (Envir.subst_type subst T, Vartab.update (idxn, (S, dummyT)) subst)
      | subst_atype T subst = (T, subst)

    val subst_type = fold_map_atypes subst_atype

    (* state: (substitution, current position, remaining spots, collected annotations) *)
    fun collect_annot _ T (subst, cp, ps as p :: ps', annots) =
        if p <> cp then
          (subst, cp + 1, ps, annots)
        else
          let val (T, subst) = subst_type T subst in
            (subst, cp + 1, ps', (p, T) :: annots)
          end
      | collect_annot _ _ x = x

    val (_, _, _, annots) = post_fold_term_type collect_annot (subst, 0, spots, []) t'

    (* state: (current position, remaining annotations in ascending position order) *)
    fun insert_annot t _ (cp, annots as (p, T) :: annots') =
        if p <> cp then (t, (cp + 1, annots)) else (Type.constraint T t, (cp + 1, annots'))
      | insert_annot t _ x = (t, x)
  in
    t |> post_traverse_term_type insert_annot (0, rev annots) |> fst
  end

(* Generalizes the types of "t", computes how the generalized types must be instantiated to
   recover those of "t", chooses a reduced set of positions to annotate, and inserts the
   corresponding type constraints into "t". *)
fun annotate_types_in_term ctxt t =
  let
    val t' = generalize_types ctxt t
    val subst = match_types ctxt t' t
    val (t', subst) = (t', subst) |> handle_trivial_tfrees ctxt
    val typing_spots = t' |> typing_spot_table |> reverse_greedy |> sort int_ord
  in
    introduce_annotations subst typing_spots t t'
  end

end;