src/Provers/Arith/fast_lin_arith.ML
author wenzelm
Wed Jun 29 21:34:16 2011 +0200 (2011-06-29)
changeset 43597 b4a093e755db
parent 43568 3e4889375781
child 43609 20760e3608fa
permissions -rw-r--r--
tuned signature;
     1 (*  Title:      Provers/Arith/fast_lin_arith.ML
     2     Author:     Tobias Nipkow and Tjark Weber and Sascha Boehme
     3 
     4 A generic linear arithmetic package.  It provides two tactics
     5 (cut_lin_arith_tac, lin_arith_tac) and a simplification procedure
     6 (lin_arith_simproc).
     7 
Only premises and conclusions that are already (negated) (in)equations
are taken into account.  lin_arith_simproc tries to prove or disprove
the given term.
    11 *)
    12 
    13 (*** Data needed for setting up the linear arithmetic package ***)
    14 
    15 signature LIN_ARITH_LOGIC =
    16 sig
    17   val conjI       : thm  (* P ==> Q ==> P & Q *)
    18   val ccontr      : thm  (* (~ P ==> False) ==> P *)
    19   val notI        : thm  (* (P ==> False) ==> ~ P *)
    20   val not_lessD   : thm  (* ~(m < n) ==> n <= m *)
    21   val not_leD     : thm  (* ~(m <= n) ==> n < m *)
    22   val sym         : thm  (* x = y ==> y = x *)
    23   val trueI       : thm  (* True *)
    24   val mk_Eq       : thm -> thm
    25   val atomize     : thm -> thm list
    26   val mk_Trueprop : term -> term
    27   val neg_prop    : term -> term
    28   val is_False    : thm -> bool
    29   val is_nat      : typ list * term -> bool
    30   val mk_nat_thm  : theory -> term -> thm
    31 end;
    32 (*
    33 mk_Eq(~in) = `in == False'
    34 mk_Eq(in) = `in == True'
    35 where `in' is an (in)equality.
    36 
    37 neg_prop(t) = neg  if t is wrapped up in Trueprop and neg is the
    38   (logically) negated version of t (again wrapped up in Trueprop),
    39   where the negation of a negative term is the term itself (no
    40   double negation!); raises TERM ("neg_prop", [t]) if t is not of
    41   the form 'Trueprop $ _'
    42 
is_nat (parameter-types, t) = t : nat
mk_nat_thm thy t = "0 <= t"
    45 *)
    46 
    47 signature LIN_ARITH_DATA =
    48 sig
    49   (*internal representation of linear (in-)equations:*)
    50   type decomp = (term * Rat.rat) list * Rat.rat * string * (term * Rat.rat) list * Rat.rat * bool
    51   val decomp: Proof.context -> term -> decomp option
    52   val domain_is_nat: term -> bool
    53 
    54   (*preprocessing, performed on a representation of subgoals as list of premises:*)
    55   val pre_decomp: Proof.context -> typ list * term list -> (typ list * term list) list
    56 
    57   (*preprocessing, performed on the goal -- must do the same as 'pre_decomp':*)
    58   val pre_tac: simpset -> int -> tactic
    59 
    60   (*the limit on the number of ~= allowed; because each ~= is split
    61     into two cases, this can lead to an explosion*)
    62   val fast_arith_neq_limit: int Config.T
    63 end;
    64 (*
decomp(`x Rel y') should yield (p,i,Rel,q,j,d)
   where Rel is one of "<", "~<", "<=", "~<=", "=" and "~=" (the latter
         is split into "<" and ">" cases by split_items), and
         p (q, respectively) is the decomposition of the sum term x
         (y, respectively) into a list of summand * multiplicity
         pairs, i (j, respectively) is its constant summand, and
         d indicates whether the domain is discrete.

domain_is_nat(`x Rel y') should yield true iff x is of type "nat".
    73 
    74 The relationship between pre_decomp and pre_tac is somewhat tricky.  The
    75 internal representation of a subgoal and the corresponding theorem must
    76 be modified by pre_decomp (pre_tac, resp.) in a corresponding way.  See
    77 the comment for split_items below.  (This is even necessary for eta- and
    78 beta-equivalent modifications, as some of the lin. arith. code is not
    79 insensitive to them.)
    80 
The simpset stored in the arithmetic data (see map_data) must reduce
   contradictory <= to False.  It should also cancel common summands to
   keep <= reduced; otherwise <= can grow to massive proportions.
    84 *)
    85 
    86 signature FAST_LIN_ARITH =
    87 sig
    88   val cut_lin_arith_tac: simpset -> int -> tactic
    89   val lin_arith_tac: Proof.context -> bool -> int -> tactic
    90   val lin_arith_simproc: simpset -> term -> thm option
    91   val map_data:
    92     ({add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
    93       lessD: thm list, neqE: thm list, simpset: Simplifier.simpset,
    94       number_of: (theory -> typ -> int -> cterm) option} ->
    95      {add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
    96       lessD: thm list, neqE: thm list, simpset: Simplifier.simpset,
    97       number_of: (theory -> typ -> int -> cterm) option}) ->
    98       Context.generic -> Context.generic
    99   val add_inj_thms: thm list -> Context.generic -> Context.generic
   100   val add_lessD: thm -> Context.generic -> Context.generic
   101   val add_simps: thm list -> Context.generic -> Context.generic
   102   val add_simprocs: simproc list -> Context.generic -> Context.generic
   103   val set_number_of: (theory -> typ -> int -> cterm) -> Context.generic -> Context.generic
   104   val trace: bool Unsynchronized.ref
   105 end;
   106 
   107 functor Fast_Lin_Arith
   108   (structure LA_Logic: LIN_ARITH_LOGIC and LA_Data: LIN_ARITH_DATA): FAST_LIN_ARITH =
   109 struct
   110 
   111 
   112 (** theory data **)
   113 
   114 structure Data = Generic_Data
   115 (
   116   type T =
   117    {add_mono_thms: thm list,
   118     mult_mono_thms: thm list,
   119     inj_thms: thm list,
   120     lessD: thm list,
   121     neqE: thm list,
   122     simpset: Simplifier.simpset,
   123     number_of: (theory -> typ -> int -> cterm) option};
   124 
   125   val empty : T =
   126    {add_mono_thms = [], mult_mono_thms = [], inj_thms = [],
   127     lessD = [], neqE = [], simpset = Simplifier.empty_ss,
   128     number_of = NONE};
   129   val extend = I;
   130   fun merge
   131     ({add_mono_thms = add_mono_thms1, mult_mono_thms = mult_mono_thms1, inj_thms = inj_thms1,
   132       lessD = lessD1, neqE = neqE1, simpset = simpset1, number_of = number_of1},
   133      {add_mono_thms = add_mono_thms2, mult_mono_thms = mult_mono_thms2, inj_thms = inj_thms2,
   134       lessD = lessD2, neqE = neqE2, simpset = simpset2, number_of = number_of2}) : T =
   135     {add_mono_thms = Thm.merge_thms (add_mono_thms1, add_mono_thms2),
   136      mult_mono_thms = Thm.merge_thms (mult_mono_thms1, mult_mono_thms2),
   137      inj_thms = Thm.merge_thms (inj_thms1, inj_thms2),
   138      lessD = Thm.merge_thms (lessD1, lessD2),
   139      neqE = Thm.merge_thms (neqE1, neqE2),
   140      simpset = Simplifier.merge_ss (simpset1, simpset2),
   141      number_of = merge_options (number_of1, number_of2)};
   142 );
   143 
   144 val map_data = Data.map;
   145 val get_data = Data.get o Context.Proof;
   146 
   147 fun map_inj_thms f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset, number_of} =
   148   {add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms, inj_thms = f inj_thms,
   149     lessD = lessD, neqE = neqE, simpset = simpset, number_of = number_of};
   150 
   151 fun map_lessD f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset, number_of} =
   152   {add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms, inj_thms = inj_thms,
   153     lessD = f lessD, neqE = neqE, simpset = simpset, number_of = number_of};
   154 
   155 fun map_simpset f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset, number_of} =
   156   {add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms, inj_thms = inj_thms,
   157     lessD = lessD, neqE = neqE, simpset = f simpset, number_of = number_of};
   158 
   159 fun add_inj_thms thms = map_data (map_inj_thms (append thms));
   160 fun add_lessD thm = map_data (map_lessD (fn thms => thms @ [thm]));
   161 fun add_simps thms = map_data (map_simpset (fn simpset => simpset addsimps thms));
   162 fun add_simprocs procs = map_data (map_simpset (fn simpset => simpset addsimprocs procs));
   163 
   164 fun set_number_of f =
   165   map_data (fn {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset, ...} =>
   166    {add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms, inj_thms = inj_thms,
   167     lessD = lessD, neqE = neqE, simpset = simpset, number_of = SOME f});
   168 
   169 fun number_of ctxt =
   170   (case Data.get (Context.Proof ctxt) of
   171     {number_of = SOME f, ...} => f (Proof_Context.theory_of ctxt)
   172   | _ => fn _ => fn _ => raise CTERM ("number_of", []));
   173 
   174 
   175 
   176 (*** A fast decision procedure ***)
   177 (*** Code ported from HOL Light ***)
(* possible optimizations:
   use (var,coeff) rep or vector rep to save space;
   treat non-negative atoms separately rather than adding 0 <= atom
*)
   182 
   183 val trace = Unsynchronized.ref false;
   184 
   185 datatype lineq_type = Eq | Le | Lt;
   186 
   187 datatype injust = Asm of int
   188                 | Nat of int (* index of atom *)
   189                 | LessD of injust
   190                 | NotLessD of injust
   191                 | NotLeD of injust
   192                 | NotLeDD of injust
   193                 | Multiplied of int * injust
   194                 | Added of injust * injust;
   195 
   196 datatype lineq = Lineq of int * lineq_type * int list * injust;
   197 
   198 (* ------------------------------------------------------------------------- *)
   199 (* Finding a (counter) example from the trace of a failed elimination        *)
   200 (* ------------------------------------------------------------------------- *)
   201 (* Examples are represented as rational numbers,                             *)
(* Don't blame John Harrison for this code - it is entirely mine. TN         *)
   203 
   204 exception NoEx;
   205 
   206 (* Coding: (i,true,cs) means i <= cs and (i,false,cs) means i < cs.
   207    In general, true means the bound is included, false means it is excluded.
   208    Need to know if it is a lower or upper bound for unambiguous interpretation!
   209 *)
   210 
   211 fun elim_eqns (Lineq (i, Le, cs, _)) = [(i, true, cs)]
   212   | elim_eqns (Lineq (i, Eq, cs, _)) = [(i, true, cs),(~i, true, map ~ cs)]
   213   | elim_eqns (Lineq (i, Lt, cs, _)) = [(i, false, cs)];
   214 
   215 (* PRE: ex[v] must be 0! *)
   216 fun eval ex v (a, le, cs) =
   217   let
   218     val rs = map Rat.rat_of_int cs;
   219     val rsum = fold2 (Rat.add oo Rat.mult) rs ex Rat.zero;
   220   in (Rat.mult (Rat.add (Rat.rat_of_int a) (Rat.neg rsum)) (Rat.inv (nth rs v)), le) end;
   221 (* If nth rs v < 0, le should be negated.
   222    Instead this swap is taken into account in ratrelmin2.
   223 *)
   224 
   225 fun ratrelmin2 (x as (r, ler), y as (s, les)) =
   226   case Rat.ord (r, s)
   227    of EQUAL => (r, (not ler) andalso (not les))
   228     | LESS => x
   229     | GREATER => y;
   230 
   231 fun ratrelmax2 (x as (r, ler), y as (s, les)) =
   232   case Rat.ord (r, s)
   233    of EQUAL => (r, ler andalso les)
   234     | LESS => y
   235     | GREATER => x;
   236 
   237 val ratrelmin = foldr1 ratrelmin2;
   238 val ratrelmax = foldr1 ratrelmax2;
   239 
   240 fun ratexact up (r, exact) =
   241   if exact then r else
   242   let
   243     val (_, q) = Rat.quotient_of_rat r;
   244     val nth = Rat.inv (Rat.rat_of_int q);
   245   in Rat.add r (if up then nth else Rat.neg nth) end;
   246 
   247 fun ratmiddle (r, s) = Rat.mult (Rat.add r s) (Rat.inv Rat.two);
   248 
   249 fun choose2 d ((lb, exactl), (ub, exactu)) =
   250   let val ord = Rat.sign lb in
   251   if (ord = LESS orelse exactl) andalso (ord = GREATER orelse exactu)
   252     then Rat.zero
   253     else if not d then
   254       if ord = GREATER
   255         then if exactl then lb else ratmiddle (lb, ub)
   256         else if exactu then ub else ratmiddle (lb, ub)
   257       else (*discrete domain, both bounds must be exact*)
   258       if ord = GREATER
   259         then let val lb' = Rat.roundup lb in
   260           if Rat.le lb' ub then lb' else raise NoEx end
   261         else let val ub' = Rat.rounddown ub in
   262           if Rat.le lb ub' then ub' else raise NoEx end
   263   end;
   264 
   265 fun findex1 discr (v, lineqs) ex =
   266   let
   267     val nz = filter (fn (Lineq (_, _, cs, _)) => nth cs v <> 0) lineqs;
   268     val ineqs = maps elim_eqns nz;
   269     val (ge, le) = List.partition (fn (_,_,cs) => nth cs v > 0) ineqs
   270     val lb = ratrelmax (map (eval ex v) ge)
   271     val ub = ratrelmin (map (eval ex v) le)
   272   in nth_map v (K (choose2 (nth discr v) (lb, ub))) ex end;
   273 
   274 fun elim1 v x =
   275   map (fn (a,le,bs) => (Rat.add a (Rat.neg (Rat.mult (nth bs v) x)), le,
   276                         nth_map v (K Rat.zero) bs));
   277 
   278 fun single_var v (_, _, cs) = case filter_out (curry (op =) EQUAL o Rat.sign) cs
   279  of [x] => x =/ nth cs v
   280   | _ => false;
   281 
   282 (* The base case:
   283    all variables occur only with positive or only with negative coefficients *)
   284 fun pick_vars discr (ineqs,ex) =
   285   let val nz = filter_out (fn (_,_,cs) => forall (curry (op =) EQUAL o Rat.sign) cs) ineqs
   286   in case nz of [] => ex
   287      | (_,_,cs) :: _ =>
   288        let val v = find_index (not o curry (op =) EQUAL o Rat.sign) cs
   289            val d = nth discr v;
   290            val pos = not (Rat.sign (nth cs v) = LESS);
   291            val sv = filter (single_var v) nz;
   292            val minmax =
   293              if pos then if d then Rat.roundup o fst o ratrelmax
   294                          else ratexact true o ratrelmax
   295                     else if d then Rat.rounddown o fst o ratrelmin
   296                          else ratexact false o ratrelmin
   297            val bnds = map (fn (a,le,bs) => (Rat.mult a (Rat.inv (nth bs v)), le)) sv
   298            val x = minmax((Rat.zero,if pos then true else false)::bnds)
   299            val ineqs' = elim1 v x nz
   300            val ex' = nth_map v (K x) ex
   301        in pick_vars discr (ineqs',ex') end
   302   end;
   303 
   304 fun findex0 discr n lineqs =
   305   let val ineqs = maps elim_eqns lineqs
   306       val rineqs = map (fn (a,le,cs) => (Rat.rat_of_int a, le, map Rat.rat_of_int cs))
   307                        ineqs
   308   in pick_vars discr (rineqs,replicate n Rat.zero) end;
   309 
   310 (* ------------------------------------------------------------------------- *)
   311 (* End of counterexample finder. The actual decision procedure starts here.  *)
   312 (* ------------------------------------------------------------------------- *)
   313 
   314 (* ------------------------------------------------------------------------- *)
   315 (* Calculate new (in)equality type after addition.                           *)
   316 (* ------------------------------------------------------------------------- *)
   317 
   318 fun find_add_type(Eq,x) = x
   319   | find_add_type(x,Eq) = x
   320   | find_add_type(_,Lt) = Lt
   321   | find_add_type(Lt,_) = Lt
   322   | find_add_type(Le,Le) = Le;
   323 
   324 (* ------------------------------------------------------------------------- *)
   325 (* Multiply out an (in)equation.                                             *)
   326 (* ------------------------------------------------------------------------- *)
   327 
   328 fun multiply_ineq n (i as Lineq(k,ty,l,just)) =
   329   if n = 1 then i
   330   else if n = 0 andalso ty = Lt then raise Fail "multiply_ineq"
   331   else if n < 0 andalso (ty=Le orelse ty=Lt) then raise Fail "multiply_ineq"
   332   else Lineq (n * k, ty, map (Integer.mult n) l, Multiplied (n, just));
   333 
   334 (* ------------------------------------------------------------------------- *)
   335 (* Add together (in)equations.                                               *)
   336 (* ------------------------------------------------------------------------- *)
   337 
   338 fun add_ineq (Lineq (k1,ty1,l1,just1)) (Lineq (k2,ty2,l2,just2)) =
   339   let val l = map2 Integer.add l1 l2
   340   in Lineq(k1+k2,find_add_type(ty1,ty2),l,Added(just1,just2)) end;
   341 
   342 (* ------------------------------------------------------------------------- *)
   343 (* Elimination of variable between a single pair of (in)equations.           *)
   344 (* If they're both inequalities, 1st coefficient must be +ve, 2nd -ve.       *)
   345 (* ------------------------------------------------------------------------- *)
   346 
   347 fun elim_var v (i1 as Lineq(k1,ty1,l1,just1)) (i2 as Lineq(k2,ty2,l2,just2)) =
   348   let val c1 = nth l1 v and c2 = nth l2 v
   349       val m = Integer.lcm (abs c1) (abs c2)
   350       val m1 = m div (abs c1) and m2 = m div (abs c2)
   351       val (n1,n2) =
   352         if (c1 >= 0) = (c2 >= 0)
   353         then if ty1 = Eq then (~m1,m2)
   354              else if ty2 = Eq then (m1,~m2)
   355                   else raise Fail "elim_var"
   356         else (m1,m2)
   357       val (p1,p2) = if ty1=Eq andalso ty2=Eq andalso (n1 = ~1 orelse n2 = ~1)
   358                     then (~n1,~n2) else (n1,n2)
   359   in add_ineq (multiply_ineq p1 i1) (multiply_ineq p2 i2) end;
   360 
   361 (* ------------------------------------------------------------------------- *)
   362 (* The main refutation-finding code.                                         *)
   363 (* ------------------------------------------------------------------------- *)
   364 
   365 fun is_trivial (Lineq(_,_,l,_)) = forall (fn i => i=0) l;
   366 
   367 fun is_contradictory (Lineq(k,ty,_,_)) =
   368   case ty  of Eq => k <> 0 | Le => k > 0 | Lt => k >= 0;
   369 
   370 fun calc_blowup l =
   371   let val (p,n) = List.partition (curry (op <) 0) (filter (curry (op <>) 0) l)
   372   in length p * length n end;
   373 
   374 (* ------------------------------------------------------------------------- *)
   375 (* Main elimination code:                                                    *)
   376 (*                                                                           *)
   377 (* (1) Looks for immediate solutions (false assertions with no variables).   *)
   378 (*                                                                           *)
   379 (* (2) If there are any equations, picks a variable with the lowest absolute *)
   380 (* coefficient in any of them, and uses it to eliminate.                     *)
   381 (*                                                                           *)
   382 (* (3) Otherwise, chooses a variable in the inequality to minimize the       *)
   383 (* blowup (number of consequences generated) and eliminates it.              *)
   384 (* ------------------------------------------------------------------------- *)
   385 
   386 fun extract_first p =
   387   let
   388     fun extract xs (y::ys) = if p y then (y, xs @ ys) else extract (y::xs) ys
   389       | extract xs [] = raise Empty
   390   in extract [] end;
   391 
   392 fun print_ineqs ineqs =
   393   if !trace then
   394      tracing(cat_lines(""::map (fn Lineq(c,t,l,_) =>
   395        string_of_int c ^
   396        (case t of Eq => " =  " | Lt=> " <  " | Le => " <= ") ^
   397        commas(map string_of_int l)) ineqs))
   398   else ();
   399 
   400 type history = (int * lineq list) list;
   401 datatype result = Success of injust | Failure of history;
   402 
(* Elimination loop (see the overview above).  Returns Success j, where j
   justifies a false variable-free (in)equation, or Failure with the history
   of eliminations for counterexample search. *)
fun elim (ineqs, hist) =
  let val _ = print_ineqs ineqs
      val (triv, nontriv) = List.partition is_trivial ineqs in
  (* (1) a false variable-free (in)equation refutes the system; true
     variable-free ones are dropped *)
  if not (null triv)
  then case Library.find_first is_contradictory triv of
         NONE => elim (nontriv, hist)
       | SOME(Lineq(_,_,_,j)) => Success j
  else
  if null nontriv then Failure hist
  else
  let val (eqs, noneqs) = List.partition (fn (Lineq(_,ty,_,_)) => ty=Eq) nontriv in
  (* (2) eliminate with an equation if there is one *)
  if not (null eqs) then
     (* c: the nonzero coefficient value of least absolute value in any equation *)
     let val c =
           fold (fn Lineq(_,_,l,_) => fn cs => union (op =) l cs) eqs []
           |> filter (fn i => i <> 0)
           |> sort (int_ord o pairself abs)
           |> hd
         (* eq: the first equation containing c; v: the position of c in it *)
         val (eq as Lineq(_,_,ceq,_),othereqs) =
               extract_first (fn Lineq(_,_,l,_) => member (op =) l c) eqs
         val v = find_index (fn v => v = c) ceq
         (* (in)equations without v are kept; the others are combined with eq *)
         val (ioth,roth) = List.partition (fn (Lineq(_,_,l,_)) => nth l v = 0)
                                     (othereqs @ noneqs)
         val others = map (elim_var v eq) roth @ ioth
     in elim(others,(v,nontriv)::hist) end
  else
  (* (3) only inequations: eliminate the variable with the smallest nonzero
     blowup (see calc_blowup) *)
  let val lists = map (fn (Lineq(_,_,l,_)) => l) noneqs
      val numlist = 0 upto (length (hd lists) - 1)
      val coeffs = map (fn i => map (fn xs => nth xs i) lists) numlist
      val blows = map calc_blowup coeffs
      val iblows = blows ~~ numlist
      val nziblows = filter_out (fn (i, _) => i = 0) iblows
  (* no variable occurs with both signs: give up, marking the step with ~1 *)
  in if null nziblows then Failure((~1,nontriv)::hist)
     else
     (* here triv is empty, so ineqs equals nontriv *)
     let val (c,v) = hd(sort (fn (x,y) => int_ord(fst(x),fst(y))) nziblows)
         val (no,yes) = List.partition (fn (Lineq(_,_,l,_)) => nth l v = 0) ineqs
         val (pos,neg) = List.partition(fn (Lineq(_,_,l,_)) => nth l v > 0) yes
     in elim(no @ map_product (elim_var v) pos neg, (v,nontriv)::hist) end
  end
  end
  end;
   443 
   444 (* ------------------------------------------------------------------------- *)
   445 (* Translate back a proof.                                                   *)
   446 (* ------------------------------------------------------------------------- *)
   447 
   448 fun trace_thm ctxt msg th =
   449   (if !trace then (tracing msg; tracing (Display.string_of_thm ctxt th)) else (); th);
   450 
   451 fun trace_term ctxt msg t =
   452   (if !trace then tracing (cat_lines [msg, Syntax.string_of_term ctxt t]) else (); t)
   453 
   454 fun trace_msg msg =
   455   if !trace then tracing msg else ();
   456 
   457 val union_term = union Pattern.aeconv;
   458 val union_bterm = union (fn ((b:bool, t), (b', t')) => b = b' andalso Pattern.aeconv (t, t'));
   459 
   460 fun add_atoms (lhs, _, _, rhs, _, _) =
   461   union_term (map fst lhs) o union_term (map fst rhs);
   462 
   463 fun atoms_of ds = fold add_atoms ds [];
   464 
(*
Simplification may detect a contradiction 'prematurely' due to type
information: n+1 <= 0 is simplified to False and does not need to be crossed
with 0 <= n.
*)
local
  (* Raised by simp (inside mkthm) with the offending theorem as soon as an
     intermediate sum simplifies to False; caught at the end of mkthm. *)
  exception FalseE of thm
in

(* mkthm ss asms just: replay the justification just as a theorem built from
   the premises asms (Asm i denotes nth asms i) and return it after
   simplification with the arithmetic simpset.  The result is expected to be
   False.  If LA_Logic.is_False does not recognise it, the premises and the
   result are traced and a warning is issued, but the theorem is still
   returned. *)
fun mkthm ss asms (just: injust) =
  let
    val ctxt = Simplifier.the_context ss;
    val thy = Proof_Context.theory_of ctxt;
    val {add_mono_thms, mult_mono_thms, inj_thms, lessD, simpset, ...} = get_data ctxt;
    val number_of = number_of ctxt;
    (* the arithmetic simpset, given the context of ss *)
    val simpset' = Simplifier.inherit_context ss simpset;
    fun only_concl f thm =
      if Thm.no_prems thm then f (Thm.concl_of thm) else NONE;
    (* atoms of the premises that have no premises themselves; Nat i below
       denotes nth atoms i.  NOTE(review): assumes this is the same atom order
       that was used when the Lineqs were built -- confirm with the caller *)
    val atoms = atoms_of (map_filter (only_concl (LA_Data.decomp ctxt)) asms);

    (* thm RS r for the first rule r in rules for which resolution succeeds *)
    fun use_first rules thm =
      get_first (fn th => SOME (thm RS th) handle THM _ => NONE) rules

    (* conjoin two (in)equations and apply the first matching add_mono rule *)
    fun add2 thm1 thm2 =
      use_first add_mono_thms (thm1 RS (thm2 RS LA_Logic.conjI));
    fun try_add thms thm = get_first (fn th => add2 th thm) thms;

    (* like add2, but if that fails retry after transforming thm1, then thm2,
       with inj_thms; raises Fail if nothing applies *)
    fun add_thms thm1 thm2 =
      (case add2 thm1 thm2 of
        NONE =>
          (case try_add ([thm1] RL inj_thms) thm2 of
            NONE =>
              (the (try_add ([thm2] RL inj_thms) thm1)
                handle Option =>
                  (trace_thm ctxt "" thm1; trace_thm ctxt "" thm2;
                   raise Fail "Linear arithmetic: failed to add thms"))
          | SOME thm => thm)
      | SOME thm => thm);

    (* n-fold sum thm + ... + thm by repeated addition.
       NOTE(review): mul counts down to 1, so this only terminates for n >= 1 *)
    fun mult_by_add n thm =
      let fun mul i th = if i = 1 then th else mul (i - 1) (add_thms thm th)
      in mul n thm end;

    val rewr = Simplifier.rewrite simpset';
    (* rewrite both operands of the relation in the conclusion with rewr *)
    val rewrite_concl = Conv.fconv_rule (Conv.concl_conv ~1 (Conv.arg_conv
      (Conv.binop_conv rewr)));
    (* if thm has premises, rewrite the (first) premise with rewr and discharge
       it with LA_Logic.trueI *)
    fun discharge_prem thm = if Thm.nprems_of thm = 0 then thm else
      let val cv = Conv.arg1_conv (Conv.arg_conv rewr)
      in Thm.implies_elim (Conv.fconv_rule cv thm) LA_Logic.trueI end

    (* multiply thm by n: instantiate the factor variable of the first matching
       mult_mono rule with the numeral number_of T n; fall back to repeated
       addition if no rule applies or instantiation/discharge fails *)
    fun mult n thm =
      (case use_first mult_mono_thms thm of
        NONE => mult_by_add n thm
      | SOME mth =>
          let
            val cv = mth |> Thm.cprop_of |> Drule.strip_imp_concl
              |> Thm.dest_arg |> Thm.dest_arg1 |> Thm.dest_arg1
            val T = #T (Thm.rep_cterm cv)
          in
            mth
            |> Thm.instantiate ([], [(cv, number_of T n)])
            |> rewrite_concl
            |> discharge_prem
            handle CTERM _ => mult_by_add n thm
                 | THM _ => mult_by_add n thm
          end);

    (* multiplication by a possibly negative factor: for n < 0, multiply by ~n
       and flip the result with LA_Logic.sym (n = ~1 only flips) *)
    fun mult_thm (n, thm) =
      if n = ~1 then thm RS LA_Logic.sym
      else if n < 0 then mult (~n) thm RS LA_Logic.sym
      else mult n thm;

    (* full simplification of an intermediate sum; escapes via FalseE as soon
       as False is reached *)
    fun simp thm =
      let val thm' = trace_thm ctxt "Simplified:" (full_simplify simpset' thm)
      in if LA_Logic.is_False thm' then raise FalseE thm' else thm' end;

    (* replay one justification step, tracing each intermediate theorem *)
    fun mk (Asm i) = trace_thm ctxt ("Asm " ^ string_of_int i) (nth asms i)
      | mk (Nat i) = trace_thm ctxt ("Nat " ^ string_of_int i) (LA_Logic.mk_nat_thm thy (nth atoms i))
      | mk (LessD j) = trace_thm ctxt "L" (hd ([mk j] RL lessD))
      | mk (NotLeD j) = trace_thm ctxt "NLe" (mk j RS LA_Logic.not_leD)
      | mk (NotLeDD j) = trace_thm ctxt "NLeD" (hd ([mk j RS LA_Logic.not_leD] RL lessD))
      | mk (NotLessD j) = trace_thm ctxt "NL" (mk j RS LA_Logic.not_lessD)
      | mk (Added (j1, j2)) = simp (trace_thm ctxt "+" (add_thms (mk j1) (mk j2)))
      | mk (Multiplied (n, j)) =
          (trace_msg ("*" ^ string_of_int n); trace_thm ctxt "*" (mult_thm (n, mk j)))

  in
    let
      val _ = trace_msg "mkthm";
      val thm = trace_thm ctxt "Final thm:" (mk just);
      val fls = simplify simpset' thm;
      val _ = trace_thm ctxt "After simplification:" fls;
      (* a result not recognised as False is reported but still returned *)
      val _ =
        if LA_Logic.is_False fls then ()
        else
         (tracing (cat_lines
           (["Assumptions:"] @ map (Display.string_of_thm ctxt) asms @ [""] @
            ["Proved:", Display.string_of_thm ctxt fls, ""]));
          warning "Linear arithmetic should have refuted the assumptions.\n\
            \Please inform Tobias Nipkow.")
    in fls end
    (* False reached during an intermediate addition: return it directly *)
    handle FalseE thm => trace_thm ctxt "False reached early:" thm
  end;

end;
   570 
   571 fun coeff poly atom =
   572   AList.lookup Pattern.aeconv poly atom |> the_default 0;
   573 
   574 fun integ(rlhs,r,rel,rrhs,s,d) =
   575 let val (rn,rd) = Rat.quotient_of_rat r and (sn,sd) = Rat.quotient_of_rat s
   576     val m = Integer.lcms(map (abs o snd o Rat.quotient_of_rat) (r :: s :: map snd rlhs @ map snd rrhs))
   577     fun mult(t,r) =
   578         let val (i,j) = Rat.quotient_of_rat r
   579         in (t,i * (m div j)) end
   580 in (m,(map mult rlhs, rn*(m div rd), rel, map mult rrhs, sn*(m div sd), d)) end
   581 
   582 fun mklineq atoms =
   583   fn (item, k) =>
   584   let val (m, (lhs,i,rel,rhs,j,discrete)) = integ item
   585       val lhsa = map (coeff lhs) atoms
   586       and rhsa = map (coeff rhs) atoms
   587       val diff = map2 (curry (op -)) rhsa lhsa
   588       val c = i-j
   589       val just = Asm k
   590       fun lineq(c,le,cs,j) = Lineq(c,le,cs, if m=1 then j else Multiplied(m,j))
   591   in case rel of
   592       "<="   => lineq(c,Le,diff,just)
   593      | "~<=" => if discrete
   594                 then lineq(1-c,Le,map (op ~) diff,NotLeDD(just))
   595                 else lineq(~c,Lt,map (op ~) diff,NotLeD(just))
   596      | "<"   => if discrete
   597                 then lineq(c+1,Le,diff,LessD(just))
   598                 else lineq(c,Lt,diff,just)
   599      | "~<"  => lineq(~c,Le,map (op~) diff,NotLessD(just))
   600      | "="   => lineq(c,Eq,diff,just)
   601      | _     => raise Fail ("mklineq" ^ rel)
   602   end;
   603 
   604 (* ------------------------------------------------------------------------- *)
   605 (* Print (counter) example                                                   *)
   606 (* ------------------------------------------------------------------------- *)
   607 
   608 fun print_atom((a,d),r) =
   609   let val (p,q) = Rat.quotient_of_rat r
   610       val s = if d then string_of_int p else
   611               if p = 0 then "0"
   612               else string_of_int p ^ "/" ^ string_of_int q
   613   in a ^ " = " ^ s end;
   614 
   615 fun produce_ex sds =
   616   curry (op ~~) sds
   617   #> map print_atom
   618   #> commas
   619   #> curry (op ^) "Counterexample (possibly spurious):\n";
   620 
   621 fun trace_ex ctxt params atoms discr n (hist: history) =
   622   case hist of
   623     [] => ()
   624   | (v, lineqs) :: hist' =>
   625       let
   626         val frees = map Free params
   627         fun show_term t = Syntax.string_of_term ctxt (subst_bounds (frees, t))
   628         val start =
   629           if v = ~1 then (hist', findex0 discr n lineqs)
   630           else (hist, replicate n Rat.zero)
   631         val ex = SOME (produce_ex (map show_term atoms ~~ discr)
   632             (uncurry (fold (findex1 discr)) start))
   633           handle NoEx => NONE
   634       in
   635         case ex of
   636           SOME s => (warning "Linear arithmetic failed - see trace for a (potentially spurious) counterexample."; tracing s)
   637         | NONE => warning "Linear arithmetic failed"
   638       end;
   639 
   640 (* ------------------------------------------------------------------------- *)
   641 
   642 fun mknat (pTs : typ list) (ixs : int list) (atom : term, i : int) : lineq option =
   643   if LA_Logic.is_nat (pTs, atom)
   644   then let val l = map (fn j => if j=i then 1 else 0) ixs
   645        in SOME (Lineq (0, Le, l, Nat i)) end
   646   else NONE;
   647 
   648 (* This code is tricky. It takes a list of premises in the order they occur
   649 in the subgoal. Numerical premises are coded as SOME(tuple), non-numerical
   650 ones as NONE. Going through the premises, each numeric one is converted into
   651 a Lineq. The tricky bit is to convert ~= which is split into two cases < and
   652 >. Thus split_items returns a list of equation systems. This may blow up if
   653 there are many ~=, but in practice it does not seem to happen. The really
   654 tricky bit is to arrange the order of the cases such that they coincide with
   655 the order in which the cases are in the end generated by the tactic that
   656 applies the generated refutation thms (see function 'refute_tac').
   657 
   658 For variables n of type nat, a constraint 0 <= n is added.
   659 *)
   660 
   661 (* FIXME: To optimize, the splitting of cases and the search for refutations *)
   662 (*        could be intertwined: separate the first (fully split) case,       *)
   663 (*        refute it, continue with splitting and refuting.  Terminate with   *)
   664 (*        failure as soon as a case could not be refuted; i.e. delay further *)
   665 (*        splitting until after a refutation for other cases has been found. *)
   666 
(* Turn one subgoal, given as (parameter types Ts, hypothesis terms), into the
   list of case systems that refutes has to refute, one entry per case.  Each
   entry pairs the parameter types with the decomposed (in)equations that
   remain, each tagged with the (0-based) position of the hypothesis it stems
   from.
   do_pre:    apply the user-defined preprocessing LA_Data.pre_decomp first.
   split_neq: split every "~=" premise into "<" and ">" cases; otherwise "~="
              premises are dropped (treated like non-numeric premises). *)
fun split_items ctxt do_pre split_neq (Ts, terms) : (typ list * (LA_Data.decomp * int) list) list =
let
  (* splits inequalities '~=' into '<' and '>'; this corresponds to *)
  (* 'REPEAT_DETERM (eresolve_tac neqE i)' at the theorem/tactic    *)
  (* level                                                          *)
  (* FIXME: this is currently sensitive to the order of theorems in *)
  (*        neqE:  The theorem for type "nat" must come first.  A   *)
  (*        better (i.e. less likely to break when neqE changes)    *)
  (*        implementation should *test* which theorem from neqE    *)
  (*        can be applied, and split the premise accordingly.      *)
  (* First pass splits only the "~=" premises flagged as nat; the second pass
     splits all remaining ones.  Each split removes the "~=" item and appends
     the "l < r" case (listed first) resp. the "r < l" case at the END of the
     list, mirroring how eresolve_tac adds the new premises. *)
  fun elim_neq (ineqs : (LA_Data.decomp option * bool) list) :
               (LA_Data.decomp option * bool) list list =
  let
    fun elim_neq' nat_only ([] : (LA_Data.decomp option * bool) list) :
                  (LA_Data.decomp option * bool) list list =
          [[]]
      | elim_neq' nat_only ((NONE, is_nat) :: ineqs) =
          map (cons (NONE, is_nat)) (elim_neq' nat_only ineqs)
      | elim_neq' nat_only ((ineq as (SOME (l, i, rel, r, j, d), is_nat)) :: ineqs) =
          if rel = "~=" andalso (not nat_only orelse is_nat) then
            (* [| ?l ~= ?r; ?l < ?r ==> ?R; ?r < ?l ==> ?R |] ==> ?R *)
            elim_neq' nat_only (ineqs @ [(SOME (l, i, "<", r, j, d), is_nat)]) @
            elim_neq' nat_only (ineqs @ [(SOME (r, j, "<", l, i, d), is_nat)])
          else
            map (cons ineq) (elim_neq' nat_only ineqs)
  in
    ineqs |> elim_neq' true
          |> maps (elim_neq' false)
  end

  (* Used when splitting is disabled: a "~=" premise becomes NONE, i.e. it is
     treated as irrelevant (but still occupies its position). *)
  fun ignore_neq (NONE, bool) = (NONE, bool)
    | ignore_neq (ineq as SOME (_, _, rel, _, _, _), bool) =
      if rel = "~=" then (NONE, bool) else (ineq, bool)

  (* Pair each SOME item with its 0-based position in the list; NONE items are
     dropped but still advance the counter. *)
  fun number_hyps _ []             = []
    | number_hyps n (NONE::xs)     = number_hyps (n+1) xs
    | number_hyps n ((SOME x)::xs) = (x, n) :: number_hyps (n+1) xs

  val result = (Ts, terms)
    |> (* user-defined preprocessing of the subgoal *)
       (if do_pre then LA_Data.pre_decomp ctxt else Library.single)
    |> tap (fn subgoals => trace_msg ("Preprocessing yields " ^
         string_of_int (length subgoals) ^ " subgoal(s) total."))
    |> (* produce the internal encoding of (in-)equalities *)
       map (apsnd (map (fn t => (LA_Data.decomp ctxt t, LA_Data.domain_is_nat t))))
    |> (* splitting of inequalities *)
       map (apsnd (if split_neq then elim_neq else
                     Library.single o map ignore_neq))
    |> maps (fn (Ts, subgoals) => map (pair Ts o map fst) subgoals)
    |> (* numbering of hypotheses, ignoring irrelevant ones *)
       map (apsnd (number_hyps 0))
in
  trace_msg ("Splitting of inequalities yields " ^
    string_of_int (length result) ^ " subgoal(s) total.");
  result
end;
   723 
   724 fun add_datoms ((lhs,_,_,rhs,_,d) : LA_Data.decomp, _) (dats : (bool * term) list) =
   725   union_bterm (map (pair d o fst) lhs) (union_bterm (map (pair d o fst) rhs) dats);
   726 
   727 fun discr (initems : (LA_Data.decomp * int) list) : bool list =
   728   map fst (fold add_datoms initems []);
   729 
   730 fun refutes ctxt params show_ex :
   731     (typ list * (LA_Data.decomp * int) list) list -> injust list -> injust list option =
   732   let
   733     fun refute ((Ts, initems : (LA_Data.decomp * int) list) :: initemss) (js: injust list) =
   734           let
   735             val atoms = atoms_of (map fst initems)
   736             val n = length atoms
   737             val mkleq = mklineq atoms
   738             val ixs = 0 upto (n - 1)
   739             val iatoms = atoms ~~ ixs
   740             val natlineqs = map_filter (mknat Ts ixs) iatoms
   741             val ineqs = map mkleq initems @ natlineqs
   742           in case elim (ineqs, []) of
   743                Success j =>
   744                  (trace_msg ("Contradiction! (" ^ string_of_int (length js + 1) ^ ")");
   745                   refute initemss (js @ [j]))
   746              | Failure hist =>
   747                  (if not show_ex then ()
   748                   else
   749                     let
   750                       val (param_names, ctxt') = ctxt |> Variable.variant_fixes (map fst params)
   751                       val (more_names, ctxt'') = ctxt' |> Variable.variant_fixes
   752                         (Name.invent (Variable.names_of ctxt') Name.uu (length Ts - length params))
   753                       val params' = (more_names @ param_names) ~~ Ts
   754                     in
   755                       trace_ex ctxt'' params' atoms (discr initems) n hist
   756                     end; NONE)
   757           end
   758       | refute [] js = SOME js
   759   in refute end;
   760 
   761 fun refute ctxt params show_ex do_pre split_neq terms : injust list option =
   762   refutes ctxt params show_ex (split_items ctxt do_pre split_neq
   763     (map snd params, terms)) [];
   764 
   765 fun count P xs = length (filter P xs);
   766 
   767 fun prove ctxt params show_ex do_pre Hs concl : bool * injust list option =
   768   let
   769     val _ = trace_msg "prove:"
   770     (* append the negated conclusion to 'Hs' -- this corresponds to     *)
   771     (* 'DETERM (resolve_tac [LA_Logic.notI, LA_Logic.ccontr] i)' at the *)
   772     (* theorem/tactic level                                             *)
   773     val Hs' = Hs @ [LA_Logic.neg_prop concl]
   774     fun is_neq NONE                 = false
   775       | is_neq (SOME (_,_,r,_,_,_)) = (r = "~=")
   776     val neq_limit = Config.get ctxt LA_Data.fast_arith_neq_limit
   777     val split_neq = count is_neq (map (LA_Data.decomp ctxt) Hs') <= neq_limit
   778   in
   779     if split_neq then ()
   780     else
   781       trace_msg ("fast_arith_neq_limit exceeded (current value is " ^
   782         string_of_int neq_limit ^ "), ignoring all inequalities");
   783     (split_neq, refute ctxt params show_ex do_pre split_neq Hs')
   784   end handle TERM ("neg_prop", _) =>
   785     (* since no meta-logic negation is available, we can only fail if   *)
   786     (* the conclusion is not of the form 'Trueprop $ _' (simply         *)
   787     (* dropping the conclusion doesn't work either, because even        *)
   788     (* 'False' does not imply arbitrary 'concl::prop')                  *)
   789     (trace_msg "prove failed (cannot negate conclusion).";
   790       (false, NONE));
   791 
   792 fun refute_tac ss (i, split_neq, justs) =
   793   fn state =>
   794     let
   795       val ctxt = Simplifier.the_context ss;
   796       val _ = trace_thm ctxt
   797         ("refute_tac (on subgoal " ^ string_of_int i ^ ", with " ^
   798           string_of_int (length justs) ^ " justification(s)):") state
   799       val {neqE, ...} = get_data ctxt;
   800       fun just1 j =
   801         (* eliminate inequalities *)
   802         (if split_neq then
   803           REPEAT_DETERM (eresolve_tac neqE i)
   804         else
   805           all_tac) THEN
   806           PRIMITIVE (trace_thm ctxt "State after neqE:") THEN
   807           (* use theorems generated from the actual justifications *)
   808           Subgoal.FOCUS (fn {prems, ...} => rtac (mkthm ss prems j) 1) ctxt i
   809     in
   810       (* rewrite "[| A1; ...; An |] ==> B" to "[| A1; ...; An; ~B |] ==> False" *)
   811       DETERM (resolve_tac [LA_Logic.notI, LA_Logic.ccontr] i) THEN
   812       (* user-defined preprocessing of the subgoal *)
   813       DETERM (LA_Data.pre_tac ss i) THEN
   814       PRIMITIVE (trace_thm ctxt "State after pre_tac:") THEN
   815       (* prove every resulting subgoal, using its justification *)
   816       EVERY (map just1 justs)
   817     end  state;
   818 
   819 (*
   820 Fast but very incomplete decider. Only premises and conclusions
   821 that are already (negated) (in)equations are taken into account.
   822 *)
   823 fun simpset_lin_arith_tac ss show_ex = SUBGOAL (fn (A, i) =>
   824   let
   825     val ctxt = Simplifier.the_context ss
   826     val params = rev (Logic.strip_params A)
   827     val Hs = Logic.strip_assums_hyp A
   828     val concl = Logic.strip_assums_concl A
   829     val _ = trace_term ctxt ("Trying to refute subgoal " ^ string_of_int i) A
   830   in
   831     case prove ctxt params show_ex true Hs concl of
   832       (_, NONE) => (trace_msg "Refutation failed."; no_tac)
   833     | (split_neq, SOME js) => (trace_msg "Refutation succeeded.";
   834                                refute_tac ss (i, split_neq, js))
   835   end);
   836 
   837 fun cut_lin_arith_tac ss =
   838   cut_facts_tac (Simplifier.prems_of ss) THEN'
   839   simpset_lin_arith_tac ss false;
   840 
   841 fun lin_arith_tac ctxt =
   842   simpset_lin_arith_tac (Simplifier.context ctxt Simplifier.empty_ss);
   843 
   844 
   845 
   846 (** Forward proof from theorems **)
   847 
   848 (* More tricky code. Needs to arrange the proofs of the multiple cases (due
   849 to splits of ~= premises) such that it coincides with the order of the cases
   850 generated by function split_items. *)
   851 
(* Case-split structure of the assumptions, as built by splitasms:
   Tip asms                 -- one case, with its list of assumptions;
   Spl (spl, ct1, t1, ct2, t2)
                            -- spl is an assumption resolved with a neqE rule,
                               of the shape "(ct1 ==> ?R) ==> (ct2 ==> ?R) ==> ?R";
                               t1 / t2 are the subtrees for the cases with the
                               additional hypothesis ct1 resp. ct2. *)
datatype splittree = Tip of thm list
                   | Spl of thm * cterm * splittree * cterm * splittree;
   854 
   855 (* "(ct1 ==> ?R) ==> (ct2 ==> ?R) ==> ?R" is taken to (ct1, ct2) *)
   856 
   857 fun extract (imp : cterm) : cterm * cterm =
   858 let val (Il, r)    = Thm.dest_comb imp
   859     val (_, imp1)  = Thm.dest_comb Il
   860     val (Ict1, _)  = Thm.dest_comb imp1
   861     val (_, ct1)   = Thm.dest_comb Ict1
   862     val (Ir, _)    = Thm.dest_comb r
   863     val (_, Ict2r) = Thm.dest_comb Ir
   864     val (Ict2, _)  = Thm.dest_comb Ict2r
   865     val (_, ct2)   = Thm.dest_comb Ict2
   866 in (ct1, ct2) end;
   867 
   868 fun splitasms ctxt (asms : thm list) : splittree =
   869 let val {neqE, ...} = get_data ctxt
   870     fun elim_neq [] (asms', []) = Tip (rev asms')
   871       | elim_neq [] (asms', asms) = Tip (rev asms' @ asms)
   872       | elim_neq (neq :: neqs) (asms', []) = elim_neq neqs ([],rev asms')
   873       | elim_neq (neqs as (neq :: _)) (asms', asm::asms) =
   874       (case get_first (fn th => SOME (asm COMP th) handle THM _ => NONE) [neq] of
   875         SOME spl =>
   876           let val (ct1, ct2) = extract (cprop_of spl)
   877               val thm1 = Thm.assume ct1
   878               val thm2 = Thm.assume ct2
   879           in Spl (spl, ct1, elim_neq neqs (asms', asms@[thm1]),
   880             ct2, elim_neq neqs (asms', asms@[thm2]))
   881           end
   882       | NONE => elim_neq neqs (asm::asms', asms))
   883 in elim_neq neqE ([], asms) end;
   884 
   885 fun fwdproof ss (Tip asms : splittree) (j::js : injust list) = (mkthm ss asms j, js)
   886   | fwdproof ss (Spl (thm, ct1, tree1, ct2, tree2)) js =
   887       let
   888         val (thm1, js1) = fwdproof ss tree1 js
   889         val (thm2, js2) = fwdproof ss tree2 js1
   890         val thm1' = Thm.implies_intr ct1 thm1
   891         val thm2' = Thm.implies_intr ct2 thm2
   892       in (thm2' COMP (thm1' COMP thm), js2) end;
   893       (* FIXME needs handle THM _ => NONE ? *)
   894 
   895 fun prover ss thms Tconcl (js : injust list) split_neq pos : thm option =
   896   let
   897     val ctxt = Simplifier.the_context ss
   898     val thy = Proof_Context.theory_of ctxt
   899     val nTconcl = LA_Logic.neg_prop Tconcl
   900     val cnTconcl = cterm_of thy nTconcl
   901     val nTconclthm = Thm.assume cnTconcl
   902     val tree = (if split_neq then splitasms ctxt else Tip) (thms @ [nTconclthm])
   903     val (Falsethm, _) = fwdproof ss tree js
   904     val contr = if pos then LA_Logic.ccontr else LA_Logic.notI
   905     val concl = Thm.implies_intr cnTconcl Falsethm COMP contr
   906   in SOME (trace_thm ctxt "Proved by lin. arith. prover:" (LA_Logic.mk_Eq concl)) end
   907   (*in case concl contains ?-var, which makes assume fail:*)   (* FIXME Variable.import_terms *)
   908   handle THM _ => NONE;
   909 
   910 (* PRE: concl is not negated!
   911    This assumption is OK because
   912    1. lin_arith_simproc tries both to prove and disprove concl and
   913    2. lin_arith_simproc is applied by the Simplifier which
   914       dives into terms and will thus try the non-negated concl anyway.
   915 *)
   916 fun lin_arith_simproc ss concl =
   917   let
   918     val ctxt = Simplifier.the_context ss
   919     val thms = maps LA_Logic.atomize (Simplifier.prems_of ss)
   920     val Hs = map Thm.prop_of thms
   921     val Tconcl = LA_Logic.mk_Trueprop concl
   922   in
   923     case prove ctxt [] false false Hs Tconcl of (* concl provable? *)
   924       (split_neq, SOME js) => prover ss thms Tconcl js split_neq true
   925     | (_, NONE) =>
   926         let val nTconcl = LA_Logic.neg_prop Tconcl in
   927           case prove ctxt [] false false Hs nTconcl of (* ~concl provable? *)
   928             (split_neq, SOME js) => prover ss thms nTconcl js split_neq false
   929           | (_, NONE) => NONE
   930         end
   931   end;
   932 
   933 end;