src/Provers/Arith/fast_lin_arith.ML
 author wenzelm Thu Jan 19 21:22:08 2006 +0100 (2006-01-19 ago) changeset 18708 4b3dadb4fe33 parent 18572 dab1dd61e59d child 19049 2103a8e14eaa permissions -rw-r--r--
setup: theory -> theory;
```     1 (*  Title:      Provers/Arith/fast_lin_arith.ML
```
```     2     ID:         \$Id\$
```
```     3     Author:     Tobias Nipkow
```
```     4     Copyright   1998  TU Munich
```
```     5
```
```     6 A generic linear arithmetic package.
```
```     7 It provides two tactics
```
```     8
```
```     9     lin_arith_tac:         bool -> int -> tactic
```
```    10 cut_lin_arith_tac: simpset -> int -> tactic
```
```    11
```
```    12 and a simplification procedure
```
```    13
```
```    14     lin_arith_prover: theory -> simpset -> term -> thm option
```
```    15
```
```    16 Only take premises and conclusions into account that are already (negated)
```
```    17 (in)equations. lin_arith_prover tries to prove or disprove the term.
```
```    18 *)
```
```    19
```
```    20 (* Debugging: set Fast_Arith.trace *)
```
```    21
```
```    22 (*** Data needed for setting up the linear arithmetic package ***)
```
```    23
```
(* Logic-specific rules and helpers that the arithmetic package is
   parameterised over.  The term/theorem helpers are explained in the
   comment that follows this signature. *)
signature LIN_ARITH_LOGIC =
sig
  (* inference rules *)
  val conjI: thm
  val ccontr: thm      (* (~ P ==> False) ==> P *)
  val notI: thm        (* (P ==> False) ==> ~ P *)
  val not_lessD: thm   (* ~(m < n) ==> n <= m *)
  val not_leD: thm     (* ~(m <= n) ==> n < m *)
  val sym: thm         (* x = y ==> y = x *)

  (* theorem and term helpers *)
  val mk_Eq: thm -> thm
  val atomize: thm -> thm list
  val mk_Trueprop: term -> term
  val neg_prop: term -> term
  val is_False: thm -> bool

  (* support for atoms of type nat *)
  val is_nat: typ list * term -> bool
  val mk_nat_thm: theory -> term -> thm
end;
```    40 (*
```
```    41 mk_Eq(~in) = `in == False'
```
```    42 mk_Eq(in) = `in == True'
```
```    43 where `in' is an (in)equality.
```
```    44
```
```    45 neg_prop(t) = nt if t is wrapped up in Trueprop and
```
```    46   nt is the (logically) negated version of t, where the negation
```
```    47   of a negative term is the term itself (no double negation!);
```
```    48
```
```    49 is_nat(parameter-types,t) =  t:nat
```
```    50 mk_nat_thm(t) = "0 <= t"
```
```    51 *)
```
```    52
```
(* Data-specific operations: decomposition of (in)equations into linear
   form and construction of numerals.  The shape of decomp's result is
   described in the comment that follows this signature. *)
signature LIN_ARITH_DATA =
sig
  val decomp:
    theory -> term ->
      ((term * Rat.rat) list * Rat.rat * string *
       (term * Rat.rat) list * Rat.rat * bool) option
  val number_of: IntInf.int * typ -> term
end;
```    59 (*
```
```    60 decomp(`x Rel y') should yield (p,i,Rel,q,j,d)
```
```    61    where Rel is one of "<", "~<", "<=", "~<=" and "=" and
```
```    62          p/q is the decomposition of the sum terms x/y into a list
```
```    63          of summand * multiplicity pairs and a constant summand and
```
```    64          d indicates if the domain is discrete.
```
```    65
```
```    66 ss must reduce contradictory <= to False.
```
```    67    It should also cancel common summands to keep <= reduced;
```
```    68    otherwise <= can grow to massive proportions.
```
```    69 *)
```
```    70
```
(* Interface exported by the Fast_Lin_Arith functor. *)
signature FAST_LIN_ARITH =
sig
  (* theory setup and update of the package's theory data *)
  val setup: theory -> theory
  val map_data:
    ({add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
      lessD: thm list, neqE: thm list, simpset: Simplifier.simpset} ->
     {add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
      lessD: thm list, neqE: thm list, simpset: Simplifier.simpset}) ->
    theory -> theory

  (* mutable settings *)
  val trace: bool ref
  val fast_arith_neq_limit: int ref

  (* simplification procedure and tactics *)
  val lin_arith_prover: theory -> simpset -> term -> thm option
  val lin_arith_tac: bool -> int -> tactic
  val cut_lin_arith_tac: simpset -> int -> tactic
end;
```    85
```
```    86 functor Fast_Lin_Arith(structure LA_Logic:LIN_ARITH_LOGIC
```
```    87                        and       LA_Data:LIN_ARITH_DATA) : FAST_LIN_ARITH =
```
```    88 struct
```
```    89
```
```    90
```
```    91 (** theory data **)
```
```    92
```
```    93 (* data kind 'Provers/fast_lin_arith' *)
```
```    94
```
(* Theory data of kind "Provers/fast_lin_arith": the rule lists and the
   simpset consulted by the decision procedure.  Merging two theories
   merges each rule list with Drule.merge_rules and the simpsets with
   Simplifier.merge_ss. *)
structure Data = TheoryDataFun
(struct
  val name = "Provers/fast_lin_arith";
  type T = {add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
            lessD: thm list, neqE: thm list, simpset: Simplifier.simpset};

  val empty = {add_mono_thms = [], mult_mono_thms = [], inj_thms = [],
               lessD = [], neqE = [], simpset = Simplifier.empty_ss};
  val copy = I;
  val extend = I;

  fun merge _ (d1: T, d2: T) =
    let
      (* merge the same rule-list field taken from both records *)
      fun rules sel = Drule.merge_rules (sel d1, sel d2)
    in
      {add_mono_thms = rules #add_mono_thms,
       mult_mono_thms = rules #mult_mono_thms,
       inj_thms = rules #inj_thms,
       lessD = rules #lessD,
       neqE = rules #neqE,
       simpset = Simplifier.merge_ss (#simpset d1, #simpset d2)}
    end;

  fun print _ _ = ();
end);
```   120
```
(* Update the package's theory data with the given function. *)
fun map_data f = Data.map f;

(* Initialise the package's theory data slot. *)
fun setup thy = Data.init thy;
```   123
```
```   124
```
```   125
```
```   126 (*** A fast decision procedure ***)
```
```   127 (*** Code ported from HOL Light ***)
```
```   128 (* possible optimizations:
```
```   129    use (var,coeff) rep or vector rep  to save space;
```
```   130    treat non-negative atoms separately rather than adding 0 <= atom
```
```   131 *)
```
```   132
```
(* Global switch: print_ineqs, trace_thm and trace_msg only emit output
   while this reference holds true. *)
```   133 val trace = ref false;
```
```   134
```
(* Relation of a linear (in)equation: =, <= or <. *)
datatype lineq_type = Eq | Le | Lt;

(* Justification tree recording how a lineq was derived; it is replayed
   into a theorem by mkthm. *)
datatype injust =
    Asm of int                        (* index into the list of assumptions *)
  | Nat of int                        (* index of atom *)
  | LessD of injust                   (* resolved with a lessD rule *)
  | NotLessD of injust                (* resolved with LA_Logic.not_lessD *)
  | NotLeD of injust                  (* resolved with LA_Logic.not_leD *)
  | NotLeDD of injust                 (* not_leD followed by a lessD rule *)
  | Multiplied of IntInf.int * injust   (* multiplication by repeated addition *)
  | Multiplied2 of IntInf.int * injust  (* multiplication via mult_mono_thms *)
  | Added of injust * injust;         (* sum of two derived facts *)

(* Lineq (k, ty, cs, just): constant k, relation ty, coefficient list cs
   (one entry per atom) and the justification of the fact. *)
datatype lineq = Lineq of IntInf.int * lineq_type * IntInf.int list * injust;
```   148
```
(* el n xs: the n-th element of xs, counting from 0.
   Raises sys_error "el" when the list is exhausted first. *)
fun el n xs =
  case (n, xs) of
    (0, x :: _) => x
  | (_, _ :: rest) => el (n - 1) rest
  | _ => sys_error "el";
```   152
```
```   153 (* ------------------------------------------------------------------------- *)
```
```   154 (* Finding a (counter) example from the trace of a failed elimination        *)
```
```   155 (* ------------------------------------------------------------------------- *)
```
```   156 (* Examples are represented as rational numbers,                             *)
```
```   157 (* Don't blame John Harrison for this code - it is entirely mine. TN         *)
```
```   158
```
(* Raised by choose2 when, in a discrete domain, no rounded value lies
   between the bounds; handled in trace_ex. *)
```   159 exception NoEx;
```
```   160
```
```   161 (* Coding: (i,true,cs) means i <= cs and (i,false,cs) means i < cs.
```
```   162    In general, true means the bound is included, false means it is excluded.
```
```   163    Need to know if it is a lower or upper bound for unambiguous interpretation!
```
```   164 *)
```
```   165
```
(* Prepend the bound triple(s) for one lineq to ineqs, using the
   (constant, inclusive?, coefficients) coding described above.
   An equation contributes two inclusive bounds: itself and its negation. *)
fun elim_eqns (ineqs, Lineq (i, ty, cs, _)) =
  (case ty of
     Le => (i, true, cs) :: ineqs
   | Lt => (i, false, cs) :: ineqs
   | Eq => (i, true, cs) :: (~i, true, map ~ cs) :: ineqs);
```   169
```
```   170 (* PRE: ex[v] must be 0! *)
```
(* Solve one bound (a, le, cs) for variable v under the partial
   assignment ex: returns ((a - cs . ex) / cs[v], le).
   PRE: ex[v] must be 0, so that cs[v] does not contribute to the sum. *)
fun eval (ex: Rat.rat list) v (a: IntInf.int, le, cs: IntInf.int list) =
  let
    val coeffs = map Rat.rat_of_intinf cs
    val products = map Rat.mult (coeffs ~~ ex)
    val weighted = Library.foldl Rat.add (Rat.zero, products)
    val remainder = Rat.add (Rat.rat_of_intinf a, Rat.neg weighted)
  in (Rat.mult (remainder, Rat.inv (el v coeffs)), le) end;
```   175 (* If el v rs < 0, le should be negated.
```
```   176    Instead this swap is taken into account in ratrelmin2.
```
```   177 *)
```
```   178
```
(* Smaller of two (value, flag) bounds.  For equal values the resulting
   flag is true only if both flags are false; this accounts for the flag
   swap that eval leaves undone for negative coefficients. *)
fun ratrelmin2 (x as (r, ler), y as (s, les)) =
  if r <> s then (if Rat.le (r, s) then x else y)
  else (r, not (ler orelse les));
(* Larger of two (value, flag) bounds.  For equal values the resulting
   flag is true only if both flags are true. *)
fun ratrelmax2 (x as (r, ler), y as (s, les)) =
  if r <> s then (if Rat.le (r, s) then y else x)
  else (r, ler andalso les);
```   183
```
(* Fold a non-empty list of bounds with ratrelmin2 / ratrelmax2. *)
fun ratrelmin bounds = foldr1 ratrelmin2 bounds;
fun ratrelmax bounds = foldr1 ratrelmax2 bounds;
```   186
```
(* Turn a bound into a concrete value.  An exact bound is returned
   unchanged; otherwise r is moved by 1/q (q being the denominator of r),
   upwards when up is true and downwards otherwise. *)
fun ratexact up (r, exact) =
  if exact then r
  else
    let
      val (_, denom) = Rat.quotient_of_rat r
      val step = Rat.inv (Rat.rat_of_intinf denom)
    in Rat.add (r, if up then step else Rat.neg step) end;
```   192
```
(* Arithmetic mean of two rationals. *)
fun ratmiddle (r, s) =
  let val half = Rat.inv (Rat.rat_of_int 2)
  in Rat.mult (Rat.add (r, s), half) end;
```   194
```
(* Choose a value between lower bound (lb, exactl) and upper bound
   (ub, exactu); d says whether the domain is discrete.
   Zero is preferred whenever it satisfies both bounds.  Otherwise, in a
   non-discrete domain, an exact bound or the midpoint is taken; in a
   discrete domain a bound is rounded inwards and NoEx is raised if the
   rounded value overshoots the opposite bound. *)
fun choose2 d ((lb, exactl), (ub, exactu)) =
  let
    val zero_fits =
      Rat.le (lb, Rat.zero) andalso (lb <> Rat.zero orelse exactl) andalso
      Rat.le (Rat.zero, ub) andalso (ub <> Rat.zero orelse exactu)
    fun continuous () =
      if Rat.ge0 lb
      then (if exactl then lb else ratmiddle (lb, ub))
      else (if exactu then ub else ratmiddle (lb, ub))
    (* discrete domain, both bounds must be exact *)
    fun discrete () =
      if Rat.ge0 lb
      then
        let val lb' = Rat.roundup lb
        in if Rat.le (lb', ub) then lb' else raise NoEx end
      else
        let val ub' = Rat.rounddown ub
        in if Rat.le (lb, ub') then ub' else raise NoEx end
  in
    if zero_fits then Rat.zero
    else if d then discrete ()
    else continuous ()
  end;
```   208
```
(* One back-substitution step: given the partial example ex and a
   history entry (v, lineqs), pick a value for variable v that respects
   the bounds imposed by the lineqs mentioning v, and store it in ex. *)
fun findex1 discr (ex, (v, lineqs)) =
  let
    fun coeff_v cs = el v cs
    val relevant = List.filter (fn Lineq (_, _, cs, _) => coeff_v cs <> 0) lineqs
    val bounds = Library.foldl elim_eqns ([], relevant)
    (* bounds with a positive coefficient feed lb, the others feed ub *)
    val (lowers, uppers) = List.partition (fn (_, _, cs) => coeff_v cs > 0) bounds
    val lb = ratrelmax (map (eval ex v) lowers)
    val ub = ratrelmin (map (eval ex v) uppers)
  in nth_update (v, choose2 (nth discr v) (lb, ub)) ex end;
```   216
```
(* Run findex1 over a whole history, starting from an initial example;
   the argument is the pair (initial example, history). *)
fun findex discr start = Library.foldl (findex1 discr) start;
```   218
```
(* Substitute the value x for variable v in every bound: the constant is
   reduced by coefficient * x and the coefficient of v is set to zero. *)
fun elim1 v x ineqs =
  let
    fun substitute (a, le, bs) =
      let val contribution = Rat.mult (el v bs, x)
      in (Rat.add (a, Rat.neg contribution), le, nth_update (v, Rat.zero) bs) end
  in map substitute ineqs end;
```   222
```
(* True iff the only non-zero coefficient of the bound is that of v. *)
fun single_var v (_, _, cs) =
  List.filter (fn c => c <> Rat.zero) cs = [el v cs];
```   224
```
```   225 (* The base case:
```
```   226    all variables occur only with positive or only with negative coefficients *)
```
(* Assign values to the variables that remain when no further
   elimination was possible.  Repeatedly takes the first variable with a
   non-zero coefficient in the first non-trivial bound, derives a value
   from zero and the bounds that mention only that variable, substitutes
   it into all bounds and records it in the example ex. *)
fun pick_vars discr (ineqs, ex) =
  let
    fun all_zero (_, _, cs) = forall (equal Rat.zero) cs
  in
    case filter_out all_zero ineqs of
      [] => ex
    | nz as ((_, _, cs) :: _) =>
        let
          val v = find_index (not o equal Rat.zero) cs
          val d = nth discr v
          val pos = Rat.ge0 (el v cs)
          val sv = List.filter (single_var v) nz
          (* how to turn the collected bounds into a value for v *)
          val minmax =
            case (pos, d) of
              (true, true) => Rat.roundup o fst o ratrelmax
            | (true, false) => ratexact true o ratrelmax
            | (false, true) => Rat.rounddown o fst o ratrelmin
            | (false, false) => ratexact false o ratrelmin
          val bnds = map (fn (a, le, bs) => (Rat.mult (a, Rat.inv (el v bs)), le)) sv
          val x = minmax ((Rat.zero, pos) :: bnds)
        in pick_vars discr (elim1 v x nz, nth_update (v, x) ex) end
  end;
```   246
```
(* Initial example for n atoms: convert the lineqs to rational bounds
   and let pick_vars assign values, starting from the all-zero example. *)
fun findex0 discr n lineqs =
  let
    fun to_rat (a, le, cs) = (Rat.rat_of_intinf a, le, map Rat.rat_of_intinf cs)
    val rineqs = map to_rat (Library.foldl elim_eqns ([], lineqs))
  in pick_vars discr (rineqs, replicate n Rat.zero) end;
```   252
```
```   253 (* ------------------------------------------------------------------------- *)
```
```   254 (* End of counter example finder. The actual decision procedure starts here. *)
```
```   255 (* ------------------------------------------------------------------------- *)
```
```   256
```
```   257 (* ------------------------------------------------------------------------- *)
```
```   258 (* Calculate new (in)equality type after addition.                           *)
```
```   259 (* ------------------------------------------------------------------------- *)
```
```   260
```
(* Relation of the sum of two (in)equations: Eq is neutral, Le + Le
   stays Le, and any sum involving Lt is Lt. *)
fun find_add_type (ty1, ty2) =
  case (ty1, ty2) of
    (Eq, other) => other
  | (other, Eq) => other
  | (Le, Le) => Le
  | _ => Lt;
```   266
```
```   267 (* ------------------------------------------------------------------------- *)
```
```   268 (* Multiply out an (in)equation.                                             *)
```
```   269 (* ------------------------------------------------------------------------- *)
```
```   270
```
(* Multiply a lineq by the integer n, recording the step as Multiplied.
   n = 1 returns the lineq untouched.  Raises sys_error "multiply_ineq"
   for n = 0 on a strict inequality and for negative n on any
   inequality. *)
fun multiply_ineq n (i as Lineq (k, ty, l, just)) =
  let
    val illegal = (n = 0 andalso ty = Lt) orelse (n < 0 andalso ty <> Eq)
  in
    if n = 1 then i
    else if illegal then sys_error "multiply_ineq"
    else Lineq (n * k, ty, map (fn c => n * c) l, Multiplied (n, just))
  end;
```   276
```
```   277 (* ------------------------------------------------------------------------- *)
```
```   278 (* Add together (in)equations.                                               *)
```
```   279 (* ------------------------------------------------------------------------- *)
```
```   280
```
(* Add two lineqs: constants and coefficients are summed pointwise, the
   relation comes from find_add_type, and the step is recorded as Added. *)
fun add_ineq (Lineq (k1, ty1, l1, just1)) (Lineq (k2, ty2, l2, just2)) =
  Lineq (k1 + k2, find_add_type (ty1, ty2),
         map2 (fn a => fn b => a + b) l1 l2,
         Added (just1, just2));
```   284
```
```   285 (* ------------------------------------------------------------------------- *)
```
```   286 (* Elimination of variable between a single pair of (in)equations.           *)
```
```   287 (* If they're both inequalities, 1st coefficient must be +ve, 2nd -ve.       *)
```
```   288 (* ------------------------------------------------------------------------- *)
```
```   289
```
(* Eliminate variable v between two lineqs by scaling both so that the
   coefficients of v cancel, then adding them.
   If the coefficients of v have the same sign, one of the two must be an
   equation (it is the one scaled by a negative factor); otherwise
   sys_error "elim_var" is raised.
   When both are equations and one factor would be ~1, both factors are
   negated.  This is sound for equations.  The original code computed
   these normalised factors (p1, p2) but then ignored them and used
   (n1, n2), leaving the normalisation as dead code. *)
fun elim_var v (i1 as Lineq (_, ty1, l1, _)) (i2 as Lineq (_, ty2, l2, _)) =
  let
    val c1 = el v l1 and c2 = el v l2
    val m = lcm (abs c1, abs c2)
    val m1 = m div (abs c1) and m2 = m div (abs c2)
    (* raw multipliers: opposite signs cancel directly *)
    val (n1, n2) =
      if (c1 >= 0) = (c2 >= 0)
      then
        if ty1 = Eq then (~m1, m2)
        else if ty2 = Eq then (m1, ~m2)
        else sys_error "elim_var"
      else (m1, m2)
    (* normalised multipliers, see header comment *)
    val (p1, p2) =
      if ty1 = Eq andalso ty2 = Eq andalso (n1 = ~1 orelse n2 = ~1)
      then (~n1, ~n2)
      else (n1, n2)
  in add_ineq (multiply_ineq p1 i1) (multiply_ineq p2 i2) end;
```   303
```
```   304 (* ------------------------------------------------------------------------- *)
```
```   305 (* The main refutation-finding code.                                         *)
```
```   306 (* ------------------------------------------------------------------------- *)
```
```   307
```
(* A lineq is trivial when every coefficient is zero. *)
fun is_trivial (Lineq (_, _, coeffs, _)) = List.all (fn c => c = 0) coeffs;
```   309
```
(* Test on the constant k and relation only; elim applies it to trivial
   lineqs (all coefficients zero), where it detects a contradiction:
   k = 0 with k <> 0, k <= 0 with k > 0, or k < 0 with k >= 0. *)
fun is_answer (Lineq (k, Eq, _, _)) = k <> 0
  | is_answer (Lineq (k, Le, _, _)) = k > 0
  | is_answer (Lineq (k, Lt, _, _)) = k >= 0;
```   312
```
(* Number of pairings of positive with negative entries in l; elim uses
   it to rank variables by how many consequences elimination creates. *)
fun calc_blowup (l: IntInf.int list) =
  let
    val positives = List.filter (fn c => 0 < c) l
    val negatives = List.filter (fn c => c < 0) l
  in length positives * length negatives end;
```   316
```
```   317 (* ------------------------------------------------------------------------- *)
```
```   318 (* Main elimination code:                                                    *)
```
```   319 (*                                                                           *)
```
```   320 (* (1) Looks for immediate solutions (false assertions with no variables).   *)
```
```   321 (*                                                                           *)
```
```   322 (* (2) If there are any equations, picks a variable with the lowest absolute *)
```
```   323 (* coefficient in any of them, and uses it to eliminate.                     *)
```
```   324 (*                                                                           *)
```
```   325 (* (3) Otherwise, chooses a variable in the inequality to minimize the       *)
```
```   326 (* blowup (number of consequences generated) and eliminates it.              *)
```
```   327 (* ------------------------------------------------------------------------- *)
```
```   328
```
(* All results f x y for x in xs and y in ys, grouped by x, in list order. *)
fun allpairs f xs ys =
  let
    fun row x = map (fn y => f x y) ys
    fun rows [] = []
      | rows (x :: rest) = row x @ rows rest
  in rows xs end;
```   331
```
(* Find the first element satisfying p.  Returns (SOME y, rest) or
   (NONE, rest).  The elements inspected before the hit (or all of them
   when there is no hit) appear in rest in reverse order, followed by the
   untouched remainder. *)
fun extract_first p =
  let
    fun scan seen remaining =
      case remaining of
        [] => (NONE, seen)
      | y :: ys => if p y then (SOME y, seen @ ys) else scan (y :: seen) ys
  in scan [] end;
```   337
```
(* When tracing is on, print every lineq on its own line as
   "<constant> <relation> <comma-separated coefficients>", preceded by an
   empty line. *)
fun print_ineqs ineqs =
  let
    fun rel_string Eq = " =  "
      | rel_string Lt = " <  "
      | rel_string Le = " <= "
    fun show (Lineq (c, t, l, _)) =
      IntInf.toString c ^ rel_string t ^ commas (map IntInf.toString l)
  in
    if ! trace then tracing (cat_lines ("" :: map show ineqs)) else ()
  end;
```   345
```
(* Elimination history: (index of the eliminated variable, system before
   that step), most recent step first.  elim uses index ~1 for the final
   entry of a failed run in which no variable could be eliminated. *)
type history = (int * lineq list) list;

(* Outcome of elim: a justification of a contradiction, or the history
   of the failed attempt. *)
datatype result =
    Success of injust
  | Failure of history;
```   348
```
(* Search for a refutation of the system ineqs by variable elimination.
   hist accumulates (eliminated variable, system before the step) pairs,
   most recent first.  Returns Success j, where j justifies a trivially
   false (in)equation, or Failure with the history.

   Each round:
   - trivial lineqs (all coefficients zero) are split off; if one of them
     is_answer the search succeeds, otherwise they are dropped and the
     round is restarted on the rest;
   - an empty remaining system means failure;
   - if there are equations, the non-zero coefficient of least absolute
     value occurring in them is chosen, and the first equation containing
     it is used to eliminate the corresponding variable from all lineqs
     that mention it;
   - otherwise the variable with the smallest non-zero calc_blowup is
     eliminated by combining every lineq with a positive coefficient with
     every lineq with a non-positive, non-zero one; if every blowup is
     zero the search fails, recording variable index ~1. *)
```   349 fun elim(ineqs,hist) =
```
```   350   let val dummy = print_ineqs ineqs;
```
```   351       val (triv,nontriv) = List.partition is_trivial ineqs in
```
```   352   if not(null triv)
```
```   353   then case Library.find_first is_answer triv of
```
```   354          NONE => elim(nontriv,hist)
```
```   355        | SOME(Lineq(_,_,_,j)) => Success j
```
```   356   else
```
```   357   if null nontriv then Failure(hist)
```
```   358   else
```
```   359   let val (eqs,noneqs) = List.partition (fn (Lineq(_,ty,_,_)) => ty=Eq) nontriv in
```
```   360   if not(null eqs) then
```
(* equations present: eliminate via the equation holding the smallest
   non-zero coefficient *)
```   361      let val clist = Library.foldl (fn (cs,Lineq(_,_,l,_)) => l union cs) ([],eqs)
```
```   362          val sclist = sort (fn (x,y) => IntInf.compare(abs(x),abs(y)))
```
```   363                            (List.filter (fn i => i<>0) clist)
```
```   364          val c = hd sclist
```
```   365          val (SOME(eq as Lineq(_,_,ceq,_)),othereqs) =
```
```   366                extract_first (fn Lineq(_,_,l,_) => c mem l) eqs
```
```   367          val v = find_index_eq c ceq
```
```   368          val (ioth,roth) = List.partition (fn (Lineq(_,_,l,_)) => el v l = 0)
```
```   369                                      (othereqs @ noneqs)
```
```   370          val others = map (elim_var v eq) roth @ ioth
```
```   371      in elim(others,(v,nontriv)::hist) end
```
```   372   else
```
(* inequalities only: pick the variable with the least non-zero blowup *)
```   373   let val lists = map (fn (Lineq(_,_,l,_)) => l) noneqs
```
```   374       val numlist = 0 upto (length(hd lists) - 1)
```
```   375       val coeffs = map (fn i => map (el i) lists) numlist
```
```   376       val blows = map calc_blowup coeffs
```
```   377       val iblows = blows ~~ numlist
```
```   378       val nziblows = List.filter (fn (i,_) => i<>0) iblows
```
```   379   in if null nziblows then Failure((~1,nontriv)::hist)
```
```   380      else
```
```   381      let val (c,v) = hd(sort (fn (x,y) => int_ord(fst(x),fst(y))) nziblows)
```
```   382          val (no,yes) = List.partition (fn (Lineq(_,_,l,_)) => el v l = 0) ineqs
```
```   383          val (pos,neg) = List.partition(fn (Lineq(_,_,l,_)) => el v l > 0) yes
```
```   384      in elim(no @ allpairs (elim_var v) pos neg, (v,nontriv)::hist) end
```
```   385   end
```
```   386   end
```
```   387   end;
```
```   388
```
```   389 (* ------------------------------------------------------------------------- *)
```
```   390 (* Translate back a proof.                                                   *)
```
```   391 (* ------------------------------------------------------------------------- *)
```
```   392
```
(* Return th unchanged; when tracing is on, first print msg and the
   theorem. *)
fun trace_thm msg th =
  (if ! trace then (tracing msg; tracing (Display.string_of_thm th)) else ();
   th);
```   395
```
(* Print msg when tracing is on; otherwise do nothing. *)
fun trace_msg msg =
  case ! trace of
    true => tracing msg
  | false => ();
```   398
```
```   399 (* FIXME OPTIMIZE!!!! (partly done already)
```
```   400    Addition/Multiplication need i*t representation rather than t+t+...
```
```   401    Get rid of Multiplied(2). For Nat LA_Data.number_of should return Suc^n
```
```   402    because Numerals are not known early enough.
```
```   403
```
```   404 Simplification may detect a contradiction 'prematurely' due to type
```
```   405 information: n+1 <= 0 is simplified to False and does not need to be crossed
```
```   406 with 0 <= n.
```
```   407 *)
```
(* mkthm (sg, ss) asms just: replay the justification tree just as a
   theorem, using the rule lists and simpset stored in the theory data
   of sg (the simpset inherits its context from ss).
   - asms are the premises; Asm i refers to the i-th of them, and Nat i
     to the i-th atom collected from the decomposable premises.
   - Added and Multiplied2 steps are simplified right away; if the result
     is already False, the local exception FalseE carries it straight to
     the handler at the end, which returns it.
   - Otherwise the final theorem is simplified once more.  If that is not
     False either, the assumptions and the result are printed together
     with a warning, and the theorem is returned regardless. *)
```   408 local
```
```   409  exception FalseE of thm
```
```   410 in
```
```   411 fun mkthm (sg, ss) asms just =
```
```   412   let val {add_mono_thms, mult_mono_thms, inj_thms, lessD, simpset, ...} =
```
```   413           Data.get sg;
```
```   414       val simpset' = Simplifier.inherit_context ss simpset;
```
```   415       val atoms = Library.foldl (fn (ats,(lhs,_,_,rhs,_,_)) =>
```
```   416                             map fst lhs  union  (map fst rhs  union  ats))
```
```   417                         ([], List.mapPartial (fn thm => if Thm.no_prems thm
```
```   418                                         then LA_Data.decomp sg (concl_of thm)
```
```   419                                         else NONE) asms)
```
```   420
```
(* add2: conjoin two facts and resolve with the first applicable
   add_mono rule *)
```   421       fun add2 thm1 thm2 =
```
```   422         let val conj = thm1 RS (thm2 RS LA_Logic.conjI)
```
```   423         in get_first (fn th => SOME(conj RS th) handle THM _ => NONE) add_mono_thms
```
```   424         end;
```
```   425
```
```   426       fun try_add [] _ = NONE
```
```   427         | try_add (thm1::thm1s) thm2 = case add2 thm1 thm2 of
```
```   428              NONE => try_add thm1s thm2 | some => some;
```
```   429
```
(* addthms: add2, retried with either argument first resolved against
   inj_thms; sys_error if every attempt fails *)
```   430       fun addthms thm1 thm2 =
```
```   431         case add2 thm1 thm2 of
```
```   432           NONE => (case try_add ([thm1] RL inj_thms) thm2 of
```
```   433                      NONE => ( valOf(try_add ([thm2] RL inj_thms) thm1)
```
```   434                                handle Option =>
```
```   435                                (trace_thm "" thm1; trace_thm "" thm2;
```
```   436                                 sys_error "Lin.arith. failed to add thms")
```
```   437                              )
```
```   438                    | SOME thm => thm)
```
```   439         | SOME thm => thm;
```
```   440
```
(* multn: multiply by n through repeated addition; a negative n also
   applies LA_Logic.sym *)
```   441       fun multn(n,thm) =
```
```   442         let fun mul(i,th) = if i=1 then th else mul(i-1, addthms thm th)
```
```   443         in if n < 0 then mul(~n,thm) RS LA_Logic.sym else mul(n,thm) end;
```
```   444 (*
```
```   445       fun multn2(n,thm) =
```
```   446         let val SOME(mth,cv) =
```
```   447               get_first (fn (th,cv) => SOME(thm RS th,cv) handle THM _ => NONE) mult_mono_thms
```
```   448             val ct = cterm_of sg (LA_Data.number_of(n,#T(rep_cterm cv)))
```
```   449         in instantiate ([],[(cv,ct)]) mth end
```
```   450 *)
```
(* multn2: multiply by n by resolving with a mult_mono rule and
   instantiating the variable found in its first premise with the
   numeral for n *)
```   451       fun multn2(n,thm) =
```
```   452         let val SOME(mth) =
```
```   453               get_first (fn th => SOME(thm RS th) handle THM _ => NONE) mult_mono_thms
```
```   454             fun cvar(th,_ \$ (_ \$ _ \$ var)) = cterm_of (#sign(rep_thm th)) var;
```
```   455             val cv = cvar(mth, hd(prems_of mth));
```
```   456             val ct = cterm_of sg (LA_Data.number_of(n,#T(rep_cterm cv)))
```
```   457         in instantiate ([],[(cv,ct)]) mth end
```
```   458
```
```   459       fun simp thm =
```
```   460         let val thm' = trace_thm "Simplified:" (full_simplify simpset' thm)
```
```   461         in if LA_Logic.is_False thm' then raise FalseE thm' else thm' end
```
```   462
```
(* mk: translate one justification node into a theorem *)
```   463       fun mk(Asm i) = trace_thm "Asm" (nth asms i)
```
```   464         | mk(Nat i) = (trace_msg "Nat"; LA_Logic.mk_nat_thm sg (nth atoms i))
```
```   465         | mk(LessD(j)) = trace_thm "L" (hd([mk j] RL lessD))
```
```   466         | mk(NotLeD(j)) = trace_thm "NLe" (mk j RS LA_Logic.not_leD)
```
```   467         | mk(NotLeDD(j)) = trace_thm "NLeD" (hd([mk j RS LA_Logic.not_leD] RL lessD))
```
```   468         | mk(NotLessD(j)) = trace_thm "NL" (mk j RS LA_Logic.not_lessD)
```
```   469         | mk(Added(j1,j2)) = simp (trace_thm "+" (addthms (mk j1) (mk j2)))
```
```   470         | mk(Multiplied(n,j)) = (trace_msg("*"^IntInf.toString n); trace_thm "*" (multn(n,mk j)))
```
```   471         | mk(Multiplied2(n,j)) = simp (trace_msg("**"^IntInf.toString n); trace_thm "**" (multn2(n,mk j)))
```
```   472
```
```   473   in trace_msg "mkthm";
```
```   474      let val thm = trace_thm "Final thm:" (mk just)
```
```   475      in let val fls = simplify simpset' thm
```
```   476         in trace_thm "After simplification:" fls;
```
```   477            if LA_Logic.is_False fls then fls
```
```   478            else
```
```   479             (tracing "Assumptions:"; List.app print_thm asms;
```
```   480              tracing "Proved:"; print_thm fls;
```
```   481              warning "Linear arithmetic should have refuted the assumptions.\n\
```
```   482                      \Please inform Tobias Nipkow (nipkow@in.tum.de).";
```
```   483              fls)
```
```   484         end
```
```   485      end handle FalseE thm => (trace_thm "False reached early:" thm; thm)
```
```   486   end
```
```   487 end;
```
```   488
```
(* Coefficient of atom in the association list poly; 0 if absent. *)
fun coeff poly atom : IntInf.int =
  (case AList.lookup (op =) poly atom of
     SOME c => c
   | NONE => 0);
```   491
```
(* Least common multiple of a list of integers, starting from 1. *)
fun lcms is =
  let
    fun accumulate acc [] = acc
      | accumulate acc (i :: rest) = accumulate (lcm (acc, i)) rest
  in accumulate 1 is end;
```   493
```
(* Clear denominators in a decomposed (in)equation.  m is the lcm of the
   denominators of all rationals involved; every rational is replaced by
   the integer numerator * (m div denominator).  Returns m together with
   the rescaled item. *)
fun integ (rlhs, r, rel, rrhs, s, d) =
  let
    val denom = abs o snd o Rat.quotient_of_rat
    val all_rats = r :: s :: (map snd rlhs @ map snd rrhs)
    val m = lcms (map denom all_rats)
    fun scale q =
      let val (num, den) = Rat.quotient_of_rat q
      in num * (m div den) end
    fun scale_summand (t, q) = (t, scale q)
  in
    (m, (map scale_summand rlhs, scale r, rel, map scale_summand rrhs, scale s, d))
  end
```   501
```
(* Convert the k-th premise, given in decomposed form, into a Lineq over
   the atom list atoms.  The item is first made integral with integ; if
   that required a factor m <> 1, the justification is wrapped in
   Multiplied2.  Coefficients are rhs - lhs and the constant is i - j.
   Negated and (for discrete domains) strict relations are rewritten to
   Le/Lt as in the case table below; an unknown relation raises
   sys_error.  The parameter n is not used. *)
fun mklineq n atoms (item, k) =
  let
    val (m, (lhs, i, rel, rhs, j, discrete)) = integ item
    val lhsa = map (coeff lhs) atoms
    val rhsa = map (coeff rhs) atoms
    val diff = map2 (fn r => fn l => r - l) rhsa lhsa
    val negated = map ~ diff
    val c = i - j
    val just = Asm k
    fun wrap jst = if m = 1 then jst else Multiplied2 (m, jst)
    fun lineq (c', ty, cs, jst) = Lineq (c', ty, cs, wrap jst)
  in
    case (rel, discrete) of
      ("<=", _) => lineq (c, Le, diff, just)
    | ("~<=", true) => lineq (1 - c, Le, negated, NotLeDD just)
    | ("~<=", false) => lineq (~c, Lt, negated, NotLeD just)
    | ("<", true) => lineq (c + 1, Le, diff, LessD just)
    | ("<", false) => lineq (c, Lt, diff, just)
    | ("~<", _) => lineq (~c, Le, negated, NotLessD just)
    | ("=", _) => lineq (c, Eq, diff, just)
    | _ => sys_error ("mklineq" ^ rel)
  end;
```   523
```
```   524 (* ------------------------------------------------------------------------- *)
```
```   525 (* Print (counter) example                                                   *)
```
```   526 (* ------------------------------------------------------------------------- *)
```
```   527
```
(* Render "<atom> = <value>".  For a discrete atom only the numerator is
   shown; otherwise the value is printed as "p/q", or "0" when p is 0. *)
fun print_atom ((a, d), r) =
  let
    val (p, q) = Rat.quotient_of_rat r
    fun fraction () =
      if p = 0 then "0" else IntInf.toString p ^ "/" ^ IntInf.toString q
    val value = if d then IntInf.toString p else fraction ()
  in a ^ " = " ^ value end;
```   534
```
(* Trace a counter example: sds are the (atom string, discrete?) pairs,
   ex the corresponding values. *)
fun print_ex sds ex =
  tracing ("Counter example:\n" ^ commas (map print_atom (sds ~~ ex)));
```   541
```
(* Reconstruct and print a counter example from the history of a failed
   elimination; does nothing for an empty history.
   If the newest entry carries variable index ~1, the initial example is
   computed by findex0 from that entry's system and the remaining history
   is replayed; otherwise replay starts from the all-zero example over
   the full history.  NoEx during replay yields an apology message. *)
fun trace_ex (sg, params, atoms, discr, n, hist: history) =
  case hist of
    [] => ()
  | (v, lineqs) :: hist' =>
      let
        val frees = map Free params
        fun s_of_t t = Sign.string_of_term sg (subst_bounds (frees, t))
        val start =
          if v = ~1 then (findex0 discr n lineqs, hist')
          else (replicate n Rat.zero, hist)
      in
        warning "arith failed - see trace for a counter example";
        (print_ex (map s_of_t atoms ~~ discr) (findex discr start)
          handle NoEx => tracing "Sorry, no counter example.")
      end;
```   553
```
(* For an atom that LA_Logic.is_nat accepts, produce the constraint
   0 <= atom as a Lineq whose only non-zero coefficient (1) sits at the
   atom's index i; NONE for other atoms. *)
fun mknat pTs ixs (atom, i) =
  if not (LA_Logic.is_nat (pTs, atom)) then NONE
  else SOME (Lineq (0, Le, map (fn j => if j = i then 1 else 0) ixs, Nat i))
```   559
```
```   560 (* This code is tricky. It takes a list of premises in the order they occur
```
```   561 in the subgoal. Numerical premises are coded as SOME(tuple), non-numerical
```
```   562 ones as NONE. Going through the premises, each numeric one is converted into
```
```   563 a Lineq. The tricky bit is to convert ~= which is split into two cases < and
```
```   564 >. Thus split_items returns a list of equation systems. This may blow up if
```
```   565 there are many ~=, but in practice it does not seem to happen. The really
```
```   566 tricky bit is to arrange the order of the cases such that they coincide with
```
```   567 the order in which the cases are in the end generated by the tactic that
```
```   568 applies the generated refutation thms (see function 'refute_tac').
```
```   569
```
```   570 For variables n of type nat, a constraint 0 <= n is added.
```
```   571 *)
```
(* Turn the list of optional decomposed premises into a list of systems,
   one per case distinction on the "~=" items.  Every kept item is paired
   with its position n in the premise list; NONE entries are skipped but
   still counted.  A "~=" item is replaced by two branches, each
   re-queueing it at the end of the pending list as a "<" item (once as
   given, once with both sides swapped); the position counter is not
   advanced at that point. *)
fun split_items items =
  let
    fun go front n pending =
      case pending of
        [] => [rev front]
      | NONE :: rest => go front (n + 1) rest
      | SOME (ineq as (l, i, rel, r, j, d)) :: rest =>
          if rel <> "~=" then go ((ineq, n) :: front) (n + 1) rest
          else
            let
              val less = SOME (l, i, "<", r, j, d)
              val greater = SOME (r, j, "<", l, i, d)
            in go front n (rest @ [less]) @ go front n (rest @ [greater]) end
  in go [] 0 items end;
```   580
```
(* Fold step: add the atoms (first components of the pairs on the left- and
   right-hand side) of one indexed item to the accumulated atom list ats. *)
fun add_atoms(ats,((lhs,_,_,rhs,_,_),_)) =
  let
    val rhs_atoms = map fst rhs
    val lhs_atoms = map fst lhs
  in lhs_atoms union (rhs_atoms union ats) end
```   583
```
(* Like add_atoms, but every atom is paired with the last tuple component d
   of the item it comes from; accumulates (d,atom) pairs in dats. *)
fun add_datoms(dats,((lhs,_,_,rhs,_,d),_)) =
  let fun tag (atom,_) = (d,atom)
  in map tag lhs union (map tag rhs union dats) end
```   586
```
(* The d-components of all (d,atom) pairs collected by add_datoms from the
   given items, in the order of the collected list. *)
fun discr initems =
  let val tagged = Library.foldl add_datoms ([],initems)
  in map fst tagged end;
```   588
```
(* Refute a list of systems of items, one after the other.

   The returned function takes the list of systems and the results found so
   far. For each system it collects the atoms, turns every item into a Lineq
   (mklineq), adds the nat constraints (mknat) and calls elim. On Success j
   the result j is appended and the next system is tried; on the first
   Failure the overall result is NONE (and, if ex is set, trace_ex is called
   with the failure history). SOME of all results is returned once no system
   is left. *)
fun refutes sg (pTs,params) ex =
let
  fun loop [] found = SOME found
    | loop (system::systems) found =
        let
          val atoms = Library.foldl add_atoms ([],system)
          val n = length atoms
          val mkleq = mklineq n atoms
          val ixs = 0 upto (n-1)
          val natlineqs = List.mapPartial (mknat pTs ixs) (atoms ~~ ixs)
          val ineqs = map mkleq system @ natlineqs
        in
          case elim (ineqs,[]) of
            Failure hist =>
              (if ex then trace_ex (sg,params,atoms,discr system,n,hist)
               else ();
               NONE)
          | Success j =>
              (trace_msg "Contradiction!"; loop systems (found @ [j]))
        end
in loop end;
```   609
```
(* Split the items into systems (one per combination of ~= cases) and try to
   refute all of them, starting with an empty list of results. *)
fun refute sg ps ex items =
  let val systems = split_items items
  in refutes sg ps ex systems [] end;
```   611
```
(* Tactic that replays the results justs on subgoal i.

   First notI or ccontr is applied to subgoal i (deterministically). Then,
   for every element j of justs in order: the neqE rules of the theory data
   are applied as elimination rules to subgoal i as long as possible, and the
   subgoal is closed with the theorem that mkthm builds from the meta-level
   hypotheses and j. *)
fun refute_tac ss (i,justs) =
  fn state =>
    let
      val sg = #sign (rep_thm state)
      val {neqE, ...} = Data.get sg
      val split_neqs = REPEAT_DETERM (eresolve_tac neqE i)
      fun close_case j =
        split_neqs THEN
        METAHYPS (fn asms => rtac (mkthm (sg, ss) asms j) 1) i
      val start = DETERM (resolve_tac [LA_Logic.notI,LA_Logic.ccontr] i)
    in
      (start THEN EVERY (map close_case justs)) state
    end;
```   622
```
(* Number of elements of xs that satisfy P (P is applied left to right). *)
fun count P xs =
  List.foldl (fn (x,k) => if P x then k + 1 else k) 0 xs;
```   624
```
(* Upper bound on the number of ~= premises that 'prove' accepts.
   Each ~= is split into two cases, so the number of systems may explode.
*)
val fast_arith_neq_limit = ref 9;

(* Try to refute the hypotheses Hs together with the negated conclusion.

   All hypotheses are decomposed with LA_Data.decomp. If more of them than
   !fast_arith_neq_limit have relation "~=", the result is NONE. Otherwise
   the conclusion is decomposed as well: if that fails, a NONE placeholder is
   appended to the items; if it succeeds, the conclusion is appended with its
   relation negated (a leading "~" is removed, otherwise "~" is prepended).
   The result is that of 'refute' on these items. *)
fun prove sg ps ex Hs concl =
let
  val Hitems = map (LA_Data.decomp sg) Hs
  fun is_neq (SOME (_,_,rel,_,_,_)) = (rel = "~=")
    | is_neq NONE = false
  fun negate rel =
    let val first :: rest = explode rel
    in if first = "~" then implode rest else "~" ^ rel end
in
  if count is_neq Hitems > !fast_arith_neq_limit then NONE
  else
    case LA_Data.decomp sg concl of
      SOME (r,i,rel,l,j,d) =>
        refute sg ps ex (Hitems @ [SOME (r,i,negate rel,l,j,d)])
    | NONE => refute sg ps ex (Hitems @ [NONE])
end;
```   642
```
(*
Fast but very incomplete decider. Only those premises and conclusions are
taken into account that already are (negated) (in)equations.

The parameters, hypotheses and conclusion of subgoal i are handed to 'prove';
if that fails the tactic is no_tac, otherwise the results are replayed by
refute_tac. The flag ex is passed on to 'prove'.
*)
fun simpset_lin_arith_tac ss ex i st =
  let
    fun attempt (A,_) =
      let
        val params = rev (Logic.strip_params A)
        val pTs = map snd params
        val Hs = Logic.strip_assums_hyp A
        val concl = Logic.strip_assums_concl A
        val _ = trace_thm ("Trying to refute subgoal " ^ string_of_int i) st
      in
        case prove (Thm.sign_of_thm st) (pTs,params) ex Hs concl of
          SOME js => (trace_msg "Refutation succeeded."; refute_tac ss (i,js))
        | NONE => (trace_msg "Refutation failed."; no_tac)
      end
  in SUBGOAL attempt i st end;
```   657
```
(* simpset_lin_arith_tac with the empty simpset, put into the context of the
   theory of the goal state st. *)
fun lin_arith_tac ex i st =
  let
    val thy = Thm.theory_of_thm st
    val ss = Simplifier.theory_context thy Simplifier.empty_ss
  in simpset_lin_arith_tac ss ex i st end;
```   661
```
(* Insert the premises of the simpset ss into subgoal i, then run
   simpset_lin_arith_tac on it (with ex = false). *)
fun cut_lin_arith_tac ss i =
  let
    val insert_prems = cut_facts_tac (Simplifier.prems_of_ss ss) i
    val refute_goal = simpset_lin_arith_tac ss false i
  in insert_prems THEN refute_goal end;
```   665
```
(** Forward proof from theorems **)

(* More tricky code: the proofs of the individual cases (which arise from
splitting ~= premises) have to be arranged in exactly the order in which
function split_items generates the cases. *)

(* Tree of case splits over a list of assumptions:
   Tip asms                          no split left, asms are the assumptions
   Spl (thm, ct1, t1, ct2, t2)       split by thm; t1 / t2 are the subtrees
                                     for the cases assuming ct1 / ct2 *)
datatype splittree =
    Tip of thm list
  | Spl of thm * cterm * splittree * cterm * splittree
```   674
```
(* Pick two subterms out of the certified term imp, purely by position.
   Presumably imp has the shape (ct1 ==> _) ==> (ct2 ==> _) ==> _, as it
   results from resolving a ~= premise with a neqE rule -- confirm against
   the neqE rules in use. Thm.dest_comb fails if imp is not nested
   applications of that depth. *)
fun extract imp =
let
  (* argument of the operator of an application: A in "op A B" *)
  fun prem ct = snd (Thm.dest_comb (fst (Thm.dest_comb ct)))
  (* argument of an application: B in "f B" *)
  fun concl ct = snd (Thm.dest_comb ct)
  val ct1 = prem (prem imp)
  val ct2 = prem (prem (concl imp))
in (ct1,ct2) end;
```   685
```
(* Build the split tree for the assumptions asms.

   The assumptions are processed from left to right. An assumption that can
   be composed (COMP) with one of the neqE rules of the theory data gives a
   Spl node: the two case hypotheses obtained by 'extract' are assumed and
   appended to the END of the remaining assumptions, one per subtree. Any
   other assumption is kept. A Tip holds the kept assumptions in their
   original order. *)
fun splitasms sg asms =
let
  val {neqE, ...} = Data.get sg
  fun try_rule asm rule = SOME (asm COMP rule) handle THM _ => NONE
  fun walk (kept, []) = Tip (rev kept)
    | walk (kept, asm :: rest) =
        (case get_first (try_rule asm) neqE of
           NONE => walk (asm :: kept, rest)
         | SOME spl =>
             let
               val (ct1,ct2) = extract (cprop_of spl)
               val hyp1 = assume ct1
               val hyp2 = assume ct2
               val left = walk (kept, rest @ [hyp1])
               val right = walk (kept, rest @ [hyp2])
             in Spl (spl,ct1,left,ct2,right) end)
in walk ([],asms) end;
```   698
```
(* Turn a split tree and the list of results js into a theorem.

   A Tip consumes the first element of js and yields the theorem built by
   mkthm from the assumptions of the Tip. A Spl node proves its first and
   then its second subtree (consuming js from left to right), discharges the
   respective case hypothesis in each result and composes both with the
   split theorem. Returns the theorem and the unused rest of js.
   There is no clause for a Tip with an empty js. *)
fun fwdproof ctxt tree js =
  (case (tree, js) of
     (Tip asms, j :: rest) => (mkthm ctxt asms j, rest)
   | (Spl (rule,ct1,left,ct2,right), _) =>
       let
         val (left_thm, js_after_left) = fwdproof ctxt left js
         val (right_thm, js_after_right) = fwdproof ctxt right js_after_left
         val case1 = implies_intr ct1 left_thm
         val case2 = implies_intr ct2 right_thm
       in (case2 COMP (case1 COMP rule), js_after_right) end);
(* needs handle THM _ => NONE ? *)
```   707
```
(* Forward proof of Tconcl from thms, given the results js.

   The negation of Tconcl (LA_Logic.neg_prop) is assumed and added to thms;
   splitasms and fwdproof turn the results into a theorem, the assumption is
   discharged again and the outcome is composed with ccontr (if pos) or notI
   (otherwise). The final theorem is passed through LA_Logic.mk_Eq.
   Any THM exception raised on the way yields NONE. *)
fun prover (ctxt as (sg, _)) thms Tconcl js pos =
  let
    val neg_concl = cterm_of sg (LA_Logic.neg_prop Tconcl)
    val tree = splitasms sg (thms @ [assume neg_concl])
    val (refutation, _) = fwdproof ctxt tree js
    val rule = if pos then LA_Logic.ccontr else LA_Logic.notI
    val discharged = implies_intr neg_concl refutation
  in SOME (LA_Logic.mk_Eq (discharged COMP rule)) end
  (* in case concl contains ?-var, which makes assume fail: *)
  handle THM _ => NONE;
```   718
```
(* PRE: concl is not negated!
   This assumption is OK because
   1. lin_arith_prover tries both to prove and to disprove concl and
   2. lin_arith_prover is applied by the simplifier, which dives into terms
      and will thus try the non-negated concl anyway.

   The premises of the simpset ss (atomized) serve as hypotheses. First
   concl itself is tried; only if 'prove' fails for it, its negation is
   tried. The theorem, if any, is built by 'prover'.
*)
fun lin_arith_prover sg ss concl =
let
  val thms = List.concat (map LA_Logic.atomize (prems_of_ss ss))
  val Hs = map (fn th => #prop (rep_thm th)) thms
  val Tconcl = LA_Logic.mk_Trueprop concl
  fun refuted goal = prove sg ([],[]) false Hs goal
in
  case refuted Tconcl of (* concl provable? *)
    SOME js => prover (sg, ss) thms Tconcl js true
  | NONE =>
      let val nTconcl = LA_Logic.neg_prop Tconcl
      in
        case refuted nTconcl of (* ~concl provable? *)
          NONE => NONE
        | SOME js => prover (sg, ss) thms nTconcl js false
      end
end;
```   738
```
```   739 end;
```