src/Provers/Arith/fast_lin_arith.ML
author wenzelm
Thu Apr 07 09:24:35 2005 +0200 (2005-04-07 ago)
changeset 15660 255055554c67
parent 15570 8d8c70b41bab
child 15922 7ef183f3fc98
permissions -rw-r--r--
handle Option instead of OPTION;
     1 (*  Title:      Provers/Arith/fast_lin_arith.ML
     2     ID:         $Id$
     3     Author:     Tobias Nipkow
     4     Copyright   1998  TU Munich
     5 
     6 A generic linear arithmetic package.
     7 It provides two tactics
     8 
     9     lin_arith_tac:     bool -> int -> tactic   (bool: trace a counter example on failure)
    10 cut_lin_arith_tac: thm list -> int -> tactic
    11 
    12 and a simplification procedure
    13 
    14     lin_arith_prover: Sign.sg -> simpset -> term -> thm option
    15 
    16 Only take premises and conclusions into account that are already (negated)
    17 (in)equations. lin_arith_prover tries to prove or disprove the term.
    18 *)
    19 
    20 (* Debugging: set Fast_Arith.trace *)
    21 
    22 (*** Data needed for setting up the linear arithmetic package ***)
    23 
    (* Object-logic specific ingredients of the package: the theorems listed
       below (their intended statements are given in the trailing comments)
       and the term/theorem operations specified in the comment following
       this signature. *)
    24 signature LIN_ARITH_LOGIC =
    25 sig
    26   val conjI:		thm (* used by mkthm to pair two facts before applying add_mono_thms *)
    27   val ccontr:           thm (* (~ P ==> False) ==> P *)
    28   val neqE:             thm (* [| m ~= n; m < n ==> P; n < m ==> P |] ==> P *)
    29   val notI:             thm (* (P ==> False) ==> ~ P *)
    30   val not_lessD:        thm (* ~(m < n) ==> n <= m *)
    31   val not_leD:          thm (* ~(m <= n) ==> n < m *)
    32   val sym:		thm (* x = y ==> y = x *)
    33   val mk_Eq: thm -> thm
    34   val mk_Trueprop: term -> term
    35   val neg_prop: term -> term
    36   val is_False: thm -> bool (* used by mkthm to recognise a completed refutation *)
    37   val is_nat: typ list * term -> bool
    38   val mk_nat_thm: Sign.sg -> term -> thm
    39 end;
    40 (*
    41 mk_Eq(~in) = `in == False'
    42 mk_Eq(in) = `in == True'
    43 where `in' is an (in)equality.
    44 
    45 neg_prop(t) = nt if t is wrapped up in Trueprop and
    46   nt is the (logically) negated version of t, where the negation
    47   of a negated term ~u is u itself (no double negation!);
    48 
    49 is_nat(parameter-types,t) =  t:nat
    50 mk_nat_thm(t) = "0 <= t"
    51 *)
    52 
    (* Object-logic specific parsing of (in)equations into linear form;
       decomp is specified in the comment following this signature. *)
    53 signature LIN_ARITH_DATA =
    54 sig
    55   val decomp:
    56     Sign.sg -> term -> ((term*rat)list * rat * string * (term*rat)list * rat * bool)option
    57   val number_of: int * typ -> term (* used by mkthm (multn2) to build the factor n at type T; presumably the numeral n *)
    58 end;
    59 (*
    60 decomp(`x Rel y') should yield (p,i,Rel,q,j,d)
    61    where Rel is one of "<", "~<", "<=", "~<=", "=" and "~=", and
    62          p/q is the decomposition of the sum terms x/y into a list
    63          of summand * multiplicity pairs and a constant summand and
    64          d indicates if the domain is discrete.
    65 
    66 ss must reduce contradictory <= to False.
    67    It should also cancel common summands to keep <= reduced;
    68    otherwise <= can grow to massive proportions.
    69 *)
    70 
    (* Interface of the instantiated linear arithmetic package. *)
    71 signature FAST_LIN_ARITH =
    72 sig
    73   val setup: (theory -> theory) list
    74   val map_data: ({add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
    75                  lessD: thm list, simpset: Simplifier.simpset}
    76                  -> {add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
    77                      lessD: thm list, simpset: Simplifier.simpset})
    78                 -> theory -> theory
    79   val trace           : bool ref (* when true, intermediate steps are printed via tracing *)
    80   val fast_arith_neq_limit: int ref (* max. number of ~= premises; beyond it prove gives up *)
    81   val lin_arith_prover: Sign.sg -> simpset -> term -> thm option
    82   val     lin_arith_tac:     bool -> int -> tactic (* bool: trace a counter example on failure *)
    83   val cut_lin_arith_tac: thm list -> int -> tactic
    84 end;
    85 
    86 functor Fast_Lin_Arith(structure LA_Logic:LIN_ARITH_LOGIC 
    87                        and       LA_Data:LIN_ARITH_DATA) : FAST_LIN_ARITH =
    88 struct
    89 
    90 
    91 (** theory data **)
    92 
    93 (* data kind 'Provers/fast_lin_arith' *)
    94 
    (* Theory data of the package: the monotonicity, multiplication and
       injection theorems plus the simpset used when a refutation is replayed
       as a theorem (see mkthm). *)
    structure DataArgs =
    struct
      val name = "Provers/fast_lin_arith";

      type T = {add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
                lessD: thm list, simpset: Simplifier.simpset};

      val empty : T =
        {add_mono_thms = [], mult_mono_thms = [], inj_thms = [],
         lessD = [], simpset = Simplifier.empty_ss};

      val copy = I;
      val prep_ext = I;

      (* Theorem lists are merged with Drule.merge_rules, the simpsets with
         Simplifier.merge_ss. *)
      fun merge (d1 : T, d2 : T) : T =
        let fun rules sel = Drule.merge_rules (sel d1, sel d2)
        in
          {add_mono_thms = rules #add_mono_thms,
           mult_mono_thms = rules #mult_mono_thms,
           inj_thms = rules #inj_thms,
           lessD = rules #lessD,
           simpset = Simplifier.merge_ss (#simpset d1, #simpset d2)}
        end;

      fun print _ _ = ();
    end;

    structure Data = TheoryDataFun(DataArgs);
    val map_data = Data.map;
    val setup = [Data.init];
   122 
   123 
   124 
   125 (*** A fast decision procedure ***)
   126 (*** Code ported from HOL Light ***)
   127 (* possible optimizations:
   128    use (var,coeff) rep or vector rep to save space;
   129    treat non-negative atoms separately rather than adding 0 <= atom
   130 *)
   131 
   (* When set, elimination steps and replayed theorems are traced. *)
   val trace = ref false;

   datatype lineq_type = Eq | Le | Lt;

   (* How a derived (in)equation was obtained; replayed by mkthm. *)
   datatype injust =
       Asm of int                  (* index of a premise *)
     | Nat of int                  (* index of atom *)
     | LessD of injust
     | NotLessD of injust
     | NotLeD of injust
     | NotLeDD of injust
     | Multiplied of int * injust
     | Multiplied2 of int * injust
     | Added of injust * injust;

   (* Lineq (k, rel, cs, j) stands for  k rel (sum over i of cs[i] times atom i),
      justified by j. *)
   datatype lineq = Lineq of int * lineq_type * int list * injust;
   147 
   (* el n xs: the n-th element (0-based) of xs; sys_error "el" when there is none. *)
   fun el n xs = List.nth (xs, n) handle Subscript => sys_error "el";
   151 
   152 (* ------------------------------------------------------------------------- *)
   153 (* Finding a (counter) example from the trace of a failed elimination        *)
   154 (* ------------------------------------------------------------------------- *)
   155 (* Examples are represented as rational numbers.                             *)
   156 (* Don't blame John Harrison for this code - it is entirely mine. TN         *)
   157 
   exception NoEx;

   (* Coding: (i,true,cs) means i <= cs and (i,false,cs) means i < cs.
      In general, true means the bound is included, false means it is excluded.
      Need to know if it is a lower or upper bound for unambiguous interpretation!
   *)

   (* Cons the bound triple(s) of one Lineq onto ineqs; an equation i = cs
      contributes both i <= cs and ~i <= ~cs. *)
   fun elim_eqns (ineqs, Lineq (i, rel, cs, _)) =
     case rel of
       Le => (i, true, cs) :: ineqs
     | Lt => (i, false, cs) :: ineqs
     | Eq => (i, true, cs) :: (~i, true, map ~ cs) :: ineqs;
   168 
   val rat0 = rat_of_int 0;

   (* eval ex v (a,le,cs): the bound on variable v implied by the triple
      (a,le,cs) when every other variable takes its value from ex, i.e.
      (a minus the sum of cs[i] times ex[i]) divided by cs[v], paired with le.
      PRE: ex[v] must be 0! *)
   fun eval (ex : rat list) v (a : int, le, cs : int list) =
     let
       val coeffs = map rat_of_int cs
       val dot = Library.foldl ratadd (rat0, map ratmul (coeffs ~~ ex))
       val numer = ratadd (rat_of_int a, ratneg dot)
     in
       (ratmul (numer, ratinv (el v coeffs)), le)
     end;
   (* If el v coeffs < 0, le should be negated.
      Instead this swap is taken into account in ratrelmin2.
   *)
   179 
   (* Sign test and <= on rationals via the numerator
      (assumes rep_rat yields a positive denominator). *)
   fun ratge0 r = #1 (rep_rat r) >= 0;
   fun ratle (r, s) = ratge0 (ratadd (s, ratneg r));

   (* Pick the smaller (ratrelmin2) or larger (ratrelmax2) of two
      (value, inclusion-flag) bounds. On equal values the flags are combined. *)
   fun ratrelmin2 (x as (r, ler), y as (s, les)) =
     if r = s then (r, not (ler orelse les))
     else if ratle (r, s) then x
     else y;

   fun ratrelmax2 (x as (r, ler), y as (s, les)) =
     if r = s then (r, ler andalso les)
     else if ratle (r, s) then y
     else x;

   val ratrelmin = foldr1 ratrelmin2;
   val ratrelmax = foldr1 ratrelmax2;
   190 
   (* Integer rounding of rationals; integral values are returned unchanged.
      SML's div rounds towards negative infinity, so p div q = floor(p/q),
      and for a non-integral p/q, floor(p/q) + 1 = ceil(p/q). *)
   fun ratroundup r =
     let val (p, q) = rep_rat r
     in if q = 1 then r else rat_of_int (p div q + 1) end;

   (* floor(p/q). Subtracting 1 from p div q (as done before) lands one below
      the floor, e.g. 3/2 went to 0 instead of 1, so choose2/pick_vars could miss
      an admissible integer and fail to build a discrete counter example. *)
   fun ratrounddown r =
     let val (p, q) = rep_rat r
     in if q = 1 then r else rat_of_int (p div q) end;

   (* For a bound that is not exact (excluded), move r by 1/q (q = denominator of r)
      up or down, depending on up; an exact bound is returned unchanged. *)
   fun ratexact up (r, exact) =
     if exact then r
     else
       let val (_, q) = rep_rat r
           val step = ratinv (rat_of_int q)
       in ratadd (r, if up then step else ratneg step) end;

   (* Arithmetic mean of r and s. *)
   fun ratmiddle (r, s) = ratmul (ratadd (r, s), ratinv (rat_of_int 2));
   204 
   (* choose2 d ((lb,exactl),(ub,exactu)): pick a value for a variable with lower
      bound lb and upper bound ub; exactl/exactu tell whether the bound itself is
      included.  d indicates a discrete domain.  0 is chosen whenever it is
      admissible.  In the discrete case NoEx is raised when no integer is found
      between the bounds. *)
   205 fun choose2 d ((lb,exactl),(ub,exactu)) =
   206   if ratle(lb,rat0) andalso (lb <> rat0 orelse exactl) andalso
   207      ratle(rat0,ub) andalso (ub <> rat0 orelse exactu)
   208   then rat0 else (* 0 lies within the bounds: prefer it *)
   209   if not d
   210   then (if ratge0 lb (* non-discrete: an included bound, otherwise the midpoint *)
   211         then if exactl then lb else ratmiddle(lb,ub)
   212         else if exactu then ub else ratmiddle(lb,ub))
   213   else (* discrete domain, both bounds must be exact *)
   214   if ratge0 lb then let val lb' = ratroundup lb
   215                     in if ratle(lb',ub) then lb' else raise NoEx end
   216                else let val ub' = ratrounddown ub
   217                     in if ratle(lb,ub') then ub' else raise NoEx end;
   218 
   (* One back-substitution step of the counter example search: choose a value
      for variable v, which was eliminated from lineqs, consistent with the
      partial example ex, and store it at position v. *)
   fun findex1 discr (ex, (v, lineqs)) =
     let
       val involving_v = List.filter (fn Lineq (_, _, cs, _) => el v cs <> 0) lineqs
       val bounds = Library.foldl elim_eqns ([], involving_v)
       val (lower, upper) = List.partition (fn (_, _, cs) => el v cs > 0) bounds
       val lb = ratrelmax (map (eval ex v) lower)
       val ub = ratrelmin (map (eval ex v) upper)
       val x = choose2 (List.nth (discr, v)) (lb, ub)
     in
       nth_update x (v, ex)
     end;

   (* Replay findex1 over a whole elimination history, starting from an example. *)
   fun findex discr = Library.foldl (findex1 discr);
   228 
   (* Substitute value x for variable v in every bound (a,le,bs):
      a becomes a - bs[v]*x and the coefficient bs[v] becomes 0. *)
   fun elim1 v x =
     map (fn (a, le, bs) =>
            let val a' = ratadd (a, ratneg (ratmul (el v bs, x)))
            in (a', le, nth_update rat0 (v, bs)) end);

   (* Is cs[v] the one and only non-zero coefficient of the bound? *)
   fun single_var v (_, _, cs) =
     let val nonzero = filter_out (equal rat0) cs
     in nonzero = [el v cs] end;
   234 
   (* The base case:
      all variables occur only with positive or only with negative coefficients.
      Repeatedly take the first variable with a non-zero coefficient, give it a
      value no closer to the constraints than 0 and than every bound in which it
      occurs alone, substitute it, and continue. *)
   fun pick_vars discr (ineqs, ex) =
     let
       fun all_zero (_, _, cs) = forall (equal rat0) cs
     in
       case filter_out all_zero ineqs of
         [] => ex
       | nz as ((_, _, cs) :: _) =>
           let
             val v = find_index (not o equal rat0) cs
             val discrete = List.nth (discr, v)
             val pos = ratge0 (el v cs)
             val bounds =
               map (fn (a, le, bs) => (ratmul (a, ratinv (el v bs)), le))
                   (List.filter (single_var v) nz)
             val candidates = (rat0, pos) :: bounds
             val x =
               case (pos, discrete) of
                 (true, true) => ratroundup (fst (ratrelmax candidates))
               | (true, false) => ratexact true (ratrelmax candidates)
               | (false, true) => ratrounddown (fst (ratrelmin candidates))
               | (false, false) => ratexact false (ratrelmin candidates)
           in
             pick_vars discr (elim1 v x nz, nth_update x (v, ex))
           end
     end;
   256 
   (* Counter example for a system where no variable could be eliminated any
      more (see pick_vars); all n variables start out as 0. *)
   fun findex0 discr n lineqs =
     let
       fun to_rat (a, le, cs) = (rat_of_int a, le, map rat_of_int cs)
       val bounds = map to_rat (Library.foldl elim_eqns ([], lineqs))
     in
       pick_vars discr (bounds, replicate n rat0)
     end;
   262 
   263 (* ------------------------------------------------------------------------- *)
   264 (* End of counter example finder. The actual decision procedure starts here. *)
   265 (* ------------------------------------------------------------------------- *)
   266 
   267 (* ------------------------------------------------------------------------- *)
   268 (* Calculate new (in)equality type after addition.                           *)
   269 (* ------------------------------------------------------------------------- *)
   270 
   (* Relation of the sum of two (in)equations: Eq is neutral, Lt absorbs Le. *)
   fun find_add_type (ty1, ty2) =
     case (ty1, ty2) of
       (Eq, _) => ty2
     | (_, Eq) => ty1
     | (Le, Le) => Le
     | _ => Lt;
   276 
   277 (* ------------------------------------------------------------------------- *)
   278 (* Multiply out an (in)equation.                                             *)
   279 (* ------------------------------------------------------------------------- *)
   280 
   (* Multiply an (in)equation by n.  Inequalities only admit non-negative
      factors, and a strict one not even 0; equations admit any factor. *)
   fun multiply_ineq n (i as Lineq (k, ty, l, just)) =
     if n = 1 then i
     else if (n = 0 andalso ty = Lt) orelse (n < 0 andalso ty <> Eq)
     then sys_error "multiply_ineq"
     else Lineq (n * k, ty, map (fn c => n * c) l, Multiplied (n, just));
   286 
   287 (* ------------------------------------------------------------------------- *)
   288 (* Add together (in)equations.                                               *)
   289 (* ------------------------------------------------------------------------- *)
   290 
   (* Componentwise sum of two (in)equations. *)
   fun add_ineq (Lineq (k1, ty1, l1, just1)) (Lineq (k2, ty2, l2, just2)) =
     let val coeffs = map2 (op +) (l1, l2)
     in Lineq (k1 + k2, find_add_type (ty1, ty2), coeffs, Added (just1, just2)) end;
   294 
   295 (* ------------------------------------------------------------------------- *)
   296 (* Elimination of variable between a single pair of (in)equations.           *)
   297 (* If they're both inequalities, 1st coefficient must be +ve, 2nd -ve.       *)
   298 (* ------------------------------------------------------------------------- *)
   299 
   (* Eliminate variable v from i1 and i2: scale both so that their v-coefficients
      have the same absolute value (the lcm), with opposite signs, and add.
      Same-sign coefficients are only allowed if one side is an equation, which
      then gets the negative factor.  When both are equations and one factor is
      ~1, both factors are negated, so that equation is used unchanged
      (multiply_ineq 1 is the identity).  Previously these adjusted factors
      (p1,p2) were computed but n1,n2 were used instead. *)
   fun elim_var v (i1 as Lineq (k1, ty1, l1, just1)) (i2 as Lineq (k2, ty2, l2, just2)) =
     let
       val c1 = el v l1 and c2 = el v l2
       val m = lcm (abs c1, abs c2)
       val m1 = m div (abs c1) and m2 = m div (abs c2)
       val (n1, n2) =
         if (c1 >= 0) = (c2 >= 0)
         then if ty1 = Eq then (~m1, m2)
              else if ty2 = Eq then (m1, ~m2)
                   else sys_error "elim_var"
         else (m1, m2)
       val (p1, p2) =
         if ty1 = Eq andalso ty2 = Eq andalso (n1 = ~1 orelse n2 = ~1)
         then (~n1, ~n2) else (n1, n2)
     in
       add_ineq (multiply_ineq p1 i1) (multiply_ineq p2 i2)
     end;
   313 
   314 (* ------------------------------------------------------------------------- *)
   315 (* The main refutation-finding code.                                         *)
   316 (* ------------------------------------------------------------------------- *)
   317 
   (* All coefficients are zero, i.e. the Lineq reads  k rel 0. *)
   fun is_trivial (Lineq (_, _, l, _)) = List.all (fn c => c = 0) l;

   (* For a trivial Lineq: is  k rel 0  false, i.e. a contradiction? *)
   fun is_answer (Lineq (k, ty, _, _)) =
     (case ty of Eq => k <> 0 | Le => k > 0 | Lt => k >= 0);

   (* Number of new inequalities produced by eliminating a variable whose
      coefficient column is l: positive entries times negative entries. *)
   fun calc_blowup l =
     let
       val npos = length (List.filter (fn c => c > 0) l)
       val nneg = length (List.filter (fn c => c < 0) l)
     in
       npos * nneg
     end;
   326 
   327 (* ------------------------------------------------------------------------- *)
   328 (* Main elimination code:                                                    *)
   329 (*                                                                           *)
   330 (* (1) Looks for immediate solutions (false assertions with no variables).   *)
   331 (*                                                                           *)
   332 (* (2) If there are any equations, picks a variable with the lowest absolute *)
   333 (* coefficient in any of them, and uses it to eliminate.                     *)
   334 (*                                                                           *)
   335 (* (3) Otherwise, chooses a variable in the inequality to minimize the       *)
   336 (* blowup (number of consequences generated) and eliminates it.              *)
   337 (* ------------------------------------------------------------------------- *)
   338 
   (* f x y for every x in xs and every y in ys, grouped by x. *)
   fun allpairs f xs ys =
     let fun with_x x = map (fn y => f x y) ys
     in List.concat (map with_x xs) end;

   (* extract_first p xs = (SOME y, rest) for the first y satisfying p, where rest
      is the elements before y in reverse order followed by those after y;
      (NONE, rev xs) when no element satisfies p. *)
   fun extract_first p =
     let
       fun go seen [] = (NONE, seen)
         | go seen (y :: ys) =
             if p y then (SOME y, seen @ ys) else go (y :: seen) ys
     in
       go []
     end;
   347 
   (* When tracing, print each Lineq on its own line as "k rel c1, c2, ...". *)
   fun print_ineqs ineqs =
     let
       fun rel_str Eq = " =  "
         | rel_str Lt = " <  "
         | rel_str Le = " <= "
       fun line (Lineq (c, t, l, _)) =
         string_of_int c ^ rel_str t ^ commas (map string_of_int l)
     in
       if !trace then tracing (cat_lines ("" :: map line ineqs)) else ()
     end;

   (* Elimination history, most recent step first: (eliminated variable, system it
      was eliminated from); variable ~1 marks a system where elim got stuck. *)
   type history = (int * lineq list) list;
   datatype result = Success of injust | Failure of history;
   358 
   (* elim (ineqs, hist): Fourier-Motzkin-style elimination.  Returns Success j
      as soon as a trivially false (in)equation with justification j appears.
      Otherwise returns Failure with the elimination history (each entry: the
      variable eliminated and the system it was eliminated from); trace_ex uses
      this history to construct a counter example. *)
   359 fun elim(ineqs,hist) =
   360   let val dummy = print_ineqs ineqs; (* tracing only *)
   361       val (triv,nontriv) = List.partition is_trivial ineqs in
   362   if not(null triv) (* a false trivial Lineq is a refutation *)
   363   then case Library.find_first is_answer triv of
   364          NONE => elim(nontriv,hist) (* all trivial ones are true: drop them *)
   365        | SOME(Lineq(_,_,_,j)) => Success j
   366   else
   367   if null nontriv then Failure(hist)
   368   else
   369   let val (eqs,noneqs) = List.partition (fn (Lineq(_,ty,_,_)) => ty=Eq) nontriv in
   370   if not(null eqs) then (* eliminate via the equation coefficient of least absolute value *)
   371      let val clist = Library.foldl (fn (cs,Lineq(_,_,l,_)) => l union cs) ([],eqs) (* coefficient values, not positions *)
   372          val sclist = sort (fn (x,y) => int_ord(abs(x),abs(y)))
   373                            (List.filter (fn i => i<>0) clist)
   374          val c = hd sclist
   375          val (SOME(eq as Lineq(_,_,ceq,_)),othereqs) =
   376                extract_first (fn Lineq(_,_,l,_) => c mem l) eqs
   377          val v = find_index_eq c ceq (* the variable carrying coefficient c in eq *)
   378          val (ioth,roth) = List.partition (fn (Lineq(_,_,l,_)) => el v l = 0)
   379                                      (othereqs @ noneqs)
   380          val others = map (elim_var v eq) roth @ ioth
   381      in elim(others,(v,nontriv)::hist) end
   382   else (* only inequalities: eliminate the variable with the smallest non-zero blowup *)
   383   let val lists = map (fn (Lineq(_,_,l,_)) => l) noneqs
   384       val numlist = 0 upto (length(hd lists) - 1)
   385       val coeffs = map (fn i => map (el i) lists) numlist (* coefficient column per variable *)
   386       val blows = map calc_blowup coeffs
   387       val iblows = blows ~~ numlist
   388       val nziblows = List.filter (fn (i,_) => i<>0) iblows
   389   in if null nziblows then Failure((~1,nontriv)::hist) (* no variable has both signs: stuck *)
   390      else
   391      let val (c,v) = hd(sort (fn (x,y) => int_ord(fst(x),fst(y))) nziblows)
   392          val (no,yes) = List.partition (fn (Lineq(_,_,l,_)) => el v l = 0) ineqs (* here ineqs = noneqs, as triv and eqs are empty *)
   393          val (pos,neg) = List.partition(fn (Lineq(_,_,l,_)) => el v l > 0) yes
   394      in elim(no @ allpairs (elim_var v) pos neg, (v,nontriv)::hist) end
   395   end
   396   end
   397   end;
   398 
   399 (* ------------------------------------------------------------------------- *)
   400 (* Translate back a proof.                                                   *)
   401 (* ------------------------------------------------------------------------- *)
   402 
   (* When tracing, print msg and the theorem; th is returned unchanged. *)
   fun trace_thm msg th =
     (if !trace then (tracing msg; tracing (Display.string_of_thm th)) else ();
      th);

   (* When tracing, print msg. *)
   fun trace_msg msg = if !trace then tracing msg else ();
   408 
   409 (* FIXME OPTIMIZE!!!! (partly done already)
   410    Addition/Multiplication need i*t representation rather than t+t+...
   411    Get rid of Multiplied(2). For Nat, LA_Data.number_of should return Suc^n
   412    because Numerals are not known early enough.
   413 
   414 Simplification may detect a contradiction 'prematurely' due to type
   415 information: n+1 <= 0 is simplified to False and does not need to be crossed
   416 with 0 <= n.
   417 *)
   (* mkthm sg asms just: replay the justification `just' as a theorem, using the
      premises asms and the theorems/simpset stored in the theory data of sg.
      The result is expected to be False (checked with LA_Logic.is_False; a
      warning is issued otherwise).  If simp reduces an intermediate step to
      False, that theorem is returned early via the local exception FalseE. *)
   418 local
   419  exception FalseE of thm (* raised by simp as soon as False has been derived *)
   420 in
   421 fun mkthm sg asms just =
   422   let val {add_mono_thms, mult_mono_thms, inj_thms, lessD, simpset} = Data.get_sg sg;
   423       val atoms = Library.foldl (fn (ats,(lhs,_,_,rhs,_,_)) => (* NOTE(review): Nat i indexes this list; presumably its order must agree with the atoms built in refutes *)
   424                             map fst lhs  union  (map fst rhs  union  ats))
   425                         ([], List.mapPartial (fn thm => if Thm.no_prems thm
   426                                         then LA_Data.decomp sg (concl_of thm)
   427                                         else NONE) asms)
   428 
   429       fun add2 thm1 thm2 = (* conjoin the facts, then apply the first fitting add_mono_thms rule *)
   430         let val conj = thm1 RS (thm2 RS LA_Logic.conjI)
   431         in get_first (fn th => SOME(conj RS th) handle THM _ => NONE) add_mono_thms
   432         end;
   433 
   434       fun try_add [] _ = NONE (* add2 with the first thm1 that works *)
   435         | try_add (thm1::thm1s) thm2 = case add2 thm1 thm2 of
   436              NONE => try_add thm1s thm2 | some => some;
   437 
   438       fun addthms thm1 thm2 = (* falls back on inj_thms for either side; sys_error if all fail *)
   439         case add2 thm1 thm2 of
   440           NONE => (case try_add ([thm1] RL inj_thms) thm2 of
   441                      NONE => ( valOf(try_add ([thm2] RL inj_thms) thm1)
   442                                handle Option =>
   443                                (trace_thm "" thm1; trace_thm "" thm2;
   444                                 sys_error "Lin.arith. failed to add thms")
   445                              )
   446                    | SOME thm => thm)
   447         | SOME thm => thm;
   448 
   449       fun multn(n,thm) = (* |n|-fold sum of thm; for n < 0 the result is flipped with sym *)
   450         let fun mul(i,th) = if i=1 then th else mul(i-1, addthms thm th)
   451         in if n < 0 then mul(~n,thm) RS LA_Logic.sym else mul(n,thm) end;
   452 (*
   453       fun multn2(n,thm) =
   454         let val SOME(mth,cv) =
   455               get_first (fn (th,cv) => SOME(thm RS th,cv) handle THM _ => NONE) mult_mono_thms
   456             val ct = cterm_of sg (LA_Data.number_of(n,#T(rep_cterm cv)))
   457         in instantiate ([],[(cv,ct)]) mth end
   458 *)
   459       fun multn2(n,thm) = (* multiply by n via the first fitting mult_mono_thms rule *)
   460         let val SOME(mth) =
   461               get_first (fn th => SOME(thm RS th) handle THM _ => NONE) mult_mono_thms
   462             fun cvar(th,_ $ (_ $ _ $ var)) = cterm_of (#sign(rep_thm th)) var; (* right operand in the first premise *)
   463             val cv = cvar(mth, hd(prems_of mth));
   464             val ct = cterm_of sg (LA_Data.number_of(n,#T(rep_cterm cv)))
   465         in instantiate ([],[(cv,ct)]) mth end
   466 
   467       fun simp thm = (* raises FalseE as soon as the simpset yields False *)
   468         let val thm' = trace_thm "Simplified:" (full_simplify simpset thm)
   469         in if LA_Logic.is_False thm' then raise FalseE thm' else thm' end
   470 
   471       fun mk(Asm i) = trace_thm "Asm" (List.nth(asms,i)) (* replay one justification step *)
   472         | mk(Nat i) = (trace_msg "Nat"; LA_Logic.mk_nat_thm sg (List.nth(atoms,i)))
   473         | mk(LessD(j)) = trace_thm "L" (hd([mk j] RL lessD))
   474         | mk(NotLeD(j)) = trace_thm "NLe" (mk j RS LA_Logic.not_leD)
   475         | mk(NotLeDD(j)) = trace_thm "NLeD" (hd([mk j RS LA_Logic.not_leD] RL lessD))
   476         | mk(NotLessD(j)) = trace_thm "NL" (mk j RS LA_Logic.not_lessD)
   477         | mk(Added(j1,j2)) = simp (trace_thm "+" (addthms (mk j1) (mk j2)))
   478         | mk(Multiplied(n,j)) = (trace_msg("*"^string_of_int n); trace_thm "*" (multn(n,mk j)))
   479         | mk(Multiplied2(n,j)) = simp (trace_msg("**"^string_of_int n); trace_thm "**" (multn2(n,mk j)))
   480 
   481   in trace_msg "mkthm";
   482      let val thm = trace_thm "Final thm:" (mk just)
   483      in let val fls = simplify simpset thm
   484         in trace_thm "After simplification:" fls;
   485            if LA_Logic.is_False fls then fls
   486            else
   487             (tracing "Assumptions:"; List.app print_thm asms;
   488              tracing "Proved:"; print_thm fls;
   489              warning "Linear arithmetic should have refuted the assumptions.\n\
   490                      \Please inform Tobias Nipkow (nipkow@in.tum.de).";
   491              fls)
   492         end
   493      end handle FalseE thm => (trace_thm "False reached early:" thm; thm) (* early False is the result *)
   494   end
   495 end;
   496 
   (* Coefficient of atom in poly; 0 when atom does not occur. *)
   fun coeff poly atom =
     (case assoc (poly, atom) of SOME i => i | NONE => 0);

   (* Least common multiple of a list of integers (1 for the empty list). *)
   fun lcms is = List.foldl (fn (i, acc) => lcm (acc, i)) 1 is;

   (* Scale a decomposed (in)equation with rational coefficients to integers.
      Returns the multiplier m (lcm of all denominators) and the item with every
      coefficient and both constants multiplied by m. *)
   fun integ (rlhs, r, rel, rrhs, s, d) =
     let
       val denoms = map (abs o snd o rep_rat) (r :: s :: map snd rlhs @ map snd rrhs)
       val m = lcms denoms
       fun scale q = let val (num, den) = rep_rat q in num * (m div den) end
       fun mult (t, q) = (t, scale q)
     in
       (m, (map mult rlhs, scale r, rel, map mult rrhs, scale s, d))
     end;
   506 
   (* mklineq n atoms (item, k): translate the k-th decomposed premise `item' into
      a Lineq over `atoms'.  The item is first scaled to integer coefficients by
      integ; if the scaling factor m is not 1, the justification is wrapped in
      Multiplied2(m, _).  On discrete domains strict inequalities are turned into
      non-strict ones by adding 1.  (The argument n is not used here.) *)
   507 fun mklineq n atoms =
   508   fn (item,k) =>
   509   let val (m,(lhs,i,rel,rhs,j,discrete)) = integ item
   510       val lhsa = map (coeff lhs) atoms
   511       and rhsa = map (coeff rhs) atoms
   512       val diff = map2 (op -) (rhsa,lhsa) (* i + lhs REL j + rhs  is encoded as  i - j REL rhs - lhs *)
   513       val c = i-j
   514       val just = Asm k
   515       fun lineq(c,le,cs,j) = Lineq(c,le,cs, if m=1 then j else Multiplied2(m,j))
   516   in case rel of
   517       "<="   => lineq(c,Le,diff,just)
   518      | "~<=" => if discrete
   519                 then lineq(1-c,Le,map (op ~) diff,NotLeDD(just)) (* discrete: rhs + 1 <= lhs *)
   520                 else lineq(~c,Lt,map (op ~) diff,NotLeD(just))
   521      | "<"   => if discrete
   522                 then lineq(c+1,Le,diff,LessD(just)) (* discrete: lhs + 1 <= rhs *)
   523                 else lineq(c,Lt,diff,just)
   524      | "~<"  => lineq(~c,Le,map (op~) diff,NotLessD(just))
   525      | "="   => lineq(c,Eq,diff,just)
   526      | _     => sys_error("mklineq" ^ rel)   
   527   end;
   528 
   529 (* ------------------------------------------------------------------------- *)
   530 (* Print (counter) example                                                   *)
   531 (* ------------------------------------------------------------------------- *)
   532 
   (* "a = value": an integer for discrete atoms, otherwise p/q ("0" for zero). *)
   fun print_atom ((a, d), r) =
     let
       val (p, q) = rep_rat r
       val value =
         case (d, p) of
           (true, _) => string_of_int p
         | (false, 0) => "0"
         | _ => string_of_int p ^ "/" ^ string_of_int q
     in
       a ^ " = " ^ value
     end;

   (* Trace a counter example: sds pairs each atom's string with its
      discreteness flag, exs holds the corresponding values. *)
   fun print_ex sds exs =
     tracing ("Counter example:\n" ^ commas (map print_atom (sds ~~ exs)));
   546 
   (* trace_ex: after a failed elimination, warn and try to trace a counter
      example reconstructed from the elimination history; if none can be
      built (NoEx), trace an explanation instead. *)
   547 fun trace_ex(sg,params,atoms,discr,n,hist:history) =
   548   if null hist then ()
   549   else let val frees = map Free params;
   550            fun s_of_t t = Sign.string_of_term sg (subst_bounds(frees,t));
   551            val (v,lineqs) :: hist' = hist (* most recent elimination step first *)
   552            val start = if v = ~1 then (findex0 discr n lineqs,hist') (* elim got stuck: start from findex0 *)
   553                        else (replicate n rat0,hist)
   554        in warning "arith failed - see trace for a counter example";
   555           print_ex ((map s_of_t atoms)~~discr) (findex discr start) (* only this call is covered by the handler *)
   556           handle NoEx =>
   557   (tracing "The decision procedure failed to prove your proposition\n\
   558            \but could not construct a counter example either.\n\
   559            \Probably the proposition is true but cannot be proved\n\
   560            \by the incomplete decision procedure.")
   561        end;
   562 
   (* For an atom of type nat with index i (among ixs): the constraint
      0 <= atom_i, justified by Nat i; NONE for other atoms. *)
   fun mknat pTs ixs (atom, i) =
     if not (LA_Logic.is_nat (pTs, atom)) then NONE
     else SOME (Lineq (0, Le, map (fn j => if j = i then 1 else 0) ixs, Nat i));
   568 
   569 (* This code is tricky. It takes a list of premises in the order they occur
   570 in the subgoal. Numerical premises are coded as SOME(tuple), non-numerical
   571 ones as NONE. Going through the premises, each numeric one is converted into
   572 a Lineq. The tricky bit is to convert ~= which is split into two cases < and
   573 >. Thus split_items returns a list of equation systems. This may blow up if
   574 there are many ~=, but in practice it does not seem to happen. The really
   575 tricky bit is to arrange the order of the cases such that they coincide with
   576 the order in which the cases are in the end generated by the tactic that
   577 applies the generated refutation thms (see function 'refute_tac').
   578 
   579 For variables n of type nat, a constraint 0 <= n is added.
   580 *)
   (* Number the numeric premises (NONE entries are skipped but still counted) and
      split each ~= into two systems, one with < and one with > (operands swapped);
      the split premise is moved to the end of the remaining ones, as neqE does in
      refute_tac. *)
   fun split_items items =
     let
       fun go acc _ [] = [rev acc]
         | go acc k (NONE :: rest) = go acc (k + 1) rest
         | go acc k (SOME (item as (l, i, rel, r, j, d)) :: rest) =
             if rel <> "~=" then go ((item, k) :: acc) (k + 1) rest
             else go acc k (rest @ [SOME (l, i, "<", r, j, d)]) @
                  go acc k (rest @ [SOME (r, j, "<", l, i, d)])
     in
       go [] 0 items
     end;
   589 
   (* Add the atoms occurring on either side of a numbered item to ats. *)
   fun add_atoms (ats, ((lhs, _, _, rhs, _, _), _)) =
     let val lhs_atoms = map fst lhs and rhs_atoms = map fst rhs
     in lhs_atoms union (rhs_atoms union ats) end;

   (* Like add_atoms, but each atom is paired with the item's discreteness flag. *)
   fun add_datoms (dats, ((lhs, _, _, rhs, _, d), _)) =
     let fun tag (t, _) = (d, t)
     in map tag lhs union (map tag rhs union dats) end;

   (* Discreteness flags, in the order collected by add_datoms. *)
   fun discr initems =
     let val dats = Library.foldl add_datoms ([], initems)
     in map fst dats end;
   597 
   (* refutes sg (pTs,params) ex systems js: try to refute every equation system
      (as produced by split_items) in turn, appending one justification per system
      to js.  NONE as soon as one system cannot be refuted (after tracing a counter
      example if ex is set). *)
   fun refutes sg (pTs, params) ex =
     let
       fun refute [] js = SOME js
         | refute (initems :: rest) js =
             let
               val atoms = Library.foldl add_atoms ([], initems)
               val n = length atoms
               val ixs = 0 upto (n - 1)
               val nat_lineqs = List.mapPartial (mknat pTs ixs) (atoms ~~ ixs)
               val ineqs = map (mklineq n atoms) initems @ nat_lineqs
             in
               case elim (ineqs, []) of
                 Success j => (trace_msg "Contradiction!"; refute rest (js @ [j]))
               | Failure hist =>
                   (if ex then trace_ex (sg, params, atoms, discr initems, n, hist) else ();
                    NONE)
             end
     in
       refute
     end;

   fun refute sg ps ex items = refutes sg ps ex (split_items items) [];
   620 
   (* Replay the justifications justs on subgoal i: turn the goal into a
      refutation with notI or ccontr, then for each justification split the ~=
      premises with neqE and close the case with the theorem built by mkthm. *)
   fun refute_tac (i, justs) state =
     let
       val sg = #sign (rep_thm state)
       fun close_case j =
         REPEAT_DETERM (etac LA_Logic.neqE i) THEN
         METAHYPS (fn asms => rtac (mkthm sg asms j) 1) i
     in
       (DETERM (resolve_tac [LA_Logic.notI, LA_Logic.ccontr] i) THEN
        EVERY (map close_case justs)) state
     end;
   630 
   (* Number of elements of xs satisfying P. *)
   fun count P xs = List.foldl (fn (x, k) => if P x then k + 1 else k) 0 xs;

   (* The limit on the number of ~= allowed.
      Because each ~= is split into two cases, this can lead to an explosion.
   *)
   val fast_arith_neq_limit = ref 9;
   637 
   (* Decompose the premises Hs and the negated conclusion and try to refute them.
      Gives up (NONE) when there are more than !fast_arith_neq_limit ~= premises. *)
   fun prove sg ps ex Hs concl =
     let
       val Hitems = map (LA_Data.decomp sg) Hs
       fun is_neq (SOME (_, _, r, _, _, _)) = r = "~="
         | is_neq NONE = false
       (* "~R" becomes "R"; any other relation R becomes "~R". *)
       fun negate rel =
         let val first :: rest = explode rel
         in if first = "~" then implode rest else "~" ^ rel end
     in
       if count is_neq Hitems > !fast_arith_neq_limit then NONE
       else
         case LA_Data.decomp sg concl of
           NONE => refute sg ps ex (Hitems @ [NONE])
         | SOME (r, i, rel, l, j, d) =>
             refute sg ps ex (Hitems @ [SOME (r, i, negate rel, l, j, d)])
     end;
   650 
   651 (*
   652 Fast but very incomplete decider. Only premises and conclusions
   653 that are already (negated) (in)equations are taken into account.
   654 *)
   (* ex: trace a counter example when the refutation fails. *)
   fun lin_arith_tac ex i st = SUBGOAL (fn (A, _) =>
     let
       val params = rev (Logic.strip_params A)
       val Hs = Logic.strip_assums_hyp A
       val concl = Logic.strip_assums_concl A
     in
       trace_thm ("Trying to refute subgoal " ^ string_of_int i) st;
       (case prove (Thm.sign_of_thm st) (map snd params, params) ex Hs concl of
          SOME js => (trace_msg "Refutation succeeded."; refute_tac (i, js))
        | NONE => (trace_msg "Refutation failed."; no_tac))
     end) i st;

   (* Add thms as premises of subgoal i, then run lin_arith_tac without counter examples. *)
   fun cut_lin_arith_tac thms i = cut_facts_tac thms i THEN lin_arith_tac false i;
   667 
   668 (** Forward proof from theorems **)
   669 
   670 (* More tricky code. Needs to arrange the proofs of the multiple cases (due
   671 to splits of ~= premises) such that it coincides with the order of the cases
   672 generated by function split_items. *)
   673 
   (* Case tree for the forward proof: Tip asms is one case with premises asms;
      Spl (thm, ct1, t1, ct2, t2) is a split via thm into the case with
      hypothesis ct1 (subtree t1) and the case with hypothesis ct2 (subtree t2). *)
   674 datatype splittree = Tip of thm list
   675                    | Spl of thm * cterm * splittree * cterm * splittree
   676 
   (* extract imp: for imp of the shape  (A ==> P) ==> (B ==> P) ==> P  (the
      proposition of asm COMP neqE, by neqE's statement), return the cterms A and B,
      i.e. the two case hypotheses. *)
   677 fun extract imp =
   678 let val (Il,r) = Thm.dest_comb imp (* Il = "==> (A ==> P)", r = "(B ==> P) ==> P" *)
   679     val (_,imp1) = Thm.dest_comb Il
   680     val (Ict1,_) = Thm.dest_comb imp1
   681     val (_,ct1) = Thm.dest_comb Ict1
   682     val (Ir,_) = Thm.dest_comb r
   683     val (_,Ict2r) = Thm.dest_comb Ir
   684     val (Ict2,_) = Thm.dest_comb Ict2r
   685     val (_,ct2) = Thm.dest_comb Ict2
   686 in (ct1,ct2) end;
   687 
   (* splitasms asms: build the case tree for the ~= premises among asms, i.e.
      those for which asm COMP neqE succeeds.  In each branch the new case
      hypothesis is appended after the remaining premises, so the order of
      cases matches the one produced by split_items. *)
   688 fun splitasms asms =
   689 let fun split(asms',[]) = Tip(rev asms')
   690       | split(asms',asm::asms) =
   691       let val spl = asm COMP LA_Logic.neqE
   692           val (ct1,ct2) = extract(cprop_of spl)
   693           val thm1 = assume ct1 and thm2 = assume ct2
   694       in Spl(spl,ct1,split(asms',asms@[thm1]),ct2,split(asms',asms@[thm2])) end
   695       handle THM _ => split(asm::asms', asms) (* not a ~= premise; NOTE(review): this handler also covers THM raised by the recursive calls above *)
   696 in split([],asms) end;
   697 
   (* fwdproof sg tree js: prove every case of the tree with the next
      justification from js (left to right), discharge the case hypotheses and
      combine the cases with the split theorem; also returns the unused
      justifications. *)
   fun fwdproof sg (Tip asms) (j :: js) = (mkthm sg asms j, js)
     | fwdproof sg (Spl (split_thm, hyp1, t1, hyp2, t2)) js =
         let
           val (th1, rest1) = fwdproof sg t1 js
           val (th2, rest2) = fwdproof sg t2 rest1
           val case1 = implies_intr hyp1 th1
           val case2 = implies_intr hyp2 th2
         in
           (case2 COMP (case1 COMP split_thm), rest2)
         end;
   (* needs handle THM _ => NONE ? *)
   706 
   (* Prove Tconcl by refuting its negation with the justifications js:
      ccontr closes the proof when pos, notI otherwise; the result is passed
      through LA_Logic.mk_Eq.  NONE on a THM exception, e.g. when the
      conclusion contains a ?-var so that assume fails. *)
   fun prover sg thms Tconcl js pos =
     let
       val negated = cterm_of sg (LA_Logic.neg_prop Tconcl)
       val hyp = assume negated
       val (falsity, _) = fwdproof sg (splitasms (thms @ [hyp])) js
       val rule = if pos then LA_Logic.ccontr else LA_Logic.notI
     in
       SOME (LA_Logic.mk_Eq (implies_intr negated falsity COMP rule))
     end
     handle THM _ => NONE;
   717 
   718 (* PRE: concl is not negated!
   719    This assumption is OK because
   720    1. lin_arith_prover tries both to prove and disprove concl and
   721    2. lin_arith_prover is applied by the simplifier which
   722       dives into terms and will thus try the non-negated concl anyway.
   723 *)
   (* Simplification procedure: using the premises of ss, try to prove concl and,
      only if that is not refutable at all, its negation. *)
   fun lin_arith_prover sg ss concl =
     let
       val thms = prems_of_ss ss
       val Hs = map (#prop o rep_thm) thms
       val Tconcl = LA_Logic.mk_Trueprop concl
     in
       case prove sg ([], []) false Hs Tconcl of
         SOME js => prover sg thms Tconcl js true         (* concl provable? *)
       | NONE =>
           let val nTconcl = LA_Logic.neg_prop Tconcl     (* ~concl provable? *)
           in Option.mapPartial (fn js => prover sg thms nTconcl js false)
                                (prove sg ([], []) false Hs nTconcl)
           end
     end;
   737 
   738 end;