src/HOL/Tools/Sledgehammer/sledgehammer_isar_annotate.ML
author blanchet
Fri Jan 31 16:10:39 2014 +0100 (2014-01-31 ago)
changeset 55212 5832470d956e
parent 55205 8450622db0c5
child 55213 dcb36a2540bc
permissions -rw-r--r--
tuning
     1 (*  Title:      HOL/Tools/Sledgehammer/sledgehammer_isar_annotate.ML
     2     Author:     Steffen Juilf Smolka, TU Muenchen
     3     Author:     Jasmin Blanchette, TU Muenchen
     4 
     5 Supplements a term with a locally minimal, complete set of type constraints. Complete: The constraints
     6 suffice to infer the term's types. Minimal: Reducing the set of constraints further will make it
     7 incomplete.
     8 
     9 When the pretty printer is configured appropriately, the constraints show up as type annotations
    10 when printing the term. This allows the term to be printed and reparsed without a change of types.
    11 
    12 NOTE: Terms should be unchecked before calling annotate_types to avoid awkward syntax.
    13 *)
    14 
    15 signature SLEDGEHAMMER_ISAR_ANNOTATE =
    16 sig
    17   val annotate_types : Proof.context -> term -> term
    18 end;
    19 
    20 structure Sledgehammer_Isar_Annotate : SLEDGEHAMMER_ISAR_ANNOTATE =
    21 struct
    22 
    23 (* Util *)
    24 fun post_traverse_term_type' f _ (t as Const (_, T)) s = f t T s
    25   | post_traverse_term_type' f _ (t as Free (_, T)) s = f t T s
    26   | post_traverse_term_type' f _ (t as Var (_, T)) s = f t T s
    27   | post_traverse_term_type' f env (t as Bound i) s = f t (nth env i) s
    28   | post_traverse_term_type' f env (Abs (x, T1, b)) s =
    29     let
    30       val ((b', s'), T2) = post_traverse_term_type' f (T1 :: env) b s
    31     in f (Abs (x, T1, b')) (T1 --> T2) s' end
    32   | post_traverse_term_type' f env (u $ v) s =
    33     let
    34       val ((u', s'), Type (_, [_, T])) = post_traverse_term_type' f env u s
    35       val ((v', s''), _) = post_traverse_term_type' f env v s'
    36     in f (u' $ v') T s'' end
    37     handle Bind => raise Fail "Sledgehammer_Isar_Annotate: post_traverse_term_type'"
    38 
    39 fun post_traverse_term_type f s t =
    40   post_traverse_term_type' (fn t => fn T => fn s => (f t T s, T)) [] t s |> fst
    41 fun post_fold_term_type f s t =
    42   post_traverse_term_type (fn t => fn T => fn s => (t, f t T s)) s t |> snd
    43 
    44 fun fold_map_atypes f T s =
    45   case T of
    46     Type (name, Ts) =>
    47         let val (Ts, s) = fold_map (fold_map_atypes f) Ts s in
    48           (Type (name, Ts), s)
    49         end
    50   | _ => f T s
    51 
    52 (** get unique elements of a list **)
    53 local
    54   fun unique' b x [] = if b then [x] else []
    55     | unique' b x (y :: ys) =
    56       if x = y then unique' false x ys
    57       else unique' true y ys |> b ? cons x
    58 in
    59   fun unique ord xs =
    60     case sort ord xs of x :: ys => unique' true x ys | [] => []
    61 end
    62 
    63 (** Data structures, orders **)
    64 val indexname_ord = Term_Ord.fast_indexname_ord
    65 val cost_ord = prod_ord int_ord (prod_ord int_ord int_ord)
    66 structure Var_Set_Tab = Table(
    67   type key = indexname list
    68   val ord = list_ord indexname_ord)
    69 
    70 (* (1) Generalize types *)
    71 fun generalize_types ctxt t =
    72   let
    73     val erase_types = map_types (fn _ => dummyT)
    74     (* use schematic type variables *)
    75     val ctxt = ctxt |> Proof_Context.set_mode Proof_Context.mode_pattern
    76     val infer_types = singleton (Type_Infer_Context.infer_types ctxt)
    77   in
    78      t |> erase_types |> infer_types
    79   end
    80 
    81 (* (2) match types *)
    82 fun match_types ctxt t1 t2 =
    83   let
    84     val thy = Proof_Context.theory_of ctxt
    85     val get_types = post_fold_term_type (K cons) []
    86   in
    87     fold (Sign.typ_match thy) (get_types t1 ~~ get_types t2) Vartab.empty
    88     handle Type.TYPE_MATCH => raise Fail "Sledgehammer_Isar_Annotate: match_types"
    89   end
    90 
    91 
    92 (* (3) handle trivial tfrees  *)
    93 fun handle_trivial_tfrees ctxt (t', subst) =
    94   let
    95     val add_tfree_names =
    96       snd #> snd #> fold_atyps (fn TFree (x, _) => cons x | _ => I)
    97 
    98     val trivial_tfree_names =
    99       Vartab.fold add_tfree_names subst []
   100       |> filter_out (Variable.is_declared ctxt)
   101       |> unique fast_string_ord
   102     val tfree_name_trivial = Ord_List.member fast_string_ord trivial_tfree_names
   103 
   104     val trivial_tvar_names =
   105       Vartab.fold
   106         (fn (tvar_name, (_, TFree (tfree_name, _))) =>
   107                tfree_name_trivial tfree_name ? cons tvar_name
   108           | _ => I)
   109         subst
   110         []
   111       |> sort indexname_ord
   112     val tvar_name_trivial = Ord_List.member indexname_ord trivial_tvar_names
   113 
   114     val t' =
   115       t' |> map_types
   116               (map_type_tvar
   117                 (fn (idxn, sort) =>
   118                   if tvar_name_trivial idxn then dummyT else TVar (idxn, sort)))
   119 
   120     val subst =
   121       subst |> fold Vartab.delete trivial_tvar_names
   122             |> Vartab.map
   123                (K (apsnd (map_type_tfree
   124                            (fn (name, sort) =>
   125                               if tfree_name_trivial name then dummyT
   126                               else TFree (name, sort)))))
   127   in
   128     (t', subst)
   129   end
   130 
   131 (* (4) Typing-spot table *)
   132 local
   133 fun key_of_atype (TVar (z, _)) = Ord_List.insert indexname_ord z
   134   | key_of_atype _ = I
   135 fun key_of_type T = fold_atyps key_of_atype T []
   136 fun update_tab t T (tab, pos) =
   137   (case key_of_type T of
   138      [] => tab
   139    | key =>
   140      let val cost = (size_of_typ T, (size_of_term t, pos)) in
   141        case Var_Set_Tab.lookup tab key of
   142          NONE => Var_Set_Tab.update_new (key, cost) tab
   143        | SOME old_cost =>
   144          (case cost_ord (cost, old_cost) of
   145             LESS => Var_Set_Tab.update (key, cost) tab
   146           | _ => tab)
   147      end,
   148    pos + 1)
   149 in
   150 val typing_spot_table =
   151   post_fold_term_type update_tab (Var_Set_Tab.empty, 0) #> fst
   152 end
   153 
   154 (* (5) Reverse-greedy *)
   155 fun reverse_greedy typing_spot_tab =
   156   let
   157     fun update_count z =
   158       fold (fn tvar => fn tab =>
   159         let val c = Vartab.lookup tab tvar |> the_default 0 in
   160           Vartab.update (tvar, c + z) tab
   161         end)
   162     fun superfluous tcount =
   163       forall (fn tvar => the (Vartab.lookup tcount tvar) > 1)
   164     fun drop_superfluous (tvars, (_, (_, spot))) (spots, tcount) =
   165       if superfluous tcount tvars then (spots, update_count ~1 tvars tcount)
   166       else (spot :: spots, tcount)
   167     val (typing_spots, tvar_count_tab) =
   168       Var_Set_Tab.fold
   169         (fn kv as (k, _) => apfst (cons kv) #> apsnd (update_count 1 k))
   170         typing_spot_tab ([], Vartab.empty)
   171       |>> sort_distinct (rev_order o cost_ord o pairself snd)
   172   in fold drop_superfluous typing_spots ([], tvar_count_tab) |> fst end
   173 
   174 (* (6) Introduce annotations *)
   175 fun introduce_annotations subst spots t t' =
   176   let
   177     fun subst_atype (T as TVar (idxn, S)) subst =
   178         (Envir.subst_type subst T, Vartab.update (idxn, (S, dummyT)) subst)
   179       | subst_atype T subst = (T, subst)
   180     val subst_type = fold_map_atypes subst_atype
   181     fun collect_annot _ T (subst, cp, ps as p :: ps', annots) =
   182         if p <> cp then
   183           (subst, cp + 1, ps, annots)
   184         else
   185           let val (T, subst) = subst_type T subst in
   186             (subst, cp + 1, ps', (p, T)::annots)
   187           end
   188       | collect_annot _ _ x = x
   189     val (_, _, _, annots) =
   190       post_fold_term_type collect_annot (subst, 0, spots, []) t'
   191     fun insert_annot t _ (cp, annots as (p, T) :: annots') =
   192         if p <> cp then (t, (cp + 1, annots)) else (Type.constraint T t, (cp + 1, annots'))
   193       | insert_annot t _ x = (t, x)
   194   in
   195     t |> post_traverse_term_type insert_annot (0, rev annots)
   196       |> fst
   197   end
   198 
   199 (* (7) Annotate *)
   200 fun annotate_types ctxt t =
   201   let
   202     val t' = generalize_types ctxt t
   203     val subst = match_types ctxt t' t
   204     val (t', subst) = (t', subst) |> handle_trivial_tfrees ctxt
   205     val typing_spots =
   206       t' |> typing_spot_table
   207          |> reverse_greedy
   208          |> sort int_ord
   209   in introduce_annotations subst typing_spots t t' end
   210 
   211 end;