ommit trivial tfrees in annotations
authorsmolkas
Wed, 26 Jun 2013 18:24:41 +0200
changeset 52452 2207825d67f3
parent 52451 e64c1344f21b
child 52453 2cba5906d836
ommit trivial tfrees in annotations
src/HOL/Tools/Sledgehammer/sledgehammer_annotate.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_annotate.ML	Tue Jun 25 17:13:09 2013 -0500
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_annotate.ML	Wed Jun 26 18:24:41 2013 +0200
@@ -42,11 +42,31 @@
 fun post_fold_term_type f s t =
   post_traverse_term_type (fn t => fn T => fn s => (t, f t T s)) s t |> snd
 
-(* Data structures, orders *)
+fun fold_map_atypes f T s =
+  case T of
+    Type (name, Ts) =>
+        let val (Ts, s) = fold_map (fold_map_atypes f) Ts s in
+          (Type (name, Ts), s)
+        end
+  | _ => f T s
+
+(** get unique elements of a list **)
+local
+  fun unique' b x [] = if b then [x] else []
+    | unique' b x (y :: ys) =
+      if x = y then unique' false x ys
+      else unique' true y ys |> b ? cons x
+in
+  fun unique ord xs =
+    case sort ord xs of x :: ys => unique' true x ys | [] => []
+end
+
+(** Data structures, orders **)
+val indexname_ord = Term_Ord.fast_indexname_ord
 val cost_ord = prod_ord int_ord (prod_ord int_ord int_ord)
 structure Var_Set_Tab = Table(
   type key = indexname list
-  val ord = list_ord Term_Ord.fast_indexname_ord)
+  val ord = list_ord indexname_ord)
 
 (* (1) Generalize types *)
 fun generalize_types ctxt t =
@@ -59,10 +79,61 @@
      t |> erase_types |> infer_types
   end
 
-(* (2) Typing-spot table *)
+(* (2) match types *)
+fun match_types ctxt t1 t2 =
+  let
+    val thy = Proof_Context.theory_of ctxt
+    val get_types = post_fold_term_type (K cons) []
+  in
+    fold (Sign.typ_match thy) (get_types t1 ~~ get_types t2) Vartab.empty
+  end
+
+
+(* (3) handle trivial tfrees  *)
+fun handle_trivial_tfrees ctxt (t', subst) =
+  let
+
+    val add_tfree_names =
+      snd #> snd #> fold_atyps (fn TFree (x, _) => cons x | _ => I)
+
+    val trivial_tfree_names =
+      Vartab.fold add_tfree_names subst []
+      |> filter_out (Variable.is_declared ctxt)
+      |> unique fast_string_ord
+    val tfree_name_trivial = Ord_List.member fast_string_ord trivial_tfree_names
+
+    val trivial_tvar_names =
+      Vartab.fold
+        (fn (tvar_name, (_, TFree (tfree_name, _))) =>
+               tfree_name_trivial tfree_name ? cons tvar_name
+          | _ => I)
+        subst
+        []
+      |> sort indexname_ord
+    val tvar_name_trivial = Ord_List.member indexname_ord trivial_tvar_names
+
+    val t' =
+      t' |> map_types
+              (map_type_tvar
+                (fn (idxn, sort) =>
+                  if tvar_name_trivial idxn then dummyT else TVar (idxn, sort)))
+
+    val subst =
+      subst |> fold Vartab.delete trivial_tvar_names
+            |> Vartab.map
+               (K (apsnd (map_type_tfree
+                           (fn (name, sort) =>
+                              if tfree_name_trivial name then dummyT
+                              else TFree (name, sort)))))
+  in
+    (t', subst)
+  end
+
+
+(* (4) Typing-spot table *)
 local
 fun key_of_atype (TVar (z, _)) =
-    Ord_List.insert Term_Ord.fast_indexname_ord z
+    Ord_List.insert indexname_ord z
   | key_of_atype _ = I
 fun key_of_type T = fold_atyps key_of_atype T []
 fun update_tab t T (tab, pos) =
@@ -83,7 +154,7 @@
   post_fold_term_type update_tab (Var_Set_Tab.empty, 0) #> fst
 end
 
-(* (3) Reverse-greedy *)
+(* (5) Reverse-greedy *)
 fun reverse_greedy typing_spot_tab =
   let
     fun update_count z =
@@ -103,72 +174,43 @@
       |>> sort_distinct (rev_order o cost_ord o pairself snd)
   in fold drop_superfluous typing_spots ([], tvar_count_tab) |> fst end
 
-(* (4) Introduce annotations *)
-fun introduce_annotations ctxt spots t t' =
+(* (6) Introduce annotations *)
+fun introduce_annotations subst spots t t' =
   let
-    val thy = Proof_Context.theory_of ctxt
-    val get_types = post_fold_term_type (K cons) []
-    fun match_types tp =
-      fold (Sign.typ_match thy) (op ~~ (pairself get_types tp)) Vartab.empty
-    fun unica' b x [] = if b then [x] else []
-      | unica' b x (y :: ys) =
-        if x = y then unica' false x ys
-        else unica' true y ys |> b ? cons x
-    fun unica ord xs =
-      case sort ord xs of x :: ys => unica' true x ys | [] => []
-    val add_all_tfree_namesT = fold_atyps (fn TFree (x, _) => cons x | _ => I)
-    fun erase_unica_tfrees env =
-      let
-        val unica =
-          Vartab.fold (add_all_tfree_namesT o snd o snd) env []
-          |> filter_out (Variable.is_declared ctxt)
-          |> unica fast_string_ord
-        val erase_unica = map_atyps
-          (fn T as TFree (s, _) =>
-              if Ord_List.member fast_string_ord unica s then dummyT else T
-            | T => T)
-      in Vartab.map (K (apsnd erase_unica)) env end
-    val env = match_types (t', t) |> erase_unica_tfrees
-    fun get_annot env (TFree _) = (false, (env, dummyT))
-      | get_annot env (T as TVar (v, S)) =
-        let val T' = Envir.subst_type env T in
-          if T' = dummyT then (false, (env, dummyT))
-          else (true, (Vartab.update (v, (S, dummyT)) env, T'))
-        end
-      | get_annot env (Type (S, Ts)) =
-        (case fold_rev (fn T => fn (b, (env, Ts)) =>
-                  let
-                    val (b', (env', T)) = get_annot env T
-                  in (b orelse b', (env', T :: Ts)) end)
-                Ts (false, (env, [])) of
-           (true, (env', Ts)) => (true, (env', Type (S, Ts)))
-         | (false, (env', _)) => (false, (env', dummyT)))
-    fun post1 _ T (env, cp, ps as p :: ps', annots) =
+    fun subst_atype (T as TVar (idxn, S)) subst =
+        (Envir.subst_type subst T,
+         Vartab.update (idxn, (S, dummyT)) subst)
+      | subst_atype T subst = (T, subst)
+    val subst_type = fold_map_atypes subst_atype
+    fun collect_annot _ T (subst, cp, ps as p :: ps', annots) =
         if p <> cp then
-          (env, cp + 1, ps, annots)
+          (subst, cp + 1, ps, annots)
         else
-          let val (annot_necessary, (env', T')) = get_annot env T in
-            (env', cp + 1, ps', annots |> annot_necessary ? cons (p, T'))
+          let val (T, subst) = subst_type T subst in
+            (subst, cp + 1, ps', (p, T)::annots)
           end
-      | post1 _ _ accum = accum
-    val (_, _, _, annots) = post_fold_term_type post1 (env, 0, spots, []) t'
-    fun post2 t _ (cp, annots as (p, T) :: annots') =
+      | collect_annot _ _ x = x
+    val (_, _, _, annots) =
+      post_fold_term_type collect_annot (subst, 0, spots, []) t'
+    fun insert_annot t _ (cp, annots as (p, T) :: annots') =
         if p <> cp then (t, (cp + 1, annots))
         else (Type.constraint T t, (cp + 1, annots'))
-      | post2 t _ x = (t, x)
+      | insert_annot t _ x = (t, x)
   in
-    t |> post_traverse_term_type post2 (0, rev annots)
+    t |> post_traverse_term_type insert_annot (0, rev annots)
       |> fst
   end
 
-(* (5) Annotate *)
+(* (7) Annotate *)
 fun annotate_types ctxt t =
   let
     val t' = generalize_types ctxt t
+    val subst = match_types ctxt t' t
+    val (t', subst) = (t', subst) |> handle_trivial_tfrees ctxt
     val typing_spots =
       t' |> typing_spot_table
          |> reverse_greedy
          |> sort int_ord
-  in introduce_annotations ctxt typing_spots t t' end
+  in introduce_annotations subst typing_spots t t' end
 
 end