(* Title: HOL/Tools/Sledgehammer/sledgehammer_isar_annotate.ML
Author: Steffen Juilf Smolka, TU Muenchen
Author: Jasmin Blanchette, TU Muenchen
Supplements a term with a locally minimal, complete set of type constraints.
Complete: The constraints suffice to infer the term's types. Minimal: Reducing
the set of constraints further will make it incomplete.
When the pretty printer is configured appropriately, the constraints show up
as type annotations when the term is printed. This allows the term to be printed
and reparsed without a change of types.
Terms should be unchecked before calling "annotate_types_in_term" to avoid
awkward syntax.
*)
(* Public interface of the type-annotation module: a single entry point. *)
signature SLEDGEHAMMER_ISAR_ANNOTATE =
sig
(* Given a proof context and a term, returns the term with type constraints
   inserted at selected subterms. The constraints are produced with
   Type.constraint, so they only become visible as annotations if the printer
   is set up to display them. *)
val annotate_types_in_term : Proof.context -> term -> term
end;
structure Sledgehammer_Isar_Annotate : SLEDGEHAMMER_ISAR_ANNOTATE =
struct
(* Bottom-up (post-order) traversal of a term that threads an accumulator and
   tracks the type of every subterm.

   visit     : called on each (already rebuilt) subterm together with its type
               and the current accumulator; yields ((subterm', acc'), type),
               where the returned type is what the enclosing node sees.
   bound_tys : types of the enclosing abstractions, innermost first, used to
               look up the type of a Bound index.

   For an application the function part must come back with a two-argument
   type constructor; its second argument is taken as the result type.
   Otherwise the pattern binding raises Bind, which is reported as Fail. *)
fun post_traverse_term_type' visit _ (tm as Const (_, ty)) acc = visit tm ty acc
  | post_traverse_term_type' visit _ (tm as Free (_, ty)) acc = visit tm ty acc
  | post_traverse_term_type' visit _ (tm as Var (_, ty)) acc = visit tm ty acc
  | post_traverse_term_type' visit bound_tys (tm as Bound idx) acc =
      visit tm (nth bound_tys idx) acc
  | post_traverse_term_type' visit bound_tys (Abs (name, arg_ty, body)) acc =
      let
        val ((body', acc'), body_ty) =
          post_traverse_term_type' visit (arg_ty :: bound_tys) body acc
      in
        visit (Abs (name, arg_ty, body')) (arg_ty --> body_ty) acc'
      end
  | post_traverse_term_type' visit bound_tys (fun_tm $ arg_tm) acc =
      let
        val ((fun_tm', acc'), Type (_, [_, res_ty])) =
          post_traverse_term_type' visit bound_tys fun_tm acc
        val ((arg_tm', acc''), _) =
          post_traverse_term_type' visit bound_tys arg_tm acc'
      in
        visit (fun_tm' $ arg_tm') res_ty acc''
      end
      handle Bind => raise Fail "Sledgehammer_Isar_Annotate: post_traverse_term_type'"
(* Post-order traversal of closed term t starting from accumulator s.
   f receives each subterm, its type and the accumulator and returns the
   replacement subterm paired with the new accumulator; subterm types are
   passed through unchanged. Returns the rebuilt term and final accumulator. *)
fun post_traverse_term_type f s t =
  let
    fun visit tm ty acc = (f tm ty acc, ty)
  in
    fst (post_traverse_term_type' visit [] t s)
  end
(* Post-order fold over the subterms of t: f is given each subterm, its type
   and the accumulator, and returns the new accumulator. The term itself is
   left untouched; only the final accumulator is returned. *)
fun post_fold_term_type f s t =
  let
    fun visit tm ty acc = (tm, f tm ty acc)
  in
    snd (post_traverse_term_type visit s t)
  end
(* Applies f to every atomic (non-Type) component of a type, left to right,
   threading the state s, and rebuilds the type from the results. *)
fun fold_map_atypes f (Type (name, args)) s =
      let
        val (args', s') = fold_map (fold_map_atypes f) args s
      in
        (Type (name, args'), s')
      end
  | fold_map_atypes f T s = f T s
(* Order on indexnames used for all sorted tvar lists in this structure. *)
val indexname_ord = Term_Ord.fast_indexname_ord
(* Order on costs, which are nested triples (int, (int, int)). *)
val cost_ord = prod_ord int_ord (prod_ord int_ord int_ord)
(* Tables keyed by lists of indexnames, compared with list_ord over
   indexname_ord. *)
structure Var_Set_Tab = Table(
type key = indexname list
val ord = list_ord indexname_ord)
(* Discards all type information in t (every type becomes dummyT) and runs
   type inference again in pattern mode, so that the inferred types use
   schematic type variables. *)
fun generalize_types ctxt t =
  let
    (* use schematic type variables *)
    val pattern_ctxt = Proof_Context.set_mode Proof_Context.mode_pattern ctxt
    val untyped = map_types (K dummyT) t
  in
    singleton (Type_Infer_Context.infer_types pattern_ctxt) untyped
  end
(* Collects the subterm types of t1 and t2 in traversal order, pairs them up,
   and accumulates a type environment by matching each pair with
   Sign.typ_match. Pairs whose match fails are skipped and leave the
   environment as it was. *)
fun match_types ctxt t1 t2 =
  let
    val thy = Proof_Context.theory_of ctxt
    fun types_of tm = post_fold_term_type (fn _ => fn T => fn Ts => T :: Ts) [] tm
    fun add_match T_pair tyenv = perhaps (try (Sign.typ_match thy T_pair)) tyenv
  in
    fold add_match (types_of t1 ~~ types_of t2) Vartab.empty
  end
(* Removes "trivial" type information from the generalized term t' and from
   the type substitution subst.

   A TFree occurring in the range of subst is considered trivial when its name
   is not declared in ctxt. A TVar is trivial when subst maps it directly to
   such a TFree.

   Result:
   - t' with every trivial TVar replaced by dummyT;
   - subst without the entries for trivial TVars, and with every remaining
     occurrence of a trivial TFree in its range replaced by dummyT. *)
fun handle_trivial_tfrees ctxt t' subst =
  let
    val add_tfree_names = snd #> snd #> fold_atyps (fn TFree (x, _) => cons x | _ => I)

    (* Ord_List.member expects a list that is sorted with respect to the given
       order, so the names must be sorted by fast_string_ord; merely removing
       duplicates leaves them in table order and makes the membership test
       unreliable. *)
    val trivial_tfree_names =
      Vartab.fold add_tfree_names subst []
      |> filter_out (Variable.is_declared ctxt)
      |> sort_distinct fast_string_ord
    val tfree_name_trivial = Ord_List.member fast_string_ord trivial_tfree_names

    (* tvars that are mapped to exactly one trivial tfree *)
    val trivial_tvar_names =
      Vartab.fold
        (fn (tvar_name, (_, TFree (tfree_name, _))) =>
            tfree_name_trivial tfree_name ? cons tvar_name
          | _ => I)
        subst
        []
      |> sort indexname_ord
    val tvar_name_trivial = Ord_List.member indexname_ord trivial_tvar_names

    (* erase trivial tvars from the term *)
    val t' =
      t' |> map_types
        (map_type_tvar
          (fn (idxn, sort) =>
            if tvar_name_trivial idxn then dummyT else TVar (idxn, sort)))

    (* drop trivial tvars from the substitution and erase trivial tfrees that
       still occur inside the remaining right-hand sides *)
    val subst =
      subst |> fold Vartab.delete trivial_tvar_names
        |> Vartab.map
          (K (apsnd (map_type_tfree
            (fn (name, sort) =>
              if tfree_name_trivial name then dummyT
              else TFree (name, sort)))))
  in
    (t', subst)
  end
(* Adds the indexname of a schematic type variable to the sorted key list;
   any other atomic type leaves the key unchanged. *)
fun key_of_atype atype key =
  (case atype of
    TVar (ixn, _) => Ord_List.insert indexname_ord ixn key
  | _ => key)
(* The key of a type: the ordered list of indexnames of the schematic type
   variables occurring in it, built by folding key_of_atype over its atoms. *)
fun key_of_type T = [] |> fold_atyps key_of_atype T
(* Fold step for building the typing-spot table. For subterm t of type T at
   position pos, records the cost (size of T, (size of t, pos)) under the key
   of T, unless the table already holds an entry for that key that is not more
   expensive. Types without schematic type variables (empty key) are ignored.
   The position counter is always advanced by one. *)
fun update_tab t T (tab, pos) =
  let
    val key = key_of_type T
    val tab' =
      if null key then tab
      else
        let
          val cost = (size_of_typ T, (size_of_term t, pos))
          fun cheaper_than old_cost = cost_ord (cost, old_cost) = LESS
        in
          (case Var_Set_Tab.lookup tab key of
            NONE => Var_Set_Tab.update_new (key, cost) tab
          | SOME old_cost =>
              if cheaper_than old_cost then Var_Set_Tab.update (key, cost) tab
              else tab)
        end
  in
    (tab', pos + 1)
  end
(* Builds the table of candidate typing spots of a term: one entry per
   distinct nonempty key, holding the cheapest cost found for it. Positions
   are counted from 0 in post-order. *)
fun typing_spot_table t =
  fst (post_fold_term_type update_tab (Var_Set_Tab.empty, 0) t)
(* Selects typing spots by a reverse greedy strategy. First counts, for each
   type variable, how many table entries mention it. Then visits the entries
   from most to least expensive and drops an entry whenever every one of its
   type variables is still mentioned by more than one remaining entry,
   decrementing the counts accordingly. Returns the positions of the entries
   that were kept. *)
fun reverse_greedy typing_spot_tab =
  let
    fun update_count delta tvars counts =
      fold (fn tvar => fn tab =>
        Vartab.update (tvar, the_default 0 (Vartab.lookup tab tvar) + delta) tab)
        tvars counts

    fun superfluous tcount tvars =
      forall (fn tvar => the (Vartab.lookup tcount tvar) > 1) tvars

    fun drop_superfluous (tvars, (_, (_, spot))) (spots, tcount) =
      if superfluous tcount tvars then (spots, update_count ~1 tvars tcount)
      else (spot :: spots, tcount)

    fun collect (entry as (tvars, _)) (entries, tcount) =
      (entry :: entries, update_count 1 tvars tcount)

    val (unsorted_spots, tvar_count_tab) =
      Var_Set_Tab.fold collect typing_spot_tab ([], Vartab.empty)
    (* most expensive entries first *)
    val typing_spots = sort_distinct (rev_order o cost_ord o apply2 snd) unsorted_spots
  in
    fst (fold drop_superfluous typing_spots ([], tvar_count_tab))
  end
(* Inserts type constraints into t at the given post-order positions.

   spots are expected in ascending order. In a first pass over t' the type of
   each selected subterm is instantiated with subst; every type variable that
   has been instantiated once is afterwards mapped to dummyT, so later
   annotations do not repeat it. In a second pass over t the collected types
   are attached to the subterms at the same positions via Type.constraint. *)
fun introduce_annotations subst spots t t' =
  let
    fun subst_atype (T as TVar (ixn, S)) env =
          (Envir.subst_type env T, Vartab.update (ixn, (S, dummyT)) env)
      | subst_atype T env = (T, env)

    (* state: (substitution, current position, remaining spots, annotations) *)
    fun collect_annot _ _ (state as (_, _, [], _)) = state
      | collect_annot _ T (env, cur, pending as spot :: rest, found) =
          if spot = cur then
            let
              val (T', env') = fold_map_atypes subst_atype T env
            in
              (env', cur + 1, rest, (spot, T') :: found)
            end
          else (env, cur + 1, pending, found)

    val (_, _, _, rev_annots) =
      post_fold_term_type collect_annot (subst, 0, spots, []) t'

    (* state: (current position, remaining annotations in ascending order) *)
    fun insert_annot tm _ (state as (_, [])) = (tm, state)
      | insert_annot tm _ (cur, todo as (spot, T) :: rest) =
          if spot = cur then (Type.constraint T tm, (cur + 1, rest))
          else (tm, (cur + 1, todo))
  in
    fst (post_traverse_term_type insert_annot (0, rev rev_annots) t)
  end
(* Entry point: generalizes the types of t, matches the generalized types
   against the original ones, prunes trivial type variables, chooses the
   typing spots, and finally inserts the corresponding constraints into t. *)
fun annotate_types_in_term ctxt t =
  let
    val generalized = generalize_types ctxt t
    val (pruned, subst) =
      handle_trivial_tfrees ctxt generalized (match_types ctxt generalized t)
    val typing_spots = sort int_ord (reverse_greedy (typing_spot_table pruned))
  in
    introduce_annotations subst typing_spots t pruned
  end
end;