src/Provers/Arith/fast_lin_arith.ML
author wenzelm
Thu Jan 19 21:22:08 2006 +0100 (2006-01-19 ago)
changeset 18708 4b3dadb4fe33
parent 18572 dab1dd61e59d
child 19049 2103a8e14eaa
permissions -rw-r--r--
setup: theory -> theory;
     1 (*  Title:      Provers/Arith/fast_lin_arith.ML
     2     ID:         $Id$
     3     Author:     Tobias Nipkow
     4     Copyright   1998  TU Munich
     5 
     6 A generic linear arithmetic package.
     7 It provides two tactics
     8 
    lin_arith_tac:     bool -> int -> tactic
    cut_lin_arith_tac: simpset -> int -> tactic
    11 
    12 and a simplification procedure
    13 
    14     lin_arith_prover: theory -> simpset -> term -> thm option
    15 
    16 Only take premises and conclusions into account that are already (negated)
    17 (in)equations. lin_arith_prover tries to prove or disprove the term.
    18 *)
    19 
    20 (* Debugging: set Fast_Arith.trace *)
    21 
    22 (*** Data needed for setting up the linear arithmetic package ***)
    23 
(* Logic-specific theorems and operations the package is instantiated with
   (see the comment below the signature for mk_Eq, neg_prop, is_nat and
   mk_nat_thm). *)
signature LIN_ARITH_LOGIC =
sig
  val conjI:		thm (* presumably P ==> Q ==> P & Q; mkthm uses it to pair
                               two facts before applying an add_mono rule *)
  val ccontr:           thm (* (~ P ==> False) ==> P *)
  val notI:             thm (* (P ==> False) ==> ~ P *)
  val not_lessD:        thm (* ~(m < n) ==> n <= m *)
  val not_leD:          thm (* ~(m <= n) ==> n < m *)
  val sym:		thm (* x = y ==> y = x *)
  val mk_Eq: thm -> thm
  (* applied to the simpset premises in lin_arith_prover; presumably splits a
     premise into atomic facts -- NOTE(review): confirm with instantiation *)
  val atomize: thm -> thm list
  val mk_Trueprop: term -> term
  val neg_prop: term -> term
  (* recognises the theorem False; used to detect a completed refutation *)
  val is_False: thm -> bool
  val is_nat: typ list * term -> bool
  val mk_nat_thm: theory -> term -> thm
end;
(*
mk_Eq(~in) = `in == False'
mk_Eq(in) = `in == True'
where `in' is an (in)equality.

neg_prop(t) = nt if t is wrapped up in Trueprop and
  nt is the (logically) negated version of t, where negating
  a negative term ~P yields P rather than ~~P (no double negation!);

is_nat(parameter-types,t) =  t:nat
mk_nat_thm(t) = "0 <= t"
*)
    52 
(* Object-logic specific term analysis used by the decision procedure
   (decomp is described in the comment below the signature). *)
signature LIN_ARITH_DATA =
sig
  val decomp:
    theory -> term -> ((term * Rat.rat) list * Rat.rat * string * (term * Rat.rat) list * Rat.rat * bool) option
  (* number_of (n, T): presumably the numeral n at type T; used by mkthm to
     instantiate the multiplier of a mult_mono rule *)
  val number_of: IntInf.int * typ -> term
end;
(*
decomp(`x Rel y') should yield (p,i,Rel,q,j,d)
   where Rel is one of "<", "~<", "<=", "~<=", "=" and "~=",
         p/q is the decomposition of the sum terms x/y into a list
         of summand * multiplicity pairs, i/j is the constant summand
         of x/y, and
         d indicates if the domain is discrete.

The simpset stored in the theory data (field simpset) must reduce
   contradictory <= to False.
   It should also cancel common summands to keep <= reduced;
   otherwise <= can grow to massive proportions.
*)
    70 
(* Interface of the instantiated linear arithmetic package. *)
signature FAST_LIN_ARITH =
sig
  (* installs the (initially empty) theory data *)
  val setup: theory -> theory
  (* modifies the theory data: rules for proof reconstruction (add_mono_thms,
     mult_mono_thms, inj_thms, lessD), the ~= case-split rules neqE, and the
     simpset applied to derived facts *)
  val map_data: ({add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
                 lessD: thm list, neqE: thm list, simpset: Simplifier.simpset}
                 -> {add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
                     lessD: thm list, neqE: thm list, simpset: Simplifier.simpset})
                -> theory -> theory
  (* when set, the procedure traces its progress *)
  val trace           : bool ref
  (* with more ~= premises than this the procedure gives up *)
  val fast_arith_neq_limit: int ref
  val lin_arith_prover: theory -> simpset -> term -> thm option
  (* the bool flag: on failure, trace a counter example *)
  val     lin_arith_tac:    bool -> int -> tactic
  (* first adds the premises of the simpset to the subgoal *)
  val cut_lin_arith_tac: simpset -> int -> tactic
end;
    85 
    86 functor Fast_Lin_Arith(structure LA_Logic:LIN_ARITH_LOGIC 
    87                        and       LA_Data:LIN_ARITH_DATA) : FAST_LIN_ARITH =
    88 struct
    89 
    90 
    91 (** theory data **)
    92 
    93 (* data kind 'Provers/fast_lin_arith' *)
    94 
    95 structure Data = TheoryDataFun
    96 (struct
    97   val name = "Provers/fast_lin_arith";
    98   type T = {add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
    99             lessD: thm list, neqE: thm list, simpset: Simplifier.simpset};
   100 
   101   val empty = {add_mono_thms = [], mult_mono_thms = [], inj_thms = [],
   102                lessD = [], neqE = [], simpset = Simplifier.empty_ss};
   103   val copy = I;
   104   val extend = I;
   105 
   106   fun merge _
   107     ({add_mono_thms= add_mono_thms1, mult_mono_thms= mult_mono_thms1, inj_thms= inj_thms1,
   108       lessD = lessD1, neqE=neqE1, simpset = simpset1},
   109      {add_mono_thms= add_mono_thms2, mult_mono_thms= mult_mono_thms2, inj_thms= inj_thms2,
   110       lessD = lessD2, neqE=neqE2, simpset = simpset2}) =
   111     {add_mono_thms = Drule.merge_rules (add_mono_thms1, add_mono_thms2),
   112      mult_mono_thms = Drule.merge_rules (mult_mono_thms1, mult_mono_thms2),
   113      inj_thms = Drule.merge_rules (inj_thms1, inj_thms2),
   114      lessD = Drule.merge_rules (lessD1, lessD2),
   115      neqE = Drule.merge_rules (neqE1, neqE2),
   116      simpset = Simplifier.merge_ss (simpset1, simpset2)};
   117 
   118   fun print _ _ = ();
   119 end);
   120 
   121 val map_data = Data.map;
   122 val setup = Data.init;
   123 
   124 
   125 
   126 (*** A fast decision procedure ***)
   127 (*** Code ported from HOL Light ***)
(* possible optimizations:
   use (var,coeff) rep or vector rep to save space;
   treat non-negative atoms separately rather than adding 0 <= atom
*)
   132 
(* When set, the procedure reports its steps via tracing. *)
val trace = ref false;

(* Relation of a linear (in)equation, see lineq below. *)
datatype lineq_type = Eq | Le | Lt;

(* Justification of a derived lineq; mkthm replays it as a proof. *)
datatype injust = Asm of int          (* index into the assumption list *)
                | Nat of int (* index of atom *)
                | LessD of injust     (* via the lessD rules *)
                | NotLessD of injust  (* via LA_Logic.not_lessD *)
                | NotLeD of injust    (* via LA_Logic.not_leD *)
                | NotLeDD of injust   (* via not_leD, then the lessD rules *)
                | Multiplied of IntInf.int * injust   (* n-fold sum, see multn *)
                | Multiplied2 of IntInf.int * injust  (* via a mult_mono rule *)
                | Added of injust * injust;           (* via an add_mono rule *)

(* Lineq (k, ty, cs, j): the constraint  k ty sum_i cs[i] * atom_i  (e.g.
   k <= ...) justified by j; cs has one coefficient per atom.  Cf. is_answer:
   a lineq whose cs are all zero is contradictory iff k <> 0 / k > 0 / k >= 0
   for Eq / Le / Lt respectively. *)
datatype lineq = Lineq of IntInf.int * lineq_type * IntInf.int list * injust;
   148 
   149 fun el 0 (h::_) = h
   150   | el n (_::t) = el (n - 1) t
   151   | el _ _  = sys_error "el";
   152 
   153 (* ------------------------------------------------------------------------- *)
   154 (* Finding a (counter) example from the trace of a failed elimination        *)
   155 (* ------------------------------------------------------------------------- *)
   156 (* Examples are represented as rational numbers,                             *)
(* Don't blame John Harrison for this code - it is entirely mine. TN         *)
   158 
   159 exception NoEx;
   160 
   161 (* Coding: (i,true,cs) means i <= cs and (i,false,cs) means i < cs.
   162    In general, true means the bound is included, false means it is excluded.
   163    Need to know if it is a lower or upper bound for unambiguous interpretation!
   164 *)
   165 
   166 fun elim_eqns(ineqs,Lineq(i,Le,cs,_)) = (i,true,cs)::ineqs
   167   | elim_eqns(ineqs,Lineq(i,Eq,cs,_)) = (i,true,cs)::(~i,true,map ~ cs)::ineqs
   168   | elim_eqns(ineqs,Lineq(i,Lt,cs,_)) = (i,false,cs)::ineqs;
   169 
   170 (* PRE: ex[v] must be 0! *)
   171 fun eval (ex:Rat.rat list) v (a:IntInf.int,le,cs:IntInf.int list) =
   172   let val rs = map Rat.rat_of_intinf cs
   173       val rsum = Library.foldl Rat.add (Rat.zero, map Rat.mult (rs ~~ ex))
   174   in (Rat.mult (Rat.add(Rat.rat_of_intinf a,Rat.neg rsum), Rat.inv(el v rs)), le) end;
   175 (* If el v rs < 0, le should be negated.
   176    Instead this swap is taken into account in ratrelmin2.
   177 *)
   178 
   179 fun ratrelmin2(x as (r,ler),y as (s,les)) =
   180   if r=s then (r, (not ler) andalso (not les)) else if Rat.le(r,s) then x else y;
   181 fun ratrelmax2(x as (r,ler),y as (s,les)) =
   182   if r=s then (r,ler andalso les) else if Rat.le(r,s) then y else x;
   183 
   184 val ratrelmin = foldr1 ratrelmin2;
   185 val ratrelmax = foldr1 ratrelmax2;
   186 
   187 fun ratexact up (r,exact) =
   188   if exact then r else
   189   let val (p,q) = Rat.quotient_of_rat r
   190       val nth = Rat.inv(Rat.rat_of_intinf q)
   191   in Rat.add(r,if up then nth else Rat.neg nth) end;
   192 
   193 fun ratmiddle(r,s) = Rat.mult(Rat.add(r,s),Rat.inv(Rat.rat_of_int 2));
   194 
(* choose2 d ((lb, exactl), (ub, exactu)): a value between the lower bound lb
   and the upper bound ub; the flags tell whether the respective bound itself
   is admissible.  Zero is preferred whenever it is admissible.  For a dense
   domain (d = false) the value is the admissible bound nearer to zero or the
   midpoint.  For a discrete domain (d = true) it is lb rounded up or ub
   rounded down; NoEx is raised if that leaves [lb, ub]. *)
fun choose2 d ((lb,exactl),(ub,exactu)) =
  (* 0 lies in the interval *)
  if Rat.le(lb,Rat.zero) andalso (lb <> Rat.zero orelse exactl) andalso
     Rat.le(Rat.zero,ub) andalso (ub <> Rat.zero orelse exactu)
  then Rat.zero else
  if not d
  then (if Rat.ge0 lb
        then if exactl then lb else ratmiddle(lb,ub)
        else if exactu then ub else ratmiddle(lb,ub))
  else (* discrete domain, both bounds must be exact *)
  if Rat.ge0 lb then let val lb' = Rat.roundup lb
                    in if Rat.le(lb',ub) then lb' else raise NoEx end
               else let val ub' = Rat.rounddown ub
                    in if Rat.le(lb,ub') then ub' else raise NoEx end;
   208 
   209 fun findex1 discr (ex,(v,lineqs)) =
   210   let val nz = List.filter (fn (Lineq(_,_,cs,_)) => el v cs <> 0) lineqs;
   211       val ineqs = Library.foldl elim_eqns ([],nz)
   212       val (ge,le) = List.partition (fn (_,_,cs) => el v cs > 0) ineqs
   213       val lb = ratrelmax(map (eval ex v) ge)
   214       val ub = ratrelmin(map (eval ex v) le)
   215   in nth_update (v, choose2 (nth discr v) (lb, ub)) ex end;
   216 
   217 fun findex discr = Library.foldl (findex1 discr);
   218 
   219 fun elim1 v x =
   220   map (fn (a,le,bs) => (Rat.add(a,Rat.neg(Rat.mult(el v bs,x))), le,
   221                         nth_update (v, Rat.zero) bs));
   222 
   223 fun single_var v (_,_,cs) = (filter_out (equal Rat.zero) cs = [el v cs]);
   224 
   225 (* The base case:
   226    all variables occur only with positive or only with negative coefficients *)
   227 fun pick_vars discr (ineqs,ex) =
   228   let val nz = filter_out (fn (_,_,cs) => forall (equal Rat.zero) cs) ineqs
   229   in case nz of [] => ex
   230      | (_,_,cs) :: _ =>
   231        let val v = find_index (not o equal Rat.zero) cs
   232            val d = nth discr v
   233            val pos = Rat.ge0(el v cs)
   234            val sv = List.filter (single_var v) nz
   235            val minmax =
   236              if pos then if d then Rat.roundup o fst o ratrelmax
   237                          else ratexact true o ratrelmax
   238                     else if d then Rat.rounddown o fst o ratrelmin
   239                          else ratexact false o ratrelmin
   240            val bnds = map (fn (a,le,bs) => (Rat.mult(a,Rat.inv(el v bs)),le)) sv
   241            val x = minmax((Rat.zero,if pos then true else false)::bnds)
   242            val ineqs' = elim1 v x nz
   243            val ex' = nth_update (v, x) ex
   244        in pick_vars discr (ineqs',ex') end
   245   end;
   246 
   247 fun findex0 discr n lineqs =
   248   let val ineqs = Library.foldl elim_eqns ([],lineqs)
   249       val rineqs = map (fn (a,le,cs) => (Rat.rat_of_intinf a, le, map Rat.rat_of_intinf cs))
   250                        ineqs
   251   in pick_vars discr (rineqs,replicate n Rat.zero) end;
   252 
   253 (* ------------------------------------------------------------------------- *)
   254 (* End of counter example finder. The actual decision procedure starts here. *)
   255 (* ------------------------------------------------------------------------- *)
   256 
   257 (* ------------------------------------------------------------------------- *)
   258 (* Calculate new (in)equality type after addition.                           *)
   259 (* ------------------------------------------------------------------------- *)
   260 
   261 fun find_add_type(Eq,x) = x
   262   | find_add_type(x,Eq) = x
   263   | find_add_type(_,Lt) = Lt
   264   | find_add_type(Lt,_) = Lt
   265   | find_add_type(Le,Le) = Le;
   266 
   267 (* ------------------------------------------------------------------------- *)
   268 (* Multiply out an (in)equation.                                             *)
   269 (* ------------------------------------------------------------------------- *)
   270 
   271 fun multiply_ineq n (i as Lineq(k,ty,l,just)) =
   272   if n = 1 then i
   273   else if n = 0 andalso ty = Lt then sys_error "multiply_ineq"
   274   else if n < 0 andalso (ty=Le orelse ty=Lt) then sys_error "multiply_ineq"
   275   else Lineq (n * k, ty, map (curry op* n) l, Multiplied (n, just));
   276 
   277 (* ------------------------------------------------------------------------- *)
   278 (* Add together (in)equations.                                               *)
   279 (* ------------------------------------------------------------------------- *)
   280 
   281 fun add_ineq (i1 as Lineq(k1,ty1,l1,just1)) (i2 as Lineq(k2,ty2,l2,just2)) =
   282   let val l = map2 (curry (op +)) l1 l2
   283   in Lineq(k1+k2,find_add_type(ty1,ty2),l,Added(just1,just2)) end;
   284 
   285 (* ------------------------------------------------------------------------- *)
   286 (* Elimination of variable between a single pair of (in)equations.           *)
   287 (* If they're both inequalities, 1st coefficient must be +ve, 2nd -ve.       *)
   288 (* ------------------------------------------------------------------------- *)
   289 
   290 fun elim_var v (i1 as Lineq(k1,ty1,l1,just1)) (i2 as Lineq(k2,ty2,l2,just2)) =
   291   let val c1 = el v l1 and c2 = el v l2
   292       val m = lcm(abs c1, abs c2)
   293       val m1 = m div (abs c1) and m2 = m div (abs c2)
   294       val (n1,n2) =
   295         if (c1 >= 0) = (c2 >= 0)
   296         then if ty1 = Eq then (~m1,m2)
   297              else if ty2 = Eq then (m1,~m2)
   298                   else sys_error "elim_var"
   299         else (m1,m2)
   300       val (p1,p2) = if ty1=Eq andalso ty2=Eq andalso (n1 = ~1 orelse n2 = ~1)
   301                     then (~n1,~n2) else (n1,n2)
   302   in add_ineq (multiply_ineq n1 i1) (multiply_ineq n2 i2) end;
   303 
   304 (* ------------------------------------------------------------------------- *)
   305 (* The main refutation-finding code.                                         *)
   306 (* ------------------------------------------------------------------------- *)
   307 
   308 fun is_trivial (Lineq(_,_,l,_)) = forall (fn i => i=0) l;
   309 
   310 fun is_answer (ans as Lineq(k,ty,l,_)) =
   311   case ty  of Eq => k <> 0 | Le => k > 0 | Lt => k >= 0;
   312 
   313 fun calc_blowup (l:IntInf.int list) =
   314   let val (p,n) = List.partition (curry (op <) 0) (List.filter (curry (op <>) 0) l)
   315   in (length p) * (length n) end;
   316 
   317 (* ------------------------------------------------------------------------- *)
   318 (* Main elimination code:                                                    *)
   319 (*                                                                           *)
   320 (* (1) Looks for immediate solutions (false assertions with no variables).   *)
   321 (*                                                                           *)
   322 (* (2) If there are any equations, picks a variable with the lowest absolute *)
   323 (* coefficient in any of them, and uses it to eliminate.                     *)
   324 (*                                                                           *)
   325 (* (3) Otherwise, chooses a variable in the inequality to minimize the       *)
   326 (* blowup (number of consequences generated) and eliminates it.              *)
   327 (* ------------------------------------------------------------------------- *)
   328 
   329 fun allpairs f xs ys =
   330   List.concat(map (fn x => map (fn y => f x y) ys) xs);
   331 
   332 fun extract_first p =
   333   let fun extract xs (y::ys) = if p y then (SOME y,xs@ys)
   334                                else extract (y::xs) ys
   335         | extract xs []      = (NONE,xs)
   336   in extract [] end;
   337 
   338 fun print_ineqs ineqs =
   339   if !trace then
   340      tracing(cat_lines(""::map (fn Lineq(c,t,l,_) =>
   341        IntInf.toString c ^
   342        (case t of Eq => " =  " | Lt=> " <  " | Le => " <= ") ^
   343        commas(map IntInf.toString l)) ineqs))
   344   else ();
   345 
(* Elimination history, newest first: each entry (v, lineqs) holds the lineqs
   present before variable v was eliminated.  v = ~1 marks the final system in
   which no variable could be chosen (see elim). *)
type history = (int * lineq list) list;
(* Outcome of elim: a refutation, or the history for counter examples. *)
datatype result = Success of injust | Failure of history;
   348 
(* elim (ineqs, hist): eliminate variables until a contradictory trivial lineq
   shows up (Success with its justification) or nothing is left to eliminate
   (Failure with the history, which trace_ex uses for counter examples). *)
fun elim(ineqs,hist) =
  let val dummy = print_ineqs ineqs;
      val (triv,nontriv) = List.partition is_trivial ineqs in
  (* (1) a false trivial lineq is a refutation; true ones are dropped *)
  if not(null triv)
  then case Library.find_first is_answer triv of
         NONE => elim(nontriv,hist)
       | SOME(Lineq(_,_,_,j)) => Success j
  else
  if null nontriv then Failure(hist)
  else
  let val (eqs,noneqs) = List.partition (fn (Lineq(_,ty,_,_)) => ty=Eq) nontriv in
  if not(null eqs) then
     (* (2) eliminate with an equation, at the nonzero coefficient value of
        smallest absolute value occurring in any equation *)
     let val clist = Library.foldl (fn (cs,Lineq(_,_,l,_)) => l union cs) ([],eqs)
         val sclist = sort (fn (x,y) => IntInf.compare(abs(x),abs(y)))
                           (List.filter (fn i => i<>0) clist)
         val c = hd sclist
         (* an equation containing the value c; v = position of c in it *)
         val (SOME(eq as Lineq(_,_,ceq,_)),othereqs) =
               extract_first (fn Lineq(_,_,l,_) => c mem l) eqs
         val v = find_index_eq c ceq
         (* lineqs not mentioning v are kept, the others combined with eq *)
         val (ioth,roth) = List.partition (fn (Lineq(_,_,l,_)) => el v l = 0)
                                     (othereqs @ noneqs)
         val others = map (elim_var v eq) roth @ ioth
     in elim(others,(v,nontriv)::hist) end
  else
  (* (3) only inequalities: eliminate the variable with the smallest nonzero
     blowup; if every blowup is 0 (each variable occurs with one sign only)
     stop with a ~1 entry, the base case of the counter example finder *)
  let val lists = map (fn (Lineq(_,_,l,_)) => l) noneqs
      val numlist = 0 upto (length(hd lists) - 1)
      val coeffs = map (fn i => map (el i) lists) numlist
      val blows = map calc_blowup coeffs
      val iblows = blows ~~ numlist
      val nziblows = List.filter (fn (i,_) => i<>0) iblows
  in if null nziblows then Failure((~1,nontriv)::hist)
     else
     let val (c,v) = hd(sort (fn (x,y) => int_ord(fst(x),fst(y))) nziblows)
         (* NOTE(review): partitions ineqs rather than nontriv; the two are
            equal here because triv is empty on this path *)
         val (no,yes) = List.partition (fn (Lineq(_,_,l,_)) => el v l = 0) ineqs
         val (pos,neg) = List.partition(fn (Lineq(_,_,l,_)) => el v l > 0) yes
     in elim(no @ allpairs (elim_var v) pos neg, (v,nontriv)::hist) end
  end
  end
  end;
   388 
   389 (* ------------------------------------------------------------------------- *)
   390 (* Translate back a proof.                                                   *)
   391 (* ------------------------------------------------------------------------- *)
   392 
   393 fun trace_thm msg th = 
   394     if !trace then (tracing msg; tracing (Display.string_of_thm th); th) else th;
   395 
   396 fun trace_msg msg = 
   397     if !trace then tracing msg else ();
   398 
(* FIXME OPTIMIZE!!!! (partly done already)
   Addition/Multiplication need i*t representation rather than t+t+...
   Get rid of Multiplied(2). For Nat LA_Data.number_of should return Suc^n
   because Numerals are not known early enough.

Simplification may detect a contradiction 'prematurely' due to type
information: n+1 <= 0 is simplified to False and does not need to be crossed
with 0 <= n.
*)
local
 (* raised by simp as soon as a simplified intermediate theorem is False *)
 exception FalseE of thm
in
(* mkthm (sg, ss) asms just: replay the justification just as a forward proof
   from the theorems asms, using the rules in the theory data of sg.  Sums are
   simplified with the stored simpset (inheriting the context of ss).  Returns
   the False theorem; if the final theorem does not simplify to False, a
   warning is issued and the simplified theorem is returned anyway. *)
fun mkthm (sg, ss) asms just =
  let val {add_mono_thms, mult_mono_thms, inj_thms, lessD, simpset, ...} =
          Data.get sg;
      val simpset' = Simplifier.inherit_context ss simpset;
      (* atoms of the decomposable, premise-free asms; Nat i indexes this
         list.  NOTE(review): must number atoms like refutes does -- keep the
         two computations in sync *)
      val atoms = Library.foldl (fn (ats,(lhs,_,_,rhs,_,_)) =>
                            map fst lhs  union  (map fst rhs  union  ats))
                        ([], List.mapPartial (fn thm => if Thm.no_prems thm
                                        then LA_Data.decomp sg (concl_of thm)
                                        else NONE) asms)

      (* conjoin the two facts and try each add_mono rule in turn *)
      fun add2 thm1 thm2 =
        let val conj = thm1 RS (thm2 RS LA_Logic.conjI)
        in get_first (fn th => SOME(conj RS th) handle THM _ => NONE) add_mono_thms
        end;

      (* first candidate in the list that can be added to thm2 *)
      fun try_add [] _ = NONE
        | try_add (thm1::thm1s) thm2 = case add2 thm1 thm2 of
             NONE => try_add thm1s thm2 | some => some;

      (* add directly; otherwise after transforming thm1, then thm2, with the
         inj_thms rules; internal error if nothing works *)
      fun addthms thm1 thm2 =
        case add2 thm1 thm2 of
          NONE => (case try_add ([thm1] RL inj_thms) thm2 of
                     NONE => ( valOf(try_add ([thm2] RL inj_thms) thm1)
                               handle Option =>
                               (trace_thm "" thm1; trace_thm "" thm2;
                                sys_error "Lin.arith. failed to add thms")
                             )
                   | SOME thm => thm)
        | SOME thm => thm;

      (* |n|-fold sum of thm, followed by sym for negative n (which only
         occurs for equations, see multiply_ineq).  NOTE(review): mul does not
         terminate for n = 0; elim_var never produces 0 *)
      fun multn(n,thm) =
        let fun mul(i,th) = if i=1 then th else mul(i-1, addthms thm th)
        in if n < 0 then mul(~n,thm) RS LA_Logic.sym else mul(n,thm) end;
(*
      fun multn2(n,thm) =
        let val SOME(mth,cv) =
              get_first (fn (th,cv) => SOME(thm RS th,cv) handle THM _ => NONE) mult_mono_thms
            val ct = cterm_of sg (LA_Data.number_of(n,#T(rep_cterm cv)))
        in instantiate ([],[(cv,ct)]) mth end
*)
      (* resolve thm with the first applicable mult_mono rule and instantiate
         the variable taken from its first premise with the numeral n.
         NOTE(review): raises Bind if no mult_mono rule applies *)
      fun multn2(n,thm) =
        let val SOME(mth) =
              get_first (fn th => SOME(thm RS th) handle THM _ => NONE) mult_mono_thms
            fun cvar(th,_ $ (_ $ _ $ var)) = cterm_of (#sign(rep_thm th)) var;
            val cv = cvar(mth, hd(prems_of mth));
            val ct = cterm_of sg (LA_Data.number_of(n,#T(rep_cterm cv)))
        in instantiate ([],[(cv,ct)]) mth end

      (* simplify; stop the whole reconstruction early when False shows up *)
      fun simp thm =
        let val thm' = trace_thm "Simplified:" (full_simplify simpset' thm)
        in if LA_Logic.is_False thm' then raise FalseE thm' else thm' end

      (* one proof step per justification constructor *)
      fun mk(Asm i) = trace_thm "Asm" (nth asms i)
        | mk(Nat i) = (trace_msg "Nat"; LA_Logic.mk_nat_thm sg (nth atoms i))
        | mk(LessD(j)) = trace_thm "L" (hd([mk j] RL lessD))
        | mk(NotLeD(j)) = trace_thm "NLe" (mk j RS LA_Logic.not_leD)
        | mk(NotLeDD(j)) = trace_thm "NLeD" (hd([mk j RS LA_Logic.not_leD] RL lessD))
        | mk(NotLessD(j)) = trace_thm "NL" (mk j RS LA_Logic.not_lessD)
        | mk(Added(j1,j2)) = simp (trace_thm "+" (addthms (mk j1) (mk j2)))
        | mk(Multiplied(n,j)) = (trace_msg("*"^IntInf.toString n); trace_thm "*" (multn(n,mk j)))
        | mk(Multiplied2(n,j)) = simp (trace_msg("**"^IntInf.toString n); trace_thm "**" (multn2(n,mk j)))

  in trace_msg "mkthm";
     let val thm = trace_thm "Final thm:" (mk just)
     in let val fls = simplify simpset' thm
        in trace_thm "After simplification:" fls;
           if LA_Logic.is_False fls then fls
           else
            (tracing "Assumptions:"; List.app print_thm asms;
             tracing "Proved:"; print_thm fls;
             warning "Linear arithmetic should have refuted the assumptions.\n\
                     \Please inform Tobias Nipkow (nipkow@in.tum.de).";
             fls)
        end
     end handle FalseE thm => (trace_thm "False reached early:" thm; thm)
  end
end;
   488 
   489 fun coeff poly atom : IntInf.int =
   490   AList.lookup (op =) poly atom |> the_default 0;
   491 
   492 fun lcms is = Library.foldl lcm (1, is);
   493 
   494 fun integ(rlhs,r,rel,rrhs,s,d) =
   495 let val (rn,rd) = Rat.quotient_of_rat r and (sn,sd) = Rat.quotient_of_rat s
   496     val m = lcms(map (abs o snd o Rat.quotient_of_rat) (r :: s :: map snd rlhs @ map snd rrhs))
   497     fun mult(t,r) = 
   498         let val (i,j) = Rat.quotient_of_rat r
   499         in (t,i * (m div j)) end
   500 in (m,(map mult rlhs, rn*(m div rd), rel, map mult rrhs, sn*(m div sd), d)) end
   501 
   502 fun mklineq n atoms =
   503   fn (item,k) =>
   504   let val (m,(lhs,i,rel,rhs,j,discrete)) = integ item
   505       val lhsa = map (coeff lhs) atoms
   506       and rhsa = map (coeff rhs) atoms
   507       val diff = map2 (curry (op -)) rhsa lhsa
   508       val c = i-j
   509       val just = Asm k
   510       fun lineq(c,le,cs,j) = Lineq(c,le,cs, if m=1 then j else Multiplied2(m,j))
   511   in case rel of
   512       "<="   => lineq(c,Le,diff,just)
   513      | "~<=" => if discrete
   514                 then lineq(1-c,Le,map (op ~) diff,NotLeDD(just))
   515                 else lineq(~c,Lt,map (op ~) diff,NotLeD(just))
   516      | "<"   => if discrete
   517                 then lineq(c+1,Le,diff,LessD(just))
   518                 else lineq(c,Lt,diff,just)
   519      | "~<"  => lineq(~c,Le,map (op~) diff,NotLessD(just))
   520      | "="   => lineq(c,Eq,diff,just)
   521      | _     => sys_error("mklineq" ^ rel)   
   522   end;
   523 
   524 (* ------------------------------------------------------------------------- *)
   525 (* Print (counter) example                                                   *)
   526 (* ------------------------------------------------------------------------- *)
   527 
   528 fun print_atom((a,d),r) =
   529   let val (p,q) = Rat.quotient_of_rat r
   530       val s = if d then IntInf.toString p else
   531               if p = 0 then "0"
   532               else IntInf.toString p ^ "/" ^ IntInf.toString q
   533   in a ^ " = " ^ s end;
   534 
   535 fun print_ex sds =
   536   curry (op ~~) sds
   537   #> map print_atom
   538   #> commas
   539   #> curry (op ^) "Counter example:\n"
   540   #> tracing;
   541 
(* trace_ex (sg, params, atoms, discr, n, hist): after a failed refutation,
   warn and trace a counter example over the n atoms reconstructed from hist.
   If the newest entry is the ~1 base-case marker, start from findex0 on its
   lineqs; otherwise start from all zeros.  Then replay the history with
   findex.  NoEx means no example was found. *)
fun trace_ex(sg,params,atoms,discr,n,hist:history) =
  if null hist then ()
  else let val frees = map Free params;
           (* print atoms with their bound variables replaced by the params *)
           fun s_of_t t = Sign.string_of_term sg (subst_bounds(frees,t));
           (* non-exhaustive binding is safe: hist is non-empty here *)
           val (v,lineqs) :: hist' = hist
           val start = if v = ~1 then (findex0 discr n lineqs,hist')
                       else (replicate n Rat.zero,hist)
       in warning "arith failed - see trace for a counter example";
          print_ex ((map s_of_t atoms)~~discr) (findex discr start)
          handle NoEx => (tracing "Sorry, no counter example.")
       end;
   553 
   554 fun mknat pTs ixs (atom,i) =
   555   if LA_Logic.is_nat(pTs,atom)
   556   then let val l = map (fn j => if j=i then 1 else 0) ixs
   557        in SOME(Lineq(0,Le,l,Nat(i))) end
   558   else NONE
   559 
(* This code is tricky. It takes a list of premises in the order they occur
in the subgoal. Numerical premises are coded as SOME(tuple), non-numerical
ones as NONE. Going through the premises, each numeric one is converted into
a Lineq. The tricky bit is to convert ~= which is split into two cases < and
>. Thus split_items returns a list of equation systems. This may blow up if
there are many ~=, but in practice it does not seem to happen. The really
tricky bit is to arrange the order of the cases such that they coincide with
the order in which the cases are in the end generated by the tactic that
applies the generated refutation thms (see function 'refute_tac').

For variables n of type nat, a constraint 0 <= n is added.
*)
(* Each resulting system is a list of (item, premise index) pairs.  A ~=
   premise is removed without advancing the index, and its two cases (l < r,
   then r < l) are appended to the end of the remaining premises.
   Presumably this mirrors where the case split (neqE, see refute_tac and
   splitasms) puts the new hypotheses -- NOTE(review): confirm. *)
fun split_items(items) =
  let fun elim_neq front _ [] = [rev front]
        | elim_neq front n (NONE::ineqs) = elim_neq front (n+1) ineqs
        | elim_neq front n (SOME(ineq as (l,i,rel,r,j,d))::ineqs) =
          if rel = "~=" then elim_neq front n (ineqs @ [SOME(l,i,"<",r,j,d)]) @
                             elim_neq front n (ineqs @ [SOME(r,j,"<",l,i,d)])
          else elim_neq ((ineq,n) :: front) (n+1) ineqs
  in elim_neq [] 0 items end;
   580 
   581 fun add_atoms(ats,((lhs,_,_,rhs,_,_),_)) =
   582     (map fst lhs) union ((map fst rhs) union ats)
   583 
   584 fun add_datoms(dats,((lhs,_,_,rhs,_,d),_)) =
   585     (map (pair d o fst) lhs) union ((map (pair d o fst) rhs) union dats)
   586 
   587 fun discr initems = map fst (Library.foldl add_datoms ([],initems));
   588 
(* refutes sg (pTs, params) ex initemss js: refute each equation system of
   initemss in turn (one per ~= case, see split_items), appending each
   refutation's justification to js.  Returns SOME of all justifications in
   case order.  Returns NONE as soon as one system cannot be refuted, after
   tracing a counter example if ex is set. *)
fun refutes sg (pTs,params) ex =
let
  fun refute (initems::initemss) js =
    let val atoms = Library.foldl add_atoms ([],initems)
        val n = length atoms
        val mkleq = mklineq n atoms
        val ixs = 0 upto (n-1)
        val iatoms = atoms ~~ ixs
        (* 0 <= atom for every atom of type nat *)
        val natlineqs = List.mapPartial (mknat pTs ixs) iatoms
        val ineqs = map mkleq initems @ natlineqs
    in case elim(ineqs,[]) of
         Success(j) =>
           (trace_msg "Contradiction!"; refute initemss (js@[j]))
       | Failure(hist) =>
           (if not ex then ()
            else trace_ex(sg,params,atoms,discr initems,n,hist);
            NONE)
    end
    | refute [] js = SOME js
in refute end;

(* refute sg ps ex items: refutes on all case splits of the premise items *)
fun refute sg ps ex items = refutes sg ps ex (split_items items) [];
   611 
(* refute_tac ss (i, justs): turn subgoal i into a refutation goal with notI or
   ccontr.  Then, for each justification, case-split the ~= hypotheses with
   the neqE rules and close the resulting goal inside METAHYPS with the False
   theorem built by mkthm from the hypotheses. *)
fun refute_tac ss (i,justs) =
  fn state =>
    let val sg = #sign(rep_thm state)
        val {neqE, ...} = Data.get sg;
        fun just1 j = REPEAT_DETERM(eresolve_tac neqE i) THEN
          METAHYPS (fn asms => rtac (mkthm (sg, ss) asms j) 1) i
    in DETERM(resolve_tac [LA_Logic.notI,LA_Logic.ccontr] i) THEN
       EVERY(map just1 justs)
    end
    state;
   622 
   623 fun count P xs = length(List.filter P xs);
   624 
(* The limit on the number of ~= allowed.
   Because each ~= is split into two cases, this can lead to an explosion.
   prove gives up (returns NONE) when more ~= premises than this occur.
*)
val fast_arith_neq_limit = ref 9;
   629 
   630 fun prove sg ps ex Hs concl =
   631 let val Hitems = map (LA_Data.decomp sg) Hs
   632 in if count (fn NONE => false | SOME(_,_,r,_,_,_) => r = "~=") Hitems
   633       > !fast_arith_neq_limit then NONE
   634    else
   635    case LA_Data.decomp sg concl of
   636      NONE => refute sg ps ex (Hitems@[NONE])
   637    | SOME(citem as (r,i,rel,l,j,d)) =>
   638        let val neg::rel0 = explode rel
   639            val nrel = if neg = "~" then implode rel0 else "~"^rel
   640        in refute sg ps ex (Hitems @ [SOME(r,i,nrel,l,j,d)]) end
   641 end;
   642 
(*
Fast but very incomplete decider. Only premises and conclusions
that are already (negated) (in)equations are taken into account.
*)
(* simpset_lin_arith_tac ss ex i st: try to refute subgoal i of st.  On
   success the refutation is replayed by refute_tac with ss.  On failure the
   tactic fails (no_tac), having traced a counter example if ex is set. *)
fun simpset_lin_arith_tac ss ex i st = SUBGOAL (fn (A,_) =>
  let val params = rev(Logic.strip_params A)
      val pTs = map snd params
      val Hs = Logic.strip_assums_hyp A
      val concl = Logic.strip_assums_concl A
  in trace_thm ("Trying to refute subgoal " ^ string_of_int i) st;
     case prove (Thm.sign_of_thm st) (pTs,params) ex Hs concl of
       NONE => (trace_msg "Refutation failed."; no_tac)
     | SOME js => (trace_msg "Refutation succeeded."; refute_tac ss (i,js))
  end) i st;
   657 
   658 fun lin_arith_tac ex i st =
   659   simpset_lin_arith_tac (Simplifier.theory_context (Thm.theory_of_thm st) Simplifier.empty_ss)
   660     ex i st;
   661 
   662 fun cut_lin_arith_tac ss i =
   663   cut_facts_tac (Simplifier.prems_of_ss ss) i THEN
   664   simpset_lin_arith_tac ss false i;
   665 
(** Forward proof from theorems **)

(* More tricky code. Needs to arrange the proofs of the multiple cases (due
to splits of ~= premises) such that it coincides with the order of the cases
generated by function split_items. *)

(* Tip asms: a case without further splits and its assumptions.
   Spl (spl, ct1, t1, ct2, t2): the case-split theorem spl obtained from a
   neqE rule, its two case hypotheses ct1/ct2 and the subtrees for them. *)
datatype splittree = Tip of thm list
                   | Spl of thm * cterm * splittree * cterm * splittree
   674 
(* extract imp: the two case hypotheses (ct1, ct2) of a case-split
   proposition.  By the dest_comb steps below imp has the shape
   _ $ (_ $ ct1 $ _) $ (_ $ (_ $ ct2 $ _) $ _), presumably
   (ct1 ==> R) ==> (ct2 ==> R) ==> R as produced from a neqE rule. *)
fun extract imp =
let val (Il,r) = Thm.dest_comb imp
    val (_,imp1) = Thm.dest_comb Il
    val (Ict1,_) = Thm.dest_comb imp1
    val (_,ct1) = Thm.dest_comb Ict1
    val (Ir,_) = Thm.dest_comb r
    val (_,Ict2r) = Thm.dest_comb Ir
    val (Ict2,_) = Thm.dest_comb Ict2r
    val (_,ct2) = Thm.dest_comb Ict2
in (ct1,ct2) end;
   685 
(* splitasms sg asms: the split tree for the assumption list asms.  When some
   neqE rule applies to an assumption (via COMP), that assumption is replaced
   by a case split.  Each branch assumes one case hypothesis, appended at the
   end of the remaining assumptions as in split_items.  Other assumptions are
   kept in order. *)
fun splitasms sg asms =
let val {neqE, ...}  = Data.get sg;
    fun split(asms',[]) = Tip(rev asms')
      | split(asms',asm::asms) =
      (case get_first (fn th => SOME(asm COMP th) handle THM _ => NONE) neqE
       of SOME spl =>
          let val (ct1,ct2) = extract(cprop_of spl)
              val thm1 = assume ct1 and thm2 = assume ct2
          in Spl(spl,ct1,split(asms',asms@[thm1]),ct2,split(asms',asms@[thm2]))
          end
       | NONE => split(asm::asms', asms))
in split([],asms) end;
   698 
(* fwdproof ctxt tree js: prove False for every leaf of tree, consuming one
   justification per leaf from js left to right.  At a split the two branch
   proofs have their case hypotheses discharged (implies_intr) and are
   combined with the split theorem by COMP.  Returns the proof and the unused
   justifications. *)
fun fwdproof ctxt (Tip asms) (j::js) = (mkthm ctxt asms j, js)
  | fwdproof ctxt (Spl(thm,ct1,tree1,ct2,tree2)) js =
    let val (thm1,js1) = fwdproof ctxt tree1 js
        val (thm2,js2) = fwdproof ctxt tree2 js1
        val thm1' = implies_intr ct1 thm1
        val thm2' = implies_intr ct2 thm2
    in (thm2' COMP (thm1' COMP thm), js2) end;
(* needs handle THM _ => NONE ? *)
   707 
(* prover ctxt thms Tconcl js pos: replay the justifications js as a forward
   proof of False from thms and the assumed negation of Tconcl.  Discharge it
   with ccontr (pos) or notI, and turn the result into a rewrite with mk_Eq.
   NONE if any step raises THM. *)
fun prover (ctxt as (sg, _)) thms Tconcl js pos =
let val nTconcl = LA_Logic.neg_prop Tconcl
    val cnTconcl = cterm_of sg nTconcl
    val nTconclthm = assume cnTconcl
    val tree = splitasms sg (thms @ [nTconclthm])
    val (thm,_) = fwdproof ctxt tree js
    val contr = if pos then LA_Logic.ccontr else LA_Logic.notI
in SOME(LA_Logic.mk_Eq((implies_intr cnTconcl thm) COMP contr)) end
(* in case concl contains ?-var, which makes assume fail: *)
handle THM _ => NONE;
   718 
(* PRE: concl is not negated!
   This assumption is OK because
   1. lin_arith_prover tries both to prove and disprove concl and
   2. lin_arith_prover is applied by the simplifier which
      dives into terms and will thus try the non-negated concl anyway.
*)
(* lin_arith_prover sg ss concl: using the atomized premises of ss, first try
   to prove concl, then its negation.  Returns the rewrite theorem produced by
   prover, or NONE.  No counter examples are traced. *)
fun lin_arith_prover sg ss concl =
let
    val thms = List.concat(map LA_Logic.atomize (prems_of_ss ss));
    val Hs = map (#prop o rep_thm) thms
    val Tconcl = LA_Logic.mk_Trueprop concl
in case prove sg ([],[]) false Hs Tconcl of (* concl provable? *)
     SOME js => prover (sg, ss) thms Tconcl js true
   | NONE => let val nTconcl = LA_Logic.neg_prop Tconcl
          in case prove sg ([],[]) false Hs nTconcl of (* ~concl provable? *)
               SOME js => prover (sg, ss) thms nTconcl js false
             | NONE => NONE
          end
end;
   738 
   739 end;