13 |
13 |
14 lin_arith_prover: Sign.sg -> simpset -> term -> thm option |
14 lin_arith_prover: Sign.sg -> simpset -> term -> thm option |
15 |
15 |
16 Only take premises and conclusions into account that are already (negated) |
16 Only take premises and conclusions into account that are already (negated) |
17 (in)equations. lin_arith_prover tries to prove or disprove the term. |
17 (in)equations. lin_arith_prover tries to prove or disprove the term. |
18 |
|
19 FIXME: convert to IntInf.int throughout. |
|
20 *) |
18 *) |
21 |
19 |
22 (* Debugging: set Fast_Arith.trace *) |
20 (* Debugging: set Fast_Arith.trace *) |
23 |
21 |
24 (*** Data needed for setting up the linear arithmetic package ***) |
22 (*** Data needed for setting up the linear arithmetic package ***) |
53 |
51 |
54 signature LIN_ARITH_DATA = |
52 signature LIN_ARITH_DATA = |
55 sig |
53 sig |
56 val decomp: |
54 val decomp: |
57 Sign.sg -> term -> ((term*rat)list * rat * string * (term*rat)list * rat * bool)option |
55 Sign.sg -> term -> ((term*rat)list * rat * string * (term*rat)list * rat * bool)option |
58 val number_of: int * typ -> term |
56 val number_of: IntInf.int * typ -> term |
59 end; |
57 end; |
60 (* |
58 (* |
61 decomp(`x Rel y') should yield (p,i,Rel,q,j,d) |
59 decomp(`x Rel y') should yield (p,i,Rel,q,j,d) |
62 where Rel is one of "<", "~<", "<=", "~<=" and "=" and |
60 where Rel is one of "<", "~<", "<=", "~<=" and "=" and |
63 p/q is the decomposition of the sum terms x/y into a list |
61 p/q is the decomposition of the sum terms x/y into a list |
139 | Nat of int (* index of atom *) |
137 | Nat of int (* index of atom *) |
140 | LessD of injust |
138 | LessD of injust |
141 | NotLessD of injust |
139 | NotLessD of injust |
142 | NotLeD of injust |
140 | NotLeD of injust |
143 | NotLeDD of injust |
141 | NotLeDD of injust |
144 | Multiplied of int * injust |
142 | Multiplied of IntInf.int * injust |
145 | Multiplied2 of int * injust |
143 | Multiplied2 of IntInf.int * injust |
146 | Added of injust * injust; |
144 | Added of injust * injust; |
147 |
145 |
148 datatype lineq = Lineq of int * lineq_type * int list * injust; |
146 datatype lineq = Lineq of IntInf.int * lineq_type * IntInf.int list * injust; |
149 |
147 |
150 fun el 0 (h::_) = h |
148 fun el 0 (h::_) = h |
151 | el n (_::t) = el (n - 1) t |
149 | el n (_::t) = el (n - 1) t |
152 | el _ _ = sys_error "el"; |
150 | el _ _ = sys_error "el"; |
153 |
151 |
169 | elim_eqns(ineqs,Lineq(i,Lt,cs,_)) = (i,false,cs)::ineqs; |
167 | elim_eqns(ineqs,Lineq(i,Lt,cs,_)) = (i,false,cs)::ineqs; |
170 |
168 |
171 val rat0 = rat_of_int 0; |
169 val rat0 = rat_of_int 0; |
172 |
170 |
173 (* PRE: ex[v] must be 0! *) |
171 (* PRE: ex[v] must be 0! *) |
174 fun eval (ex:rat list) v (a:int,le,cs:int list) = |
172 fun eval (ex:rat list) v (a:IntInf.int,le,cs:IntInf.int list) = |
175 let val rs = map rat_of_int cs |
173 let val rs = map rat_of_intinf cs |
176 val rsum = Library.foldl ratadd (rat0,map ratmul (rs ~~ ex)) |
174 val rsum = Library.foldl ratadd (rat0,map ratmul (rs ~~ ex)) |
177 in (ratmul(ratadd(rat_of_int a,ratneg rsum), ratinv(el v rs)), le) end; |
175 in (ratmul(ratadd(rat_of_intinf a,ratneg rsum), ratinv(el v rs)), le) end; |
178 (* If el v rs < 0, le should be negated. |
176 (* If el v rs < 0, le should be negated. |
179 Instead this swap is taken into account in ratrelmin2. |
177 Instead this swap is taken into account in ratrelmin2. |
180 *) |
178 *) |
181 |
179 |
182 fun ratge0 r = fst(rep_rat r) >= 0; |
180 fun ratge0 r = fst(rep_rat r) >= 0; |
256 in pick_vars discr (ineqs',ex') end |
254 in pick_vars discr (ineqs',ex') end |
257 end; |
255 end; |
258 |
256 |
259 fun findex0 discr n lineqs = |
257 fun findex0 discr n lineqs = |
260 let val ineqs = Library.foldl elim_eqns ([],lineqs) |
258 let val ineqs = Library.foldl elim_eqns ([],lineqs) |
261 val rineqs = map (fn (a,le,cs) => (rat_of_int a, le, map rat_of_int cs)) |
259 val rineqs = map (fn (a,le,cs) => (rat_of_intinf a, le, map rat_of_intinf cs)) |
262 ineqs |
260 ineqs |
263 in pick_vars discr (rineqs,replicate n rat0) end; |
261 in pick_vars discr (rineqs,replicate n rat0) end; |
264 |
262 |
265 (* ------------------------------------------------------------------------- *) |
263 (* ------------------------------------------------------------------------- *) |
266 (* End of counter example finder. The actual decision procedure starts here. *) |
264 (* End of counter example finder. The actual decision procedure starts here. *) |
299 (* If they're both inequalities, 1st coefficient must be +ve, 2nd -ve. *) |
297 (* If they're both inequalities, 1st coefficient must be +ve, 2nd -ve. *) |
300 (* ------------------------------------------------------------------------- *) |
298 (* ------------------------------------------------------------------------- *) |
301 |
299 |
302 fun elim_var v (i1 as Lineq(k1,ty1,l1,just1)) (i2 as Lineq(k2,ty2,l2,just2)) = |
300 fun elim_var v (i1 as Lineq(k1,ty1,l1,just1)) (i2 as Lineq(k2,ty2,l2,just2)) = |
303 let val c1 = el v l1 and c2 = el v l2 |
301 let val c1 = el v l1 and c2 = el v l2 |
304 val m = IntInf.toInt (lcm(IntInf.fromInt (abs c1), IntInf.fromInt(abs c2))) |
302 val m = lcm(abs c1, abs c2) |
305 val m1 = m div (abs c1) and m2 = m div (abs c2) |
303 val m1 = m div (abs c1) and m2 = m div (abs c2) |
306 val (n1,n2) = |
304 val (n1,n2) = |
307 if (c1 >= 0) = (c2 >= 0) |
305 if (c1 >= 0) = (c2 >= 0) |
308 then if ty1 = Eq then (~m1,m2) |
306 then if ty1 = Eq then (~m1,m2) |
309 else if ty2 = Eq then (m1,~m2) |
307 else if ty2 = Eq then (m1,~m2) |
320 fun is_trivial (Lineq(_,_,l,_)) = forall (fn i => i=0) l; |
318 fun is_trivial (Lineq(_,_,l,_)) = forall (fn i => i=0) l; |
321 |
319 |
322 fun is_answer (ans as Lineq(k,ty,l,_)) = |
320 fun is_answer (ans as Lineq(k,ty,l,_)) = |
323 case ty of Eq => k <> 0 | Le => k > 0 | Lt => k >= 0; |
321 case ty of Eq => k <> 0 | Le => k > 0 | Lt => k >= 0; |
324 |
322 |
325 fun calc_blowup l = |
323 fun calc_blowup (l:IntInf.int list) = |
326 let val (p,n) = List.partition (apl(0,op<)) (List.filter (apl(0,op<>)) l) |
324 let val (p,n) = List.partition (apl(0,op<)) (List.filter (apl(0,op<>)) l) |
327 in (length p) * (length n) end; |
325 in (length p) * (length n) end; |
328 |
326 |
329 (* ------------------------------------------------------------------------- *) |
327 (* ------------------------------------------------------------------------- *) |
330 (* Main elimination code: *) |
328 (* Main elimination code: *) |
348 in extract [] end; |
346 in extract [] end; |
349 |
347 |
350 fun print_ineqs ineqs = |
348 fun print_ineqs ineqs = |
351 if !trace then |
349 if !trace then |
352 tracing(cat_lines(""::map (fn Lineq(c,t,l,_) => |
350 tracing(cat_lines(""::map (fn Lineq(c,t,l,_) => |
353 string_of_int c ^ |
351 IntInf.toString c ^ |
354 (case t of Eq => " = " | Lt=> " < " | Le => " <= ") ^ |
352 (case t of Eq => " = " | Lt=> " < " | Le => " <= ") ^ |
355 commas(map string_of_int l)) ineqs)) |
353 commas(map IntInf.toString l)) ineqs)) |
356 else (); |
354 else (); |
357 |
355 |
358 type history = (int * lineq list) list; |
356 type history = (int * lineq list) list; |
359 datatype result = Success of injust | Failure of history; |
357 datatype result = Success of injust | Failure of history; |
360 |
358 |
369 if null nontriv then Failure(hist) |
367 if null nontriv then Failure(hist) |
370 else |
368 else |
371 let val (eqs,noneqs) = List.partition (fn (Lineq(_,ty,_,_)) => ty=Eq) nontriv in |
369 let val (eqs,noneqs) = List.partition (fn (Lineq(_,ty,_,_)) => ty=Eq) nontriv in |
372 if not(null eqs) then |
370 if not(null eqs) then |
373 let val clist = Library.foldl (fn (cs,Lineq(_,_,l,_)) => l union cs) ([],eqs) |
371 let val clist = Library.foldl (fn (cs,Lineq(_,_,l,_)) => l union cs) ([],eqs) |
374 val sclist = sort (fn (x,y) => int_ord(abs(x),abs(y))) |
372 val sclist = sort (fn (x,y) => IntInf.compare(abs(x),abs(y))) |
375 (List.filter (fn i => i<>0) clist) |
373 (List.filter (fn i => i<>0) clist) |
376 val c = hd sclist |
374 val c = hd sclist |
377 val (SOME(eq as Lineq(_,_,ceq,_)),othereqs) = |
375 val (SOME(eq as Lineq(_,_,ceq,_)),othereqs) = |
378 extract_first (fn Lineq(_,_,l,_) => c mem l) eqs |
376 extract_first (fn Lineq(_,_,l,_) => c mem l) eqs |
379 val v = find_index_eq c ceq |
377 val v = find_index_eq c ceq |
476 | mk(LessD(j)) = trace_thm "L" (hd([mk j] RL lessD)) |
474 | mk(LessD(j)) = trace_thm "L" (hd([mk j] RL lessD)) |
477 | mk(NotLeD(j)) = trace_thm "NLe" (mk j RS LA_Logic.not_leD) |
475 | mk(NotLeD(j)) = trace_thm "NLe" (mk j RS LA_Logic.not_leD) |
478 | mk(NotLeDD(j)) = trace_thm "NLeD" (hd([mk j RS LA_Logic.not_leD] RL lessD)) |
476 | mk(NotLeDD(j)) = trace_thm "NLeD" (hd([mk j RS LA_Logic.not_leD] RL lessD)) |
479 | mk(NotLessD(j)) = trace_thm "NL" (mk j RS LA_Logic.not_lessD) |
477 | mk(NotLessD(j)) = trace_thm "NL" (mk j RS LA_Logic.not_lessD) |
480 | mk(Added(j1,j2)) = simp (trace_thm "+" (addthms (mk j1) (mk j2))) |
478 | mk(Added(j1,j2)) = simp (trace_thm "+" (addthms (mk j1) (mk j2))) |
481 | mk(Multiplied(n,j)) = (trace_msg("*"^string_of_int n); trace_thm "*" (multn(n,mk j))) |
479 | mk(Multiplied(n,j)) = (trace_msg("*"^IntInf.toString n); trace_thm "*" (multn(n,mk j))) |
482 | mk(Multiplied2(n,j)) = simp (trace_msg("**"^string_of_int n); trace_thm "**" (multn2(n,mk j))) |
480 | mk(Multiplied2(n,j)) = simp (trace_msg("**"^IntInf.toString n); trace_thm "**" (multn2(n,mk j))) |
483 |
481 |
484 in trace_msg "mkthm"; |
482 in trace_msg "mkthm"; |
485 let val thm = trace_thm "Final thm:" (mk just) |
483 let val thm = trace_thm "Final thm:" (mk just) |
486 in let val fls = simplify simpset thm |
484 in let val fls = simplify simpset thm |
487 in trace_thm "After simplification:" fls; |
485 in trace_thm "After simplification:" fls; |
495 end |
493 end |
496 end handle FalseE thm => (trace_thm "False reached early:" thm; thm) |
494 end handle FalseE thm => (trace_thm "False reached early:" thm; thm) |
497 end |
495 end |
498 end; |
496 end; |
499 |
497 |
500 fun coeff poly atom = case assoc(poly,atom) of NONE => 0 | SOME i => i; |
498 fun coeff poly atom : IntInf.int = |
501 |
499 case assoc(poly,atom) of NONE => 0 | SOME i => i; |
502 fun lcms_intinf is = Library.foldl lcm (1, is); |
500 |
503 fun lcms is = IntInf.toInt (lcms_intinf (map IntInf.fromInt is)); |
501 fun lcms is = Library.foldl lcm (1, is); |
504 |
502 |
505 fun integ(rlhs,r,rel,rrhs,s,d) = |
503 fun integ(rlhs,r,rel,rrhs,s,d) = |
506 let val (rn,rd) = pairself IntInf.toInt (rep_rat r) and (sn,sd) = pairself IntInf.toInt (rep_rat s) |
504 let val (rn,rd) = rep_rat r and (sn,sd) = rep_rat s |
507 val m = IntInf.toInt (lcms_intinf(map (abs o snd o rep_rat) (r :: s :: map snd rlhs @ map snd rrhs))) |
505 val m = lcms(map (abs o snd o rep_rat) (r :: s :: map snd rlhs @ map snd rrhs)) |
508 fun mult(t,r) = |
506 fun mult(t,r) = |
509 let val (i,j) = pairself IntInf.toInt (rep_rat r) |
507 let val (i,j) = (rep_rat r) |
510 in (t,i * (m div j)) end |
508 in (t,i * (m div j)) end |
511 in (m,(map mult rlhs, rn*(m div rd), rel, map mult rrhs, sn*(m div sd), d)) end |
509 in (m,(map mult rlhs, rn*(m div rd), rel, map mult rrhs, sn*(m div sd), d)) end |
512 |
510 |
513 fun mklineq n atoms = |
511 fun mklineq n atoms = |
514 fn (item,k) => |
512 fn (item,k) => |