src/Provers/Arith/fast_lin_arith.ML
 author wenzelm Thu Jan 19 21:22:08 2006 +0100 (2006-01-19 ago) changeset 18708 4b3dadb4fe33 parent 18572 dab1dd61e59d child 19049 2103a8e14eaa permissions -rw-r--r--
setup: theory -> theory;
1 (*  Title:      Provers/Arith/fast_lin_arith.ML
2     ID:         \$Id\$
3     Author:     Tobias Nipkow
4     Copyright   1998  TU Munich
6 A generic linear arithmetic package.
7 It provides two tactics
9     lin_arith_tac:         int -> tactic
10 cut_lin_arith_tac: thms -> int -> tactic
12 and a simplification procedure
14     lin_arith_prover: theory -> simpset -> term -> thm option
16 Only take premises and conclusions into account that are already (negated)
17 (in)equations. lin_arith_prover tries to prove or disprove the term.
18 *)
20 (* Debugging: set Fast_Arith.trace *)
22 (*** Data needed for setting up the linear arithmetic package ***)
24 signature LIN_ARITH_LOGIC =
25 sig
26   val conjI:		thm
27   val ccontr:           thm (* (~ P ==> False) ==> P *)
28   val notI:             thm (* (P ==> False) ==> ~ P *)
29   val not_lessD:        thm (* ~(m < n) ==> n <= m *)
30   val not_leD:          thm (* ~(m <= n) ==> n < m *)
31   val sym:		thm (* x = y ==> y = x *)
32   val mk_Eq: thm -> thm
33   val atomize: thm -> thm list
34   val mk_Trueprop: term -> term
35   val neg_prop: term -> term
36   val is_False: thm -> bool
37   val is_nat: typ list * term -> bool
38   val mk_nat_thm: theory -> term -> thm
39 end;
40 (*
41 mk_Eq(~in) = `in == False'
42 mk_Eq(in) = `in == True'
43 where `in' is an (in)equality.
45 neg_prop(t) = neg if t is wrapped up in Trueprop and
46   nt is the (logically) negated version of t, where the negation
47   of a negative term is the term itself (no double negation!);
49 is_nat(parameter-types,t) =  t:nat
50 mk_nat_thm(t) = "0 <= t"
51 *)
53 signature LIN_ARITH_DATA =
54 sig
55   val decomp:
56     theory -> term -> ((term * Rat.rat) list * Rat.rat * string * (term * Rat.rat) list * Rat.rat * bool) option
57   val number_of: IntInf.int * typ -> term
58 end;
59 (*
60 decomp(`x Rel y') should yield (p,i,Rel,q,j,d)
61    where Rel is one of "<", "~<", "<=", "~<=" and "=" and
62          p/q is the decomposition of the sum terms x/y into a list
63          of summand * multiplicity pairs and a constant summand and
64          d indicates if the domain is discrete.
66 ss must reduce contradictory <= to False.
67    It should also cancel common summands to keep <= reduced;
68    otherwise <= can grow to massive proportions.
69 *)
71 signature FAST_LIN_ARITH =
72 sig
73   val setup: theory -> theory
74   val map_data: ({add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
75                  lessD: thm list, neqE: thm list, simpset: Simplifier.simpset}
76                  -> {add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
77                      lessD: thm list, neqE: thm list, simpset: Simplifier.simpset})
78                 -> theory -> theory
79   val trace           : bool ref
80   val fast_arith_neq_limit: int ref
81   val lin_arith_prover: theory -> simpset -> term -> thm option
82   val     lin_arith_tac:    bool -> int -> tactic
83   val cut_lin_arith_tac: simpset -> int -> tactic
84 end;
86 functor Fast_Lin_Arith(structure LA_Logic:LIN_ARITH_LOGIC
87                        and       LA_Data:LIN_ARITH_DATA) : FAST_LIN_ARITH =
88 struct
91 (** theory data **)
93 (* data kind 'Provers/fast_lin_arith' *)
95 structure Data = TheoryDataFun
96 (struct
97   val name = "Provers/fast_lin_arith";
98   type T = {add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
99             lessD: thm list, neqE: thm list, simpset: Simplifier.simpset};
101   val empty = {add_mono_thms = [], mult_mono_thms = [], inj_thms = [],
102                lessD = [], neqE = [], simpset = Simplifier.empty_ss};
103   val copy = I;
104   val extend = I;
106   fun merge _
107     ({add_mono_thms= add_mono_thms1, mult_mono_thms= mult_mono_thms1, inj_thms= inj_thms1,
108       lessD = lessD1, neqE=neqE1, simpset = simpset1},
109      {add_mono_thms= add_mono_thms2, mult_mono_thms= mult_mono_thms2, inj_thms= inj_thms2,
110       lessD = lessD2, neqE=neqE2, simpset = simpset2}) =
112      mult_mono_thms = Drule.merge_rules (mult_mono_thms1, mult_mono_thms2),
113      inj_thms = Drule.merge_rules (inj_thms1, inj_thms2),
114      lessD = Drule.merge_rules (lessD1, lessD2),
115      neqE = Drule.merge_rules (neqE1, neqE2),
116      simpset = Simplifier.merge_ss (simpset1, simpset2)};
118   fun print _ _ = ();
119 end);
121 val map_data = Data.map;
122 val setup = Data.init;
126 (*** A fast decision procedure ***)
127 (*** Code ported from HOL Light ***)
128 (* possible optimizations:
129    use (var,coeff) rep or vector rep  tp save space;
130    treat non-negative atoms separately rather than adding 0 <= atom
131 *)
133 val trace = ref false;
135 datatype lineq_type = Eq | Le | Lt;
137 datatype injust = Asm of int
138                 | Nat of int (* index of atom *)
139                 | LessD of injust
140                 | NotLessD of injust
141                 | NotLeD of injust
142                 | NotLeDD of injust
143                 | Multiplied of IntInf.int * injust
144                 | Multiplied2 of IntInf.int * injust
145                 | Added of injust * injust;
147 datatype lineq = Lineq of IntInf.int * lineq_type * IntInf.int list * injust;
149 fun el 0 (h::_) = h
150   | el n (_::t) = el (n - 1) t
151   | el _ _  = sys_error "el";
153 (* ------------------------------------------------------------------------- *)
154 (* Finding a (counter) example from the trace of a failed elimination        *)
155 (* ------------------------------------------------------------------------- *)
156 (* Examples are represented as rational numbers,                             *)
157 (* Dont blame John Harrison for this code - it is entirely mine. TN          *)
159 exception NoEx;
161 (* Coding: (i,true,cs) means i <= cs and (i,false,cs) means i < cs.
162    In general, true means the bound is included, false means it is excluded.
163    Need to know if it is a lower or upper bound for unambiguous interpretation!
164 *)
166 fun elim_eqns(ineqs,Lineq(i,Le,cs,_)) = (i,true,cs)::ineqs
167   | elim_eqns(ineqs,Lineq(i,Eq,cs,_)) = (i,true,cs)::(~i,true,map ~ cs)::ineqs
168   | elim_eqns(ineqs,Lineq(i,Lt,cs,_)) = (i,false,cs)::ineqs;
170 (* PRE: ex[v] must be 0! *)
171 fun eval (ex:Rat.rat list) v (a:IntInf.int,le,cs:IntInf.int list) =
172   let val rs = map Rat.rat_of_intinf cs
173       val rsum = Library.foldl Rat.add (Rat.zero, map Rat.mult (rs ~~ ex))
174   in (Rat.mult (Rat.add(Rat.rat_of_intinf a,Rat.neg rsum), Rat.inv(el v rs)), le) end;
175 (* If el v rs < 0, le should be negated.
176    Instead this swap is taken into account in ratrelmin2.
177 *)
179 fun ratrelmin2(x as (r,ler),y as (s,les)) =
180   if r=s then (r, (not ler) andalso (not les)) else if Rat.le(r,s) then x else y;
181 fun ratrelmax2(x as (r,ler),y as (s,les)) =
182   if r=s then (r,ler andalso les) else if Rat.le(r,s) then y else x;
184 val ratrelmin = foldr1 ratrelmin2;
185 val ratrelmax = foldr1 ratrelmax2;
187 fun ratexact up (r,exact) =
188   if exact then r else
189   let val (p,q) = Rat.quotient_of_rat r
190       val nth = Rat.inv(Rat.rat_of_intinf q)
191   in Rat.add(r,if up then nth else Rat.neg nth) end;
193 fun ratmiddle(r,s) = Rat.mult(Rat.add(r,s),Rat.inv(Rat.rat_of_int 2));
195 fun choose2 d ((lb,exactl),(ub,exactu)) =
196   if Rat.le(lb,Rat.zero) andalso (lb <> Rat.zero orelse exactl) andalso
197      Rat.le(Rat.zero,ub) andalso (ub <> Rat.zero orelse exactu)
198   then Rat.zero else
199   if not d
200   then (if Rat.ge0 lb
201         then if exactl then lb else ratmiddle(lb,ub)
202         else if exactu then ub else ratmiddle(lb,ub))
203   else (* discrete domain, both bounds must be exact *)
204   if Rat.ge0 lb then let val lb' = Rat.roundup lb
205                     in if Rat.le(lb',ub) then lb' else raise NoEx end
206                else let val ub' = Rat.rounddown ub
207                     in if Rat.le(lb,ub') then ub' else raise NoEx end;
209 fun findex1 discr (ex,(v,lineqs)) =
210   let val nz = List.filter (fn (Lineq(_,_,cs,_)) => el v cs <> 0) lineqs;
211       val ineqs = Library.foldl elim_eqns ([],nz)
212       val (ge,le) = List.partition (fn (_,_,cs) => el v cs > 0) ineqs
213       val lb = ratrelmax(map (eval ex v) ge)
214       val ub = ratrelmin(map (eval ex v) le)
215   in nth_update (v, choose2 (nth discr v) (lb, ub)) ex end;
217 fun findex discr = Library.foldl (findex1 discr);
219 fun elim1 v x =
220   map (fn (a,le,bs) => (Rat.add(a,Rat.neg(Rat.mult(el v bs,x))), le,
221                         nth_update (v, Rat.zero) bs));
223 fun single_var v (_,_,cs) = (filter_out (equal Rat.zero) cs = [el v cs]);
225 (* The base case:
226    all variables occur only with positive or only with negative coefficients *)
227 fun pick_vars discr (ineqs,ex) =
228   let val nz = filter_out (fn (_,_,cs) => forall (equal Rat.zero) cs) ineqs
229   in case nz of [] => ex
230      | (_,_,cs) :: _ =>
231        let val v = find_index (not o equal Rat.zero) cs
232            val d = nth discr v
233            val pos = Rat.ge0(el v cs)
234            val sv = List.filter (single_var v) nz
235            val minmax =
236              if pos then if d then Rat.roundup o fst o ratrelmax
237                          else ratexact true o ratrelmax
238                     else if d then Rat.rounddown o fst o ratrelmin
239                          else ratexact false o ratrelmin
240            val bnds = map (fn (a,le,bs) => (Rat.mult(a,Rat.inv(el v bs)),le)) sv
241            val x = minmax((Rat.zero,if pos then true else false)::bnds)
242            val ineqs' = elim1 v x nz
243            val ex' = nth_update (v, x) ex
244        in pick_vars discr (ineqs',ex') end
245   end;
247 fun findex0 discr n lineqs =
248   let val ineqs = Library.foldl elim_eqns ([],lineqs)
249       val rineqs = map (fn (a,le,cs) => (Rat.rat_of_intinf a, le, map Rat.rat_of_intinf cs))
250                        ineqs
251   in pick_vars discr (rineqs,replicate n Rat.zero) end;
253 (* ------------------------------------------------------------------------- *)
254 (* End of counter example finder. The actual decision procedure starts here. *)
255 (* ------------------------------------------------------------------------- *)
257 (* ------------------------------------------------------------------------- *)
258 (* Calculate new (in)equality type after addition.                           *)
259 (* ------------------------------------------------------------------------- *)
261 fun find_add_type(Eq,x) = x
262   | find_add_type(x,Eq) = x
263   | find_add_type(_,Lt) = Lt
264   | find_add_type(Lt,_) = Lt
265   | find_add_type(Le,Le) = Le;
267 (* ------------------------------------------------------------------------- *)
268 (* Multiply out an (in)equation.                                             *)
269 (* ------------------------------------------------------------------------- *)
271 fun multiply_ineq n (i as Lineq(k,ty,l,just)) =
272   if n = 1 then i
273   else if n = 0 andalso ty = Lt then sys_error "multiply_ineq"
274   else if n < 0 andalso (ty=Le orelse ty=Lt) then sys_error "multiply_ineq"
275   else Lineq (n * k, ty, map (curry op* n) l, Multiplied (n, just));
277 (* ------------------------------------------------------------------------- *)
278 (* Add together (in)equations.                                               *)
279 (* ------------------------------------------------------------------------- *)
281 fun add_ineq (i1 as Lineq(k1,ty1,l1,just1)) (i2 as Lineq(k2,ty2,l2,just2)) =
282   let val l = map2 (curry (op +)) l1 l2
285 (* ------------------------------------------------------------------------- *)
286 (* Elimination of variable between a single pair of (in)equations.           *)
287 (* If they're both inequalities, 1st coefficient must be +ve, 2nd -ve.       *)
288 (* ------------------------------------------------------------------------- *)
290 fun elim_var v (i1 as Lineq(k1,ty1,l1,just1)) (i2 as Lineq(k2,ty2,l2,just2)) =
291   let val c1 = el v l1 and c2 = el v l2
292       val m = lcm(abs c1, abs c2)
293       val m1 = m div (abs c1) and m2 = m div (abs c2)
294       val (n1,n2) =
295         if (c1 >= 0) = (c2 >= 0)
296         then if ty1 = Eq then (~m1,m2)
297              else if ty2 = Eq then (m1,~m2)
298                   else sys_error "elim_var"
299         else (m1,m2)
300       val (p1,p2) = if ty1=Eq andalso ty2=Eq andalso (n1 = ~1 orelse n2 = ~1)
301                     then (~n1,~n2) else (n1,n2)
302   in add_ineq (multiply_ineq n1 i1) (multiply_ineq n2 i2) end;
304 (* ------------------------------------------------------------------------- *)
305 (* The main refutation-finding code.                                         *)
306 (* ------------------------------------------------------------------------- *)
308 fun is_trivial (Lineq(_,_,l,_)) = forall (fn i => i=0) l;
310 fun is_answer (ans as Lineq(k,ty,l,_)) =
311   case ty  of Eq => k <> 0 | Le => k > 0 | Lt => k >= 0;
313 fun calc_blowup (l:IntInf.int list) =
314   let val (p,n) = List.partition (curry (op <) 0) (List.filter (curry (op <>) 0) l)
315   in (length p) * (length n) end;
317 (* ------------------------------------------------------------------------- *)
318 (* Main elimination code:                                                    *)
319 (*                                                                           *)
320 (* (1) Looks for immediate solutions (false assertions with no variables).   *)
321 (*                                                                           *)
322 (* (2) If there are any equations, picks a variable with the lowest absolute *)
323 (* coefficient in any of them, and uses it to eliminate.                     *)
324 (*                                                                           *)
325 (* (3) Otherwise, chooses a variable in the inequality to minimize the       *)
326 (* blowup (number of consequences generated) and eliminates it.              *)
327 (* ------------------------------------------------------------------------- *)
329 fun allpairs f xs ys =
330   List.concat(map (fn x => map (fn y => f x y) ys) xs);
332 fun extract_first p =
333   let fun extract xs (y::ys) = if p y then (SOME y,xs@ys)
334                                else extract (y::xs) ys
335         | extract xs []      = (NONE,xs)
336   in extract [] end;
338 fun print_ineqs ineqs =
339   if !trace then
340      tracing(cat_lines(""::map (fn Lineq(c,t,l,_) =>
341        IntInf.toString c ^
342        (case t of Eq => " =  " | Lt=> " <  " | Le => " <= ") ^
343        commas(map IntInf.toString l)) ineqs))
344   else ();
346 type history = (int * lineq list) list;
347 datatype result = Success of injust | Failure of history;
349 fun elim(ineqs,hist) =
350   let val dummy = print_ineqs ineqs;
351       val (triv,nontriv) = List.partition is_trivial ineqs in
352   if not(null triv)
353   then case Library.find_first is_answer triv of
354          NONE => elim(nontriv,hist)
355        | SOME(Lineq(_,_,_,j)) => Success j
356   else
357   if null nontriv then Failure(hist)
358   else
359   let val (eqs,noneqs) = List.partition (fn (Lineq(_,ty,_,_)) => ty=Eq) nontriv in
360   if not(null eqs) then
361      let val clist = Library.foldl (fn (cs,Lineq(_,_,l,_)) => l union cs) ([],eqs)
362          val sclist = sort (fn (x,y) => IntInf.compare(abs(x),abs(y)))
363                            (List.filter (fn i => i<>0) clist)
364          val c = hd sclist
365          val (SOME(eq as Lineq(_,_,ceq,_)),othereqs) =
366                extract_first (fn Lineq(_,_,l,_) => c mem l) eqs
367          val v = find_index_eq c ceq
368          val (ioth,roth) = List.partition (fn (Lineq(_,_,l,_)) => el v l = 0)
369                                      (othereqs @ noneqs)
370          val others = map (elim_var v eq) roth @ ioth
371      in elim(others,(v,nontriv)::hist) end
372   else
373   let val lists = map (fn (Lineq(_,_,l,_)) => l) noneqs
374       val numlist = 0 upto (length(hd lists) - 1)
375       val coeffs = map (fn i => map (el i) lists) numlist
376       val blows = map calc_blowup coeffs
377       val iblows = blows ~~ numlist
378       val nziblows = List.filter (fn (i,_) => i<>0) iblows
379   in if null nziblows then Failure((~1,nontriv)::hist)
380      else
381      let val (c,v) = hd(sort (fn (x,y) => int_ord(fst(x),fst(y))) nziblows)
382          val (no,yes) = List.partition (fn (Lineq(_,_,l,_)) => el v l = 0) ineqs
383          val (pos,neg) = List.partition(fn (Lineq(_,_,l,_)) => el v l > 0) yes
384      in elim(no @ allpairs (elim_var v) pos neg, (v,nontriv)::hist) end
385   end
386   end
387   end;
389 (* ------------------------------------------------------------------------- *)
390 (* Translate back a proof.                                                   *)
391 (* ------------------------------------------------------------------------- *)
393 fun trace_thm msg th =
394     if !trace then (tracing msg; tracing (Display.string_of_thm th); th) else th;
396 fun trace_msg msg =
397     if !trace then tracing msg else ();
399 (* FIXME OPTIMIZE!!!! (partly done already)
400    Addition/Multiplication need i*t representation rather than t+t+...
401    Get rid of Mulitplied(2). For Nat LA_Data.number_of should return Suc^n
402    because Numerals are not known early enough.
404 Simplification may detect a contradiction 'prematurely' due to type
405 information: n+1 <= 0 is simplified to False and does not need to be crossed
406 with 0 <= n.
407 *)
408 local
409  exception FalseE of thm
410 in
411 fun mkthm (sg, ss) asms just =
412   let val {add_mono_thms, mult_mono_thms, inj_thms, lessD, simpset, ...} =
413           Data.get sg;
414       val simpset' = Simplifier.inherit_context ss simpset;
415       val atoms = Library.foldl (fn (ats,(lhs,_,_,rhs,_,_)) =>
416                             map fst lhs  union  (map fst rhs  union  ats))
417                         ([], List.mapPartial (fn thm => if Thm.no_prems thm
418                                         then LA_Data.decomp sg (concl_of thm)
419                                         else NONE) asms)
421       fun add2 thm1 thm2 =
422         let val conj = thm1 RS (thm2 RS LA_Logic.conjI)
423         in get_first (fn th => SOME(conj RS th) handle THM _ => NONE) add_mono_thms
424         end;
426       fun try_add [] _ = NONE
427         | try_add (thm1::thm1s) thm2 = case add2 thm1 thm2 of
428              NONE => try_add thm1s thm2 | some => some;
430       fun addthms thm1 thm2 =
431         case add2 thm1 thm2 of
432           NONE => (case try_add ([thm1] RL inj_thms) thm2 of
433                      NONE => ( valOf(try_add ([thm2] RL inj_thms) thm1)
434                                handle Option =>
435                                (trace_thm "" thm1; trace_thm "" thm2;
436                                 sys_error "Lin.arith. failed to add thms")
437                              )
438                    | SOME thm => thm)
439         | SOME thm => thm;
441       fun multn(n,thm) =
442         let fun mul(i,th) = if i=1 then th else mul(i-1, addthms thm th)
443         in if n < 0 then mul(~n,thm) RS LA_Logic.sym else mul(n,thm) end;
444 (*
445       fun multn2(n,thm) =
446         let val SOME(mth,cv) =
447               get_first (fn (th,cv) => SOME(thm RS th,cv) handle THM _ => NONE) mult_mono_thms
448             val ct = cterm_of sg (LA_Data.number_of(n,#T(rep_cterm cv)))
449         in instantiate ([],[(cv,ct)]) mth end
450 *)
451       fun multn2(n,thm) =
452         let val SOME(mth) =
453               get_first (fn th => SOME(thm RS th) handle THM _ => NONE) mult_mono_thms
454             fun cvar(th,_ \$ (_ \$ _ \$ var)) = cterm_of (#sign(rep_thm th)) var;
455             val cv = cvar(mth, hd(prems_of mth));
456             val ct = cterm_of sg (LA_Data.number_of(n,#T(rep_cterm cv)))
457         in instantiate ([],[(cv,ct)]) mth end
459       fun simp thm =
460         let val thm' = trace_thm "Simplified:" (full_simplify simpset' thm)
461         in if LA_Logic.is_False thm' then raise FalseE thm' else thm' end
463       fun mk(Asm i) = trace_thm "Asm" (nth asms i)
464         | mk(Nat i) = (trace_msg "Nat"; LA_Logic.mk_nat_thm sg (nth atoms i))
465         | mk(LessD(j)) = trace_thm "L" (hd([mk j] RL lessD))
466         | mk(NotLeD(j)) = trace_thm "NLe" (mk j RS LA_Logic.not_leD)
467         | mk(NotLeDD(j)) = trace_thm "NLeD" (hd([mk j RS LA_Logic.not_leD] RL lessD))
468         | mk(NotLessD(j)) = trace_thm "NL" (mk j RS LA_Logic.not_lessD)
469         | mk(Added(j1,j2)) = simp (trace_thm "+" (addthms (mk j1) (mk j2)))
470         | mk(Multiplied(n,j)) = (trace_msg("*"^IntInf.toString n); trace_thm "*" (multn(n,mk j)))
471         | mk(Multiplied2(n,j)) = simp (trace_msg("**"^IntInf.toString n); trace_thm "**" (multn2(n,mk j)))
473   in trace_msg "mkthm";
474      let val thm = trace_thm "Final thm:" (mk just)
475      in let val fls = simplify simpset' thm
476         in trace_thm "After simplification:" fls;
477            if LA_Logic.is_False fls then fls
478            else
479             (tracing "Assumptions:"; List.app print_thm asms;
480              tracing "Proved:"; print_thm fls;
481              warning "Linear arithmetic should have refuted the assumptions.\n\
482                      \Please inform Tobias Nipkow (nipkow@in.tum.de).";
483              fls)
484         end
485      end handle FalseE thm => (trace_thm "False reached early:" thm; thm)
486   end
487 end;
489 fun coeff poly atom : IntInf.int =
490   AList.lookup (op =) poly atom |> the_default 0;
492 fun lcms is = Library.foldl lcm (1, is);
494 fun integ(rlhs,r,rel,rrhs,s,d) =
495 let val (rn,rd) = Rat.quotient_of_rat r and (sn,sd) = Rat.quotient_of_rat s
496     val m = lcms(map (abs o snd o Rat.quotient_of_rat) (r :: s :: map snd rlhs @ map snd rrhs))
497     fun mult(t,r) =
498         let val (i,j) = Rat.quotient_of_rat r
499         in (t,i * (m div j)) end
500 in (m,(map mult rlhs, rn*(m div rd), rel, map mult rrhs, sn*(m div sd), d)) end
502 fun mklineq n atoms =
503   fn (item,k) =>
504   let val (m,(lhs,i,rel,rhs,j,discrete)) = integ item
505       val lhsa = map (coeff lhs) atoms
506       and rhsa = map (coeff rhs) atoms
507       val diff = map2 (curry (op -)) rhsa lhsa
508       val c = i-j
509       val just = Asm k
510       fun lineq(c,le,cs,j) = Lineq(c,le,cs, if m=1 then j else Multiplied2(m,j))
511   in case rel of
512       "<="   => lineq(c,Le,diff,just)
513      | "~<=" => if discrete
514                 then lineq(1-c,Le,map (op ~) diff,NotLeDD(just))
515                 else lineq(~c,Lt,map (op ~) diff,NotLeD(just))
516      | "<"   => if discrete
517                 then lineq(c+1,Le,diff,LessD(just))
518                 else lineq(c,Lt,diff,just)
519      | "~<"  => lineq(~c,Le,map (op~) diff,NotLessD(just))
520      | "="   => lineq(c,Eq,diff,just)
521      | _     => sys_error("mklineq" ^ rel)
522   end;
524 (* ------------------------------------------------------------------------- *)
525 (* Print (counter) example                                                   *)
526 (* ------------------------------------------------------------------------- *)
528 fun print_atom((a,d),r) =
529   let val (p,q) = Rat.quotient_of_rat r
530       val s = if d then IntInf.toString p else
531               if p = 0 then "0"
532               else IntInf.toString p ^ "/" ^ IntInf.toString q
533   in a ^ " = " ^ s end;
535 fun print_ex sds =
536   curry (op ~~) sds
537   #> map print_atom
538   #> commas
539   #> curry (op ^) "Counter example:\n"
540   #> tracing;
542 fun trace_ex(sg,params,atoms,discr,n,hist:history) =
543   if null hist then ()
544   else let val frees = map Free params;
545            fun s_of_t t = Sign.string_of_term sg (subst_bounds(frees,t));
546            val (v,lineqs) :: hist' = hist
547            val start = if v = ~1 then (findex0 discr n lineqs,hist')
548                        else (replicate n Rat.zero,hist)
549        in warning "arith failed - see trace for a counter example";
550           print_ex ((map s_of_t atoms)~~discr) (findex discr start)
551           handle NoEx => (tracing "Sorry, no counter example.")
552        end;
554 fun mknat pTs ixs (atom,i) =
555   if LA_Logic.is_nat(pTs,atom)
556   then let val l = map (fn j => if j=i then 1 else 0) ixs
557        in SOME(Lineq(0,Le,l,Nat(i))) end
558   else NONE
560 (* This code is tricky. It takes a list of premises in the order they occur
561 in the subgoal. Numerical premises are coded as SOME(tuple), non-numerical
562 ones as NONE. Going through the premises, each numeric one is converted into
563 a Lineq. The tricky bit is to convert ~= which is split into two cases < and
564 >. Thus split_items returns a list of equation systems. This may blow up if
565 there are many ~=, but in practice it does not seem to happen. The really
566 tricky bit is to arrange the order of the cases such that they coincide with
567 the order in which the cases are in the end generated by the tactic that
568 applies the generated refutation thms (see function 'refute_tac').
570 For variables n of type nat, a constraint 0 <= n is added.
571 *)
572 fun split_items(items) =
573   let fun elim_neq front _ [] = [rev front]
574         | elim_neq front n (NONE::ineqs) = elim_neq front (n+1) ineqs
575         | elim_neq front n (SOME(ineq as (l,i,rel,r,j,d))::ineqs) =
576           if rel = "~=" then elim_neq front n (ineqs @ [SOME(l,i,"<",r,j,d)]) @
577                              elim_neq front n (ineqs @ [SOME(r,j,"<",l,i,d)])
578           else elim_neq ((ineq,n) :: front) (n+1) ineqs
579   in elim_neq [] 0 items end;
581 fun add_atoms(ats,((lhs,_,_,rhs,_,_),_)) =
582     (map fst lhs) union ((map fst rhs) union ats)
584 fun add_datoms(dats,((lhs,_,_,rhs,_,d),_)) =
585     (map (pair d o fst) lhs) union ((map (pair d o fst) rhs) union dats)
587 fun discr initems = map fst (Library.foldl add_datoms ([],initems));
589 fun refutes sg (pTs,params) ex =
590 let
591   fun refute (initems::initemss) js =
592     let val atoms = Library.foldl add_atoms ([],initems)
593         val n = length atoms
594         val mkleq = mklineq n atoms
595         val ixs = 0 upto (n-1)
596         val iatoms = atoms ~~ ixs
597         val natlineqs = List.mapPartial (mknat pTs ixs) iatoms
598         val ineqs = map mkleq initems @ natlineqs
599     in case elim(ineqs,[]) of
600          Success(j) =>
601            (trace_msg "Contradiction!"; refute initemss (js@[j]))
602        | Failure(hist) =>
603            (if not ex then ()
604             else trace_ex(sg,params,atoms,discr initems,n,hist);
605             NONE)
606     end
607     | refute [] js = SOME js
608 in refute end;
610 fun refute sg ps ex items = refutes sg ps ex (split_items items) [];
612 fun refute_tac ss (i,justs) =
613   fn state =>
614     let val sg = #sign(rep_thm state)
615         val {neqE, ...} = Data.get sg;
616         fun just1 j = REPEAT_DETERM(eresolve_tac neqE i) THEN
617           METAHYPS (fn asms => rtac (mkthm (sg, ss) asms j) 1) i
618     in DETERM(resolve_tac [LA_Logic.notI,LA_Logic.ccontr] i) THEN
619        EVERY(map just1 justs)
620     end
621     state;
623 fun count P xs = length(List.filter P xs);
625 (* The limit on the number of ~= allowed.
626    Because each ~= is split into two cases, this can lead to an explosion.
627 *)
628 val fast_arith_neq_limit = ref 9;
630 fun prove sg ps ex Hs concl =
631 let val Hitems = map (LA_Data.decomp sg) Hs
632 in if count (fn NONE => false | SOME(_,_,r,_,_,_) => r = "~=") Hitems
633       > !fast_arith_neq_limit then NONE
634    else
635    case LA_Data.decomp sg concl of
636      NONE => refute sg ps ex (Hitems@[NONE])
637    | SOME(citem as (r,i,rel,l,j,d)) =>
638        let val neg::rel0 = explode rel
639            val nrel = if neg = "~" then implode rel0 else "~"^rel
640        in refute sg ps ex (Hitems @ [SOME(r,i,nrel,l,j,d)]) end
641 end;
643 (*
644 Fast but very incomplete decider. Only premises and conclusions
645 that are already (negated) (in)equations are taken into account.
646 *)
647 fun simpset_lin_arith_tac ss ex i st = SUBGOAL (fn (A,_) =>
648   let val params = rev(Logic.strip_params A)
649       val pTs = map snd params
650       val Hs = Logic.strip_assums_hyp A
651       val concl = Logic.strip_assums_concl A
652   in trace_thm ("Trying to refute subgoal " ^ string_of_int i) st;
653      case prove (Thm.sign_of_thm st) (pTs,params) ex Hs concl of
654        NONE => (trace_msg "Refutation failed."; no_tac)
655      | SOME js => (trace_msg "Refutation succeeded."; refute_tac ss (i,js))
656   end) i st;
658 fun lin_arith_tac ex i st =
659   simpset_lin_arith_tac (Simplifier.theory_context (Thm.theory_of_thm st) Simplifier.empty_ss)
660     ex i st;
662 fun cut_lin_arith_tac ss i =
663   cut_facts_tac (Simplifier.prems_of_ss ss) i THEN
664   simpset_lin_arith_tac ss false i;
666 (** Forward proof from theorems **)
668 (* More tricky code. Needs to arrange the proofs of the multiple cases (due
669 to splits of ~= premises) such that it coincides with the order of the cases
670 generated by function split_items. *)
672 datatype splittree = Tip of thm list
673                    | Spl of thm * cterm * splittree * cterm * splittree
675 fun extract imp =
676 let val (Il,r) = Thm.dest_comb imp
677     val (_,imp1) = Thm.dest_comb Il
678     val (Ict1,_) = Thm.dest_comb imp1
679     val (_,ct1) = Thm.dest_comb Ict1
680     val (Ir,_) = Thm.dest_comb r
681     val (_,Ict2r) = Thm.dest_comb Ir
682     val (Ict2,_) = Thm.dest_comb Ict2r
683     val (_,ct2) = Thm.dest_comb Ict2
684 in (ct1,ct2) end;
686 fun splitasms sg asms =
687 let val {neqE, ...}  = Data.get sg;
688     fun split(asms',[]) = Tip(rev asms')
689       | split(asms',asm::asms) =
690       (case get_first (fn th => SOME(asm COMP th) handle THM _ => NONE) neqE
691        of SOME spl =>
692           let val (ct1,ct2) = extract(cprop_of spl)
693               val thm1 = assume ct1 and thm2 = assume ct2
694           in Spl(spl,ct1,split(asms',asms@[thm1]),ct2,split(asms',asms@[thm2]))
695           end
696        | NONE => split(asm::asms', asms))
697 in split([],asms) end;
699 fun fwdproof ctxt (Tip asms) (j::js) = (mkthm ctxt asms j, js)
700   | fwdproof ctxt (Spl(thm,ct1,tree1,ct2,tree2)) js =
701     let val (thm1,js1) = fwdproof ctxt tree1 js
702         val (thm2,js2) = fwdproof ctxt tree2 js1
703         val thm1' = implies_intr ct1 thm1
704         val thm2' = implies_intr ct2 thm2
705     in (thm2' COMP (thm1' COMP thm), js2) end;
706 (* needs handle THM _ => NONE ? *)
708 fun prover (ctxt as (sg, _)) thms Tconcl js pos =
709 let val nTconcl = LA_Logic.neg_prop Tconcl
710     val cnTconcl = cterm_of sg nTconcl
711     val nTconclthm = assume cnTconcl
712     val tree = splitasms sg (thms @ [nTconclthm])
713     val (thm,_) = fwdproof ctxt tree js
714     val contr = if pos then LA_Logic.ccontr else LA_Logic.notI
715 in SOME(LA_Logic.mk_Eq((implies_intr cnTconcl thm) COMP contr)) end
716 (* in case concl contains ?-var, which makes assume fail: *)
717 handle THM _ => NONE;
719 (* PRE: concl is not negated!
720    This assumption is OK because
721    1. lin_arith_prover tries both to prove and disprove concl and
722    2. lin_arith_prover is applied by the simplifier which
723       dives into terms and will thus try the non-negated concl anyway.
724 *)
725 fun lin_arith_prover sg ss concl =
726 let
727     val thms = List.concat(map LA_Logic.atomize (prems_of_ss ss));
728     val Hs = map (#prop o rep_thm) thms
729     val Tconcl = LA_Logic.mk_Trueprop concl
730 in case prove sg ([],[]) false Hs Tconcl of (* concl provable? *)
731      SOME js => prover (sg, ss) thms Tconcl js true
732    | NONE => let val nTconcl = LA_Logic.neg_prop Tconcl
733           in case prove sg ([],[]) false Hs nTconcl of (* ~concl provable? *)
734                SOME js => prover (sg, ss) thms nTconcl js false
735              | NONE => NONE
736           end
737 end;
739 end;