eta-expand built-in constants; also rewrite partially applied natural number terms
authorboehmes
Fri, 29 Oct 2010 18:17:10 +0200
changeset 40279 96365b4ae7b6
parent 40278 0fc78bb54f18
child 40280 0dd2827e8596
eta-expand built-in constants; also rewrite partially applied natural number terms
src/HOL/Tools/SMT/smt_normalize.ML
--- a/src/HOL/Tools/SMT/smt_normalize.ML	Fri Oct 29 18:17:09 2010 +0200
+++ b/src/HOL/Tools/SMT/smt_normalize.ML	Fri Oct 29 18:17:10 2010 +0200
@@ -129,27 +129,29 @@
   val nat_rewriting = @{lemma
     "0 = nat 0"
     "1 = nat 1"
-    "number_of i = nat (number_of i)"
+    "(number_of :: int => nat) = (%i. nat (number_of i))"
     "int (nat 0) = 0"
     "int (nat 1) = 1"
-    "a < b = (int a < int b)"
-    "a <= b = (int a <= int b)"
-    "Suc a = nat (int a + 1)"
-    "a + b = nat (int a + int b)"
-    "a - b = nat (int a - int b)"
-    "a * b = nat (int a * int b)"
-    "a div b = nat (int a div int b)"
-    "a mod b = nat (int a mod int b)"
-    "min a b = nat (min (int a) (int b))"
-    "max a b = nat (max (int a) (int b))"
+    "op < = (%a b. int a < int b)"
+    "op <= = (%a b. int a <= int b)"
+    "Suc = (%a. nat (int a + 1))"
+    "op + = (%a b. nat (int a + int b))"
+    "op - = (%a b. nat (int a - int b))"
+    "op * = (%a b. nat (int a * int b))"
+    "op div = (%a b. nat (int a div int b))"
+    "op mod = (%a b. nat (int a mod int b))"
+    "min = (%a b. nat (min (int a) (int b)))"
+    "max = (%a b. nat (max (int a) (int b)))"
     "int (nat (int a + int b)) = int a + int b"
+    "int (nat (int a + 1)) = int a + 1"  (* special rule due to Suc above *)
     "int (nat (int a * int b)) = int a * int b"
     "int (nat (int a div int b)) = int a div int b"
     "int (nat (int a mod int b)) = int a mod int b"
     "int (nat (min (int a) (int b))) = min (int a) (int b)"
     "int (nat (max (int a) (int b))) = max (int a) (int b)"
-    by (simp_all add: nat_mult_distrib nat_div_distrib nat_mod_distrib
-      int_mult[symmetric] zdiv_int[symmetric] zmod_int[symmetric])}
+    by (auto intro!: ext simp add: nat_mult_distrib nat_div_distrib
+      nat_mod_distrib int_mult[symmetric] zdiv_int[symmetric]
+      zmod_int[symmetric])}
 
   fun on_positive num f x = 
     (case try HOLogic.dest_number (Thm.term_of num) of
@@ -171,9 +173,10 @@
 
   val nat_ss = HOL_ss
     addsimps nat_rewriting
-    addsimprocs [Simplifier.make_simproc {
-      name = "cancel_int_nat_num", lhss = [@{cpat "int (nat _)"}],
-      proc = cancel_int_nat_simproc, identifier = [] }]
+    addsimprocs [
+      Simplifier.make_simproc {
+        name = "cancel_int_nat_num", lhss = [@{cpat "int (nat _)"}],
+        proc = cancel_int_nat_simproc, identifier = [] }]
 
   fun conv ctxt = Simplifier.rewrite (Simplifier.context ctxt nat_ss)
 
@@ -198,6 +201,14 @@
 local
   val eta_conv = eta_expand_conv
 
+  fun args_conv cv ct =
+    (case Thm.term_of ct of
+      _ $ _ => Conv.combination_conv (args_conv cv) cv
+    | _ => Conv.all_conv) ct
+
+  fun eta_args_conv cv 0 = args_conv o cv
+    | eta_args_conv cv i = eta_conv (eta_args_conv cv (i-1))
+
   fun keep_conv ctxt = Conv.binder_conv (norm_conv o snd) ctxt
   and eta_binder_conv ctxt = Conv.arg_conv (eta_conv norm_conv ctxt)
   and keep_let_conv ctxt = Conv.combination_conv
@@ -229,30 +240,48 @@
     | Const (@{const_name Bex}, _) $ _ => eta_conv unfold_bex_conv
     | Const (@{const_name Bex}, _) => eta_conv (eta_conv unfold_bex_conv)
     | Abs _ => Conv.abs_conv (norm_conv o snd)
-    | _ $ _ => Conv.comb_conv o norm_conv
-    | _ => K Conv.all_conv) ctxt ct
+    | _ =>
+        (case Term.strip_comb (Thm.term_of ct) of
+          (Const (c as (_, T)), ts) =>
+            if SMT_Builtin.is_builtin ctxt c
+            then eta_args_conv norm_conv
+              (length (Term.binder_types T) - length ts)
+            else args_conv o norm_conv
+        | (_, ts) => args_conv o norm_conv)) ctxt ct
 
-  fun is_normed t =
+  fun is_normed ctxt t =
     (case t of
-      Const (@{const_name All}, _) $ Abs (_, _, u) => is_normed u
+      Const (@{const_name All}, _) $ Abs (_, _, u) => is_normed ctxt u
     | Const (@{const_name All}, _) $ _ => false
     | Const (@{const_name All}, _) => false
-    | Const (@{const_name Ex}, _) $ Abs (_, _, u) => is_normed u
+    | Const (@{const_name Ex}, _) $ Abs (_, _, u) => is_normed ctxt u
     | Const (@{const_name Ex}, _) $ _ => false
     | Const (@{const_name Ex}, _) => false
     | Const (@{const_name Let}, _) $ u1 $ Abs (_, _, u2) =>
-        is_normed u1 andalso is_normed u2
+        is_normed ctxt u1 andalso is_normed ctxt u2
     | Const (@{const_name Let}, _) $ _ $ _ => false
     | Const (@{const_name Let}, _) $ _ => false
     | Const (@{const_name Let}, _) => false
+    | Const (@{const_name Ex1}, _) $ _ => false
     | Const (@{const_name Ex1}, _) => false
+    | Const (@{const_name Ball}, _) $ _ $ _ => false
+    | Const (@{const_name Ball}, _) $ _ => false
     | Const (@{const_name Ball}, _) => false
+    | Const (@{const_name Bex}, _) $ _ $ _ => false
+    | Const (@{const_name Bex}, _) $ _ => false
     | Const (@{const_name Bex}, _) => false
-    | Abs (_, _, u) => is_normed u
-    | u1 $ u2 => is_normed u1 andalso is_normed u2
-    | _ => true)
+    | Abs (_, _, u) => is_normed ctxt u
+    | _ =>
+        (case Term.strip_comb t of
+          (Const (c as (_, T)), ts) =>
+            if SMT_Builtin.is_builtin ctxt c
+            then length (Term.binder_types T) = length ts andalso
+              forall (is_normed ctxt) ts
+            else forall (is_normed ctxt) ts
+        | (_, ts) => forall (is_normed ctxt) ts))
 in
-fun norm_binder_conv ctxt = if_conv is_normed Conv.all_conv (norm_conv ctxt)
+fun norm_binder_conv ctxt =
+  if_conv (is_normed ctxt) Conv.all_conv (norm_conv ctxt)
 end
 
 fun norm_def ctxt thm =