src/HOL/Tools/Sledgehammer/sledgehammer_isar_annotate.ML
changeset 55243 66709d41601e
parent 55213 dcb36a2540bc
child 55286 7bbbd9393ce0
     1.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_isar_annotate.ML	Sun Feb 02 19:15:25 2014 +0000
     1.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_isar_annotate.ML	Sun Feb 02 20:53:51 2014 +0100
     1.3 @@ -20,15 +20,14 @@
     1.4  structure Sledgehammer_Isar_Annotate : SLEDGEHAMMER_ISAR_ANNOTATE =
     1.5  struct
     1.6  
     1.7 -(* Util *)
     1.8  fun post_traverse_term_type' f _ (t as Const (_, T)) s = f t T s
     1.9    | post_traverse_term_type' f _ (t as Free (_, T)) s = f t T s
    1.10    | post_traverse_term_type' f _ (t as Var (_, T)) s = f t T s
    1.11    | post_traverse_term_type' f env (t as Bound i) s = f t (nth env i) s
    1.12    | post_traverse_term_type' f env (Abs (x, T1, b)) s =
    1.13 -    let
    1.14 -      val ((b', s'), T2) = post_traverse_term_type' f (T1 :: env) b s
    1.15 -    in f (Abs (x, T1, b')) (T1 --> T2) s' end
    1.16 +    let val ((b', s'), T2) = post_traverse_term_type' f (T1 :: env) b s in
    1.17 +      f (Abs (x, T1, b')) (T1 --> T2) s'
    1.18 +    end
    1.19    | post_traverse_term_type' f env (u $ v) s =
    1.20      let
    1.21        val ((u', s'), Type (_, [_, T])) = post_traverse_term_type' f env u s
    1.22 @@ -49,25 +48,13 @@
    1.23          end
    1.24    | _ => f T s
    1.25  
    1.26 -(** get unique elements of a list **)
    1.27 -local
    1.28 -  fun unique' b x [] = if b then [x] else []
    1.29 -    | unique' b x (y :: ys) =
    1.30 -      if x = y then unique' false x ys
    1.31 -      else unique' true y ys |> b ? cons x
    1.32 -in
    1.33 -  fun unique ord xs =
    1.34 -    case sort ord xs of x :: ys => unique' true x ys | [] => []
    1.35 -end
    1.36 -
    1.37 -(** Data structures, orders **)
    1.38  val indexname_ord = Term_Ord.fast_indexname_ord
    1.39  val cost_ord = prod_ord int_ord (prod_ord int_ord int_ord)
    1.40 +
    1.41  structure Var_Set_Tab = Table(
    1.42    type key = indexname list
    1.43    val ord = list_ord indexname_ord)
    1.44  
    1.45 -(* (1) Generalize types *)
    1.46  fun generalize_types ctxt t =
    1.47    let
    1.48      val erase_types = map_types (fn _ => dummyT)
    1.49 @@ -78,7 +65,6 @@
    1.50       t |> erase_types |> infer_types
    1.51    end
    1.52  
    1.53 -(* (2) match types *)
    1.54  fun match_types ctxt t1 t2 =
    1.55    let
    1.56      val thy = Proof_Context.theory_of ctxt
    1.57 @@ -88,17 +74,14 @@
    1.58      handle Type.TYPE_MATCH => raise Fail "Sledgehammer_Isar_Annotate: match_types"
    1.59    end
    1.60  
    1.61 -
    1.62 -(* (3) handle trivial tfrees  *)
    1.63  fun handle_trivial_tfrees ctxt (t', subst) =
    1.64    let
    1.65 -    val add_tfree_names =
    1.66 -      snd #> snd #> fold_atyps (fn TFree (x, _) => cons x | _ => I)
    1.67 +    val add_tfree_names = snd #> snd #> fold_atyps (fn TFree (x, _) => cons x | _ => I)
    1.68  
    1.69      val trivial_tfree_names =
    1.70        Vartab.fold add_tfree_names subst []
    1.71        |> filter_out (Variable.is_declared ctxt)
    1.72 -      |> unique fast_string_ord
    1.73 +      |> distinct (op =)
    1.74      val tfree_name_trivial = Ord_List.member fast_string_ord trivial_tfree_names
    1.75  
    1.76      val trivial_tvar_names =
    1.77 @@ -128,30 +111,26 @@
    1.78      (t', subst)
    1.79    end
    1.80  
    1.81 -(* (4) Typing-spot table *)
    1.82 -local
    1.83  fun key_of_atype (TVar (z, _)) = Ord_List.insert indexname_ord z
    1.84    | key_of_atype _ = I
    1.85  fun key_of_type T = fold_atyps key_of_atype T []
    1.86 +
    1.87  fun update_tab t T (tab, pos) =
    1.88 -  (case key_of_type T of
    1.89 +  ((case key_of_type T of
    1.90       [] => tab
    1.91     | key =>
    1.92       let val cost = (size_of_typ T, (size_of_term t, pos)) in
    1.93 -       case Var_Set_Tab.lookup tab key of
    1.94 +       (case Var_Set_Tab.lookup tab key of
    1.95           NONE => Var_Set_Tab.update_new (key, cost) tab
    1.96         | SOME old_cost =>
    1.97           (case cost_ord (cost, old_cost) of
    1.98 -            LESS => Var_Set_Tab.update (key, cost) tab
    1.99 -          | _ => tab)
   1.100 -     end,
   1.101 +           LESS => Var_Set_Tab.update (key, cost) tab
   1.102 +         | _ => tab))
   1.103 +     end),
   1.104     pos + 1)
   1.105 -in
   1.106 -val typing_spot_table =
   1.107 -  post_fold_term_type update_tab (Var_Set_Tab.empty, 0) #> fst
   1.108 -end
   1.109  
   1.110 -(* (5) Reverse-greedy *)
   1.111 +val typing_spot_table = post_fold_term_type update_tab (Var_Set_Tab.empty, 0) #> fst
   1.112 +
   1.113  fun reverse_greedy typing_spot_tab =
   1.114    let
   1.115      fun update_count z =
   1.116 @@ -159,53 +138,53 @@
   1.117          let val c = Vartab.lookup tab tvar |> the_default 0 in
   1.118            Vartab.update (tvar, c + z) tab
   1.119          end)
   1.120 -    fun superfluous tcount =
   1.121 -      forall (fn tvar => the (Vartab.lookup tcount tvar) > 1)
   1.122 +    fun superfluous tcount = forall (fn tvar => the (Vartab.lookup tcount tvar) > 1)
   1.123      fun drop_superfluous (tvars, (_, (_, spot))) (spots, tcount) =
   1.124        if superfluous tcount tvars then (spots, update_count ~1 tvars tcount)
   1.125        else (spot :: spots, tcount)
   1.126 +
   1.127      val (typing_spots, tvar_count_tab) =
   1.128 -      Var_Set_Tab.fold
   1.129 -        (fn kv as (k, _) => apfst (cons kv) #> apsnd (update_count 1 k))
   1.130 +      Var_Set_Tab.fold (fn kv as (k, _) => apfst (cons kv) #> apsnd (update_count 1 k))
   1.131          typing_spot_tab ([], Vartab.empty)
   1.132        |>> sort_distinct (rev_order o cost_ord o pairself snd)
   1.133 -  in fold drop_superfluous typing_spots ([], tvar_count_tab) |> fst end
   1.134 +  in
   1.135 +    fold drop_superfluous typing_spots ([], tvar_count_tab) |> fst
   1.136 +  end
   1.137  
   1.138 -(* (6) Introduce annotations *)
   1.139  fun introduce_annotations subst spots t t' =
   1.140    let
   1.141      fun subst_atype (T as TVar (idxn, S)) subst =
   1.142          (Envir.subst_type subst T, Vartab.update (idxn, (S, dummyT)) subst)
   1.143        | subst_atype T subst = (T, subst)
   1.144 +
   1.145      val subst_type = fold_map_atypes subst_atype
   1.146 +
   1.147      fun collect_annot _ T (subst, cp, ps as p :: ps', annots) =
   1.148          if p <> cp then
   1.149            (subst, cp + 1, ps, annots)
   1.150          else
   1.151            let val (T, subst) = subst_type T subst in
   1.152 -            (subst, cp + 1, ps', (p, T)::annots)
   1.153 +            (subst, cp + 1, ps', (p, T) :: annots)
   1.154            end
   1.155        | collect_annot _ _ x = x
   1.156 -    val (_, _, _, annots) =
   1.157 -      post_fold_term_type collect_annot (subst, 0, spots, []) t'
   1.158 +
   1.159 +    val (_, _, _, annots) = post_fold_term_type collect_annot (subst, 0, spots, []) t'
   1.160 +
   1.161      fun insert_annot t _ (cp, annots as (p, T) :: annots') =
   1.162          if p <> cp then (t, (cp + 1, annots)) else (Type.constraint T t, (cp + 1, annots'))
   1.163        | insert_annot t _ x = (t, x)
   1.164    in
   1.165 -    t |> post_traverse_term_type insert_annot (0, rev annots)
   1.166 -      |> fst
   1.167 +    t |> post_traverse_term_type insert_annot (0, rev annots) |> fst
   1.168    end
   1.169  
   1.170 -(* (7) Annotate *)
   1.171  fun annotate_types_in_term ctxt t =
   1.172    let
   1.173      val t' = generalize_types ctxt t
   1.174      val subst = match_types ctxt t' t
   1.175      val (t', subst) = (t', subst) |> handle_trivial_tfrees ctxt
   1.176 -    val typing_spots =
   1.177 -      t' |> typing_spot_table
   1.178 -         |> reverse_greedy
   1.179 -         |> sort int_ord
   1.180 -  in introduce_annotations subst typing_spots t t' end
   1.181 +    val typing_spots = t' |> typing_spot_table |> reverse_greedy |> sort int_ord
   1.182 +  in
   1.183 +    introduce_annotations subst typing_spots t t'
   1.184 +  end
   1.185  
   1.186  end;