src/HOL/Tools/Sledgehammer/sledgehammer_annotate.ML
changeset 55202 824c48a539c9
parent 55201 1ee776da8da7
child 55203 e872d196a73b
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_annotate.ML	Fri Jan 31 10:23:32 2014 +0100
+++ /dev/null	Thu Jan 01 00:00:00 1970 +0000
@@ -1,214 +0,0 @@
-(*  Title:      HOL/Tools/Sledgehammer/sledgehammer_annotate.ML
-    Author:     Jasmin Blanchette, TU Muenchen
-    Author:     Steffen Juilf Smolka, TU Muenchen
-
-Supplements term with a locally minmal, complete set of type constraints.
-Complete: The constraints suffice to infer the term's types.
-Minimal: Reducing the set of constraints further will make it incomplete.
-
-When configuring the pretty printer appropriately, the constraints will show up
-as type annotations when printing the term. This allows the term to be printed
-and reparsed without a change of types.
-
-NOTE: Terms should be unchecked before calling annotate_types to avoid awkward
-syntax.
-*)
-
-signature SLEDGEHAMMER_ANNOTATE =
-sig
-  val annotate_types : Proof.context -> term -> term
-end;
-
-structure Sledgehammer_Annotate : SLEDGEHAMMER_ANNOTATE =
-struct
-
-(* Util *)
-fun post_traverse_term_type' f _ (t as Const (_, T)) s = f t T s
-  | post_traverse_term_type' f _ (t as Free (_, T)) s = f t T s
-  | post_traverse_term_type' f _ (t as Var (_, T)) s = f t T s
-  | post_traverse_term_type' f env (t as Bound i) s = f t (nth env i) s
-  | post_traverse_term_type' f env (Abs (x, T1, b)) s =
-    let
-      val ((b', s'), T2) = post_traverse_term_type' f (T1 :: env) b s
-    in f (Abs (x, T1, b')) (T1 --> T2) s' end
-  | post_traverse_term_type' f env (u $ v) s =
-    let
-      val ((u', s'), Type (_, [_, T])) = post_traverse_term_type' f env u s
-      val ((v', s''), _) = post_traverse_term_type' f env v s'
-    in f (u' $ v') T s'' end
-    handle Bind => raise Fail "Sledgehammer_Annotate: post_traverse_term_type'"
-
-fun post_traverse_term_type f s t =
-  post_traverse_term_type' (fn t => fn T => fn s => (f t T s, T)) [] t s |> fst
-fun post_fold_term_type f s t =
-  post_traverse_term_type (fn t => fn T => fn s => (t, f t T s)) s t |> snd
-
-fun fold_map_atypes f T s =
-  case T of
-    Type (name, Ts) =>
-        let val (Ts, s) = fold_map (fold_map_atypes f) Ts s in
-          (Type (name, Ts), s)
-        end
-  | _ => f T s
-
-(** get unique elements of a list **)
-local
-  fun unique' b x [] = if b then [x] else []
-    | unique' b x (y :: ys) =
-      if x = y then unique' false x ys
-      else unique' true y ys |> b ? cons x
-in
-  fun unique ord xs =
-    case sort ord xs of x :: ys => unique' true x ys | [] => []
-end
-
-(** Data structures, orders **)
-val indexname_ord = Term_Ord.fast_indexname_ord
-val cost_ord = prod_ord int_ord (prod_ord int_ord int_ord)
-structure Var_Set_Tab = Table(
-  type key = indexname list
-  val ord = list_ord indexname_ord)
-
-(* (1) Generalize types *)
-fun generalize_types ctxt t =
-  let
-    val erase_types = map_types (fn _ => dummyT)
-    (* use schematic type variables *)
-    val ctxt = ctxt |> Proof_Context.set_mode Proof_Context.mode_pattern
-    val infer_types = singleton (Type_Infer_Context.infer_types ctxt)
-  in
-     t |> erase_types |> infer_types
-  end
-
-(* (2) match types *)
-fun match_types ctxt t1 t2 =
-  let
-    val thy = Proof_Context.theory_of ctxt
-    val get_types = post_fold_term_type (K cons) []
-  in
-    fold (Sign.typ_match thy) (get_types t1 ~~ get_types t2) Vartab.empty
-    handle Type.TYPE_MATCH => raise Fail "Sledgehammer_Annotate: match_types"
-  end
-
-
-(* (3) handle trivial tfrees  *)
-fun handle_trivial_tfrees ctxt (t', subst) =
-  let
-    val add_tfree_names =
-      snd #> snd #> fold_atyps (fn TFree (x, _) => cons x | _ => I)
-
-    val trivial_tfree_names =
-      Vartab.fold add_tfree_names subst []
-      |> filter_out (Variable.is_declared ctxt)
-      |> unique fast_string_ord
-    val tfree_name_trivial = Ord_List.member fast_string_ord trivial_tfree_names
-
-    val trivial_tvar_names =
-      Vartab.fold
-        (fn (tvar_name, (_, TFree (tfree_name, _))) =>
-               tfree_name_trivial tfree_name ? cons tvar_name
-          | _ => I)
-        subst
-        []
-      |> sort indexname_ord
-    val tvar_name_trivial = Ord_List.member indexname_ord trivial_tvar_names
-
-    val t' =
-      t' |> map_types
-              (map_type_tvar
-                (fn (idxn, sort) =>
-                  if tvar_name_trivial idxn then dummyT else TVar (idxn, sort)))
-
-    val subst =
-      subst |> fold Vartab.delete trivial_tvar_names
-            |> Vartab.map
-               (K (apsnd (map_type_tfree
-                           (fn (name, sort) =>
-                              if tfree_name_trivial name then dummyT
-                              else TFree (name, sort)))))
-  in
-    (t', subst)
-  end
-
-
-(* (4) Typing-spot table *)
-local
-fun key_of_atype (TVar (z, _)) = Ord_List.insert indexname_ord z
-  | key_of_atype _ = I
-fun key_of_type T = fold_atyps key_of_atype T []
-fun update_tab t T (tab, pos) =
-  (case key_of_type T of
-     [] => tab
-   | key =>
-     let val cost = (size_of_typ T, (size_of_term t, pos)) in
-       case Var_Set_Tab.lookup tab key of
-         NONE => Var_Set_Tab.update_new (key, cost) tab
-       | SOME old_cost =>
-         (case cost_ord (cost, old_cost) of
-            LESS => Var_Set_Tab.update (key, cost) tab
-          | _ => tab)
-     end,
-   pos + 1)
-in
-val typing_spot_table =
-  post_fold_term_type update_tab (Var_Set_Tab.empty, 0) #> fst
-end
-
-(* (5) Reverse-greedy *)
-fun reverse_greedy typing_spot_tab =
-  let
-    fun update_count z =
-      fold (fn tvar => fn tab =>
-        let val c = Vartab.lookup tab tvar |> the_default 0 in
-          Vartab.update (tvar, c + z) tab
-        end)
-    fun superfluous tcount =
-      forall (fn tvar => the (Vartab.lookup tcount tvar) > 1)
-    fun drop_superfluous (tvars, (_, (_, spot))) (spots, tcount) =
-      if superfluous tcount tvars then (spots, update_count ~1 tvars tcount)
-      else (spot :: spots, tcount)
-    val (typing_spots, tvar_count_tab) =
-      Var_Set_Tab.fold
-        (fn kv as (k, _) => apfst (cons kv) #> apsnd (update_count 1 k))
-        typing_spot_tab ([], Vartab.empty)
-      |>> sort_distinct (rev_order o cost_ord o pairself snd)
-  in fold drop_superfluous typing_spots ([], tvar_count_tab) |> fst end
-
-(* (6) Introduce annotations *)
-fun introduce_annotations subst spots t t' =
-  let
-    fun subst_atype (T as TVar (idxn, S)) subst =
-        (Envir.subst_type subst T, Vartab.update (idxn, (S, dummyT)) subst)
-      | subst_atype T subst = (T, subst)
-    val subst_type = fold_map_atypes subst_atype
-    fun collect_annot _ T (subst, cp, ps as p :: ps', annots) =
-        if p <> cp then
-          (subst, cp + 1, ps, annots)
-        else
-          let val (T, subst) = subst_type T subst in
-            (subst, cp + 1, ps', (p, T)::annots)
-          end
-      | collect_annot _ _ x = x
-    val (_, _, _, annots) =
-      post_fold_term_type collect_annot (subst, 0, spots, []) t'
-    fun insert_annot t _ (cp, annots as (p, T) :: annots') =
-        if p <> cp then (t, (cp + 1, annots)) else (Type.constraint T t, (cp + 1, annots'))
-      | insert_annot t _ x = (t, x)
-  in
-    t |> post_traverse_term_type insert_annot (0, rev annots)
-      |> fst
-  end
-
-(* (7) Annotate *)
-fun annotate_types ctxt t =
-  let
-    val t' = generalize_types ctxt t
-    val subst = match_types ctxt t' t
-    val (t', subst) = (t', subst) |> handle_trivial_tfrees ctxt
-    val typing_spots =
-      t' |> typing_spot_table
-         |> reverse_greedy
-         |> sort int_ord
-  in introduce_annotations subst typing_spots t t' end
-
-end;