src/HOL/Tools/Nitpick/nitpick_peephole.ML
changeset 34124 c4628a1dcf75
parent 33982 1ae222745c4a
child 34126 8a2c5d7aff51
--- a/src/HOL/Tools/Nitpick/nitpick_peephole.ML	Mon Dec 14 16:48:49 2009 +0100
+++ b/src/HOL/Tools/Nitpick/nitpick_peephole.ML	Thu Dec 17 15:22:11 2009 +0100
@@ -7,6 +7,7 @@
 
 signature NITPICK_PEEPHOLE =
 sig
+  type n_ary_index = Kodkod.n_ary_index
   type formula = Kodkod.formula
   type int_expr = Kodkod.int_expr
   type rel_expr = Kodkod.rel_expr
@@ -14,29 +15,29 @@
   type expr_assign = Kodkod.expr_assign
 
   type name_pool = {
-    rels: Kodkod.n_ary_index list,
-    vars: Kodkod.n_ary_index list,
+    rels: n_ary_index list,
+    vars: n_ary_index list,
     formula_reg: int,
     rel_reg: int}
 
   val initial_pool : name_pool
-  val not3_rel : rel_expr
-  val suc_rel : rel_expr
-  val nat_add_rel : rel_expr
-  val int_add_rel : rel_expr
-  val nat_subtract_rel : rel_expr
-  val int_subtract_rel : rel_expr
-  val nat_multiply_rel : rel_expr
-  val int_multiply_rel : rel_expr
-  val nat_divide_rel : rel_expr
-  val int_divide_rel : rel_expr
-  val nat_modulo_rel : rel_expr
-  val int_modulo_rel : rel_expr
-  val nat_less_rel : rel_expr
-  val int_less_rel : rel_expr
-  val gcd_rel : rel_expr
-  val lcm_rel : rel_expr
-  val norm_frac_rel : rel_expr
+  val not3_rel : n_ary_index
+  val suc_rel : n_ary_index
+  val unsigned_bit_word_sel_rel : n_ary_index
+  val signed_bit_word_sel_rel : n_ary_index
+  val nat_add_rel : n_ary_index
+  val int_add_rel : n_ary_index
+  val nat_subtract_rel : n_ary_index
+  val int_subtract_rel : n_ary_index
+  val nat_multiply_rel : n_ary_index
+  val int_multiply_rel : n_ary_index
+  val nat_divide_rel : n_ary_index
+  val int_divide_rel : n_ary_index
+  val nat_less_rel : n_ary_index
+  val int_less_rel : n_ary_index
+  val gcd_rel : n_ary_index
+  val lcm_rel : n_ary_index
+  val norm_frac_rel : n_ary_index
   val atom_for_bool : int -> bool -> rel_expr
   val formula_for_bool : bool -> formula
   val atom_for_nat : int * int -> int -> int
@@ -44,6 +45,7 @@
   val max_int_for_card : int -> int
   val int_for_atom : int * int -> int -> int
   val atom_for_int : int * int -> int -> int
+  val is_twos_complement_representable : int -> int -> bool
   val inline_rel_expr : rel_expr -> bool
   val empty_n_ary_rel : int -> rel_expr
   val num_seq : int -> int -> int_expr list
@@ -105,23 +107,23 @@
   {rels = [(2, 10), (3, 20), (4, 10)], vars = [], formula_reg = 10,
    rel_reg = 10}
 
-val not3_rel = Rel (2, 0)
-val suc_rel = Rel (2, 1)
-val nat_add_rel = Rel (3, 0)
-val int_add_rel = Rel (3, 1)
-val nat_subtract_rel = Rel (3, 2)
-val int_subtract_rel = Rel (3, 3)
-val nat_multiply_rel = Rel (3, 4)
-val int_multiply_rel = Rel (3, 5)
-val nat_divide_rel = Rel (3, 6)
-val int_divide_rel = Rel (3, 7)
-val nat_modulo_rel = Rel (3, 8)
-val int_modulo_rel = Rel (3, 9)
-val nat_less_rel = Rel (3, 10)
-val int_less_rel = Rel (3, 11)
-val gcd_rel = Rel (3, 12)
-val lcm_rel = Rel (3, 13)
-val norm_frac_rel = Rel (4, 0)
+val not3_rel = (2, 0)
+val suc_rel = (2, 1)
+val unsigned_bit_word_sel_rel = (2, 2)
+val signed_bit_word_sel_rel = (2, 3)
+val nat_add_rel = (3, 0)
+val int_add_rel = (3, 1)
+val nat_subtract_rel = (3, 2)
+val int_subtract_rel = (3, 3)
+val nat_multiply_rel = (3, 4)
+val int_multiply_rel = (3, 5)
+val nat_divide_rel = (3, 6)
+val int_divide_rel = (3, 7)
+val nat_less_rel = (3, 8)
+val int_less_rel = (3, 9)
+val gcd_rel = (3, 10)
+val lcm_rel = (3, 11)
+val norm_frac_rel = (4, 0)
 
 (* int -> bool -> rel_expr *)
 fun atom_for_bool j0 = Atom o Integer.add j0 o int_for_bool
@@ -140,6 +142,9 @@
   if n < min_int_for_card k orelse n > max_int_for_card k then ~1
   else if n < 0 then n + k + j0
   else n + j0
+(* int -> int -> bool *)
+fun is_twos_complement_representable bits n =
+  let val max = reasonable_power 2 bits in n >= ~ max andalso n < max end
 
 (* rel_expr -> bool *)
 fun is_none_product (Product (r1, r2)) =
@@ -365,16 +370,28 @@
     (* rel_expr -> rel_expr *)
     fun s_not3 (Atom j) = Atom (if j = main_j0 then j + 1 else j - 1)
       | s_not3 (r as Join (r1, r2)) =
-        if r2 = not3_rel then r1 else Join (r, not3_rel)
-      | s_not3 r = Join (r, not3_rel)
+        if r2 = Rel not3_rel then r1 else Join (r, Rel not3_rel)
+      | s_not3 r = Join (r, Rel not3_rel)
 
     (* rel_expr -> rel_expr -> formula *)
     fun s_rel_eq r1 r2 =
       (case (r1, r2) of
-         (Join (r11, r12), _) =>
-         if r12 = not3_rel then s_rel_eq r11 (s_not3 r2) else raise SAME ()
-       | (_, Join (r21, r22)) =>
-         if r22 = not3_rel then s_rel_eq r21 (s_not3 r1) else raise SAME ()
+         (Join (r11, Rel x), _) =>
+         if x = not3_rel then s_rel_eq r11 (s_not3 r2) else raise SAME ()
+       | (_, Join (r21, Rel x)) =>
+         if x = not3_rel then s_rel_eq r21 (s_not3 r1) else raise SAME ()
+       | (RelIf (f, r11, r12), _) =>
+         if inline_rel_expr r2 then
+           s_formula_if f (s_rel_eq r11 r2) (s_rel_eq r12 r2)
+         else
+           raise SAME ()
+       | (_, RelIf (f, r21, r22)) =>
+         if inline_rel_expr r1 then
+           s_formula_if f (s_rel_eq r1 r21) (s_rel_eq r1 r22)
+         else
+           raise SAME ()
+       | (RelLet (bs, r1'), Atom _) => s_formula_let bs (s_rel_eq r1' r2)
+       | (Atom _, RelLet (bs, r2')) => s_formula_let bs (s_rel_eq r1 r2')
        | _ => raise SAME ())
       handle SAME () =>
              case rel_expr_equal r1 r2 of
@@ -499,8 +516,8 @@
       | s_join (r1 as RelIf (f, r11, r12)) r2 =
         if inline_rel_expr r2 then s_rel_if f (s_join r11 r2) (s_join r12 r2)
         else Join (r1, r2)
-      | s_join (r1 as Atom j1) (r2 as Rel (2, j2)) =
-        if r2 = suc_rel then
+      | s_join (r1 as Atom j1) (r2 as Rel (x as (2, j2))) =
+        if x = suc_rel then
           let val n = to_nat j1 + 1 in
             if n < nat_card then from_nat n else None
           end
@@ -511,8 +528,8 @@
           s_project (s_join r21 r1) is
         else
           Join (r1, r2)
-      | s_join r1 (Join (r21, r22 as Rel (3, j22))) =
-        ((if r22 = nat_add_rel then
+      | s_join r1 (Join (r21, r22 as Rel (x as (3, j22)))) =
+        ((if x = nat_add_rel then
             case (r21, r1) of
               (Atom j1, Atom j2) =>
               let val n = to_nat j1 + to_nat j2 in
@@ -521,19 +538,19 @@
             | (Atom j, r) =>
               (case to_nat j of
                  0 => r
-               | 1 => s_join r suc_rel
+               | 1 => s_join r (Rel suc_rel)
                | _ => raise SAME ())
             | (r, Atom j) =>
               (case to_nat j of
                  0 => r
-               | 1 => s_join r suc_rel
+               | 1 => s_join r (Rel suc_rel)
                | _ => raise SAME ())
             | _ => raise SAME ()
-          else if r22 = nat_subtract_rel then
+          else if x = nat_subtract_rel then
             case (r21, r1) of
               (Atom j1, Atom j2) => from_nat (nat_minus (to_nat j1) (to_nat j2))
             | _ => raise SAME ()
-          else if r22 = nat_multiply_rel then
+          else if x = nat_multiply_rel then
             case (r21, r1) of
               (Atom j1, Atom j2) =>
               let val n = to_nat j1 * to_nat j2 in
@@ -596,20 +613,20 @@
       in aux (arity_of_rel_expr r) r end
 
     (* rel_expr -> rel_expr -> rel_expr *)
-    fun s_nat_subtract r1 r2 = fold s_join [r1, r2] nat_subtract_rel
+    fun s_nat_subtract r1 r2 = fold s_join [r1, r2] (Rel nat_subtract_rel)
     fun s_nat_less (Atom j1) (Atom j2) = from_bool (j1 < j2)
-      | s_nat_less r1 r2 = fold s_join [r1, r2] nat_less_rel
+      | s_nat_less r1 r2 = fold s_join [r1, r2] (Rel nat_less_rel)
     fun s_int_less (Atom j1) (Atom j2) = from_bool (to_int j1 < to_int j2)
-      | s_int_less r1 r2 = fold s_join [r1, r2] int_less_rel
+      | s_int_less r1 r2 = fold s_join [r1, r2] (Rel int_less_rel)
 
     (* rel_expr -> int -> int -> rel_expr *)
     fun d_project_seq r j0 n = Project (r, num_seq j0 n)
     (* rel_expr -> rel_expr *)
-    fun d_not3 r = Join (r, not3_rel)
+    fun d_not3 r = Join (r, Rel not3_rel)
     (* rel_expr -> rel_expr -> rel_expr *)
-    fun d_nat_subtract r1 r2 = List.foldl Join nat_subtract_rel [r1, r2]
-    fun d_nat_less r1 r2 = List.foldl Join nat_less_rel [r1, r2]
-    fun d_int_less r1 r2 = List.foldl Join int_less_rel [r1, r2]
+    fun d_nat_subtract r1 r2 = List.foldl Join (Rel nat_subtract_rel) [r1, r2]
+    fun d_nat_less r1 r2 = List.foldl Join (Rel nat_less_rel) [r1, r2]
+    fun d_int_less r1 r2 = List.foldl Join (Rel int_less_rel) [r1, r2]
   in
     if optim then
       {kk_all = s_all, kk_exist = s_exist, kk_formula_let = s_formula_let,