src/HOL/Tools/Sledgehammer/sledgehammer_annotate.ML
author smolkas
Tue Jun 11 19:58:09 2013 -0400 (2013-06-11 ago)
changeset 52369 0b395800fdf0
parent 52366 ff89424b5094
child 52452 2207825d67f3
permissions -rw-r--r--
uncheck terms before annotation to avoid awkward syntax
     1 (*  Title:      HOL/Tools/Sledgehammer/sledgehammer_annotate.ML
     2     Author:     Jasmin Blanchette, TU Muenchen
     3     Author:     Steffen Juilf Smolka, TU Muenchen
     4 
     5 Supplements a term with a locally minimal, complete set of type constraints.
     6 Complete: The constraints suffice to infer the term's types.
     7 Minimal: Reducing the set of constraints further will make it incomplete.
     8 
     9 When configuring the pretty printer appropriately, the constraints will show up
    10 as type annotations when printing the term. This allows the term to be reparsed
    11 without a change of types.
    12 
    13 NOTE: Terms should be unchecked before calling annotate_types to avoid awkward
    14 syntax.
    15 *)
    16 
    17 signature SLEDGEHAMMER_ANNOTATE =
    18 sig
    19   val annotate_types : Proof.context -> term -> term
    20 end
    21 
    22 structure Sledgehammer_Annotate : SLEDGEHAMMER_ANNOTATE =
    23 struct
    24 
    25 (* Util *)
    26 fun post_traverse_term_type' f _ (t as Const (_, T)) s = f t T s
    27   | post_traverse_term_type' f _ (t as Free (_, T)) s = f t T s
    28   | post_traverse_term_type' f _ (t as Var (_, T)) s = f t T s
    29   | post_traverse_term_type' f env (t as Bound i) s = f t (nth env i) s
    30   | post_traverse_term_type' f env (Abs (x, T1, b)) s =
    31     let
    32       val ((b', s'), T2) = post_traverse_term_type' f (T1 :: env) b s
    33     in f (Abs (x, T1, b')) (T1 --> T2) s' end
    34   | post_traverse_term_type' f env (u $ v) s =
    35     let
    36       val ((u', s'), Type (_, [_, T])) = post_traverse_term_type' f env u s
    37       val ((v', s''), _) = post_traverse_term_type' f env v s'
    38     in f (u' $ v') T s'' end
    39 
    40 fun post_traverse_term_type f s t =
    41   post_traverse_term_type' (fn t => fn T => fn s => (f t T s, T)) [] t s |> fst
    42 fun post_fold_term_type f s t =
    43   post_traverse_term_type (fn t => fn T => fn s => (t, f t T s)) s t |> snd
    44 
    45 (* Data structures, orders *)
    46 val cost_ord = prod_ord int_ord (prod_ord int_ord int_ord)
    47 structure Var_Set_Tab = Table(
    48   type key = indexname list
    49   val ord = list_ord Term_Ord.fast_indexname_ord)
    50 
    51 (* (1) Generalize types *)
    52 fun generalize_types ctxt t =
    53   let
    54     val erase_types = map_types (fn _ => dummyT)
    55     (* use schematic type variables *)
    56     val ctxt = ctxt |> Proof_Context.set_mode Proof_Context.mode_pattern
    57     val infer_types = singleton (Type_Infer_Context.infer_types ctxt)
    58   in
    59      t |> erase_types |> infer_types
    60   end
    61 
    62 (* (2) Typing-spot table *)
    63 local
    64 fun key_of_atype (TVar (z, _)) =
    65     Ord_List.insert Term_Ord.fast_indexname_ord z
    66   | key_of_atype _ = I
    67 fun key_of_type T = fold_atyps key_of_atype T []
    68 fun update_tab t T (tab, pos) =
    69   (case key_of_type T of
    70      [] => tab
    71    | key =>
    72      let val cost = (size_of_typ T, (size_of_term t, pos)) in
    73        case Var_Set_Tab.lookup tab key of
    74          NONE => Var_Set_Tab.update_new (key, cost) tab
    75        | SOME old_cost =>
    76          (case cost_ord (cost, old_cost) of
    77             LESS => Var_Set_Tab.update (key, cost) tab
    78           | _ => tab)
    79      end,
    80    pos + 1)
    81 in
    82 val typing_spot_table =
    83   post_fold_term_type update_tab (Var_Set_Tab.empty, 0) #> fst
    84 end
    85 
    86 (* (3) Reverse-greedy *)
    87 fun reverse_greedy typing_spot_tab =
    88   let
    89     fun update_count z =
    90       fold (fn tvar => fn tab =>
    91         let val c = Vartab.lookup tab tvar |> the_default 0 in
    92           Vartab.update (tvar, c + z) tab
    93         end)
    94     fun superfluous tcount =
    95       forall (fn tvar => the (Vartab.lookup tcount tvar) > 1)
    96     fun drop_superfluous (tvars, (_, (_, spot))) (spots, tcount) =
    97       if superfluous tcount tvars then (spots, update_count ~1 tvars tcount)
    98       else (spot :: spots, tcount)
    99     val (typing_spots, tvar_count_tab) =
   100       Var_Set_Tab.fold
   101         (fn kv as (k, _) => apfst (cons kv) #> apsnd (update_count 1 k))
   102         typing_spot_tab ([], Vartab.empty)
   103       |>> sort_distinct (rev_order o cost_ord o pairself snd)
   104   in fold drop_superfluous typing_spots ([], tvar_count_tab) |> fst end
   105 
   106 (* (4) Introduce annotations *)
   107 fun introduce_annotations ctxt spots t t' =
   108   let
   109     val thy = Proof_Context.theory_of ctxt
   110     val get_types = post_fold_term_type (K cons) []
   111     fun match_types tp =
   112       fold (Sign.typ_match thy) (op ~~ (pairself get_types tp)) Vartab.empty
   113     fun unica' b x [] = if b then [x] else []
   114       | unica' b x (y :: ys) =
   115         if x = y then unica' false x ys
   116         else unica' true y ys |> b ? cons x
   117     fun unica ord xs =
   118       case sort ord xs of x :: ys => unica' true x ys | [] => []
   119     val add_all_tfree_namesT = fold_atyps (fn TFree (x, _) => cons x | _ => I)
   120     fun erase_unica_tfrees env =
   121       let
   122         val unica =
   123           Vartab.fold (add_all_tfree_namesT o snd o snd) env []
   124           |> filter_out (Variable.is_declared ctxt)
   125           |> unica fast_string_ord
   126         val erase_unica = map_atyps
   127           (fn T as TFree (s, _) =>
   128               if Ord_List.member fast_string_ord unica s then dummyT else T
   129             | T => T)
   130       in Vartab.map (K (apsnd erase_unica)) env end
   131     val env = match_types (t', t) |> erase_unica_tfrees
   132     fun get_annot env (TFree _) = (false, (env, dummyT))
   133       | get_annot env (T as TVar (v, S)) =
   134         let val T' = Envir.subst_type env T in
   135           if T' = dummyT then (false, (env, dummyT))
   136           else (true, (Vartab.update (v, (S, dummyT)) env, T'))
   137         end
   138       | get_annot env (Type (S, Ts)) =
   139         (case fold_rev (fn T => fn (b, (env, Ts)) =>
   140                   let
   141                     val (b', (env', T)) = get_annot env T
   142                   in (b orelse b', (env', T :: Ts)) end)
   143                 Ts (false, (env, [])) of
   144            (true, (env', Ts)) => (true, (env', Type (S, Ts)))
   145          | (false, (env', _)) => (false, (env', dummyT)))
   146     fun post1 _ T (env, cp, ps as p :: ps', annots) =
   147         if p <> cp then
   148           (env, cp + 1, ps, annots)
   149         else
   150           let val (annot_necessary, (env', T')) = get_annot env T in
   151             (env', cp + 1, ps', annots |> annot_necessary ? cons (p, T'))
   152           end
   153       | post1 _ _ accum = accum
   154     val (_, _, _, annots) = post_fold_term_type post1 (env, 0, spots, []) t'
   155     fun post2 t _ (cp, annots as (p, T) :: annots') =
   156         if p <> cp then (t, (cp + 1, annots))
   157         else (Type.constraint T t, (cp + 1, annots'))
   158       | post2 t _ x = (t, x)
   159   in
   160     t |> post_traverse_term_type post2 (0, rev annots)
   161       |> fst
   162   end
   163 
   164 (* (5) Annotate *)
   165 fun annotate_types ctxt t =
   166   let
   167     val t' = generalize_types ctxt t
   168     val typing_spots =
   169       t' |> typing_spot_table
   170          |> reverse_greedy
   171          |> sort int_ord
   172   in introduce_annotations ctxt typing_spots t t' end
   173 
   174 end