put annotate in own structure
authorsmolkas
Wed Nov 28 12:22:05 2012 +0100 (2012-11-28)
changeset 502581c708d7728c7
parent 50257 bafbc4a3d976
child 50259 9c64a52ae499
put annotate in own structure
src/HOL/Sledgehammer.thy
src/HOL/Tools/Sledgehammer/sledgehammer_annotate.ML
src/HOL/Tools/Sledgehammer/sledgehammer_reconstruct.ML
     1.1 --- a/src/HOL/Sledgehammer.thy	Wed Nov 28 12:21:42 2012 +0100
     1.2 +++ b/src/HOL/Sledgehammer.thy	Wed Nov 28 12:22:05 2012 +0100
     1.3 @@ -14,6 +14,7 @@
     1.4  ML_file "Tools/Sledgehammer/async_manager.ML"
     1.5  ML_file "Tools/Sledgehammer/sledgehammer_util.ML"
     1.6  ML_file "Tools/Sledgehammer/sledgehammer_fact.ML"
     1.7 +ML_file "Tools/Sledgehammer/sledgehammer_annotate.ML"
     1.8  ML_file "Tools/Sledgehammer/sledgehammer_reconstruct.ML" 
     1.9  ML_file "Tools/Sledgehammer/sledgehammer_provers.ML"
    1.10  ML_file "Tools/Sledgehammer/sledgehammer_minimize.ML"
     2.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     2.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_annotate.ML	Wed Nov 28 12:22:05 2012 +0100
     2.3 @@ -0,0 +1,151 @@
     2.4 +signature SLEDGEHAMMER_ANNOTATE =
     2.5 +sig
     2.6 +  val annotate_types : Proof.context -> term -> term
     2.7 +end
     2.8 +
     2.9 +structure Sledgehammer_Annotate : SLEDGEHAMMER_ANNOTATE =
    2.10 +struct
    2.11 +
    2.12 +(* Util *)
    2.13 +fun post_traverse_term_type' f _ (t as Const (_, T)) s = f t T s
    2.14 +  | post_traverse_term_type' f _ (t as Free (_, T)) s = f t T s
    2.15 +  | post_traverse_term_type' f _ (t as Var (_, T)) s = f t T s
    2.16 +  | post_traverse_term_type' f env (t as Bound i) s = f t (nth env i) s
    2.17 +  | post_traverse_term_type' f env (Abs (x, T1, b)) s =
    2.18 +    let
    2.19 +      val ((b', s'), T2) = post_traverse_term_type' f (T1 :: env) b s
    2.20 +    in f (Abs (x, T1, b')) (T1 --> T2) s' end
    2.21 +  | post_traverse_term_type' f env (u $ v) s =
    2.22 +    let
    2.23 +      val ((u', s'), Type (_, [_, T])) = post_traverse_term_type' f env u s
    2.24 +      val ((v', s''), _) = post_traverse_term_type' f env v s'
    2.25 +    in f (u' $ v') T s'' end
    2.26 +
    2.27 +fun post_traverse_term_type f s t =
    2.28 +  post_traverse_term_type' (fn t => fn T => fn s => (f t T s, T)) [] t s |> fst
    2.29 +fun post_fold_term_type f s t =
    2.30 +  post_traverse_term_type (fn t => fn T => fn s => (t, f t T s)) s t |> snd
    2.31 +
    2.32 +(* Data structures, orders *)
    2.33 +val cost_ord = prod_ord int_ord (prod_ord int_ord int_ord)
    2.34 +
    2.35 +structure Var_Set_Tab = Table(
    2.36 +  type key = indexname list
    2.37 +  val ord = list_ord Term_Ord.fast_indexname_ord)
    2.38 +
    2.39 +(* (1) Generalize types *)
    2.40 +fun generalize_types ctxt t =
    2.41 +  t |> map_types (fn _ => dummyT)
    2.42 +    |> Syntax.check_term
    2.43 +         (Proof_Context.set_mode Proof_Context.mode_pattern ctxt)
    2.44 +
    2.45 +(* (2) Typing-spot table *)
    2.46 +local
    2.47 +fun key_of_atype (TVar (z, _)) =
    2.48 +    Ord_List.insert Term_Ord.fast_indexname_ord z
    2.49 +  | key_of_atype _ = I
    2.50 +fun key_of_type T = fold_atyps key_of_atype T []
    2.51 +fun update_tab t T (tab, pos) =
    2.52 +  (case key_of_type T of
    2.53 +     [] => tab
    2.54 +   | key =>
    2.55 +     let val cost = (size_of_typ T, (size_of_term t, pos)) in
    2.56 +       case Var_Set_Tab.lookup tab key of
    2.57 +         NONE => Var_Set_Tab.update_new (key, cost) tab
    2.58 +       | SOME old_cost =>
    2.59 +         (case cost_ord (cost, old_cost) of
    2.60 +            LESS => Var_Set_Tab.update (key, cost) tab
    2.61 +          | _ => tab)
    2.62 +     end,
    2.63 +   pos + 1)
    2.64 +in
    2.65 +val typing_spot_table =
    2.66 +  post_fold_term_type update_tab (Var_Set_Tab.empty, 0) #> fst
    2.67 +end
    2.68 +
    2.69 +(* (3) Reverse-greedy *)
    2.70 +fun reverse_greedy typing_spot_tab =
    2.71 +  let
    2.72 +    fun update_count z =
    2.73 +      fold (fn tvar => fn tab =>
    2.74 +        let val c = Vartab.lookup tab tvar |> the_default 0 in
    2.75 +          Vartab.update (tvar, c + z) tab
    2.76 +        end)
    2.77 +    fun superfluous tcount =
    2.78 +      forall (fn tvar => the (Vartab.lookup tcount tvar) > 1)
    2.79 +    fun drop_superfluous (tvars, (_, (_, spot))) (spots, tcount) =
    2.80 +      if superfluous tcount tvars then (spots, update_count ~1 tvars tcount)
    2.81 +      else (spot :: spots, tcount)
    2.82 +    val (typing_spots, tvar_count_tab) =
    2.83 +      Var_Set_Tab.fold
    2.84 +        (fn kv as (k, _) => apfst (cons kv) #> apsnd (update_count 1 k))
    2.85 +        typing_spot_tab ([], Vartab.empty)
    2.86 +      |>> sort_distinct (rev_order o cost_ord o pairself snd)
    2.87 +  in fold drop_superfluous typing_spots ([], tvar_count_tab) |> fst end
    2.88 +
    2.89 +(* (4) Introduce annotations *)
    2.90 +fun introduce_annotations ctxt spots t t' =
    2.91 +  let
    2.92 +    val thy = Proof_Context.theory_of ctxt
    2.93 +    val get_types = post_fold_term_type (K cons) []
    2.94 +    fun match_types tp =
    2.95 +      fold (Sign.typ_match thy) (op ~~ (pairself get_types tp)) Vartab.empty
    2.96 +    fun unica' b x [] = if b then [x] else []
    2.97 +      | unica' b x (y :: ys) =
    2.98 +        if x = y then unica' false x ys
    2.99 +        else unica' true y ys |> b ? cons x
   2.100 +    fun unica ord xs =
   2.101 +      case sort ord xs of x :: ys => unica' true x ys | [] => []
   2.102 +    val add_all_tfree_namesT = fold_atyps (fn TFree (x, _) => cons x | _ => I)
   2.103 +    fun erase_unica_tfrees env =
   2.104 +      let
   2.105 +        val unica =
   2.106 +          Vartab.fold (add_all_tfree_namesT o snd o snd) env []
   2.107 +          |> filter_out (Variable.is_declared ctxt)
   2.108 +          |> unica fast_string_ord
   2.109 +        val erase_unica = map_atyps
   2.110 +          (fn T as TFree (s, _) =>
   2.111 +              if Ord_List.member fast_string_ord unica s then dummyT else T
   2.112 +            | T => T)
   2.113 +      in Vartab.map (K (apsnd erase_unica)) env end
   2.114 +    val env = match_types (t', t) |> erase_unica_tfrees
   2.115 +    fun get_annot env (TFree _) = (false, (env, dummyT))
   2.116 +      | get_annot env (T as TVar (v, S)) =
   2.117 +        let val T' = Envir.subst_type env T in
   2.118 +          if T' = dummyT then (false, (env, dummyT))
   2.119 +          else (true, (Vartab.update (v, (S, dummyT)) env, T'))
   2.120 +        end
   2.121 +      | get_annot env (Type (S, Ts)) =
   2.122 +        (case fold_rev (fn T => fn (b, (env, Ts)) =>
   2.123 +                  let
   2.124 +                    val (b', (env', T)) = get_annot env T
   2.125 +                  in (b orelse b', (env', T :: Ts)) end)
   2.126 +                Ts (false, (env, [])) of
   2.127 +           (true, (env', Ts)) => (true, (env', Type (S, Ts)))
   2.128 +         | (false, (env', _)) => (false, (env', dummyT)))
   2.129 +    fun post1 _ T (env, cp, ps as p :: ps', annots) =
   2.130 +        if p <> cp then
   2.131 +          (env, cp + 1, ps, annots)
   2.132 +        else
   2.133 +          let val (_, (env', T')) = get_annot env T in
   2.134 +            (env', cp + 1, ps', (p, T') :: annots)
   2.135 +          end
   2.136 +      | post1 _ _ accum = accum
   2.137 +    val (_, _, _, annots) = post_fold_term_type post1 (env, 0, spots, []) t'
   2.138 +    fun post2 t _ (cp, annots as (p, T) :: annots') =
   2.139 +        if p <> cp then (t, (cp + 1, annots))
   2.140 +        else (Type.constraint T t, (cp + 1, annots'))
   2.141 +      | post2 t _ x = (t, x)
   2.142 +  in post_traverse_term_type post2 (0, rev annots) t |> fst end
   2.143 +
   2.144 +(* (5) Annotate *)
   2.145 +fun annotate_types ctxt t =
   2.146 +  let
   2.147 +    val t' = generalize_types ctxt t
   2.148 +    val typing_spots =
   2.149 +      t' |> typing_spot_table
   2.150 +         |> reverse_greedy
   2.151 +         |> sort int_ord
   2.152 +  in introduce_annotations ctxt typing_spots t t' end
   2.153 +
   2.154 +end
     3.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_reconstruct.ML	Wed Nov 28 12:21:42 2012 +0100
     3.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_reconstruct.ML	Wed Nov 28 12:22:05 2012 +0100
     3.3 @@ -53,6 +53,7 @@
     3.4  open ATP_Problem_Generate
     3.5  open ATP_Proof_Reconstruct
     3.6  open Sledgehammer_Util
     3.7 +open Sledgehammer_Annotate
     3.8  
     3.9  structure String_Redirect = ATP_Proof_Redirect(
    3.10    type key = step_name
    3.11 @@ -477,149 +478,6 @@
    3.12       else
    3.13         map (replace_dependencies_in_line (name, deps)) lines)  (* drop line *)
    3.14  
    3.15 -(** Type annotations **)
    3.16 -
    3.17 -fun post_traverse_term_type' f _ (t as Const (_, T)) s = f t T s
    3.18 -  | post_traverse_term_type' f _ (t as Free (_, T)) s = f t T s
    3.19 -  | post_traverse_term_type' f _ (t as Var (_, T)) s = f t T s
    3.20 -  | post_traverse_term_type' f env (t as Bound i) s = f t (nth env i) s
    3.21 -  | post_traverse_term_type' f env (Abs (x, T1, b)) s =
    3.22 -    let
    3.23 -      val ((b', s'), T2) = post_traverse_term_type' f (T1 :: env) b s
    3.24 -    in f (Abs (x, T1, b')) (T1 --> T2) s' end
    3.25 -  | post_traverse_term_type' f env (u $ v) s =
    3.26 -    let
    3.27 -      val ((u', s'), Type (_, [_, T])) = post_traverse_term_type' f env u s
    3.28 -      val ((v', s''), _) = post_traverse_term_type' f env v s'
    3.29 -    in f (u' $ v') T s'' end
    3.30 -
    3.31 -fun post_traverse_term_type f s t =
    3.32 -  post_traverse_term_type' (fn t => fn T => fn s => (f t T s, T)) [] t s |> fst
    3.33 -fun post_fold_term_type f s t =
    3.34 -  post_traverse_term_type (fn t => fn T => fn s => (t, f t T s)) s t |> snd
    3.35 -
    3.36 -(* Data structures, orders *)
    3.37 -val cost_ord = prod_ord int_ord (prod_ord int_ord int_ord)
    3.38 -
    3.39 -structure Var_Set_Tab = Table(
    3.40 -  type key = indexname list
    3.41 -  val ord = list_ord Term_Ord.fast_indexname_ord)
    3.42 -
    3.43 -(* (1) Generalize types *)
    3.44 -fun generalize_types ctxt t =
    3.45 -  t |> map_types (fn _ => dummyT)
    3.46 -    |> Syntax.check_term
    3.47 -         (Proof_Context.set_mode Proof_Context.mode_pattern ctxt)
    3.48 -
    3.49 -(* (2) Typing-spot table *)
    3.50 -local
    3.51 -fun key_of_atype (TVar (z, _)) =
    3.52 -    Ord_List.insert Term_Ord.fast_indexname_ord z
    3.53 -  | key_of_atype _ = I
    3.54 -fun key_of_type T = fold_atyps key_of_atype T []
    3.55 -fun update_tab t T (tab, pos) =
    3.56 -  (case key_of_type T of
    3.57 -     [] => tab
    3.58 -   | key =>
    3.59 -     let val cost = (size_of_typ T, (size_of_term t, pos)) in
    3.60 -       case Var_Set_Tab.lookup tab key of
    3.61 -         NONE => Var_Set_Tab.update_new (key, cost) tab
    3.62 -       | SOME old_cost =>
    3.63 -         (case cost_ord (cost, old_cost) of
    3.64 -            LESS => Var_Set_Tab.update (key, cost) tab
    3.65 -          | _ => tab)
    3.66 -     end,
    3.67 -   pos + 1)
    3.68 -in
    3.69 -val typing_spot_table =
    3.70 -  post_fold_term_type update_tab (Var_Set_Tab.empty, 0) #> fst
    3.71 -end
    3.72 -
    3.73 -(* (3) Reverse-greedy *)
    3.74 -fun reverse_greedy typing_spot_tab =
    3.75 -  let
    3.76 -    fun update_count z =
    3.77 -      fold (fn tvar => fn tab =>
    3.78 -        let val c = Vartab.lookup tab tvar |> the_default 0 in
    3.79 -          Vartab.update (tvar, c + z) tab
    3.80 -        end)
    3.81 -    fun superfluous tcount =
    3.82 -      forall (fn tvar => the (Vartab.lookup tcount tvar) > 1)
    3.83 -    fun drop_superfluous (tvars, (_, (_, spot))) (spots, tcount) =
    3.84 -      if superfluous tcount tvars then (spots, update_count ~1 tvars tcount)
    3.85 -      else (spot :: spots, tcount)
    3.86 -    val (typing_spots, tvar_count_tab) =
    3.87 -      Var_Set_Tab.fold
    3.88 -        (fn kv as (k, _) => apfst (cons kv) #> apsnd (update_count 1 k))
    3.89 -        typing_spot_tab ([], Vartab.empty)
    3.90 -      |>> sort_distinct (rev_order o cost_ord o pairself snd)
    3.91 -  in fold drop_superfluous typing_spots ([], tvar_count_tab) |> fst end
    3.92 -
    3.93 -(* (4) Introduce annotations *)
    3.94 -fun introduce_annotations ctxt spots t t' =
    3.95 -  let
    3.96 -    val thy = Proof_Context.theory_of ctxt
    3.97 -    val get_types = post_fold_term_type (K cons) []
    3.98 -    fun match_types tp =
    3.99 -      fold (Sign.typ_match thy) (op ~~ (pairself get_types tp)) Vartab.empty
   3.100 -    fun unica' b x [] = if b then [x] else []
   3.101 -      | unica' b x (y :: ys) =
   3.102 -        if x = y then unica' false x ys
   3.103 -        else unica' true y ys |> b ? cons x
   3.104 -    fun unica ord xs =
   3.105 -      case sort ord xs of x :: ys => unica' true x ys | [] => []
   3.106 -    val add_all_tfree_namesT = fold_atyps (fn TFree (x, _) => cons x | _ => I)
   3.107 -    fun erase_unica_tfrees env =
   3.108 -      let
   3.109 -        val unica =
   3.110 -          Vartab.fold (add_all_tfree_namesT o snd o snd) env []
   3.111 -          |> filter_out (Variable.is_declared ctxt)
   3.112 -          |> unica fast_string_ord
   3.113 -        val erase_unica = map_atyps
   3.114 -          (fn T as TFree (s, _) =>
   3.115 -              if Ord_List.member fast_string_ord unica s then dummyT else T
   3.116 -            | T => T)
   3.117 -      in Vartab.map (K (apsnd erase_unica)) env end
   3.118 -    val env = match_types (t', t) |> erase_unica_tfrees
   3.119 -    fun get_annot env (TFree _) = (false, (env, dummyT))
   3.120 -      | get_annot env (T as TVar (v, S)) =
   3.121 -        let val T' = Envir.subst_type env T in
   3.122 -          if T' = dummyT then (false, (env, dummyT))
   3.123 -          else (true, (Vartab.update (v, (S, dummyT)) env, T'))
   3.124 -        end
   3.125 -      | get_annot env (Type (S, Ts)) =
   3.126 -        (case fold_rev (fn T => fn (b, (env, Ts)) =>
   3.127 -                  let
   3.128 -                    val (b', (env', T)) = get_annot env T
   3.129 -                  in (b orelse b', (env', T :: Ts)) end)
   3.130 -                Ts (false, (env, [])) of
   3.131 -           (true, (env', Ts)) => (true, (env', Type (S, Ts)))
   3.132 -         | (false, (env', _)) => (false, (env', dummyT)))
   3.133 -    fun post1 _ T (env, cp, ps as p :: ps', annots) =
   3.134 -        if p <> cp then
   3.135 -          (env, cp + 1, ps, annots)
   3.136 -        else
   3.137 -          let val (_, (env', T')) = get_annot env T in
   3.138 -            (env', cp + 1, ps', (p, T') :: annots)
   3.139 -          end
   3.140 -      | post1 _ _ accum = accum
   3.141 -    val (_, _, _, annots) = post_fold_term_type post1 (env, 0, spots, []) t'
   3.142 -    fun post2 t _ (cp, annots as (p, T) :: annots') =
   3.143 -        if p <> cp then (t, (cp + 1, annots))
   3.144 -        else (Type.constraint T t, (cp + 1, annots'))
   3.145 -      | post2 t _ x = (t, x)
   3.146 -  in post_traverse_term_type post2 (0, rev annots) t |> fst end
   3.147 -
   3.148 -(* (5) Annotate *)
   3.149 -fun annotate_types ctxt t =
   3.150 -  let
   3.151 -    val t' = generalize_types ctxt t
   3.152 -    val typing_spots =
   3.153 -      t' |> typing_spot_table
   3.154 -         |> reverse_greedy
   3.155 -         |> sort int_ord
   3.156 -  in introduce_annotations ctxt typing_spots t t' end
   3.157 -
   3.158  val indent_size = 2
   3.159  val no_label = ("", ~1)
   3.160