src/HOL/Tools/Sledgehammer/sledgehammer_annotate.ML
changeset 52110 411db77f96f2
parent 51877 71052c42edf2
child 52366 ff89424b5094
equal deleted inserted replaced
52109:39ac12f31f5c 52110:411db77f96f2
    31 
    31 
    32 fun post_traverse_term_type f s t =
    32 fun post_traverse_term_type f s t =
    33   post_traverse_term_type' (fn t => fn T => fn s => (f t T s, T)) [] t s |> fst
    33   post_traverse_term_type' (fn t => fn T => fn s => (f t T s, T)) [] t s |> fst
    34 fun post_fold_term_type f s t =
    34 fun post_fold_term_type f s t =
    35   post_traverse_term_type (fn t => fn T => fn s => (t, f t T s)) s t |> snd
    35   post_traverse_term_type (fn t => fn T => fn s => (t, f t T s)) s t |> snd
       
    36 
       
    37 local
       
    38 fun natify_numeral' (t as Const (s, T)) =
       
    39     (case s of
       
    40       "Groups.zero_class.zero" => Const (s, @{typ "nat"})
       
    41     | "Groups.one_class.one" => Const (s, @{typ "nat"})
       
    42     | "Num.numeral_class.numeral" => Const(s, @{typ "num"} --> @{typ "nat"})
       
    43     | "Num.numeral_class.neg_numeral" => Const(s, @{typ "num"} --> @{typ "nat"})
       
    44     | _ => t)
       
    45   | natify_numeral' t = t
       
    46 in
       
    47 val natify_numerals = Term.map_aterms natify_numeral'
       
    48 end
    36 
    49 
    37 (* Data structures, orders *)
    50 (* Data structures, orders *)
    38 val cost_ord = prod_ord int_ord (prod_ord int_ord int_ord)
    51 val cost_ord = prod_ord int_ord (prod_ord int_ord int_ord)
    39 structure Var_Set_Tab = Table(
    52 structure Var_Set_Tab = Table(
    40   type key = indexname list
    53   type key = indexname list
   141     val (_, _, _, annots) = post_fold_term_type post1 (env, 0, spots, []) t'
   154     val (_, _, _, annots) = post_fold_term_type post1 (env, 0, spots, []) t'
   142     fun post2 t _ (cp, annots as (p, T) :: annots') =
   155     fun post2 t _ (cp, annots as (p, T) :: annots') =
   143         if p <> cp then (t, (cp + 1, annots))
   156         if p <> cp then (t, (cp + 1, annots))
   144         else (Type.constraint T t, (cp + 1, annots'))
   157         else (Type.constraint T t, (cp + 1, annots'))
   145       | post2 t _ x = (t, x)
   158       | post2 t _ x = (t, x)
   146   in post_traverse_term_type post2 (0, rev annots) t |> fst end
   159   in
       
   160     t |> natify_numerals (* typing all numerals as "nat"s prevents the pretty
       
   161          printer from inserting additional, unwanted type annotations *)
       
   162       |> post_traverse_term_type post2 (0, rev annots)
       
   163       |> fst
       
   164   end
   147 
   165 
   148 (* (5) Annotate *)
   166 (* (5) Annotate *)
   149 fun annotate_types ctxt t =
   167 fun annotate_types ctxt t =
   150   let
   168   let
   151     val t' = generalize_types ctxt t
   169     val t' = generalize_types ctxt t