# HG changeset patch # User smolkas # Date 1372263881 -7200 # Node ID 2207825d67f31b29fa7f60cf68c1767d3282212e # Parent e64c1344f21b48b82bbea6890e6f6f908bb19d0e ommit trivial tfrees in annotations diff -r e64c1344f21b -r 2207825d67f3 src/HOL/Tools/Sledgehammer/sledgehammer_annotate.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_annotate.ML Tue Jun 25 17:13:09 2013 -0500 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_annotate.ML Wed Jun 26 18:24:41 2013 +0200 @@ -42,11 +42,31 @@ fun post_fold_term_type f s t = post_traverse_term_type (fn t => fn T => fn s => (t, f t T s)) s t |> snd -(* Data structures, orders *) +fun fold_map_atypes f T s = + case T of + Type (name, Ts) => + let val (Ts, s) = fold_map (fold_map_atypes f) Ts s in + (Type (name, Ts), s) + end + | _ => f T s + +(** get unique elements of a list **) +local + fun unique' b x [] = if b then [x] else [] + | unique' b x (y :: ys) = + if x = y then unique' false x ys + else unique' true y ys |> b ? cons x +in + fun unique ord xs = + case sort ord xs of x :: ys => unique' true x ys | [] => [] +end + +(** Data structures, orders **) +val indexname_ord = Term_Ord.fast_indexname_ord val cost_ord = prod_ord int_ord (prod_ord int_ord int_ord) structure Var_Set_Tab = Table( type key = indexname list - val ord = list_ord Term_Ord.fast_indexname_ord) + val ord = list_ord indexname_ord) (* (1) Generalize types *) fun generalize_types ctxt t = @@ -59,10 +79,61 @@ t |> erase_types |> infer_types end -(* (2) Typing-spot table *) +(* (2) match types *) +fun match_types ctxt t1 t2 = + let + val thy = Proof_Context.theory_of ctxt + val get_types = post_fold_term_type (K cons) [] + in + fold (Sign.typ_match thy) (get_types t1 ~~ get_types t2) Vartab.empty + end + + +(* (3) handle trivial tfrees *) +fun handle_trivial_tfrees ctxt (t', subst) = + let + + val add_tfree_names = + snd #> snd #> fold_atyps (fn TFree (x, _) => cons x | _ => I) + + val trivial_tfree_names = + Vartab.fold add_tfree_names subst [] + |> filter_out (Variable.is_declared ctxt) + |> unique fast_string_ord + val tfree_name_trivial = Ord_List.member fast_string_ord trivial_tfree_names + + val trivial_tvar_names = + Vartab.fold + (fn (tvar_name, (_, TFree (tfree_name, _))) => + tfree_name_trivial tfree_name ? cons tvar_name + | _ => I) + subst + [] + |> sort indexname_ord + val tvar_name_trivial = Ord_List.member indexname_ord trivial_tvar_names + + val t' = + t' |> map_types + (map_type_tvar + (fn (idxn, sort) => + if tvar_name_trivial idxn then dummyT else TVar (idxn, sort))) + + val subst = + subst |> fold Vartab.delete trivial_tvar_names + |> Vartab.map + (K (apsnd (map_type_tfree + (fn (name, sort) => + if tfree_name_trivial name then dummyT + else TFree (name, sort))))) + in + (t', subst) + end + + +(* (4) Typing-spot table *) local fun key_of_atype (TVar (z, _)) = - Ord_List.insert Term_Ord.fast_indexname_ord z + Ord_List.insert indexname_ord z | key_of_atype _ = I fun key_of_type T = fold_atyps key_of_atype T [] fun update_tab t T (tab, pos) = @@ -83,7 +154,7 @@ post_fold_term_type update_tab (Var_Set_Tab.empty, 0) #> fst end -(* (3) Reverse-greedy *) +(* (5) Reverse-greedy *) fun reverse_greedy typing_spot_tab = let fun update_count z = @@ -103,72 +174,43 @@ |>> sort_distinct (rev_order o cost_ord o pairself snd) in fold drop_superfluous typing_spots ([], tvar_count_tab) |> fst end -(* (4) Introduce annotations *) -fun introduce_annotations ctxt spots t t' = +(* (6) Introduce annotations *) +fun introduce_annotations subst spots t t' = let - val thy = Proof_Context.theory_of ctxt - val get_types = post_fold_term_type (K cons) [] - fun match_types tp = - fold (Sign.typ_match thy) (op ~~ (pairself get_types tp)) Vartab.empty - fun unica' b x [] = if b then [x] else [] - | unica' b x (y :: ys) = - if x = y then unica' false x ys - else unica' true y ys |> b ? cons x - fun unica ord xs = - case sort ord xs of x :: ys => unica' true x ys | [] => [] - val add_all_tfree_namesT = fold_atyps (fn TFree (x, _) => cons x | _ => I) - fun erase_unica_tfrees env = - let - val unica = - Vartab.fold (add_all_tfree_namesT o snd o snd) env [] - |> filter_out (Variable.is_declared ctxt) - |> unica fast_string_ord - val erase_unica = map_atyps - (fn T as TFree (s, _) => - if Ord_List.member fast_string_ord unica s then dummyT else T - | T => T) - in Vartab.map (K (apsnd erase_unica)) env end - val env = match_types (t', t) |> erase_unica_tfrees - fun get_annot env (TFree _) = (false, (env, dummyT)) - | get_annot env (T as TVar (v, S)) = - let val T' = Envir.subst_type env T in - if T' = dummyT then (false, (env, dummyT)) - else (true, (Vartab.update (v, (S, dummyT)) env, T')) - end - | get_annot env (Type (S, Ts)) = - (case fold_rev (fn T => fn (b, (env, Ts)) => - let - val (b', (env', T)) = get_annot env T - in (b orelse b', (env', T :: Ts)) end) - Ts (false, (env, [])) of - (true, (env', Ts)) => (true, (env', Type (S, Ts))) - | (false, (env', _)) => (false, (env', dummyT))) - fun post1 _ T (env, cp, ps as p :: ps', annots) = + fun subst_atype (T as TVar (idxn, S)) subst = + (Envir.subst_type subst T, + Vartab.update (idxn, (S, dummyT)) subst) + | subst_atype T subst = (T, subst) + val subst_type = fold_map_atypes subst_atype + fun collect_annot _ T (subst, cp, ps as p :: ps', annots) = if p <> cp then - (env, cp + 1, ps, annots) + (subst, cp + 1, ps, annots) else - let val (annot_necessary, (env', T')) = get_annot env T in - (env', cp + 1, ps', annots |> annot_necessary ? cons (p, T')) + let val (T, subst) = subst_type T subst in + (subst, cp + 1, ps', (p, T)::annots) end - | post1 _ _ accum = accum - val (_, _, _, annots) = post_fold_term_type post1 (env, 0, spots, []) t' - fun post2 t _ (cp, annots as (p, T) :: annots') = + | collect_annot _ _ x = x + val (_, _, _, annots) = + post_fold_term_type collect_annot (subst, 0, spots, []) t' + fun insert_annot t _ (cp, annots as (p, T) :: annots') = if p <> cp then (t, (cp + 1, annots)) else (Type.constraint T t, (cp + 1, annots')) - | post2 t _ x = (t, x) + | insert_annot t _ x = (t, x) in - t |> post_traverse_term_type post2 (0, rev annots) + t |> post_traverse_term_type insert_annot (0, rev annots) |> fst end -(* (5) Annotate *) +(* (7) Annotate *) fun annotate_types ctxt t = let val t' = generalize_types ctxt t + val subst = match_types ctxt t' t + val (t', subst) = (t', subst) |> handle_trivial_tfrees ctxt val typing_spots = t' |> typing_spot_table |> reverse_greedy |> sort int_ord - in introduce_annotations ctxt typing_spots t t' end + in introduce_annotations subst typing_spots t t' end end