src/HOL/Tools/Sledgehammer/sledgehammer_isar_annotate.ML
changeset 55202 824c48a539c9
parent 54821 a12796872603
child 55205 8450622db0c5
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_isar_annotate.ML	Fri Jan 31 10:23:32 2014 +0100
     1.3 @@ -0,0 +1,214 @@
     1.4 +(*  Title:      HOL/Tools/Sledgehammer/sledgehammer_isar_annotate.ML
     1.5 +    Author:     Jasmin Blanchette, TU Muenchen
     1.6 +    Author:     Steffen Juilf Smolka, TU Muenchen
     1.7 +
     1.8 +Supplements term with a locally minmal, complete set of type constraints.
     1.9 +Complete: The constraints suffice to infer the term's types.
    1.10 +Minimal: Reducing the set of constraints further will make it incomplete.
    1.11 +
    1.12 +When configuring the pretty printer appropriately, the constraints will show up
    1.13 +as type annotations when printing the term. This allows the term to be printed
    1.14 +and reparsed without a change of types.
    1.15 +
    1.16 +NOTE: Terms should be unchecked before calling annotate_types to avoid awkward
    1.17 +syntax.
    1.18 +*)
    1.19 +
    1.20 +signature SLEDGEHAMMER_ISAR_ANNOTATE =
    1.21 +sig
    1.22 +  val annotate_types : Proof.context -> term -> term
    1.23 +end;
    1.24 +
    1.25 +structure Sledgehammer_Isar_Annotate : SLEDGEHAMMER_ISAR_ANNOTATE =
    1.26 +struct
    1.27 +
    1.28 +(* Util *)
    1.29 +fun post_traverse_term_type' f _ (t as Const (_, T)) s = f t T s
    1.30 +  | post_traverse_term_type' f _ (t as Free (_, T)) s = f t T s
    1.31 +  | post_traverse_term_type' f _ (t as Var (_, T)) s = f t T s
    1.32 +  | post_traverse_term_type' f env (t as Bound i) s = f t (nth env i) s
    1.33 +  | post_traverse_term_type' f env (Abs (x, T1, b)) s =
    1.34 +    let
    1.35 +      val ((b', s'), T2) = post_traverse_term_type' f (T1 :: env) b s
    1.36 +    in f (Abs (x, T1, b')) (T1 --> T2) s' end
    1.37 +  | post_traverse_term_type' f env (u $ v) s =
    1.38 +    let
    1.39 +      val ((u', s'), Type (_, [_, T])) = post_traverse_term_type' f env u s
    1.40 +      val ((v', s''), _) = post_traverse_term_type' f env v s'
    1.41 +    in f (u' $ v') T s'' end
    1.42 +    handle Bind => raise Fail "Sledgehammer_Isar_Annotate: post_traverse_term_type'"
    1.43 +
    1.44 +fun post_traverse_term_type f s t =
    1.45 +  post_traverse_term_type' (fn t => fn T => fn s => (f t T s, T)) [] t s |> fst
    1.46 +fun post_fold_term_type f s t =
    1.47 +  post_traverse_term_type (fn t => fn T => fn s => (t, f t T s)) s t |> snd
    1.48 +
    1.49 +fun fold_map_atypes f T s =
    1.50 +  case T of
    1.51 +    Type (name, Ts) =>
    1.52 +        let val (Ts, s) = fold_map (fold_map_atypes f) Ts s in
    1.53 +          (Type (name, Ts), s)
    1.54 +        end
    1.55 +  | _ => f T s
    1.56 +
    1.57 +(** get unique elements of a list **)
    1.58 +local
    1.59 +  fun unique' b x [] = if b then [x] else []
    1.60 +    | unique' b x (y :: ys) =
    1.61 +      if x = y then unique' false x ys
    1.62 +      else unique' true y ys |> b ? cons x
    1.63 +in
    1.64 +  fun unique ord xs =
    1.65 +    case sort ord xs of x :: ys => unique' true x ys | [] => []
    1.66 +end
    1.67 +
    1.68 +(** Data structures, orders **)
    1.69 +val indexname_ord = Term_Ord.fast_indexname_ord
    1.70 +val cost_ord = prod_ord int_ord (prod_ord int_ord int_ord)
    1.71 +structure Var_Set_Tab = Table(
    1.72 +  type key = indexname list
    1.73 +  val ord = list_ord indexname_ord)
    1.74 +
    1.75 +(* (1) Generalize types *)
    1.76 +fun generalize_types ctxt t =
    1.77 +  let
    1.78 +    val erase_types = map_types (fn _ => dummyT)
    1.79 +    (* use schematic type variables *)
    1.80 +    val ctxt = ctxt |> Proof_Context.set_mode Proof_Context.mode_pattern
    1.81 +    val infer_types = singleton (Type_Infer_Context.infer_types ctxt)
    1.82 +  in
    1.83 +     t |> erase_types |> infer_types
    1.84 +  end
    1.85 +
    1.86 +(* (2) match types *)
    1.87 +fun match_types ctxt t1 t2 =
    1.88 +  let
    1.89 +    val thy = Proof_Context.theory_of ctxt
    1.90 +    val get_types = post_fold_term_type (K cons) []
    1.91 +  in
    1.92 +    fold (Sign.typ_match thy) (get_types t1 ~~ get_types t2) Vartab.empty
    1.93 +    handle Type.TYPE_MATCH => raise Fail "Sledgehammer_Isar_Annotate: match_types"
    1.94 +  end
    1.95 +
    1.96 +
    1.97 +(* (3) handle trivial tfrees  *)
    1.98 +fun handle_trivial_tfrees ctxt (t', subst) =
    1.99 +  let
   1.100 +    val add_tfree_names =
   1.101 +      snd #> snd #> fold_atyps (fn TFree (x, _) => cons x | _ => I)
   1.102 +
   1.103 +    val trivial_tfree_names =
   1.104 +      Vartab.fold add_tfree_names subst []
   1.105 +      |> filter_out (Variable.is_declared ctxt)
   1.106 +      |> unique fast_string_ord
   1.107 +    val tfree_name_trivial = Ord_List.member fast_string_ord trivial_tfree_names
   1.108 +
   1.109 +    val trivial_tvar_names =
   1.110 +      Vartab.fold
   1.111 +        (fn (tvar_name, (_, TFree (tfree_name, _))) =>
   1.112 +               tfree_name_trivial tfree_name ? cons tvar_name
   1.113 +          | _ => I)
   1.114 +        subst
   1.115 +        []
   1.116 +      |> sort indexname_ord
   1.117 +    val tvar_name_trivial = Ord_List.member indexname_ord trivial_tvar_names
   1.118 +
   1.119 +    val t' =
   1.120 +      t' |> map_types
   1.121 +              (map_type_tvar
   1.122 +                (fn (idxn, sort) =>
   1.123 +                  if tvar_name_trivial idxn then dummyT else TVar (idxn, sort)))
   1.124 +
   1.125 +    val subst =
   1.126 +      subst |> fold Vartab.delete trivial_tvar_names
   1.127 +            |> Vartab.map
   1.128 +               (K (apsnd (map_type_tfree
   1.129 +                           (fn (name, sort) =>
   1.130 +                              if tfree_name_trivial name then dummyT
   1.131 +                              else TFree (name, sort)))))
   1.132 +  in
   1.133 +    (t', subst)
   1.134 +  end
   1.135 +
   1.136 +
   1.137 +(* (4) Typing-spot table *)
   1.138 +local
   1.139 +fun key_of_atype (TVar (z, _)) = Ord_List.insert indexname_ord z
   1.140 +  | key_of_atype _ = I
   1.141 +fun key_of_type T = fold_atyps key_of_atype T []
   1.142 +fun update_tab t T (tab, pos) =
   1.143 +  (case key_of_type T of
   1.144 +     [] => tab
   1.145 +   | key =>
   1.146 +     let val cost = (size_of_typ T, (size_of_term t, pos)) in
   1.147 +       case Var_Set_Tab.lookup tab key of
   1.148 +         NONE => Var_Set_Tab.update_new (key, cost) tab
   1.149 +       | SOME old_cost =>
   1.150 +         (case cost_ord (cost, old_cost) of
   1.151 +            LESS => Var_Set_Tab.update (key, cost) tab
   1.152 +          | _ => tab)
   1.153 +     end,
   1.154 +   pos + 1)
   1.155 +in
   1.156 +val typing_spot_table =
   1.157 +  post_fold_term_type update_tab (Var_Set_Tab.empty, 0) #> fst
   1.158 +end
   1.159 +
   1.160 +(* (5) Reverse-greedy *)
   1.161 +fun reverse_greedy typing_spot_tab =
   1.162 +  let
   1.163 +    fun update_count z =
   1.164 +      fold (fn tvar => fn tab =>
   1.165 +        let val c = Vartab.lookup tab tvar |> the_default 0 in
   1.166 +          Vartab.update (tvar, c + z) tab
   1.167 +        end)
   1.168 +    fun superfluous tcount =
   1.169 +      forall (fn tvar => the (Vartab.lookup tcount tvar) > 1)
   1.170 +    fun drop_superfluous (tvars, (_, (_, spot))) (spots, tcount) =
   1.171 +      if superfluous tcount tvars then (spots, update_count ~1 tvars tcount)
   1.172 +      else (spot :: spots, tcount)
   1.173 +    val (typing_spots, tvar_count_tab) =
   1.174 +      Var_Set_Tab.fold
   1.175 +        (fn kv as (k, _) => apfst (cons kv) #> apsnd (update_count 1 k))
   1.176 +        typing_spot_tab ([], Vartab.empty)
   1.177 +      |>> sort_distinct (rev_order o cost_ord o pairself snd)
   1.178 +  in fold drop_superfluous typing_spots ([], tvar_count_tab) |> fst end
   1.179 +
   1.180 +(* (6) Introduce annotations *)
   1.181 +fun introduce_annotations subst spots t t' =
   1.182 +  let
   1.183 +    fun subst_atype (T as TVar (idxn, S)) subst =
   1.184 +        (Envir.subst_type subst T, Vartab.update (idxn, (S, dummyT)) subst)
   1.185 +      | subst_atype T subst = (T, subst)
   1.186 +    val subst_type = fold_map_atypes subst_atype
   1.187 +    fun collect_annot _ T (subst, cp, ps as p :: ps', annots) =
   1.188 +        if p <> cp then
   1.189 +          (subst, cp + 1, ps, annots)
   1.190 +        else
   1.191 +          let val (T, subst) = subst_type T subst in
   1.192 +            (subst, cp + 1, ps', (p, T)::annots)
   1.193 +          end
   1.194 +      | collect_annot _ _ x = x
   1.195 +    val (_, _, _, annots) =
   1.196 +      post_fold_term_type collect_annot (subst, 0, spots, []) t'
   1.197 +    fun insert_annot t _ (cp, annots as (p, T) :: annots') =
   1.198 +        if p <> cp then (t, (cp + 1, annots)) else (Type.constraint T t, (cp + 1, annots'))
   1.199 +      | insert_annot t _ x = (t, x)
   1.200 +  in
   1.201 +    t |> post_traverse_term_type insert_annot (0, rev annots)
   1.202 +      |> fst
   1.203 +  end
   1.204 +
   1.205 +(* (7) Annotate *)
   1.206 +fun annotate_types ctxt t =
   1.207 +  let
   1.208 +    val t' = generalize_types ctxt t
   1.209 +    val subst = match_types ctxt t' t
   1.210 +    val (t', subst) = (t', subst) |> handle_trivial_tfrees ctxt
   1.211 +    val typing_spots =
   1.212 +      t' |> typing_spot_table
   1.213 +         |> reverse_greedy
   1.214 +         |> sort int_ord
   1.215 +  in introduce_annotations subst typing_spots t t' end
   1.216 +
   1.217 +end;