towards rtional arithmetic
authornipkow
Mon, 18 Dec 2000 14:57:34 +0100
changeset 10691 4ea37fba9c02
parent 10690 cd80241125b0
child 10692 6077fd933575
towards rtional arithmetic
src/Provers/Arith/fast_lin_arith.ML
--- a/src/Provers/Arith/fast_lin_arith.ML	Mon Dec 18 12:23:54 2000 +0100
+++ b/src/Provers/Arith/fast_lin_arith.ML	Mon Dec 18 14:57:34 2000 +0100
@@ -53,7 +53,8 @@
 signature LIN_ARITH_DATA =
 sig
   val decomp:
-    Sign.sg -> term -> ((term*int)list * int * string * (term*int)list * int * bool)option
+    Sign.sg -> term -> ((term*rat)list * rat * string * (term*rat)list * rat * bool)option
+  val number_of: int * typ -> term
 end;
 (*
 decomp(`x Rel y') should yield (p,i,Rel,q,j,d)
@@ -70,9 +71,9 @@
 signature FAST_LIN_ARITH =
 sig
   val setup: (theory -> theory) list
-  val map_data: ({add_mono_thms: thm list, inj_thms: thm list,
+  val map_data: ({add_mono_thms: thm list, mult_mono_thms: (thm*cterm)list, inj_thms: thm list,
                  lessD: thm list, simpset: Simplifier.simpset}
-                 -> {add_mono_thms: thm list, inj_thms: thm list,
+                 -> {add_mono_thms: thm list, mult_mono_thms: (thm*cterm)list, inj_thms: thm list,
                      lessD: thm list, simpset: Simplifier.simpset})
                 -> theory -> theory
   val trace           : bool ref
@@ -93,19 +94,20 @@
 structure DataArgs =
 struct
   val name = "Provers/fast_lin_arith";
-  type T = {add_mono_thms: thm list, inj_thms: thm list,
+  type T = {add_mono_thms: thm list, mult_mono_thms: (thm*cterm)list, inj_thms: thm list,
             lessD: thm list, simpset: Simplifier.simpset};
 
-  val empty = {add_mono_thms = [], inj_thms = [],
+  val empty = {add_mono_thms = [], mult_mono_thms = [], inj_thms = [],
                lessD = [], simpset = Simplifier.empty_ss};
   val copy = I;
   val prep_ext = I;
 
-  fun merge ({add_mono_thms = add_mono_thms1, inj_thms = inj_thms1,
+  fun merge ({add_mono_thms= add_mono_thms1, mult_mono_thms= mult_mono_thms1, inj_thms= inj_thms1,
               lessD = lessD1, simpset = simpset1},
-             {add_mono_thms = add_mono_thms2, inj_thms = inj_thms2,
+             {add_mono_thms= add_mono_thms2, mult_mono_thms= mult_mono_thms2, inj_thms= inj_thms2,
               lessD = lessD2, simpset = simpset2}) =
     {add_mono_thms = Drule.merge_rules (add_mono_thms1, add_mono_thms2),
+     mult_mono_thms= generic_merge (eq_thm o pairself fst) I I mult_mono_thms1 mult_mono_thms2,
      inj_thms = Drule.merge_rules (inj_thms1, inj_thms2),
      lessD = Drule.merge_rules (lessD1, lessD2),
      simpset = Simplifier.merge_ss (simpset1, simpset2)};
@@ -137,6 +139,7 @@
                 | NotLeD of injust
                 | NotLeDD of injust
                 | Multiplied of int * injust
+                | Multiplied2 of int * injust
                 | Added of injust * injust;
 
 datatype lineq = Lineq of int * lineq_type * int list * injust;
@@ -174,20 +177,13 @@
 (* If they're both inequalities, 1st coefficient must be +ve, 2nd -ve.       *)
 (* ------------------------------------------------------------------------- *)
 
-fun gcd x y =
-  let fun gxd x y =
-    if y = 0 then x else gxd y (x mod y)
-  in if x < y then gxd y x else gxd x y end;
-
-fun lcm x y = (x * y) div gcd x y;
-
 fun el 0 (h::_) = h
   | el n (_::t) = el (n - 1) t
   | el _ _  = sys_error "el";
 
 fun elim_var v (i1 as Lineq(k1,ty1,l1,just1)) (i2 as Lineq(k2,ty2,l2,just2)) =
   let val c1 = el v l1 and c2 = el v l2
-      val m = lcm (abs c1) (abs c2)
+      val m = lcm(abs c1,abs c2)
       val m1 = m div (abs c1) and m2 = m div (abs c2)
       val (n1,n2) =
         if (c1 >= 0) = (c2 >= 0)
@@ -290,6 +286,8 @@
 
 (* FIXME OPTIMIZE!!!!
    Addition/Multiplication need i*t representation rather than t+t+...
+   Get rid of Mulitplied(2). For Nat LA_Data.number_of should return Suc^n
+   because Numerals are not known early enough.
 
 Simplification may detect a contradiction 'prematurely' due to type
 information: n+1 <= 0 is simplified to False and does not need to be crossed
@@ -299,7 +297,7 @@
  exception FalseE of thm
 in
 fun mkthm sg asms just =
-  let val {add_mono_thms, inj_thms, lessD, simpset} = Data.get_sg sg;
+  let val {add_mono_thms, mult_mono_thms, inj_thms, lessD, simpset} = Data.get_sg sg;
       val atoms = foldl (fn (ats,(lhs,_,_,rhs,_,_)) =>
                             map fst lhs  union  (map fst rhs  union  ats))
                         ([], mapfilter (LA_Data.decomp sg o concl_of) asms)
@@ -324,8 +322,14 @@
         let fun mul(i,th) = if i=1 then th else mul(i-1, addthms thm th)
         in if n < 0 then mul(~n,thm) RS LA_Logic.sym else mul(n,thm) end;
 
+      fun multn2(n,thm) =
+        let val Some(mth,cv) =
+              get_first (fn (th,cv) => Some(thm RS th,cv) handle _ => None) mult_mono_thms
+            val ct = cterm_of sg (LA_Data.number_of(n,#T(rep_cterm cv)))
+        in instantiate ([],[(cv,ct)]) mth end
+
       fun simp thm =
-        let val thm' = simplify simpset thm
+        let val thm' = full_simplify simpset thm
         in if LA_Logic.is_False thm' then raise FalseE thm' else thm' end
 
       fun mk(Asm i) = trace_thm "Asm" (nth_elem(i,asms))
@@ -337,6 +341,7 @@
         | mk(NotLessD(j)) = trace_thm "NL" (mk j RS LA_Logic.not_lessD)
         | mk(Added(j1,j2)) = simp (trace_thm "+" (addthms (mk j1) (mk j2)))
         | mk(Multiplied(n,j)) = (trace_msg "*"; multn(n,mk j))
+        | mk(Multiplied2(n,j)) = simp (trace_msg "*2"; multn2(n,mk j))
 
   in trace_msg "mkthm";
      simplify simpset (mk just) handle FalseE thm => thm end
@@ -344,24 +349,34 @@
 
 fun coeff poly atom = case assoc(poly,atom) of None => 0 | Some i => i;
 
+fun lcms is = foldl lcm (1,is);
+
+fun integ(rlhs,r,rel,rrhs,s,d) =
+let val (rn,rd) = rep_rat r and (sn,sd) = rep_rat s
+    val m = lcms(map (abs o snd o rep_rat) (r :: s :: map snd rlhs @ map snd rrhs))
+    fun mult(t,r) = let val (i,j) = rep_rat r in (t,i * (m div j)) end
+in (m,(map mult rlhs, rn * (m div rd), rel, map mult rrhs, sn * (m div sd), d)) end
+
 fun mklineq atoms =
   let val n = length atoms in
-    fn ((lhs,i,rel,rhs,j,discrete),k) =>
-    let val lhsa = map (coeff lhs) atoms
+    fn (item,k) =>
+    let val (m,(lhs,i,rel,rhs,j,discrete)) = integ item
+        val lhsa = map (coeff lhs) atoms
         and rhsa = map (coeff rhs) atoms
         val diff = map2 (op -) (rhsa,lhsa)
         val c = i-j
         val just = Asm k
+        fun lineq(c,le,cs,j) = Some(Lineq(c,le,cs, if m=1 then j else Multiplied2(m,j)))
     in case rel of
-        "<="   => Some(Lineq(c,Le,diff,just))
+        "<="   => lineq(c,Le,diff,just)
        | "~<=" => if discrete
-                  then Some(Lineq(1-c,Le,map (op ~) diff,NotLeDD(just)))
-                  else Some(Lineq(~c,Lt,map (op ~) diff,NotLeD(just)))
+                  then lineq(1-c,Le,map (op ~) diff,NotLeDD(just))
+                  else lineq(~c,Lt,map (op ~) diff,NotLeD(just))
        | "<"   => if discrete
-                  then Some(Lineq(c+1,Le,diff,LessD(just)))
-                  else Some(Lineq(c,Lt,diff,just))
-       | "~<"  => Some(Lineq(~c,Le,map (op~) diff,NotLessD(just)))
-       | "="   => Some(Lineq(c,Eq,diff,just))
+                  then lineq(c+1,Le,diff,LessD(just))
+                  else lineq(c,Lt,diff,just)
+       | "~<"  => lineq(~c,Le,map (op~) diff,NotLessD(just))
+       | "="   => lineq(c,Eq,diff,just)
        | "~="  => None
        | _     => sys_error("mklineq" ^ rel)   
     end