src/HOL/Tools/SMT/smt_normalize.ML
changeset 41059 d2b1fc1b8e19
parent 40686 4725ed462387
child 41072 9f9bc1bdacef
--- a/src/HOL/Tools/SMT/smt_normalize.ML	Mon Dec 06 16:54:22 2010 +0100
+++ b/src/HOL/Tools/SMT/smt_normalize.ML	Tue Dec 07 14:53:12 2010 +0100
@@ -23,12 +23,14 @@
     (int * thm) list * Proof.context
   val atomize_conv: Proof.context -> conv
   val eta_expand_conv: (Proof.context -> conv) -> Proof.context -> conv
+  val setup: theory -> theory
 end
 
 structure SMT_Normalize: SMT_NORMALIZE =
 struct
 
 structure U = SMT_Utils
+structure B = SMT_Builtin
 
 infix 2 ??
 fun (test ?? f) x = if test x then f x else x
@@ -95,6 +97,9 @@
 fun rewrite_bool_cases ctxt =
   map (apsnd ((Term.exists_subterm is_bool_case o Thm.prop_of) ??
     Conv.fconv_rule (Conv.top_conv (K unfold_conv) ctxt)))
+
+val setup_bool_case = B.add_builtin_fun_ext'' @{const_name "bool.bool_case"}
+
 end
 
 
@@ -203,10 +208,20 @@
   val uses_nat_type = Term.exists_type (Term.exists_subtype (equal @{typ nat}))
   val uses_nat_int = Term.exists_subterm (member (op aconv)
     [@{const of_nat (int)}, @{const nat}])
+
+  val nat_ops = [
+    @{const less (nat)}, @{const less_eq (nat)},
+    @{const Suc}, @{const plus (nat)}, @{const minus (nat)},
+    @{const times (nat)}, @{const div (nat)}, @{const mod (nat)}]
+  val nat_ops' = @{const of_nat (int)} :: @{const nat} :: nat_ops
 in
 fun nat_as_int ctxt =
   map (apsnd ((uses_nat_type o Thm.prop_of) ?? Conv.fconv_rule (conv ctxt))) #>
   exists (uses_nat_int o Thm.prop_of o snd) ?? append nat_embedding
+
+val setup_nat_as_int =
+  B.add_builtin_typ_ext (@{typ nat}, K true) #>
+  fold (B.add_builtin_fun_ext' o Term.dest_Const) nat_ops'
 end
 
 
@@ -263,7 +278,7 @@
     | _ =>
         (case Term.strip_comb (Thm.term_of ct) of
           (Const (c as (_, T)), ts) =>
-            if SMT_Builtin.is_partially_builtin ctxt c
+            if SMT_Builtin.is_builtin_fun ctxt c ts
             then eta_args_conv norm_conv
               (length (Term.binder_types T) - length ts)
             else args_conv o norm_conv
@@ -294,7 +309,7 @@
     | _ =>
         (case Term.strip_comb t of
           (Const (c as (_, T)), ts) =>
-            if SMT_Builtin.is_builtin ctxt c
+            if SMT_Builtin.is_builtin_fun ctxt c ts
             then length (Term.binder_types T) = length ts andalso
               forall (is_normed ctxt) ts
             else forall (is_normed ctxt) ts
@@ -302,6 +317,11 @@
 in
 fun norm_binder_conv ctxt =
   U.if_conv (is_normed ctxt) Conv.all_conv (norm_conv ctxt)
+
+val setup_unfolded_quants =
+  fold B.add_builtin_fun_ext'' [@{const_name Ball}, @{const_name Bex},
+    @{const_name Ex1}]
+
 end
 
 fun norm_def ctxt thm =
@@ -325,6 +345,10 @@
       Conv.rewr_conv @{thm atomize_all}
   | _ => Conv.all_conv) ct
 
+val setup_atomize =
+  fold B.add_builtin_fun_ext'' [@{const_name "==>"}, @{const_name "=="},
+    @{const_name all}, @{const_name Trueprop}]
+
 fun normalize_rule ctxt =
   Conv.fconv_rule (
     (* reduce lambda abstractions, except at known binders: *)
@@ -554,4 +578,14 @@
     |-> (if with_datatypes then datatype_selectors else pair)
   end
 
+
+
+(* setup *)
+
+val setup =
+  setup_bool_case #>
+  setup_nat_as_int #>
+  setup_unfolded_quants #>
+  setup_atomize
+
 end