--- a/src/Provers/Arith/fast_lin_arith.ML Mon Dec 18 12:23:54 2000 +0100
+++ b/src/Provers/Arith/fast_lin_arith.ML Mon Dec 18 14:57:34 2000 +0100
@@ -53,7 +53,8 @@
signature LIN_ARITH_DATA =
sig
val decomp:
- Sign.sg -> term -> ((term*int)list * int * string * (term*int)list * int * bool)option
+ Sign.sg -> term -> ((term*rat)list * rat * string * (term*rat)list * rat * bool)option
+ val number_of: int * typ -> term
end;
(*
decomp(`x Rel y') should yield (p,i,Rel,q,j,d)
@@ -70,9 +71,9 @@
signature FAST_LIN_ARITH =
sig
val setup: (theory -> theory) list
- val map_data: ({add_mono_thms: thm list, inj_thms: thm list,
+ val map_data: ({add_mono_thms: thm list, mult_mono_thms: (thm*cterm)list, inj_thms: thm list,
lessD: thm list, simpset: Simplifier.simpset}
- -> {add_mono_thms: thm list, inj_thms: thm list,
+ -> {add_mono_thms: thm list, mult_mono_thms: (thm*cterm)list, inj_thms: thm list,
lessD: thm list, simpset: Simplifier.simpset})
-> theory -> theory
val trace : bool ref
@@ -93,19 +94,20 @@
structure DataArgs =
struct
val name = "Provers/fast_lin_arith";
- type T = {add_mono_thms: thm list, inj_thms: thm list,
+ type T = {add_mono_thms: thm list, mult_mono_thms: (thm*cterm)list, inj_thms: thm list,
lessD: thm list, simpset: Simplifier.simpset};
- val empty = {add_mono_thms = [], inj_thms = [],
+ val empty = {add_mono_thms = [], mult_mono_thms = [], inj_thms = [],
lessD = [], simpset = Simplifier.empty_ss};
val copy = I;
val prep_ext = I;
- fun merge ({add_mono_thms = add_mono_thms1, inj_thms = inj_thms1,
+ fun merge ({add_mono_thms= add_mono_thms1, mult_mono_thms= mult_mono_thms1, inj_thms= inj_thms1,
lessD = lessD1, simpset = simpset1},
- {add_mono_thms = add_mono_thms2, inj_thms = inj_thms2,
+ {add_mono_thms= add_mono_thms2, mult_mono_thms= mult_mono_thms2, inj_thms= inj_thms2,
lessD = lessD2, simpset = simpset2}) =
{add_mono_thms = Drule.merge_rules (add_mono_thms1, add_mono_thms2),
+ mult_mono_thms= generic_merge (eq_thm o pairself fst) I I mult_mono_thms1 mult_mono_thms2,
inj_thms = Drule.merge_rules (inj_thms1, inj_thms2),
lessD = Drule.merge_rules (lessD1, lessD2),
simpset = Simplifier.merge_ss (simpset1, simpset2)};
@@ -137,6 +139,7 @@
| NotLeD of injust
| NotLeDD of injust
| Multiplied of int * injust
+ | Multiplied2 of int * injust
| Added of injust * injust;
datatype lineq = Lineq of int * lineq_type * int list * injust;
@@ -174,20 +177,13 @@
(* If they're both inequalities, 1st coefficient must be +ve, 2nd -ve. *)
(* ------------------------------------------------------------------------- *)
-fun gcd x y =
- let fun gxd x y =
- if y = 0 then x else gxd y (x mod y)
- in if x < y then gxd y x else gxd x y end;
-
-fun lcm x y = (x * y) div gcd x y;
-
fun el 0 (h::_) = h
| el n (_::t) = el (n - 1) t
| el _ _ = sys_error "el";
fun elim_var v (i1 as Lineq(k1,ty1,l1,just1)) (i2 as Lineq(k2,ty2,l2,just2)) =
let val c1 = el v l1 and c2 = el v l2
- val m = lcm (abs c1) (abs c2)
+ val m = lcm(abs c1,abs c2)
val m1 = m div (abs c1) and m2 = m div (abs c2)
val (n1,n2) =
if (c1 >= 0) = (c2 >= 0)
@@ -290,6 +286,8 @@
(* FIXME OPTIMIZE!!!!
Addition/Multiplication need i*t representation rather than t+t+...
+ Get rid of Mulitplied(2). For Nat LA_Data.number_of should return Suc^n
+ because Numerals are not known early enough.
Simplification may detect a contradiction 'prematurely' due to type
information: n+1 <= 0 is simplified to False and does not need to be crossed
@@ -299,7 +297,7 @@
exception FalseE of thm
in
fun mkthm sg asms just =
- let val {add_mono_thms, inj_thms, lessD, simpset} = Data.get_sg sg;
+ let val {add_mono_thms, mult_mono_thms, inj_thms, lessD, simpset} = Data.get_sg sg;
val atoms = foldl (fn (ats,(lhs,_,_,rhs,_,_)) =>
map fst lhs union (map fst rhs union ats))
([], mapfilter (LA_Data.decomp sg o concl_of) asms)
@@ -324,8 +322,14 @@
let fun mul(i,th) = if i=1 then th else mul(i-1, addthms thm th)
in if n < 0 then mul(~n,thm) RS LA_Logic.sym else mul(n,thm) end;
+ fun multn2(n,thm) =
+ let val Some(mth,cv) =
+ get_first (fn (th,cv) => Some(thm RS th,cv) handle _ => None) mult_mono_thms
+ val ct = cterm_of sg (LA_Data.number_of(n,#T(rep_cterm cv)))
+ in instantiate ([],[(cv,ct)]) mth end
+
fun simp thm =
- let val thm' = simplify simpset thm
+ let val thm' = full_simplify simpset thm
in if LA_Logic.is_False thm' then raise FalseE thm' else thm' end
fun mk(Asm i) = trace_thm "Asm" (nth_elem(i,asms))
@@ -337,6 +341,7 @@
| mk(NotLessD(j)) = trace_thm "NL" (mk j RS LA_Logic.not_lessD)
| mk(Added(j1,j2)) = simp (trace_thm "+" (addthms (mk j1) (mk j2)))
| mk(Multiplied(n,j)) = (trace_msg "*"; multn(n,mk j))
+ | mk(Multiplied2(n,j)) = simp (trace_msg "*2"; multn2(n,mk j))
in trace_msg "mkthm";
simplify simpset (mk just) handle FalseE thm => thm end
@@ -344,24 +349,34 @@
fun coeff poly atom = case assoc(poly,atom) of None => 0 | Some i => i;
+fun lcms is = foldl lcm (1,is);
+
+fun integ(rlhs,r,rel,rrhs,s,d) =
+let val (rn,rd) = rep_rat r and (sn,sd) = rep_rat s
+ val m = lcms(map (abs o snd o rep_rat) (r :: s :: map snd rlhs @ map snd rrhs))
+ fun mult(t,r) = let val (i,j) = rep_rat r in (t,i * (m div j)) end
+in (m,(map mult rlhs, rn * (m div rd), rel, map mult rrhs, sn * (m div sd), d)) end
+
fun mklineq atoms =
let val n = length atoms in
- fn ((lhs,i,rel,rhs,j,discrete),k) =>
- let val lhsa = map (coeff lhs) atoms
+ fn (item,k) =>
+ let val (m,(lhs,i,rel,rhs,j,discrete)) = integ item
+ val lhsa = map (coeff lhs) atoms
and rhsa = map (coeff rhs) atoms
val diff = map2 (op -) (rhsa,lhsa)
val c = i-j
val just = Asm k
+ fun lineq(c,le,cs,j) = Some(Lineq(c,le,cs, if m=1 then j else Multiplied2(m,j)))
in case rel of
- "<=" => Some(Lineq(c,Le,diff,just))
+ "<=" => lineq(c,Le,diff,just)
| "~<=" => if discrete
- then Some(Lineq(1-c,Le,map (op ~) diff,NotLeDD(just)))
- else Some(Lineq(~c,Lt,map (op ~) diff,NotLeD(just)))
+ then lineq(1-c,Le,map (op ~) diff,NotLeDD(just))
+ else lineq(~c,Lt,map (op ~) diff,NotLeD(just))
| "<" => if discrete
- then Some(Lineq(c+1,Le,diff,LessD(just)))
- else Some(Lineq(c,Lt,diff,just))
- | "~<" => Some(Lineq(~c,Le,map (op~) diff,NotLessD(just)))
- | "=" => Some(Lineq(c,Eq,diff,just))
+ then lineq(c+1,Le,diff,LessD(just))
+ else lineq(c,Lt,diff,just)
+ | "~<" => lineq(~c,Le,map (op~) diff,NotLessD(just))
+ | "=" => lineq(c,Eq,diff,just)
| "~=" => None
| _ => sys_error("mklineq" ^ rel)
end