(*  Title:      HOL/Tools/Sledgehammer/sledgehammer_annotate.ML
    Author:     Jasmin Blanchette, TU Muenchen
    Author:     Steffen Juilf Smolka, TU Muenchen

Supplements a term with a locally minimal, complete set of type constraints.
Complete: The constraints suffice to infer the term's types.
Minimal: Removing any constraint from the set makes it incomplete.

When the pretty printer is configured appropriately, the constraints show up
as type annotations when the term is printed. This allows the term to be
printed and reparsed without a change of types.

NOTE: Terms should be unchecked before calling annotate_types to avoid awkward
syntax.
*)
|
16 |
|
signature SLEDGEHAMMER_ANNOTATE =
sig
  (* annotate_types ctxt t: returns t supplemented with type constraints that
     suffice to infer its types, so that it can be printed and reparsed
     without a change of types. *)
  val annotate_types :
    Proof.context -> term -> term
end;
|
21 |
|
structure Sledgehammer_Annotate : SLEDGEHAMMER_ANNOTATE =
struct

(* Util *)

(* Post-order traversal of a term that keeps track of each subterm's type.
   "f t T s" is called on every subterm t (after its children have been
   processed) with the type T of t and the current state s.  It must return
   ((t', s'), T'): the possibly rewritten subterm, the new state, and the type
   that is reported to the parent.  "env" holds the binder types of the
   enclosing abstractions, innermost first, and is used to type Bound
   variables.  The type of an application "u $ v" is taken as the second
   argument of the type reported for u.  If that type is not a Type with
   exactly two arguments, Fail is raised. *)
fun post_traverse_term_type' f _ (t as Const (_, T)) s = f t T s
  | post_traverse_term_type' f _ (t as Free (_, T)) s = f t T s
  | post_traverse_term_type' f _ (t as Var (_, T)) s = f t T s
  | post_traverse_term_type' f env (t as Bound i) s = f t (nth env i) s
  | post_traverse_term_type' f env (Abs (x, T1, b)) s =
    let
      val ((b', s'), T2) = post_traverse_term_type' f (T1 :: env) b s
    in f (Abs (x, T1, b')) (T1 --> T2) s' end
  | post_traverse_term_type' f env (u $ v) s =
    let
      val ((u', s'), Type (_, [_, T])) = post_traverse_term_type' f env u s
      val ((v', s''), _) = post_traverse_term_type' f env v s'
    in f (u' $ v') T s'' end
    handle Bind => raise Fail "Sledgehammer_Annotate: post_traverse_term_type'"

(* State-threading post-order rewrite.  "f t T s" returns the rewritten
   subterm paired with the new state.  The result is the rewritten term
   paired with the final state.  Parents always see the original types, not
   those of the rewritten subterms. *)
fun post_traverse_term_type f s t =
  post_traverse_term_type' (fn t => fn T => fn s => (f t T s, T)) [] t s |> fst

(* Folds "f t T s" over all subterms t (of type T) in post-order, without
   rewriting the term, and returns the final state. *)
fun post_fold_term_type f s t =
  post_traverse_term_type (fn t => fn T => fn s => (t, f t T s)) s t |> snd

(* Maps "f" over the non-Type leaves (TFree/TVar) of T from left to right,
   threading the state s through the calls. *)
fun fold_map_atypes f T s =
  case T of
    Type (name, Ts) =>
      let val (Ts, s) = fold_map (fold_map_atypes f) Ts s in
        (Type (name, Ts), s)
      end
  | _ => f T s

(** Elements that occur exactly once in a list **)
local
  (* "x" is the current candidate of the sorted list; "b" is true as long as
     no duplicate of x has been seen.  Duplicates are detected with the same
     order that was used for sorting, so grouping and equality agree. *)
  fun unique' _ b x [] = if b then [x] else []
    | unique' ord b x (y :: ys) =
      if ord (x, y) = EQUAL then unique' ord false x ys
      else unique' ord true y ys |> b ? cons x
in
  (* Returns, sorted by "ord", the elements of xs that occur exactly once
     (w.r.t. "ord").  Elements occurring several times are dropped entirely;
     this is not a deduplication. *)
  fun unique ord xs =
    case sort ord xs of x :: ys => unique' ord true x ys | [] => []
end

(** Data structures, orders **)
val indexname_ord = Term_Ord.fast_indexname_ord

(* Cost of a typing spot: (size of its type, (size of its subterm, its
   post-order position)), compared lexicographically. *)
val cost_ord = prod_ord int_ord (prod_ord int_ord int_ord)

(* Tables keyed by sorted lists (sets) of schematic type variable names. *)
structure Var_Set_Tab = Table(
  type key = indexname list
  val ord = list_ord indexname_ord)

(* (1) Generalize types: erase every type in t (replace it with dummyT) and
   run type inference again.  The context is switched to pattern mode so that
   the inferred types use schematic type variables. *)
fun generalize_types ctxt t =
  let
    val erase_types = map_types (fn _ => dummyT)
    (* use schematic type variables *)
    val ctxt = ctxt |> Proof_Context.set_mode Proof_Context.mode_pattern
    val infer_types = singleton (Type_Infer_Context.infer_types ctxt)
  in
    t |> erase_types |> infer_types
  end

(* (2) Match types.  Collects the types of all subterms of t1 and of t2 (in
   the same traversal order) and matches each type of t1 against the
   corresponding type of t2.  The result is a type environment instantiating
   the schematic type variables of t1.  Raises Fail if matching fails or if
   the two terms yield different numbers of subterms (i.e. differ in
   shape). *)
fun match_types ctxt t1 t2 =
  let
    val thy = Proof_Context.theory_of ctxt
    val get_types = post_fold_term_type (K cons) []
  in
    fold (Sign.typ_match thy) (get_types t1 ~~ get_types t2) Vartab.empty
    handle Type.TYPE_MATCH => raise Fail "Sledgehammer_Annotate: match_types"
         | ListPair.UnequalLengths =>
             (* "~~" requires equally long lists; report a shape mismatch
                with the same error as a failed match *)
             raise Fail "Sledgehammer_Annotate: match_types"
  end

(* (3) Handle trivial tfrees.  A type free variable is trivial if it is not
   declared in ctxt and occurs exactly once in the range of subst.
   - Schematic type variables that subst maps directly to a trivial tfree
     are replaced by dummyT in t' and deleted from subst.
   - Any remaining occurrence of a trivial tfree inside subst is replaced
     by dummyT. *)
fun handle_trivial_tfrees ctxt (t', subst) =
  let
    (* adds the names of all TFrees in the instantiation of one binding *)
    val add_tfree_names =
      snd #> snd #> fold_atyps (fn TFree (x, _) => cons x | _ => I)

    (* sorted (required by Ord_List.member below) *)
    val trivial_tfree_names =
      Vartab.fold add_tfree_names subst []
      |> filter_out (Variable.is_declared ctxt)
      |> unique fast_string_ord
    val tfree_name_trivial = Ord_List.member fast_string_ord trivial_tfree_names

    (* schematic type variables bound to exactly a trivial tfree *)
    val trivial_tvar_names =
      Vartab.fold
        (fn (tvar_name, (_, TFree (tfree_name, _))) =>
            tfree_name_trivial tfree_name ? cons tvar_name
          | _ => I)
        subst
        []
      |> sort indexname_ord
    val tvar_name_trivial = Ord_List.member indexname_ord trivial_tvar_names

    val t' =
      t' |> map_types
              (map_type_tvar
                (fn (idxn, sort) =>
                  if tvar_name_trivial idxn then dummyT else TVar (idxn, sort)))

    val subst =
      subst |> fold Vartab.delete trivial_tvar_names
            |> Vartab.map
               (K (apsnd (map_type_tfree
                           (fn (name, sort) =>
                              if tfree_name_trivial name then dummyT
                              else TFree (name, sort)))))
  in
    (t', subst)
  end

(* (4) Typing-spot table.  Walks the term in post-order, numbering subterms
   by position.  Each subterm whose type contains schematic type variables
   contributes to the key formed by the sorted set of those variables.  For
   each key, the table keeps the cheapest spot according to cost_ord. *)
local
  fun key_of_atype (TVar (z, _)) = Ord_List.insert indexname_ord z
    | key_of_atype _ = I
  fun key_of_type T = fold_atyps key_of_atype T []
  fun update_tab t T (tab, pos) =
    (case key_of_type T of
       [] => tab
     | key =>
       let val cost = (size_of_typ T, (size_of_term t, pos)) in
         case Var_Set_Tab.lookup tab key of
           NONE => Var_Set_Tab.update_new (key, cost) tab
         | SOME old_cost =>
           (case cost_ord (cost, old_cost) of
              LESS => Var_Set_Tab.update (key, cost) tab
            | _ => tab)
       end,
     (* every subterm gets a position, whether or not it is recorded *)
     pos + 1)
in
  val typing_spot_table =
    post_fold_term_type update_tab (Var_Set_Tab.empty, 0) #> fst
end

(* (5) Reverse-greedy.  Counts, for every schematic type variable, how many
   spots of the table mention it.  Then considers the spots from most to
   least expensive and drops a spot whenever each of its variables is still
   mentioned by at least one other remaining spot.  Returns the positions of
   the remaining spots. *)
fun reverse_greedy typing_spot_tab =
  let
    fun update_count z =
      fold (fn tvar => fn tab =>
        let val c = Vartab.lookup tab tvar |> the_default 0 in
          Vartab.update (tvar, c + z) tab
        end)
    fun superfluous tcount =
      forall (fn tvar => the (Vartab.lookup tcount tvar) > 1)
    fun drop_superfluous (tvars, (_, (_, spot))) (spots, tcount) =
      if superfluous tcount tvars then (spots, update_count ~1 tvars tcount)
      else (spot :: spots, tcount)
    val (typing_spots, tvar_count_tab) =
      Var_Set_Tab.fold
        (fn kv as (k, _) => apfst (cons kv) #> apsnd (update_count 1 k))
        typing_spot_tab ([], Vartab.empty)
      (* most expensive spot first *)
      |>> sort_distinct (rev_order o cost_ord o pairself snd)
  in fold drop_superfluous typing_spots ([], tvar_count_tab) |> fst end

(* (6) Introduce annotations.
   - First pass: walks t' in post-order.  At each position listed in spots
     (expected in ascending order), it instantiates the subterm's type with
     subst.  A schematic type variable is instantiated only at its first
     annotated occurrence; afterwards subst maps it to dummyT.
   - Second pass: walks t (assumed to have the same shape as t') and wraps
     the subterm at each such position in a type constraint with the
     computed type. *)
fun introduce_annotations subst spots t t' =
  let
    fun subst_atype (T as TVar (idxn, S)) subst =
        (Envir.subst_type subst T, Vartab.update (idxn, (S, dummyT)) subst)
      | subst_atype T subst = (T, subst)
    val subst_type = fold_map_atypes subst_atype
    (* state: (substitution, current position, pending spots, annotations) *)
    fun collect_annot _ T (subst, cp, ps as p :: ps', annots) =
        if p <> cp then
          (subst, cp + 1, ps, annots)
        else
          let val (T, subst) = subst_type T subst in
            (subst, cp + 1, ps', (p, T) :: annots)
          end
      | collect_annot _ _ x = x
    val (_, _, _, annots) =
      post_fold_term_type collect_annot (subst, 0, spots, []) t'
    (* state: (current position, pending annotations in ascending order) *)
    fun insert_annot t _ (cp, annots as (p, T) :: annots') =
        if p <> cp then (t, (cp + 1, annots))
        else (Type.constraint T t, (cp + 1, annots'))
      | insert_annot t _ x = (t, x)
  in
    t |> post_traverse_term_type insert_annot (0, rev annots)
      |> fst
  end

(* (7) Annotate: generalize the types of t, match them against t's actual
   types, simplify trivial type frees, choose a small set of typing spots,
   and add constraints at those spots. *)
fun annotate_types ctxt t =
  let
    val t' = generalize_types ctxt t
    val subst = match_types ctxt t' t
    val (t', subst) = (t', subst) |> handle_trivial_tfrees ctxt
    val typing_spots =
      t' |> typing_spot_table
         |> reverse_greedy
         |> sort int_ord
  in introduce_annotations subst typing_spots t t' end

end;
|