src/HOL/Tools/Sledgehammer/sledgehammer_annotate.ML
changeset 55202 824c48a539c9
parent 55201 1ee776da8da7
child 55203 e872d196a73b
equal deleted inserted replaced
55201:1ee776da8da7 55202:824c48a539c9
     1 (*  Title:      HOL/Tools/Sledgehammer/sledgehammer_annotate.ML
       
     2     Author:     Jasmin Blanchette, TU Muenchen
       
     3     Author:     Steffen Juilf Smolka, TU Muenchen
       
     4 
       
     5 Supplements a term with a locally minimal, complete set of type constraints.
       
     6 Complete: The constraints suffice to infer the term's types.
       
     7 Minimal: Reducing the set of constraints further will make it incomplete.
       
     8 
       
     9 When configuring the pretty printer appropriately, the constraints will show up
       
    10 as type annotations when printing the term. This allows the term to be printed
       
    11 and reparsed without a change of types.
       
    12 
       
    13 NOTE: Terms should be unchecked before calling annotate_types to avoid awkward
       
    14 syntax.
       
    15 *)
       
    16 
       
    17 signature SLEDGEHAMMER_ANNOTATE =
       
    18 sig
       
         (* annotate_types ctxt t: returns t with type constraints inserted at
            selected subterms (see the header comment of this file for the
            completeness/minimality criteria). *)
    19   val annotate_types : Proof.context -> term -> term
       
    20 end;
       
    21 
       
    22 structure Sledgehammer_Annotate : SLEDGEHAMMER_ANNOTATE =
       
    23 struct
       
    24 
       
    25 (* Util *)
       
    (* Post-order traversal of a term that threads a state and computes types.

       f is applied to every subterm (after its immediate subterms have been
       processed and replaced by the results of f), together with the type of
       that subterm and the current state.  It must return a pair consisting
       of (new term, new state) and a type; the type component is the one used
       by the enclosing subterm.

       env holds the types of the enclosing abstractions, innermost first, so
       that "Bound i" has type "nth env i".

       Raises Fail if the type reported for the function part of an
       application is not a type constructor with exactly two arguments. *)
    fun post_traverse_term_type' f env tm st =
      (case tm of
        Const (_, T) => f tm T st
      | Free (_, T) => f tm T st
      | Var (_, T) => f tm T st
      | Bound i => f tm (nth env i) st
      | Abs (x, T1, body) =>
          let
            val ((body', st'), T2) = post_traverse_term_type' f (T1 :: env) body st
          in
            f (Abs (x, T1, body')) (T1 --> T2) st'
          end
      | u $ v =>
          (* the refutable "val" pattern below raises Bind on a non-function type *)
          (let
            val ((u', st1), Type (_, [_, T])) = post_traverse_term_type' f env u st
            val ((v', st2), _) = post_traverse_term_type' f env v st1
          in
            f (u' $ v') T st2
          end
          handle Bind => raise Fail "Sledgehammer_Annotate: post_traverse_term_type'"))
       
    40 
       
    (* Post-order traversal where f maps (subterm, type, state) to
       (new subterm, new state); the type of each subterm is passed on
       unchanged.  Returns the rebuilt term together with the final state. *)
    fun post_traverse_term_type f s t =
      let
        fun keep_type u T acc = (f u T acc, T)
        val (result, _) = post_traverse_term_type' keep_type [] t s
      in
        result
      end
       
    (* Post-order fold over all subterms of t: f receives each subterm, its
       type and the accumulated state.  The term itself is left unchanged and
       only the final state is returned. *)
    fun post_fold_term_type f s t =
      let
        fun step u T acc = (u, f u T acc)
        val (_, s') = post_traverse_term_type step s t
      in
        s'
      end
       
    45 
       
    (* Combined map and fold over the atomic types (everything that is not a
       "Type" constructor) of a type: f rewrites each atomic type while
       threading the state s; type constructors are rebuilt around the
       rewritten arguments. *)
    fun fold_map_atypes f (Type (name, Ts)) s =
          fold_map (fold_map_atypes f) Ts s
          |>> (fn Ts' => Type (name, Ts'))
      | fold_map_atypes f T s = f T s
       
    53 
       
    54 (** get the elements of a list that occur exactly once **)
       
    local
      (* Scan a sorted list.  "cur" is the element currently under
         consideration and "fresh" records that no duplicate of it has been
         seen so far; "cur" is kept only if it is still fresh when a different
         element (or the end of the list) is reached. *)
      fun scan fresh cur rest =
        (case rest of
          [] => if fresh then [cur] else []
        | next :: rest' =>
            if cur = next then scan false cur rest'
            else
              let val kept = scan true next rest'
              in if fresh then cur :: kept else kept end)
    in
      (* unique ord xs: the elements of xs that occur exactly once, listed in
         the order produced by "sort ord".  Elements are compared with "=". *)
      fun unique ord xs =
        (case sort ord xs of
          [] => []
        | x :: ys => scan true x ys)
    end
       
    64 
       
    65 (** Data structures, orders **)
       
    (* order on index names, used for all sorted lists and tables of type
       variable names in this file *)
    val indexname_ord = Term_Ord.fast_indexname_ord

    (* lexicographic order on costs of the form (int, (int, int)) *)
    val cost_ord =
      let val inner_ord = prod_ord int_ord int_ord
      in prod_ord int_ord inner_ord end

    (* tables keyed by lists of index names *)
    structure Var_Set_Tab = Table(
      type key = indexname list
      val ord = list_ord indexname_ord)
       
    71 
       
    72 (* (1) Generalize types *)
       
    (* Replace every type in t by dummyT and run type inference again, with
       the context switched to pattern mode (so that schematic type variables
       are used). *)
    fun generalize_types ctxt t =
      let
        val pattern_ctxt = Proof_Context.set_mode Proof_Context.mode_pattern ctxt
        val untyped = map_types (K dummyT) t
      in
        singleton (Type_Infer_Context.infer_types pattern_ctxt) untyped
      end
       
    82 
       
    83 (* (2) match types *)
       
    (* Match the types of all subterms of t1 against the types of the
       corresponding subterms of t2, accumulating a type substitution that
       starts from the empty table.  Raises Fail if some pair of types does
       not match. *)
    fun match_types ctxt t1 t2 =
      let
        val thy = Proof_Context.theory_of ctxt
        (* types of all subterms, collected in post-order *)
        fun subterm_types t = post_fold_term_type (fn _ => fn T => fn Ts => T :: Ts) [] t
      in
        fold (Sign.typ_match thy) (subterm_types t1 ~~ subterm_types t2) Vartab.empty
        handle Type.TYPE_MATCH => raise Fail "Sledgehammer_Annotate: match_types"
      end
       
    92 
       
    93 
       
    94 (* (3) handle trivial tfrees  *)
       
    (* Simplify the pair (generalized term, type substitution) with respect to
       "trivial" type variables.

       A TFree name is trivial here if it is not declared in ctxt (according
       to Variable.is_declared) and occurs exactly once among all types in the
       range of subst.  A TVar is trivial if subst maps it directly to such a
       TFree.

       Trivial TVars are replaced by dummyT in the term and removed from the
       substitution; trivial TFrees that remain in the range of the
       substitution are replaced by dummyT as well. *)
    fun handle_trivial_tfrees ctxt (t', subst) =
      let
        (* names of all TFrees in the range of one substitution entry *)
        fun add_tfree_names (_, (_, T)) =
          fold_atyps (fn TFree (x, _) => cons x | _ => I) T

        (* sorted, as required by Ord_List.member below *)
        val trivial_tfree_names =
          unique fast_string_ord
            (filter_out (Variable.is_declared ctxt)
              (Vartab.fold add_tfree_names subst []))
        fun tfree_name_trivial name =
          Ord_List.member fast_string_ord trivial_tfree_names name

        fun add_trivial_tvar (tvar_name, (_, TFree (tfree_name, _))) names =
              if tfree_name_trivial tfree_name then tvar_name :: names else names
          | add_trivial_tvar _ names = names

        val trivial_tvar_names =
          sort indexname_ord (Vartab.fold add_trivial_tvar subst [])
        fun tvar_name_trivial idxn =
          Ord_List.member indexname_ord trivial_tvar_names idxn

        fun prune_tvar (idxn, S) =
          if tvar_name_trivial idxn then dummyT else TVar (idxn, S)
        fun prune_tfree (name, S) =
          if tfree_name_trivial name then dummyT else TFree (name, S)

        val pruned_term = map_types (map_type_tvar prune_tvar) t'
        val pruned_subst =
          Vartab.map (fn _ => apsnd (map_type_tfree prune_tfree))
            (fold Vartab.delete trivial_tvar_names subst)
      in
        (pruned_term, pruned_subst)
      end
       
   132 
       
   133 
       
   134 (* (4) Typing-spot table *)
       
   local
     (* sorted list of the names of all TVars occurring in T *)
     fun key_of_type T =
       fold_atyps
         (fn TVar (z, _) => Ord_List.insert indexname_ord z | _ => I) T []

     (* Visit one subterm t of type T at post-order position pos.  Subterms
        whose type contains no TVar are ignored; otherwise the table entry
        for the TVar set of T is created, or replaced if the new cost
        (size of T, (size of t, pos)) is strictly smaller than the stored
        one.  The position counter is advanced in every case. *)
     fun update_tab t T (tab, pos) =
       let
         val key = key_of_type T
         val tab' =
           if null key then tab
           else
             let
               val cost = (size_of_typ T, (size_of_term t, pos))
             in
               (case Var_Set_Tab.lookup tab key of
                 NONE => Var_Set_Tab.update_new (key, cost) tab
               | SOME old_cost =>
                   if cost_ord (cost, old_cost) = LESS
                   then Var_Set_Tab.update (key, cost) tab
                   else tab)
             end
       in
         (tab', pos + 1)
       end
   in
     (* Table mapping each set of TVars that occurs as the TVar set of some
        subterm's type to the cheapest typing spot found for it. *)
     fun typing_spot_table t =
       fst (post_fold_term_type update_tab (Var_Set_Tab.empty, 0) t)
   end
       
   156 
       
   157 (* (5) Reverse-greedy *)
       
   (* Select typing spots from the table.  The entries are visited in order
      of decreasing cost; an entry is dropped when every TVar of its key is
      still counted more than once among the entries not yet dropped, in
      which case the counts of its TVars are decremented.  Returns the
      positions of the entries that were kept. *)
   fun reverse_greedy typing_spot_tab =
     let
       (* add z to the counter of a single tvar (missing counters are 0) *)
       fun add_to_count z tvar tab =
         Vartab.update (tvar, the_default 0 (Vartab.lookup tab tvar) + z) tab
       fun update_count z tvars tab = fold (add_to_count z) tvars tab

       fun superfluous tcount tvars =
         forall (fn tvar => the (Vartab.lookup tcount tvar) > 1) tvars

       fun drop_superfluous (tvars, (_, (_, spot))) (spots, tcount) =
         if superfluous tcount tvars
         then (spots, update_count ~1 tvars tcount)
         else (spot :: spots, tcount)

       (* gather all table entries and count the occurrences of each tvar *)
       fun collect (entry as (tvars, _)) (entries, tcount) =
         (entry :: entries, update_count 1 tvars tcount)
       val (all_entries, tvar_count_tab) =
         Var_Set_Tab.fold collect typing_spot_tab ([], Vartab.empty)

       (* most expensive entries first *)
       fun descending_cost (e1, e2) = rev_order (cost_ord (snd e1, snd e2))
       val typing_spots = sort_distinct descending_cost all_entries
     in
       fst (fold drop_superfluous typing_spots ([], tvar_count_tab))
     end
       
   176 
       
   177 (* (6) Introduce annotations *)
       
   (* Insert type constraints into t at the post-order positions listed in
      spots (expected in ascending order).  The constraint for a position is
      the type of the subterm of t' at that position, instantiated with
      subst.  After a TVar has been instantiated once, it is mapped to dummyT
      for all later annotations. *)
   fun introduce_annotations subst spots t t' =
     let
       (* instantiate a TVar, then make its later occurrences dummyT *)
       fun subst_atype (T as TVar (idxn, S)) env =
             (Envir.subst_type env T, Vartab.update (idxn, (S, dummyT)) env)
         | subst_atype T env = (T, env)

       (* state: (substitution, current position, remaining spots, annotations) *)
       fun collect_annot _ _ (state as (_, _, [], _)) = state
         | collect_annot _ T (env, cp, ps as p :: ps', annots) =
             if p = cp then
               let val (T', env') = fold_map_atypes subst_atype T env
               in (env', cp + 1, ps', (p, T') :: annots) end
             else (env, cp + 1, ps, annots)

       val (_, _, _, annots) =
         post_fold_term_type collect_annot (subst, 0, spots, []) t'

       (* state: (current position, remaining annotations) *)
       fun insert_annot u _ (state as (_, [])) = (u, state)
         | insert_annot u _ (cp, annots as (p, T) :: annots') =
             if p = cp then (Type.constraint T u, (cp + 1, annots'))
             else (u, (cp + 1, annots))
     in
       (* annots was accumulated in reverse order of position *)
       fst (post_traverse_term_type insert_annot (0, rev annots) t)
     end
       
   201 
       
   202 (* (7) Annotate *)
       
   (* Entry point: generalize the types of t, match the generalized term
      against t, simplify trivial type variables, choose the typing spots and
      finally insert the corresponding type constraints into t. *)
   fun annotate_types ctxt t =
     let
       val generalized = generalize_types ctxt t
       val (t', subst) =
         handle_trivial_tfrees ctxt (generalized, match_types ctxt generalized t)
       (* introduce_annotations consumes the spots in ascending order *)
       val typing_spots = sort int_ord (reverse_greedy (typing_spot_table t'))
     in
       introduce_annotations subst typing_spots t t'
     end
       
   213 
       
   214 end;