only linear occurrences of multiplication are treated as built-in (SMT solvers only support linear arithmetic in general);
authorboehmes
Sun, 19 Dec 2010 17:55:56 +0100
changeset 41280 a7de9d36f4f2
parent 41277 1369c27c6966
child 41281 679118e35378
only linear occurrences of multiplication are treated as built-in (SMT solvers only support linear arithmetic in general); hide internal constants z3div and z3mod; rewrite div/mod to z3div/z3mod instead of adding extra rules characterizing div/mod in terms of z3div/z3mod
src/HOL/SMT.thy
src/HOL/Tools/SMT/smt_builtin.ML
src/HOL/Tools/SMT/smt_normalize.ML
src/HOL/Tools/SMT/smt_real.ML
src/HOL/Tools/SMT/smt_utils.ML
src/HOL/Tools/SMT/smtlib_interface.ML
src/HOL/Tools/SMT/z3_interface.ML
--- a/src/HOL/SMT.thy	Sun Dec 19 00:13:25 2010 +0100
+++ b/src/HOL/SMT.thy	Sun Dec 19 17:55:56 2010 +0100
@@ -130,21 +130,6 @@
 definition z3mod :: "int \<Rightarrow> int \<Rightarrow> int" where
   "z3mod k l = (if 0 \<le> l then k mod l else k mod (-l))"
 
-lemma div_by_z3div:
-  "\<forall>k l. k div l = (
-    if k = 0 \<or> l = 0 then 0
-    else if (0 < k \<and> 0 < l) \<or> (k < 0 \<and> 0 < l) then z3div k l
-    else z3div (-k) (-l))"
-  by (auto simp add: z3div_def trigger_def)
-
-lemma mod_by_z3mod:
-  "\<forall>k l. k mod l = (
-    if l = 0 then k
-    else if k = 0 then 0
-    else if (0 < k \<and> 0 < l) \<or> (k < 0 \<and> 0 < l) then z3mod k l
-    else - z3mod (-k) (-l))"
-  by (auto simp add: z3mod_def trigger_def)
-
 
 
 subsection {* Setup *}
@@ -391,8 +376,8 @@
 
 hide_type term_bool
 hide_type (open) pattern
-hide_const Pattern fun_app
-hide_const (open) trigger pat nopat weight z3div z3mod
+hide_const Pattern fun_app z3div z3mod
+hide_const (open) trigger pat nopat weight
 
 
 
--- a/src/HOL/Tools/SMT/smt_builtin.ML	Sun Dec 19 00:13:25 2010 +0100
+++ b/src/HOL/Tools/SMT/smt_builtin.ML	Sun Dec 19 17:55:56 2010 +0100
@@ -188,17 +188,8 @@
   | SOME (_, Ext f) => f ctxt T ts
   | NONE => false)
 
-(* FIXME: move this information to the interfaces *)
-val only_partially_supported = [
-  @{const_name times},
-  @{const_name div_class.div},
-  @{const_name div_class.mod},
-  @{const_name inverse_class.divide} ]
-
-fun is_builtin_ext ctxt (c as (n, _)) ts =
-  if member (op =) only_partially_supported n then false
-  else
-    is_builtin_num_ext ctxt (Term.list_comb (Const c, ts)) orelse 
-    is_builtin_fun_ext ctxt c ts
+fun is_builtin_ext ctxt c ts =
+  is_builtin_num_ext ctxt (Term.list_comb (Const c, ts)) orelse 
+  is_builtin_fun_ext ctxt c ts
 
 end
--- a/src/HOL/Tools/SMT/smt_normalize.ML	Sun Dec 19 00:13:25 2010 +0100
+++ b/src/HOL/Tools/SMT/smt_normalize.ML	Sun Dec 19 17:55:56 2010 +0100
@@ -431,20 +431,27 @@
     "ALL i. i < 0 --> int (nat i) = 0"
     by simp_all}
 
-  val nat_ops = [
+  val simple_nat_ops = [
     @{const less (nat)}, @{const less_eq (nat)},
-    @{const Suc}, @{const plus (nat)}, @{const minus (nat)},
-    @{const times (nat)}, @{const div (nat)}, @{const mod (nat)}]
+    @{const Suc}, @{const plus (nat)}, @{const minus (nat)}]
+
+  val mult_nat_ops =
+    [@{const times (nat)}, @{const div (nat)}, @{const mod (nat)}]
+
+  val nat_ops = simple_nat_ops @ mult_nat_ops
 
   val nat_consts = nat_ops @ [@{const number_of (nat)},
     @{const zero_class.zero (nat)}, @{const one_class.one (nat)}]
 
   val nat_int_coercions = [@{const of_nat (int)}, @{const nat}]
 
-  val nat_ops' = nat_int_coercions @ nat_ops
+  val builtin_nat_ops = nat_int_coercions @ simple_nat_ops
 
   val is_nat_const = member (op aconv) nat_consts
 
+  fun is_nat_const' @{const of_nat (int)} = true
+    | is_nat_const' t = is_nat_const t
+
   val expands = map mk_meta_eq @{lemma
     "0 = nat 0"
     "1 = nat 1"
@@ -494,16 +501,17 @@
         Conv.rewr_conv (mk_number_eq ctxt (snd (HOLogic.dest_number n)) ct)
     | @{const of_nat (int)} $ _ =>
         (Conv.rewrs_conv ints then_conv Conv.sub_conv ints_conv ctxt) else_conv
-        Conv.top_sweep_conv nat_conv ctxt        
+        Conv.sub_conv (Conv.top_sweep_conv nat_conv) ctxt        
     | _ => Conv.no_conv) ct
 
   and ints_conv ctxt = Conv.top_sweep_conv int_conv ctxt
 
   and expand_conv ctxt =
-    U.if_conv (not o is_nat_const o Term.head_of) Conv.no_conv
+    U.if_conv (is_nat_const o Term.head_of)
       (expand_head_conv (Conv.rewrs_conv expands) then_conv ints_conv ctxt)
+      (int_conv ctxt)
 
-  and nat_conv ctxt = U.if_exists_conv is_nat_const
+  and nat_conv ctxt = U.if_exists_conv is_nat_const'
     (Conv.top_sweep_conv expand_conv ctxt)
 
   val uses_nat_int = Term.exists_subterm (member (op aconv) nat_int_coercions)
@@ -517,7 +525,7 @@
 
 val setup_nat_as_int =
   B.add_builtin_typ_ext (@{typ nat}, K true) #>
-  fold (B.add_builtin_fun_ext' o Term.dest_Const) nat_ops'
+  fold (B.add_builtin_fun_ext' o Term.dest_Const) builtin_nat_ops
 
 end
 
--- a/src/HOL/Tools/SMT/smt_real.ML	Sun Dec 19 00:13:25 2010 +0100
+++ b/src/HOL/Tools/SMT/smt_real.ML	Sun Dec 19 17:55:56 2010 +0100
@@ -12,6 +12,7 @@
 structure SMT_Real: SMT_REAL =
 struct
 
+structure U = SMT_Utils
 structure B = SMT_Builtin
 
 
@@ -29,18 +30,31 @@
   val smtlibC = SMTLIB_Interface.smtlibC
 
   fun real_num _ i = SOME (string_of_int i ^ ".0")
+
+  fun is_linear [t] = U.is_number t
+    | is_linear [t, u] = U.is_number t orelse U.is_number u
+    | is_linear _ = false
+
+  fun times _ T ts = if is_linear ts then SOME ((("*", 2), T), ts, T) else NONE
+    | times _ _ _  = NONE
+
+  fun divide _ T (ts as [_, t]) =
+        if U.is_number t then SOME ((("/", 2), T), ts, T) else NONE
+    | divide _ _ _ = NONE
 in
 
 val setup_builtins =
   B.add_builtin_typ smtlibC (@{typ real}, K (SOME "Real"), real_num) #>
   fold (B.add_builtin_fun' smtlibC) [
+    (@{const less (real)}, "<"),
+    (@{const less_eq (real)}, "<="),
     (@{const uminus (real)}, "~"),
     (@{const plus (real)}, "+"),
-    (@{const minus (real)}, "-"),
-    (@{const times (real)}, "*"),
-    (@{const less (real)}, "<"),
-    (@{const less_eq (real)}, "<=") ] #>
-  B.add_builtin_fun' Z3_Interface.smtlib_z3C (@{const divide (real)}, "/")
+    (@{const minus (real)}, "-") ] #>
+  B.add_builtin_fun SMTLIB_Interface.smtlibC
+    (Term.dest_Const @{const times (real)}, times) #>
+  B.add_builtin_fun Z3_Interface.smtlib_z3C
+    (Term.dest_Const @{const divide (real)}, divide)
 
 end
 
--- a/src/HOL/Tools/SMT/smt_utils.ML	Sun Dec 19 00:13:25 2010 +0100
+++ b/src/HOL/Tools/SMT/smt_utils.ML	Sun Dec 19 17:55:56 2010 +0100
@@ -28,6 +28,7 @@
   val dest_conj: term -> term * term
   val dest_disj: term -> term * term
   val under_quant: (term -> 'a) -> term -> 'a
+  val is_number: term -> bool
 
   (*patterns and instantiations*)
   val mk_const_pat: theory -> string -> (ctyp -> 'a) -> 'a * cterm
@@ -132,6 +133,19 @@
   | Const (@{const_name Ex}, _) $ Abs (_, _, u) => under_quant f u
   | _ => f t)
 
+val is_number =
+  let
+    fun is_num env (Const (@{const_name If}, _) $ _ $ t $ u) =
+          is_num env t andalso is_num env u
+      | is_num env (Const (@{const_name Let}, _) $ t $ Abs (_, _, u)) =
+          is_num (t :: env) u
+      | is_num env (Const (@{const_name uminus}, _) $ t) = is_num env t
+      | is_num env (Const (@{const_name divide}, _) $ t $ u) =
+          is_num env t andalso is_num env u
+      | is_num env (Bound i) = i < length env andalso is_num env (nth env i)
+      | is_num _ t = can HOLogic.dest_number t
+  in is_num [] end
+
 
 (* patterns and instantiations *)
 
--- a/src/HOL/Tools/SMT/smtlib_interface.ML	Sun Dec 19 00:13:25 2010 +0100
+++ b/src/HOL/Tools/SMT/smtlib_interface.ML	Sun Dec 19 17:55:56 2010 +0100
@@ -16,6 +16,7 @@
 structure SMTLIB_Interface: SMTLIB_INTERFACE =
 struct
 
+structure U = SMT_Utils
 structure B = SMT_Builtin
 structure N = SMT_Normalize
 structure T = SMT_Translate
@@ -28,6 +29,12 @@
 local
   fun int_num _ i = SOME (string_of_int i)
 
+  fun is_linear [t] = U.is_number t
+    | is_linear [t, u] = U.is_number t orelse U.is_number u
+    | is_linear _ = false
+
+  fun times _ T ts = if is_linear ts then SOME ((("*", 2), T), ts, T) else NONE
+
   fun distinct _ (Type (_, [Type (_, [T]), _])) [t] =
         (case try HOLogic.dest_list t of
           SOME (ts as _ :: _) =>
@@ -56,8 +63,8 @@
     (@{const less_eq (int)}, "<="),
     (@{const uminus (int)}, "~"),
     (@{const plus (int)}, "+"),
-    (@{const minus (int)}, "-"),
-    (@{const times (int)}, "*") ] #>
+    (@{const minus (int)}, "-") ] #>
+  B.add_builtin_fun smtlibC (Term.dest_Const @{const times (int)}, times) #>
   B.add_builtin_fun smtlibC (Term.dest_Const @{const distinct ('a)}, distinct)
 
 end
--- a/src/HOL/Tools/SMT/z3_interface.ML	Sun Dec 19 00:13:25 2010 +0100
+++ b/src/HOL/Tools/SMT/z3_interface.ML	Sun Dec 19 17:55:56 2010 +0100
@@ -44,17 +44,34 @@
         has_datatypes=true}
     end
 
-  fun is_int_div_mod @{const div (int)} = true
-    | is_int_div_mod @{const mod (int)} = true
-    | is_int_div_mod _ = false
+  fun is_div_mod (@{const div (int)} $ _ $ t) = U.is_number t
+    | is_div_mod (@{const mod (int)} $ _ $ t) = U.is_number t
+    | is_div_mod _ = false
+
+  val div_by_z3div = mk_meta_eq @{lemma
+    "k div l = (
+      if k = 0 | l = 0 then 0
+      else if (0 < k & 0 < l) | (k < 0 & 0 < l) then z3div k l
+      else z3div (-k) (-l))"
+    by (simp add: SMT.z3div_def)}
 
-  val have_int_div_mod =
-    exists (Term.exists_subterm is_int_div_mod o Thm.prop_of)
+  val mod_by_z3mod = mk_meta_eq @{lemma
+    "k mod l = (
+      if l = 0 then k
+      else if k = 0 then 0
+      else if (0 < k & 0 < l) | (k < 0 & 0 < l) then z3mod k l
+      else - z3mod (-k) (-l))"
+    by (simp add: z3mod_def)}
 
-  fun add_div_mod _ (thms, extra_thms) =
-    if have_int_div_mod thms orelse have_int_div_mod extra_thms then
-      (thms, @{thm div_by_z3div} :: @{thm mod_by_z3mod} :: extra_thms)
-    else (thms, extra_thms)
+  fun div_mod_conv _ =
+    U.if_true_conv is_div_mod (Conv.rewrs_conv [div_by_z3div, mod_by_z3mod])
+
+  fun rewrite_div_mod ctxt thm =
+    if Term.exists_subterm is_div_mod (Thm.prop_of thm) then
+      Conv.fconv_rule (Conv.top_conv div_mod_conv ctxt) thm
+    else thm
+
+  fun norm_div_mod ctxt = pairself (map (rewrite_div_mod ctxt))
 
   val setup_builtins =
     B.add_builtin_fun' smtlib_z3C (@{const z3div}, "div") #>
@@ -63,7 +80,7 @@
 
 val setup = Context.theory_map (
   setup_builtins #>
-  SMT_Normalize.add_extra_norm (smtlib_z3C, add_div_mod) #>
+  SMT_Normalize.add_extra_norm (smtlib_z3C, norm_div_mod) #>
   SMT_Translate.add_config (smtlib_z3C, translate_config))
 
 end