eta-expand built-in constants; also rewrite partially applied natural number terms
authorboehmes
Fri Oct 29 18:17:10 2010 +0200 (2010-10-29)
changeset 4027996365b4ae7b6
parent 40278 0fc78bb54f18
child 40280 0dd2827e8596
eta-expand built-in constants; also rewrite partially applied natural number terms
src/HOL/Tools/SMT/smt_normalize.ML
     1.1 --- a/src/HOL/Tools/SMT/smt_normalize.ML	Fri Oct 29 18:17:09 2010 +0200
     1.2 +++ b/src/HOL/Tools/SMT/smt_normalize.ML	Fri Oct 29 18:17:10 2010 +0200
     1.3 @@ -129,27 +129,29 @@
     1.4    val nat_rewriting = @{lemma
     1.5      "0 = nat 0"
     1.6      "1 = nat 1"
     1.7 -    "number_of i = nat (number_of i)"
     1.8 +    "(number_of :: int => nat) = (%i. nat (number_of i))"
     1.9      "int (nat 0) = 0"
    1.10      "int (nat 1) = 1"
    1.11 -    "a < b = (int a < int b)"
    1.12 -    "a <= b = (int a <= int b)"
    1.13 -    "Suc a = nat (int a + 1)"
    1.14 -    "a + b = nat (int a + int b)"
    1.15 -    "a - b = nat (int a - int b)"
    1.16 -    "a * b = nat (int a * int b)"
    1.17 -    "a div b = nat (int a div int b)"
    1.18 -    "a mod b = nat (int a mod int b)"
    1.19 -    "min a b = nat (min (int a) (int b))"
    1.20 -    "max a b = nat (max (int a) (int b))"
    1.21 +    "op < = (%a b. int a < int b)"
    1.22 +    "op <= = (%a b. int a <= int b)"
    1.23 +    "Suc = (%a. nat (int a + 1))"
    1.24 +    "op + = (%a b. nat (int a + int b))"
    1.25 +    "op - = (%a b. nat (int a - int b))"
    1.26 +    "op * = (%a b. nat (int a * int b))"
    1.27 +    "op div = (%a b. nat (int a div int b))"
    1.28 +    "op mod = (%a b. nat (int a mod int b))"
    1.29 +    "min = (%a b. nat (min (int a) (int b)))"
    1.30 +    "max = (%a b. nat (max (int a) (int b)))"
    1.31      "int (nat (int a + int b)) = int a + int b"
    1.32 +    "int (nat (int a + 1)) = int a + 1"  (* special rule due to Suc above *)
    1.33      "int (nat (int a * int b)) = int a * int b"
    1.34      "int (nat (int a div int b)) = int a div int b"
    1.35      "int (nat (int a mod int b)) = int a mod int b"
    1.36      "int (nat (min (int a) (int b))) = min (int a) (int b)"
    1.37      "int (nat (max (int a) (int b))) = max (int a) (int b)"
    1.38 -    by (simp_all add: nat_mult_distrib nat_div_distrib nat_mod_distrib
    1.39 -      int_mult[symmetric] zdiv_int[symmetric] zmod_int[symmetric])}
    1.40 +    by (auto intro!: ext simp add: nat_mult_distrib nat_div_distrib
    1.41 +      nat_mod_distrib int_mult[symmetric] zdiv_int[symmetric]
    1.42 +      zmod_int[symmetric])}
    1.43  
    1.44    fun on_positive num f x = 
    1.45      (case try HOLogic.dest_number (Thm.term_of num) of
    1.46 @@ -171,9 +173,10 @@
    1.47  
    1.48    val nat_ss = HOL_ss
    1.49      addsimps nat_rewriting
    1.50 -    addsimprocs [Simplifier.make_simproc {
    1.51 -      name = "cancel_int_nat_num", lhss = [@{cpat "int (nat _)"}],
    1.52 -      proc = cancel_int_nat_simproc, identifier = [] }]
    1.53 +    addsimprocs [
    1.54 +      Simplifier.make_simproc {
    1.55 +        name = "cancel_int_nat_num", lhss = [@{cpat "int (nat _)"}],
    1.56 +        proc = cancel_int_nat_simproc, identifier = [] }]
    1.57  
    1.58    fun conv ctxt = Simplifier.rewrite (Simplifier.context ctxt nat_ss)
    1.59  
    1.60 @@ -198,6 +201,14 @@
    1.61  local
    1.62    val eta_conv = eta_expand_conv
    1.63  
    1.64 +  fun args_conv cv ct =
    1.65 +    (case Thm.term_of ct of
    1.66 +      _ $ _ => Conv.combination_conv (args_conv cv) cv
    1.67 +    | _ => Conv.all_conv) ct
    1.68 +
    1.69 +  fun eta_args_conv cv 0 = args_conv o cv
    1.70 +    | eta_args_conv cv i = eta_conv (eta_args_conv cv (i-1))
    1.71 +
    1.72    fun keep_conv ctxt = Conv.binder_conv (norm_conv o snd) ctxt
    1.73    and eta_binder_conv ctxt = Conv.arg_conv (eta_conv norm_conv ctxt)
    1.74    and keep_let_conv ctxt = Conv.combination_conv
    1.75 @@ -229,30 +240,48 @@
    1.76      | Const (@{const_name Bex}, _) $ _ => eta_conv unfold_bex_conv
    1.77      | Const (@{const_name Bex}, _) => eta_conv (eta_conv unfold_bex_conv)
    1.78      | Abs _ => Conv.abs_conv (norm_conv o snd)
    1.79 -    | _ $ _ => Conv.comb_conv o norm_conv
    1.80 -    | _ => K Conv.all_conv) ctxt ct
    1.81 +    | _ =>
    1.82 +        (case Term.strip_comb (Thm.term_of ct) of
    1.83 +          (Const (c as (_, T)), ts) =>
    1.84 +            if SMT_Builtin.is_builtin ctxt c
    1.85 +            then eta_args_conv norm_conv
    1.86 +              (length (Term.binder_types T) - length ts)
    1.87 +            else args_conv o norm_conv
    1.88 +        | (_, ts) => args_conv o norm_conv)) ctxt ct
    1.89  
    1.90 -  fun is_normed t =
    1.91 +  fun is_normed ctxt t =
    1.92      (case t of
    1.93 -      Const (@{const_name All}, _) $ Abs (_, _, u) => is_normed u
    1.94 +      Const (@{const_name All}, _) $ Abs (_, _, u) => is_normed ctxt u
    1.95      | Const (@{const_name All}, _) $ _ => false
    1.96      | Const (@{const_name All}, _) => false
    1.97 -    | Const (@{const_name Ex}, _) $ Abs (_, _, u) => is_normed u
    1.98 +    | Const (@{const_name Ex}, _) $ Abs (_, _, u) => is_normed ctxt u
    1.99      | Const (@{const_name Ex}, _) $ _ => false
   1.100      | Const (@{const_name Ex}, _) => false
   1.101      | Const (@{const_name Let}, _) $ u1 $ Abs (_, _, u2) =>
   1.102 -        is_normed u1 andalso is_normed u2
   1.103 +        is_normed ctxt u1 andalso is_normed ctxt u2
   1.104      | Const (@{const_name Let}, _) $ _ $ _ => false
   1.105      | Const (@{const_name Let}, _) $ _ => false
   1.106      | Const (@{const_name Let}, _) => false
   1.107 +    | Const (@{const_name Ex1}, _) $ _ => false
   1.108      | Const (@{const_name Ex1}, _) => false
   1.109 +    | Const (@{const_name Ball}, _) $ _ $ _ => false
   1.110 +    | Const (@{const_name Ball}, _) $ _ => false
   1.111      | Const (@{const_name Ball}, _) => false
   1.112 +    | Const (@{const_name Bex}, _) $ _ $ _ => false
   1.113 +    | Const (@{const_name Bex}, _) $ _ => false
   1.114      | Const (@{const_name Bex}, _) => false
   1.115 -    | Abs (_, _, u) => is_normed u
   1.116 -    | u1 $ u2 => is_normed u1 andalso is_normed u2
   1.117 -    | _ => true)
   1.118 +    | Abs (_, _, u) => is_normed ctxt u
   1.119 +    | _ =>
   1.120 +        (case Term.strip_comb t of
   1.121 +          (Const (c as (_, T)), ts) =>
   1.122 +            if SMT_Builtin.is_builtin ctxt c
   1.123 +            then length (Term.binder_types T) = length ts andalso
   1.124 +              forall (is_normed ctxt) ts
   1.125 +            else forall (is_normed ctxt) ts
   1.126 +        | (_, ts) => forall (is_normed ctxt) ts))
   1.127  in
   1.128 -fun norm_binder_conv ctxt = if_conv is_normed Conv.all_conv (norm_conv ctxt)
   1.129 +fun norm_binder_conv ctxt =
   1.130 +  if_conv (is_normed ctxt) Conv.all_conv (norm_conv ctxt)
   1.131  end
   1.132  
   1.133  fun norm_def ctxt thm =