src/HOL/Tools/SMT/smt_normalize.ML
changeset 50601 74da81de127f
parent 47207 9368aa814518
child 51575 907efc894051
--- a/src/HOL/Tools/SMT/smt_normalize.ML	Thu Dec 20 09:49:00 2012 +0100
+++ b/src/HOL/Tools/SMT/smt_normalize.ML	Fri Dec 21 11:05:42 2012 +0100
@@ -346,6 +346,14 @@
 
 (* unfolding of definitions and theory-specific rewritings *)
 
+fun expand_head_conv cv ct =
+  (case Thm.term_of ct of
+    _ $ _ =>
+      Conv.fun_conv (expand_head_conv cv) then_conv
+      Conv.try_conv (Thm.beta_conversion false)
+  | _ => cv) ct
+
+
 (** rewrite bool case expressions as if expressions **)
 
 local
@@ -355,7 +363,9 @@
   val thm = mk_meta_eq @{lemma
     "bool_case = (%x y P. if P then x else y)" by (rule ext)+ simp}
 
-  fun unfold_conv _ = SMT_Utils.if_true_conv is_bool_case (Conv.rewr_conv thm)
+  fun unfold_conv _ =
+    SMT_Utils.if_true_conv (is_bool_case o Term.head_of)
+      (expand_head_conv (Conv.rewr_conv thm))
 in
 
 fun rewrite_bool_case_conv ctxt =
@@ -393,8 +403,8 @@
     | abs_min_max _ _ = NONE
 
   fun unfold_amm_conv ctxt ct =
-    (case abs_min_max ctxt (Thm.term_of ct) of
-      SOME thm => Conv.rewr_conv thm
+    (case abs_min_max ctxt (Term.head_of (Thm.term_of ct)) of
+      SOME thm => expand_head_conv (Conv.rewr_conv thm)
     | NONE => Conv.all_conv) ct
 in
 
@@ -460,8 +470,11 @@
     "int (n * m) = int n * int m"
     "int (n div m) = int n div int m"
     "int (n mod m) = int n mod int m"
+    by (auto simp add: int_mult zdiv_int zmod_int)}
+
+  val int_if = mk_meta_eq @{lemma
     "int (if P then n else m) = (if P then int n else int m)"
-    by (auto simp add: int_mult zdiv_int zmod_int)}
+    by simp}
 
   fun mk_number_eq ctxt i lhs =
     let
@@ -471,12 +484,8 @@
       fun tac _ = Simplifier.simp_tac (Simplifier.context ctxt ss) 1       
     in Goal.norm_result (Goal.prove_internal [] eq tac) end
 
-  fun expand_head_conv cv ct =
-    (case Thm.term_of ct of
-      _ $ _ =>
-        Conv.fun_conv (expand_head_conv cv) then_conv
-        Thm.beta_conversion false
-    | _ => cv) ct
+  fun ite_conv cv1 cv2 =
+    Conv.combination_conv (Conv.combination_conv (Conv.arg_conv cv1) cv2) cv2
 
   fun int_conv ctxt ct =
     (case Thm.term_of ct of
@@ -484,7 +493,9 @@
         Conv.rewr_conv (mk_number_eq ctxt (snd (HOLogic.dest_number n)) ct)
     | @{const of_nat (int)} $ _ =>
         (Conv.rewrs_conv ints then_conv Conv.sub_conv ints_conv ctxt) else_conv
-        Conv.sub_conv (Conv.top_sweep_conv nat_conv) ctxt        
+        (Conv.rewr_conv int_if then_conv
+          ite_conv (nat_conv ctxt) (int_conv ctxt)) else_conv
+        Conv.sub_conv (Conv.top_sweep_conv nat_conv) ctxt
     | _ => Conv.no_conv) ct
 
   and ints_conv ctxt = Conv.top_sweep_conv int_conv ctxt