src/HOL/Tools/Sledgehammer/sledgehammer_annotate.ML
changeset 52452 2207825d67f3
parent 52369 0b395800fdf0
child 52555 6811291d1869
     1.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_annotate.ML	Tue Jun 25 17:13:09 2013 -0500
     1.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_annotate.ML	Wed Jun 26 18:24:41 2013 +0200
     1.3 @@ -42,11 +42,31 @@
     1.4  fun post_fold_term_type f s t =
     1.5    post_traverse_term_type (fn t => fn T => fn s => (t, f t T s)) s t |> snd
     1.6  
     1.7 -(* Data structures, orders *)
     1.8 +fun fold_map_atypes f T s =
     1.9 +  case T of
    1.10 +    Type (name, Ts) =>
    1.11 +        let val (Ts, s) = fold_map (fold_map_atypes f) Ts s in
    1.12 +          (Type (name, Ts), s)
    1.13 +        end
    1.14 +  | _ => f T s
    1.15 +
    1.16 +(** get unique elements of a list **)
    1.17 +local
    1.18 +  fun unique' b x [] = if b then [x] else []
    1.19 +    | unique' b x (y :: ys) =
    1.20 +      if x = y then unique' false x ys
    1.21 +      else unique' true y ys |> b ? cons x
    1.22 +in
    1.23 +  fun unique ord xs =
    1.24 +    case sort ord xs of x :: ys => unique' true x ys | [] => []
    1.25 +end
    1.26 +
    1.27 +(** Data structures, orders **)
    1.28 +val indexname_ord = Term_Ord.fast_indexname_ord
    1.29  val cost_ord = prod_ord int_ord (prod_ord int_ord int_ord)
    1.30  structure Var_Set_Tab = Table(
    1.31    type key = indexname list
    1.32 -  val ord = list_ord Term_Ord.fast_indexname_ord)
    1.33 +  val ord = list_ord indexname_ord)
    1.34  
    1.35  (* (1) Generalize types *)
    1.36  fun generalize_types ctxt t =
    1.37 @@ -59,10 +79,61 @@
    1.38       t |> erase_types |> infer_types
    1.39    end
    1.40  
    1.41 -(* (2) Typing-spot table *)
    1.42 +(* (2) match types *)
    1.43 +fun match_types ctxt t1 t2 =
    1.44 +  let
    1.45 +    val thy = Proof_Context.theory_of ctxt
    1.46 +    val get_types = post_fold_term_type (K cons) []
    1.47 +  in
    1.48 +    fold (Sign.typ_match thy) (get_types t1 ~~ get_types t2) Vartab.empty
    1.49 +  end
    1.50 +
    1.51 +
    1.52 +(* (3) handle trivial tfrees  *)
    1.53 +fun handle_trivial_tfrees ctxt (t', subst) =
    1.54 +  let
    1.55 +
    1.56 +    val add_tfree_names =
    1.57 +      snd #> snd #> fold_atyps (fn TFree (x, _) => cons x | _ => I)
    1.58 +
    1.59 +    val trivial_tfree_names =
    1.60 +      Vartab.fold add_tfree_names subst []
    1.61 +      |> filter_out (Variable.is_declared ctxt)
    1.62 +      |> unique fast_string_ord
    1.63 +    val tfree_name_trivial = Ord_List.member fast_string_ord trivial_tfree_names
    1.64 +
    1.65 +    val trivial_tvar_names =
    1.66 +      Vartab.fold
    1.67 +        (fn (tvar_name, (_, TFree (tfree_name, _))) =>
    1.68 +               tfree_name_trivial tfree_name ? cons tvar_name
    1.69 +          | _ => I)
    1.70 +        subst
    1.71 +        []
    1.72 +      |> sort indexname_ord
    1.73 +    val tvar_name_trivial = Ord_List.member indexname_ord trivial_tvar_names
    1.74 +
    1.75 +    val t' =
    1.76 +      t' |> map_types
    1.77 +              (map_type_tvar
    1.78 +                (fn (idxn, sort) =>
    1.79 +                  if tvar_name_trivial idxn then dummyT else TVar (idxn, sort)))
    1.80 +
    1.81 +    val subst =
    1.82 +      subst |> fold Vartab.delete trivial_tvar_names
    1.83 +            |> Vartab.map
    1.84 +               (K (apsnd (map_type_tfree
    1.85 +                           (fn (name, sort) =>
    1.86 +                              if tfree_name_trivial name then dummyT
    1.87 +                              else TFree (name, sort)))))
    1.88 +  in
    1.89 +    (t', subst)
    1.90 +  end
    1.91 +
    1.92 +
    1.93 +(* (4) Typing-spot table *)
    1.94  local
    1.95  fun key_of_atype (TVar (z, _)) =
    1.96 -    Ord_List.insert Term_Ord.fast_indexname_ord z
    1.97 +    Ord_List.insert indexname_ord z
    1.98    | key_of_atype _ = I
    1.99  fun key_of_type T = fold_atyps key_of_atype T []
   1.100  fun update_tab t T (tab, pos) =
   1.101 @@ -83,7 +154,7 @@
   1.102    post_fold_term_type update_tab (Var_Set_Tab.empty, 0) #> fst
   1.103  end
   1.104  
   1.105 -(* (3) Reverse-greedy *)
   1.106 +(* (5) Reverse-greedy *)
   1.107  fun reverse_greedy typing_spot_tab =
   1.108    let
   1.109      fun update_count z =
   1.110 @@ -103,72 +174,43 @@
   1.111        |>> sort_distinct (rev_order o cost_ord o pairself snd)
   1.112    in fold drop_superfluous typing_spots ([], tvar_count_tab) |> fst end
   1.113  
   1.114 -(* (4) Introduce annotations *)
   1.115 -fun introduce_annotations ctxt spots t t' =
   1.116 +(* (6) Introduce annotations *)
   1.117 +fun introduce_annotations subst spots t t' =
   1.118    let
   1.119 -    val thy = Proof_Context.theory_of ctxt
   1.120 -    val get_types = post_fold_term_type (K cons) []
   1.121 -    fun match_types tp =
   1.122 -      fold (Sign.typ_match thy) (op ~~ (pairself get_types tp)) Vartab.empty
   1.123 -    fun unica' b x [] = if b then [x] else []
   1.124 -      | unica' b x (y :: ys) =
   1.125 -        if x = y then unica' false x ys
   1.126 -        else unica' true y ys |> b ? cons x
   1.127 -    fun unica ord xs =
   1.128 -      case sort ord xs of x :: ys => unica' true x ys | [] => []
   1.129 -    val add_all_tfree_namesT = fold_atyps (fn TFree (x, _) => cons x | _ => I)
   1.130 -    fun erase_unica_tfrees env =
   1.131 -      let
   1.132 -        val unica =
   1.133 -          Vartab.fold (add_all_tfree_namesT o snd o snd) env []
   1.134 -          |> filter_out (Variable.is_declared ctxt)
   1.135 -          |> unica fast_string_ord
   1.136 -        val erase_unica = map_atyps
   1.137 -          (fn T as TFree (s, _) =>
   1.138 -              if Ord_List.member fast_string_ord unica s then dummyT else T
   1.139 -            | T => T)
   1.140 -      in Vartab.map (K (apsnd erase_unica)) env end
   1.141 -    val env = match_types (t', t) |> erase_unica_tfrees
   1.142 -    fun get_annot env (TFree _) = (false, (env, dummyT))
   1.143 -      | get_annot env (T as TVar (v, S)) =
   1.144 -        let val T' = Envir.subst_type env T in
   1.145 -          if T' = dummyT then (false, (env, dummyT))
   1.146 -          else (true, (Vartab.update (v, (S, dummyT)) env, T'))
   1.147 -        end
   1.148 -      | get_annot env (Type (S, Ts)) =
   1.149 -        (case fold_rev (fn T => fn (b, (env, Ts)) =>
   1.150 -                  let
   1.151 -                    val (b', (env', T)) = get_annot env T
   1.152 -                  in (b orelse b', (env', T :: Ts)) end)
   1.153 -                Ts (false, (env, [])) of
   1.154 -           (true, (env', Ts)) => (true, (env', Type (S, Ts)))
   1.155 -         | (false, (env', _)) => (false, (env', dummyT)))
   1.156 -    fun post1 _ T (env, cp, ps as p :: ps', annots) =
   1.157 +    fun subst_atype (T as TVar (idxn, S)) subst =
   1.158 +        (Envir.subst_type subst T,
   1.159 +         Vartab.update (idxn, (S, dummyT)) subst)
   1.160 +      | subst_atype T subst = (T, subst)
   1.161 +    val subst_type = fold_map_atypes subst_atype
   1.162 +    fun collect_annot _ T (subst, cp, ps as p :: ps', annots) =
   1.163          if p <> cp then
   1.164 -          (env, cp + 1, ps, annots)
   1.165 +          (subst, cp + 1, ps, annots)
   1.166          else
   1.167 -          let val (annot_necessary, (env', T')) = get_annot env T in
   1.168 -            (env', cp + 1, ps', annots |> annot_necessary ? cons (p, T'))
   1.169 +          let val (T, subst) = subst_type T subst in
   1.170 +            (subst, cp + 1, ps', (p, T)::annots)
   1.171            end
   1.172 -      | post1 _ _ accum = accum
   1.173 -    val (_, _, _, annots) = post_fold_term_type post1 (env, 0, spots, []) t'
   1.174 -    fun post2 t _ (cp, annots as (p, T) :: annots') =
   1.175 +      | collect_annot _ _ x = x
   1.176 +    val (_, _, _, annots) =
   1.177 +      post_fold_term_type collect_annot (subst, 0, spots, []) t'
   1.178 +    fun insert_annot t _ (cp, annots as (p, T) :: annots') =
   1.179          if p <> cp then (t, (cp + 1, annots))
   1.180          else (Type.constraint T t, (cp + 1, annots'))
   1.181 -      | post2 t _ x = (t, x)
   1.182 +      | insert_annot t _ x = (t, x)
   1.183    in
   1.184 -    t |> post_traverse_term_type post2 (0, rev annots)
   1.185 +    t |> post_traverse_term_type insert_annot (0, rev annots)
   1.186        |> fst
   1.187    end
   1.188  
   1.189 -(* (5) Annotate *)
   1.190 +(* (7) Annotate *)
   1.191  fun annotate_types ctxt t =
   1.192    let
   1.193      val t' = generalize_types ctxt t
   1.194 +    val subst = match_types ctxt t' t
   1.195 +    val (t', subst) = (t', subst) |> handle_trivial_tfrees ctxt
   1.196      val typing_spots =
   1.197        t' |> typing_spot_table
   1.198           |> reverse_greedy
   1.199           |> sort int_ord
   1.200 -  in introduce_annotations ctxt typing_spots t t' end
   1.201 +  in introduce_annotations subst typing_spots t t' end
   1.202  
   1.203  end