src/HOL/Tools/Sledgehammer/sledgehammer_isar_annotate.ML
author blanchet
Mon Jun 02 17:34:26 2014 +0200 (2014-06-02)
changeset 57158 f028d93798e6
parent 55286 7bbbd9393ce0
child 57467 03345dad8430
permissions -rw-r--r--
simplified counterexample handling
     1 (*  Title:      HOL/Tools/Sledgehammer/sledgehammer_isar_annotate.ML
     2     Author:     Steffen Juilf Smolka, TU Muenchen
     3     Author:     Jasmin Blanchette, TU Muenchen
     4 
     5 Supplements a term with a locally minimal, complete set of type constraints. Complete: The constraints
     6 suffice to infer the term's types. Minimal: Reducing the set of constraints further will make it
     7 incomplete.
     8 
     9 When configuring the pretty printer appropriately, the constraints will show up as type annotations
    10 when printing the term. This allows the term to be printed and reparsed without a change of types.
    11 
    12 Note: Terms should be unchecked before calling "annotate_types_in_term" to avoid awkward syntax.
    13 *)
    14 
    signature SLEDGEHAMMER_ISAR_ANNOTATE =
    sig
      (*Returns the given term with type constraints inserted at a locally minimal set of
        positions that suffices to infer all of the term's types, so that it can be printed
        and reparsed without a change of types.*)
      val annotate_types_in_term : Proof.context -> term -> term
    end;
    19 
    structure Sledgehammer_Isar_Annotate : SLEDGEHAMMER_ISAR_ANNOTATE =
    struct

    (* Post-order traversal of a term that threads a state s and reconstructs the type of every
       subterm bottom-up. env holds the types of the enclosing abstractions' bound variables,
       innermost first, so "Bound i" gets type "nth env i". The callback "f t T s" receives the
       rebuilt subterm t, its type T and the current state, and must return
       "((new_term, new_state), type)"; the whole traversal returns that same shape.
       For an application "u $ v" the result type is read off the type returned for u, which is
       expected to have the form "Type (_, [_, T])"; a failed pattern binding (Bind) in that
       clause is reported as Fail. *)
    fun post_traverse_term_type' f _ (t as Const (_, T)) s = f t T s
      | post_traverse_term_type' f _ (t as Free (_, T)) s = f t T s
      | post_traverse_term_type' f _ (t as Var (_, T)) s = f t T s
      | post_traverse_term_type' f env (t as Bound i) s = f t (nth env i) s
      | post_traverse_term_type' f env (Abs (x, T1, b)) s =
        let val ((b', s'), T2) = post_traverse_term_type' f (T1 :: env) b s in
          f (Abs (x, T1, b')) (T1 --> T2) s'
        end
      | post_traverse_term_type' f env (u $ v) s =
        let
          val ((u', s'), Type (_, [_, T])) = post_traverse_term_type' f env u s
          val ((v', s''), _) = post_traverse_term_type' f env v s'
        in f (u' $ v') T s'' end
        handle Bind => raise Fail "Sledgehammer_Isar_Annotate: post_traverse_term_type'"

    (* Post-order map with state: "f t T s" returns the replacement for subterm t (of type T)
       together with the new state. Types are propagated unchanged. Returns (term, state). *)
    fun post_traverse_term_type f s t =
      post_traverse_term_type' (fn t => fn T => fn s => (f t T s, T)) [] t s |> fst

    (* Post-order fold: "f t T s" returns the new state; subterms are left as they are.
       Returns only the final state. *)
    fun post_fold_term_type f s t =
      post_traverse_term_type (fn t => fn T => fn s => (t, f t T s)) s t |> snd

    (* Maps f over the atomic types of T (everything that is not a "Type" node), from left to
       right, threading the state s; "Type" nodes are rebuilt with the mapped arguments. *)
    fun fold_map_atypes f T s =
      (case T of
        Type (name, Ts) =>
        let val (Ts, s) = fold_map (fold_map_atypes f) Ts s in
          (Type (name, Ts), s)
        end
      | _ => f T s)

    val indexname_ord = Term_Ord.fast_indexname_ord

    (* Lexicographic order on costs "(type size, (term size, position))", as built by
       update_tab. *)
    val cost_ord = prod_ord int_ord (prod_ord int_ord int_ord)

    (* Tables keyed by lists of schematic type variable names; the keys are built sorted by
       indexname_ord (see key_of_type). *)
    structure Var_Set_Tab = Table(
      type key = indexname list
      val ord = list_ord indexname_ord)

    (* Erases all types in t and re-infers them. The context is put into pattern mode so that
       schematic type variables are used in the result. *)
    fun generalize_types ctxt t =
      let
        val erase_types = map_types (fn _ => dummyT)
        (* use schematic type variables *)
        val ctxt = ctxt |> Proof_Context.set_mode Proof_Context.mode_pattern
        val infer_types = singleton (Type_Infer_Context.infer_types ctxt)
      in
         t |> erase_types |> infer_types
      end

    (* Matches the subterm types of t1 (pattern side) against those of t2 (object side),
       pairing them in traversal order, and returns the resulting type environment.
       The two terms are assumed to have the same shape. Raises Fail if a pair of types does
       not match. *)
    fun match_types ctxt t1 t2 =
      let
        val thy = Proof_Context.theory_of ctxt
        val get_types = post_fold_term_type (K cons) []
      in
        fold (Sign.typ_match thy) (get_types t1 ~~ get_types t2) Vartab.empty
        handle Type.TYPE_MATCH => raise Fail "Sledgehammer_Isar_Annotate: match_types"
      end

    (* Treats free type variables that occur in subst but are not declared in ctxt as
       "trivial":
       - schematic type variables that subst maps directly to such a free type variable are
         replaced by dummyT in t' and removed from subst;
       - remaining occurrences of trivial free type variables inside subst become dummyT. *)
    fun handle_trivial_tfrees ctxt (t', subst) =
      let
        val add_tfree_names = snd #> snd #> fold_atyps (fn TFree (x, _) => cons x | _ => I)

        (* Ord_List.member below only works on lists sorted by the same order, so sort here
           (the plain "distinct" previously used left the list unsorted, which made membership
           tests return spurious false results). *)
        val trivial_tfree_names =
          Vartab.fold add_tfree_names subst []
          |> filter_out (Variable.is_declared ctxt)
          |> sort_distinct fast_string_ord
        val tfree_name_trivial = Ord_List.member fast_string_ord trivial_tfree_names

        val trivial_tvar_names =
          Vartab.fold
            (fn (tvar_name, (_, TFree (tfree_name, _))) =>
                   tfree_name_trivial tfree_name ? cons tvar_name
              | _ => I)
            subst
            []
          |> sort indexname_ord
        val tvar_name_trivial = Ord_List.member indexname_ord trivial_tvar_names

        val t' =
          t' |> map_types
                  (map_type_tvar
                    (fn (idxn, sort) =>
                      if tvar_name_trivial idxn then dummyT else TVar (idxn, sort)))

        val subst =
          subst |> fold Vartab.delete trivial_tvar_names
                |> Vartab.map
                   (K (apsnd (map_type_tfree
                               (fn (name, sort) =>
                                  if tfree_name_trivial name then dummyT
                                  else TFree (name, sort)))))
      in
        (t', subst)
      end

    (* Inserts the name of a schematic type variable into a sorted key; ignores other atomic
       types. *)
    fun key_of_atype (TVar (z, _)) = Ord_List.insert indexname_ord z
      | key_of_atype _ = I

    (* The sorted, duplicate-free list of schematic type variable names occurring in T. *)
    fun key_of_type T = fold_atyps key_of_atype T []

    (* Visits the subterm t of type T at position pos. If T contains schematic type variables,
       the cost "(size of T, (size of t, pos))" is recorded under the set of those variables.
       An existing entry for the same set is replaced only if the new cost is strictly smaller.
       The position counter advances for every subterm, whether or not anything is recorded. *)
    fun update_tab t T (tab, pos) =
      ((case key_of_type T of
         [] => tab
       | key =>
         let val cost = (size_of_typ T, (size_of_term t, pos)) in
           (case Var_Set_Tab.lookup tab key of
             NONE => Var_Set_Tab.update_new (key, cost) tab
           | SOME old_cost =>
             (case cost_ord (cost, old_cost) of
               LESS => Var_Set_Tab.update (key, cost) tab
             | _ => tab))
         end),
       pos + 1)

    (* For each set of schematic type variables that occurs as the variable set of some
       subterm's type, the cheapest such subterm's cost (positions are counted in post-order). *)
    val typing_spot_table = post_fold_term_type update_tab (Var_Set_Tab.empty, 0) #> fst

    (* Greedy minimization of the candidate typing spots:
       - count, for every type variable, how many spots mention it;
       - visit the spots from most to least expensive;
       - drop a spot if each of its type variables is still mentioned by another remaining
         spot (count > 1), decrementing the counts.
       Returns the positions of the kept spots. *)
    fun reverse_greedy typing_spot_tab =
      let
        fun update_count z =
          fold (fn tvar => fn tab =>
            let val c = Vartab.lookup tab tvar |> the_default 0 in
              Vartab.update (tvar, c + z) tab
            end)
        fun superfluous tcount = forall (fn tvar => the (Vartab.lookup tcount tvar) > 1)
        fun drop_superfluous (tvars, (_, (_, spot))) (spots, tcount) =
          if superfluous tcount tvars then (spots, update_count ~1 tvars tcount)
          else (spot :: spots, tcount)

        val (typing_spots, tvar_count_tab) =
          Var_Set_Tab.fold (fn kv as (k, _) => apfst (cons kv) #> apsnd (update_count 1 k))
            typing_spot_tab ([], Vartab.empty)
          |>> sort_distinct (rev_order o cost_ord o pairself snd)
      in
        fold drop_superfluous typing_spots ([], tvar_count_tab) |> fst
      end

    (* Wraps the subterms of t at the given positions (expected in ascending post-order) in
       type constraints. The constraint types are the corresponding types of the generalized
       term t' with subst applied. Each schematic type variable met while building an
       annotation is afterwards mapped to dummyT, so its concrete type is spelled out in at most
       one annotation. *)
    fun introduce_annotations subst spots t t' =
      let
        fun subst_atype (T as TVar (idxn, S)) subst =
            (Envir.subst_type subst T, Vartab.update (idxn, (S, dummyT)) subst)
          | subst_atype T subst = (T, subst)

        val subst_type = fold_map_atypes subst_atype

        (* Walks t' with a position counter cp; at each pending spot p, records (p, type). *)
        fun collect_annot _ T (subst, cp, ps as p :: ps', annots) =
            if p <> cp then
              (subst, cp + 1, ps, annots)
            else
              let val (T, subst) = subst_type T subst in
                (subst, cp + 1, ps', (p, T) :: annots)
              end
          | collect_annot _ _ x = x

        val (_, _, _, annots) = post_fold_term_type collect_annot (subst, 0, spots, []) t'

        (* Walks t with the same numbering and inserts the recorded constraints. *)
        fun insert_annot t _ (cp, annots as (p, T) :: annots') =
            if p <> cp then (t, (cp + 1, annots)) else (Type.constraint T t, (cp + 1, annots'))
          | insert_annot t _ x = (t, x)
      in
        t |> post_traverse_term_type insert_annot (0, rev annots) |> fst
      end

    (* Generalizes the types of t, matches the generalized term against t to recover the
       instantiation, discards trivial free type variables, selects a minimal set of typing
       spots, and annotates t at those spots. *)
    fun annotate_types_in_term ctxt t =
      let
        val t' = generalize_types ctxt t
        val subst = match_types ctxt t' t
        val (t', subst) = (t', subst) |> handle_trivial_tfrees ctxt
        val typing_spots = t' |> typing_spot_table |> reverse_greedy |> sort int_ord
      in
        introduce_annotations subst typing_spots t t'
      end

    end;