src/Provers/Arith/fast_lin_arith.ML
author haftmann
Tue Oct 20 16:13:01 2009 +0200 (2009-10-20)
changeset 33037 b22e44496dc2
parent 33002 f3f02f36a3e2
child 33038 8f9594c31de4
permissions -rw-r--r--
replaced old_style infixes eq_set, subset, union, inter and variants by generic versions
     1 (*  Title:      Provers/Arith/fast_lin_arith.ML
     2     ID:         $Id$
     3     Author:     Tobias Nipkow and Tjark Weber and Sascha Boehme
     4 
     5 A generic linear arithmetic package.  It provides two tactics
     6 (cut_lin_arith_tac, lin_arith_tac) and a simplification procedure
     7 (lin_arith_simproc).
     8 
     9 Only take premises and conclusions into account that are already
    10 (negated) (in)equations. lin_arith_simproc tries to prove or disprove
    11 the term.
    12 *)
    13 
    14 (*** Data needed for setting up the linear arithmetic package ***)
    15 
       (* Object-logic interface: the theorems and term/theorem operations
          that the Fast_Lin_Arith functor receives as its LA_Logic argument.
          The intended meaning of mk_Eq, neg_prop, is_nat and mk_nat_thm is
          spelled out in the comment that follows this signature. *)
    16 signature LIN_ARITH_LOGIC =
    17 sig
    18   val conjI       : thm  (* P ==> Q ==> P & Q *)
    19   val ccontr      : thm  (* (~ P ==> False) ==> P *)
    20   val notI        : thm  (* (P ==> False) ==> ~ P *)
    21   val not_lessD   : thm  (* ~(m < n) ==> n <= m *)
    22   val not_leD     : thm  (* ~(m <= n) ==> n < m *)
    23   val sym         : thm  (* x = y ==> y = x *)
    24   val trueI       : thm  (* True *)
    25   val mk_Eq       : thm -> thm
    26   val atomize     : thm -> thm list
    27   val mk_Trueprop : term -> term
    28   val neg_prop    : term -> term
         (* is_False: test whether a theorem states False; mkthm uses it to
            recognise a successful refutation *)
    29   val is_False    : thm -> bool
    30   val is_nat      : typ list * term -> bool
    31   val mk_nat_thm  : theory -> term -> thm
    32 end;
    33 (*
    34 mk_Eq(~in) = `in == False'
    35 mk_Eq(in) = `in == True'
    36 where `in' is an (in)equality.
    37 
    38 neg_prop(t) = neg  if t is wrapped up in Trueprop and neg is the
    39   (logically) negated version of t (again wrapped up in Trueprop),
    40   where the negation of a negative term is the term itself (no
    41   double negation!); raises TERM ("neg_prop", [t]) if t is not of
    42   the form 'Trueprop $ _'
    43 
    44 is_nat(parameter-types,t) =  t:nat
    45 mk_nat_thm(t) = "0 <= t"
    46 *)
    47 
       (* Arithmetic-specific parameters of the package (the LA_Data argument
          of Fast_Lin_Arith): how terms are decomposed into linear
          (in)equations and how subgoals are preprocessed. *)
    48 signature LIN_ARITH_DATA =
    49 sig
    50   (*internal representation of linear (in-)equations:*)
         (* (lhs summands, lhs constant, relation, rhs summands, rhs constant,
            discrete domain?) -- see the decomp comment after this signature *)
    51   type decomp = (term * Rat.rat) list * Rat.rat * string * (term * Rat.rat) list * Rat.rat * bool
         (* NONE for premises that are not linear (in)equations *)
    52   val decomp: Proof.context -> term -> decomp option
    53   val domain_is_nat: term -> bool
    54
    55   (*preprocessing, performed on a representation of subgoals as list of premises:*)
    56   val pre_decomp: Proof.context -> typ list * term list -> (typ list * term list) list
    57
    58   (*preprocessing, performed on the goal -- must do the same as 'pre_decomp':*)
    59   val pre_tac: Proof.context -> int -> tactic
    60
    61   (*the limit on the number of ~= allowed; because each ~= is split
    62     into two cases, this can lead to an explosion*)
    63   val fast_arith_neq_limit: int Config.T
    64 end;
    65 (*
    66 decomp(`x Rel y') should yield (p,i,Rel,q,j,d)
    67    where Rel is one of "<", "~<", "<=", "~<=" and "=" and
    68          p (q, respectively) is the decomposition of the sum term x
    69          (y, respectively) into a list of summand * multiplicity
    70          pairs and a constant summand and d indicates if the domain
    71          is discrete.
    72 
    73 domain_is_nat(`x Rel y') should yield true iff x is of type "nat".
    74 
    75 The relationship between pre_decomp and pre_tac is somewhat tricky.  The
    76 internal representation of a subgoal and the corresponding theorem must
    77 be modified by pre_decomp (pre_tac, resp.) in a corresponding way.  See
    78 the comment for split_items below.  (This is even necessary for eta- and
    79 beta-equivalent modifications, as some of the lin. arith. code is
    80 sensitive to them.)
    81 
    82 The simpset stored in the package data must reduce contradictory <= to False.
    83    It should also cancel common summands to keep <= reduced;
    84    otherwise <= can grow to massive proportions.
    85 *)
    86 
       (* Interface of the generated package: the two tactics and the
          simproc named in the file header, access to the package's
          context data, and diagnostic switches. *)
    87 signature FAST_LIN_ARITH =
    88 sig
    89   val cut_lin_arith_tac: simpset -> int -> tactic
         (* NOTE(review): the meaning of the bool argument is not visible
            here -- confirm against the definition *)
    90   val lin_arith_tac: Proof.context -> bool -> int -> tactic
    91   val lin_arith_simproc: simpset -> term -> thm option
         (* map_data is Data.map: it updates the theorem lists, the simpset
            and the number_of function stored in the context *)
    92   val map_data: ({add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
    93                  lessD: thm list, neqE: thm list, simpset: Simplifier.simpset,
    94                  number_of : serial * (theory -> typ -> int -> cterm)}
    95                  -> {add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
    96                      lessD: thm list, neqE: thm list, simpset: Simplifier.simpset,
    97                      number_of : serial * (theory -> typ -> int -> cterm)})
    98                 -> Context.generic -> Context.generic
         (* trace switches on tracing output; warning_count counts the
            "should have refuted" diagnostics of mkthm, which stops reporting
            once warning_count_max is exceeded *)
    99   val trace: bool Unsynchronized.ref
   100   val warning_count: int Unsynchronized.ref;
   101 end;
   102 
   103 functor Fast_Lin_Arith
   104   (structure LA_Logic: LIN_ARITH_LOGIC and LA_Data: LIN_ARITH_DATA): FAST_LIN_ARITH =
   105 struct
   106 
   107 
   108 (** theory data **)
   109 
   110 fun no_number_of _ _ _ = raise CTERM ("number_of", [])
   111 
       (* Context data of the package.  The theorem lists feed the proof
          reconstruction in mkthm (add_mono_thms, mult_mono_thms, inj_thms,
          lessD); neqE is used to split ~= premises; simpset is used by mkthm
          to simplify intermediate and final theorems; number_of builds a
          cterm of a given type for an integer and is tagged with a serial. *)
   112 structure Data = GenericDataFun
   113 (
   114   type T =
   115    {add_mono_thms: thm list,
   116     mult_mono_thms: thm list,
   117     inj_thms: thm list,
   118     lessD: thm list,
   119     neqE: thm list,
   120     simpset: Simplifier.simpset,
   121     number_of : serial * (theory -> typ -> int -> cterm)};
   122
         (* Empty data; the default number_of always raises CTERM. *)
   123   val empty = {add_mono_thms = [], mult_mono_thms = [], inj_thms = [],
   124                lessD = [], neqE = [], simpset = Simplifier.empty_ss,
   125                number_of = (serial (), no_number_of) } : T;
   126   val extend = I;
         (* Theorem lists and simpsets are merged pointwise; of the two
            number_of entries the one with the larger serial wins
            (presumably the more recently registered one -- confirm). *)
   127   fun merge _
   128     ({add_mono_thms= add_mono_thms1, mult_mono_thms= mult_mono_thms1, inj_thms= inj_thms1,
   129       lessD = lessD1, neqE=neqE1, simpset = simpset1,
   130       number_of = (number_of1 as (s1, _))},
   131      {add_mono_thms= add_mono_thms2, mult_mono_thms= mult_mono_thms2, inj_thms= inj_thms2,
   132       lessD = lessD2, neqE=neqE2, simpset = simpset2,
   133       number_of = (number_of2 as (s2, _))}) =
   134     {add_mono_thms = Thm.merge_thms (add_mono_thms1, add_mono_thms2),
   135      mult_mono_thms = Thm.merge_thms (mult_mono_thms1, mult_mono_thms2),
   136      inj_thms = Thm.merge_thms (inj_thms1, inj_thms2),
   137      lessD = Thm.merge_thms (lessD1, lessD2),
   138      neqE = Thm.merge_thms (neqE1, neqE2),
   139      simpset = Simplifier.merge_ss (simpset1, simpset2),
   140      number_of = if s1 > s2 then number_of1 else number_of2};
   141 );
   142 
   143 val map_data = Data.map;
   144 val get_data = Data.get o Context.Proof;
   145 
   146 
   147 
   148 (*** A fast decision procedure ***)
   149 (*** Code ported from HOL Light ***)
   150 (* possible optimizations:
   151    use (var,coeff) rep or vector rep to save space;
   152    treat non-negative atoms separately rather than adding 0 <= atom
   153 *)
   154 
       (* When set, the trace_* helpers and print_ineqs emit tracing output. *)
   155 val trace = Unsynchronized.ref false;
   156
       (* Relation of a Lineq: equal, less-or-equal, strictly less. *)
   157 datatype lineq_type = Eq | Le | Lt;
   158
       (* Justification of a Lineq, replayed by mkthm into a theorem:
          Asm i        -- the i-th assumption;
          Nat i        -- 0 <= (i-th atom), via LA_Logic.mk_nat_thm;
          LessD, NotLessD, NotLeD, NotLeDD -- the sub-justification with the
                          lessD / not_lessD / not_leD rules applied;
          Multiplied (n, j) -- j scaled by the factor n;
          Added (j1, j2)    -- the sum of j1 and j2. *)
   159 datatype injust = Asm of int
   160                 | Nat of int (* index of atom *)
   161                 | LessD of injust
   162                 | NotLessD of injust
   163                 | NotLeD of injust
   164                 | NotLeDD of injust
   165                 | Multiplied of int * injust
   166                 | Added of injust * injust;
   167
       (* Lineq (k, ty, cs, j) encodes  k ty sum_i cs_i * atom_i
          (e.g. k <= ...), with justification j; cf. mklineq and
          is_contradictory. *)
   168 datatype lineq = Lineq of int * lineq_type * int list * injust;
   169 
   170 (* ------------------------------------------------------------------------- *)
   171 (* Finding a (counter) example from the trace of a failed elimination        *)
   172 (* ------------------------------------------------------------------------- *)
   173 (* Examples are represented as rational numbers,                             *)
   174 (* Don't blame John Harrison for this code - it is entirely mine. TN         *)
   175 
   176 exception NoEx;
   177 
   178 (* Coding: (i,true,cs) means i <= cs and (i,false,cs) means i < cs.
   179    In general, true means the bound is included, false means it is excluded.
   180    Need to know if it is a lower or upper bound for unambiguous interpretation!
   181 *)
   182 
   183 fun elim_eqns (Lineq (i, Le, cs, _)) = [(i, true, cs)]
   184   | elim_eqns (Lineq (i, Eq, cs, _)) = [(i, true, cs),(~i, true, map ~ cs)]
   185   | elim_eqns (Lineq (i, Lt, cs, _)) = [(i, false, cs)];
   186 
   187 (* PRE: ex[v] must be 0! *)
   188 fun eval ex v (a, le, cs) =
   189   let
   190     val rs = map Rat.rat_of_int cs;
   191     val rsum = fold2 (Rat.add oo Rat.mult) rs ex Rat.zero;
   192   in (Rat.mult (Rat.add (Rat.rat_of_int a) (Rat.neg rsum)) (Rat.inv (nth rs v)), le) end;
   193 (* If nth rs v < 0, le should be negated.
   194    Instead this swap is taken into account in ratrelmin2.
   195 *)
   196 
   197 fun ratrelmin2 (x as (r, ler), y as (s, les)) =
   198   case Rat.ord (r, s)
   199    of EQUAL => (r, (not ler) andalso (not les))
   200     | LESS => x
   201     | GREATER => y;
   202 
   203 fun ratrelmax2 (x as (r, ler), y as (s, les)) =
   204   case Rat.ord (r, s)
   205    of EQUAL => (r, ler andalso les)
   206     | LESS => y
   207     | GREATER => x;
   208 
   209 val ratrelmin = foldr1 ratrelmin2;
   210 val ratrelmax = foldr1 ratrelmax2;
   211 
   212 fun ratexact up (r, exact) =
   213   if exact then r else
   214   let
   215     val (p, q) = Rat.quotient_of_rat r;
   216     val nth = Rat.inv (Rat.rat_of_int q);
   217   in Rat.add r (if up then nth else Rat.neg nth) end;
   218 
   219 fun ratmiddle (r, s) = Rat.mult (Rat.add r s) (Rat.inv Rat.two);
   220 
       (* Choose a value for one variable from its lower bound (lb, exactl) and
          upper bound (ub, exactu), as produced by ratrelmax / ratrelmin; d
          tells whether the variable's domain is discrete.  Prefers 0, then an
          exact bound, then the midpoint.  In a discrete domain it rounds lb up
          (or ub down) and raises NoEx if the rounded value falls outside the
          other bound.
          NOTE(review): the test for returning 0 inspects only the sign of lb,
          never ub; e.g. with lb > 0 and exactl it still returns 0, which lies
          below lb.  The result only feeds the "possibly spurious"
          counterexample trace, but confirm this is intended. *)
   221 fun choose2 d ((lb, exactl), (ub, exactu)) =
   222   let val ord = Rat.sign lb in
   223   if (ord = LESS orelse exactl) andalso (ord = GREATER orelse exactu)
   224     then Rat.zero
   225     else if not d then
   226       if ord = GREATER
   227         then if exactl then lb else ratmiddle (lb, ub)
   228         else if exactu then ub else ratmiddle (lb, ub)
   229       else (*discrete domain, both bounds must be exact*)
   230       if ord = GREATER
   231         then let val lb' = Rat.roundup lb in
   232           if Rat.le lb' ub then lb' else raise NoEx end
   233         else let val ub' = Rat.rounddown ub in
   234           if Rat.le lb ub' then ub' else raise NoEx end
   235   end;
   236 
   237 fun findex1 discr (v, lineqs) ex =
   238   let
   239     val nz = filter (fn (Lineq (_, _, cs, _)) => nth cs v <> 0) lineqs;
   240     val ineqs = maps elim_eqns nz;
   241     val (ge, le) = List.partition (fn (_,_,cs) => nth cs v > 0) ineqs
   242     val lb = ratrelmax (map (eval ex v) ge)
   243     val ub = ratrelmin (map (eval ex v) le)
   244   in nth_map v (K (choose2 (nth discr v) (lb, ub))) ex end;
   245 
   246 fun elim1 v x =
   247   map (fn (a,le,bs) => (Rat.add a (Rat.neg (Rat.mult (nth bs v) x)), le,
   248                         nth_map v (K Rat.zero) bs));
   249 
   250 fun single_var v (_, _, cs) = case filter_out (curry (op =) EQUAL o Rat.sign) cs
   251  of [x] => x =/ nth cs v
   252   | _ => false;
   253 
   254 (* The base case:
   255    all variables occur only with positive or only with negative coefficients *)
   256 fun pick_vars discr (ineqs,ex) =
   257   let val nz = filter_out (fn (_,_,cs) => forall (curry (op =) EQUAL o Rat.sign) cs) ineqs
   258   in case nz of [] => ex
   259      | (_,_,cs) :: _ =>
   260        let val v = find_index (not o curry (op =) EQUAL o Rat.sign) cs
   261            val d = nth discr v;
   262            val pos = not (Rat.sign (nth cs v) = LESS);
   263            val sv = filter (single_var v) nz;
   264            val minmax =
   265              if pos then if d then Rat.roundup o fst o ratrelmax
   266                          else ratexact true o ratrelmax
   267                     else if d then Rat.rounddown o fst o ratrelmin
   268                          else ratexact false o ratrelmin
   269            val bnds = map (fn (a,le,bs) => (Rat.mult a (Rat.inv (nth bs v)), le)) sv
   270            val x = minmax((Rat.zero,if pos then true else false)::bnds)
   271            val ineqs' = elim1 v x nz
   272            val ex' = nth_map v (K x) ex
   273        in pick_vars discr (ineqs',ex') end
   274   end;
   275 
   276 fun findex0 discr n lineqs =
   277   let val ineqs = maps elim_eqns lineqs
   278       val rineqs = map (fn (a,le,cs) => (Rat.rat_of_int a, le, map Rat.rat_of_int cs))
   279                        ineqs
   280   in pick_vars discr (rineqs,replicate n Rat.zero) end;
   281 
   282 (* ------------------------------------------------------------------------- *)
   283 (* End of counterexample finder. The actual decision procedure starts here.  *)
   284 (* ------------------------------------------------------------------------- *)
   285 
   286 (* ------------------------------------------------------------------------- *)
   287 (* Calculate new (in)equality type after addition.                           *)
   288 (* ------------------------------------------------------------------------- *)
   289 
   290 fun find_add_type(Eq,x) = x
   291   | find_add_type(x,Eq) = x
   292   | find_add_type(_,Lt) = Lt
   293   | find_add_type(Lt,_) = Lt
   294   | find_add_type(Le,Le) = Le;
   295 
   296 (* ------------------------------------------------------------------------- *)
   297 (* Multiply out an (in)equation.                                             *)
   298 (* ------------------------------------------------------------------------- *)
   299 
   300 fun multiply_ineq n (i as Lineq(k,ty,l,just)) =
   301   if n = 1 then i
   302   else if n = 0 andalso ty = Lt then sys_error "multiply_ineq"
   303   else if n < 0 andalso (ty=Le orelse ty=Lt) then sys_error "multiply_ineq"
   304   else Lineq (n * k, ty, map (Integer.mult n) l, Multiplied (n, just));
   305 
   306 (* ------------------------------------------------------------------------- *)
   307 (* Add together (in)equations.                                               *)
   308 (* ------------------------------------------------------------------------- *)
   309 
   310 fun add_ineq (i1 as Lineq(k1,ty1,l1,just1)) (i2 as Lineq(k2,ty2,l2,just2)) =
   311   let val l = map2 Integer.add l1 l2
   312   in Lineq(k1+k2,find_add_type(ty1,ty2),l,Added(just1,just2)) end;
   313 
   314 (* ------------------------------------------------------------------------- *)
   315 (* Elimination of variable between a single pair of (in)equations.           *)
   316 (* If they're both inequalities, 1st coefficient must be +ve, 2nd -ve.       *)
   317 (* ------------------------------------------------------------------------- *)
   318 
   319 fun elim_var v (i1 as Lineq(k1,ty1,l1,just1)) (i2 as Lineq(k2,ty2,l2,just2)) =
   320   let val c1 = nth l1 v and c2 = nth l2 v
   321       val m = Integer.lcm (abs c1) (abs c2)
   322       val m1 = m div (abs c1) and m2 = m div (abs c2)
   323       val (n1,n2) =
   324         if (c1 >= 0) = (c2 >= 0)
   325         then if ty1 = Eq then (~m1,m2)
   326              else if ty2 = Eq then (m1,~m2)
   327                   else sys_error "elim_var"
   328         else (m1,m2)
   329       val (p1,p2) = if ty1=Eq andalso ty2=Eq andalso (n1 = ~1 orelse n2 = ~1)
   330                     then (~n1,~n2) else (n1,n2)
   331   in add_ineq (multiply_ineq p1 i1) (multiply_ineq p2 i2) end;
   332 
   333 (* ------------------------------------------------------------------------- *)
   334 (* The main refutation-finding code.                                         *)
   335 (* ------------------------------------------------------------------------- *)
   336 
   337 fun is_trivial (Lineq(_,_,l,_)) = forall (fn i => i=0) l;
   338 
   339 fun is_contradictory (ans as Lineq(k,ty,l,_)) =
   340   case ty  of Eq => k <> 0 | Le => k > 0 | Lt => k >= 0;
   341 
   342 fun calc_blowup l =
   343   let val (p,n) = List.partition (curry (op <) 0) (List.filter (curry (op <>) 0) l)
   344   in length p * length n end;
   345 
   346 (* ------------------------------------------------------------------------- *)
   347 (* Main elimination code:                                                    *)
   348 (*                                                                           *)
   349 (* (1) Looks for immediate solutions (false assertions with no variables).   *)
   350 (*                                                                           *)
   351 (* (2) If there are any equations, picks a variable with the lowest absolute *)
   352 (* coefficient in any of them, and uses it to eliminate.                     *)
   353 (*                                                                           *)
   354 (* (3) Otherwise, chooses a variable in the inequality to minimize the       *)
   355 (* blowup (number of consequences generated) and eliminates it.              *)
   356 (* ------------------------------------------------------------------------- *)
   357 
   358 fun extract_first p =
   359   let
   360     fun extract xs (y::ys) = if p y then (y, xs @ ys) else extract (y::xs) ys
   361       | extract xs [] = raise Empty
   362   in extract [] end;
   363 
   364 fun print_ineqs ineqs =
   365   if !trace then
   366      tracing(cat_lines(""::map (fn Lineq(c,t,l,_) =>
   367        string_of_int c ^
   368        (case t of Eq => " =  " | Lt=> " <  " | Le => " <= ") ^
   369        commas(map string_of_int l)) ineqs))
   370   else ();
   371 
       (* Elimination history, most recent step first: (v, system) records
          that variable v was eliminated from system; v = ~1 marks a final
          system in which no variable could be eliminated. *)
   372 type history = (int * lineq list) list;
   373 datatype result = Success of injust | Failure of history;
   374
       (* Search for a refutation of ineqs by variable elimination.
          Returns Success j when a trivial (all-zero) (in)equation is
          contradictory (j is its justification), and Failure with the
          history when nothing is left or no variable can be eliminated. *)
   375 fun elim (ineqs, hist) =
   376   let val _ = print_ineqs ineqs
   377       val (triv, nontriv) = List.partition is_trivial ineqs in
   378   if not (null triv)
   379   then case Library.find_first is_contradictory triv of
   380          NONE => elim (nontriv, hist)
   381        | SOME(Lineq(_,_,_,j)) => Success j
   382   else
   383   if null nontriv then Failure hist
   384   else
   385   let val (eqs, noneqs) = List.partition (fn (Lineq(_,ty,_,_)) => ty=Eq) nontriv in
   386   if not (null eqs) then
         (* Equations present: take the non-zero coefficient value c of least
            absolute value occurring in any equation, the first equation eq
            containing it, and v = its position there; eliminate v from all
            other (in)equations with a non-zero v-coefficient. *)
   387      let val c =
   388            fold (fn Lineq(_,_,l,_) => fn cs => gen_union (op =) (l, cs)) eqs []
   389            |> filter (fn i => i <> 0)
   390            |> sort (int_ord o pairself abs)
   391            |> hd
   392          val (eq as Lineq(_,_,ceq,_),othereqs) =
   393                extract_first (fn Lineq(_,_,l,_) => c mem l) eqs
   394          val v = find_index (fn v => v = c) ceq
   395          val (ioth,roth) = List.partition (fn (Lineq(_,_,l,_)) => nth l v = 0)
   396                                      (othereqs @ noneqs)
   397          val others = map (elim_var v eq) roth @ ioth
   398      in elim(others,(v,nontriv)::hist) end
   399   else
         (* Only inequalities: pick the variable with the smallest non-zero
            blowup (#positive * #negative coefficients) and combine every
            positive occurrence with every negative one (Fourier-Motzkin). *)
   400   let val lists = map (fn (Lineq(_,_,l,_)) => l) noneqs
   401       val numlist = 0 upto (length (hd lists) - 1)
   402       val coeffs = map (fn i => map (fn xs => nth xs i) lists) numlist
   403       val blows = map calc_blowup coeffs
   404       val iblows = blows ~~ numlist
   405       val nziblows = filter_out (fn (i, _) => i = 0) iblows
   406   in if null nziblows then Failure((~1,nontriv)::hist)
   407      else
             (* NOTE(review): c is unused; ineqs equals noneqs on this path,
                since triv and eqs are both empty here. *)
   408      let val (c,v) = hd(sort (fn (x,y) => int_ord(fst(x),fst(y))) nziblows)
   409          val (no,yes) = List.partition (fn (Lineq(_,_,l,_)) => nth l v = 0) ineqs
   410          val (pos,neg) = List.partition(fn (Lineq(_,_,l,_)) => nth l v > 0) yes
   411      in elim(no @ map_product (elim_var v) pos neg, (v,nontriv)::hist) end
   412   end
   413   end
   414   end;
   415 
   416 (* ------------------------------------------------------------------------- *)
   417 (* Translate back a proof.                                                   *)
   418 (* ------------------------------------------------------------------------- *)
   419 
   420 fun trace_thm ctxt msg th =
   421   (if !trace then (tracing msg; tracing (Display.string_of_thm ctxt th)) else (); th);
   422 
   423 fun trace_term ctxt msg t =
   424   (if !trace then tracing (cat_lines [msg, Syntax.string_of_term ctxt t]) else (); t)
   425 
   426 fun trace_msg msg =
   427   if !trace then tracing msg else ();
   428 
   429 val warning_count = Unsynchronized.ref 0;
   430 val warning_count_max = 10;
   431 
   432 val union_term = curry (gen_union Pattern.aeconv);
   433 val union_bterm = curry (gen_union
   434   (fn ((b:bool, t), (b', t')) => b = b' andalso Pattern.aeconv (t, t')));
   435 
   436 fun add_atoms (lhs, _, _, rhs, _, _) =
   437   union_term (map fst lhs) o union_term (map fst rhs);
   438 
   439 fun atoms_of ds = fold add_atoms ds [];
   440 
   441 (*
   442 Simplification may detect a contradiction 'prematurely' due to type
   443 information: n+1 <= 0 is simplified to False and does not need to be crossed
   444 with 0 <= n.
   445 *)
   446 local
         (* Raised by simp as soon as an intermediate sum simplifies to False. *)
   447   exception FalseE of thm
   448 in
   449
       (* Replay the justification just over the assumption theorems asms and
          return the resulting theorem, which is expected to simplify to False.
          If an intermediate sum already simplifies to False, that theorem is
          returned early.  If the final theorem is not False, the assumptions
          and the proved theorem are traced and a warning is issued, but only
          while warning_count has not exceeded warning_count_max. *)
   450 fun mkthm ss asms (just: injust) =
   451   let
   452     val ctxt = Simplifier.the_context ss;
   453     val thy = ProofContext.theory_of ctxt;
   454     val {add_mono_thms, mult_mono_thms, inj_thms, lessD, simpset,
   455       number_of = (_, num_of), ...} = get_data ctxt;
   456     val simpset' = Simplifier.inherit_context ss simpset;
   457     fun only_concl f thm =
   458       if Thm.no_prems thm then f (Thm.concl_of thm) else NONE;
           (* Atoms of the premise-free assumptions; Nat i refers to the i-th
              of these (presumably the same order the refutation search used
              -- confirm). *)
   459     val atoms = atoms_of (map_filter (only_concl (LA_Data.decomp ctxt)) asms);
   460
           (* Result of resolving thm with the first rule that applies. *)
   461     fun use_first rules thm =
   462       get_first (fn th => SOME (thm RS th) handle THM _ => NONE) rules
   463
   464     fun add2 thm1 thm2 =
   465       use_first add_mono_thms (thm1 RS (thm2 RS LA_Logic.conjI));
   466     fun try_add thms thm = get_first (fn th => add2 th thm) thms;
   467
           (* Add two (in)equations: directly, else after converting thm1 with
              inj_thms, else after converting thm2; sys_error if all fail. *)
   468     fun add_thms thm1 thm2 =
   469       (case add2 thm1 thm2 of
   470         NONE =>
   471           (case try_add ([thm1] RL inj_thms) thm2 of
   472             NONE =>
   473               (the (try_add ([thm2] RL inj_thms) thm1)
   474                 handle Option =>
   475                   (trace_thm ctxt "" thm1; trace_thm ctxt "" thm2;
   476                    sys_error "Linear arithmetic: failed to add thms"))
   477           | SOME thm => thm)
   478       | SOME thm => thm);
   479
           (* n * thm computed as thm + thm + ... (n summands).
              NOTE(review): does not terminate for n < 1; callers appear to
              pass only positive factors. *)
   480     fun mult_by_add n thm =
   481       let fun mul i th = if i = 1 then th else mul (i - 1) (add_thms thm th)
   482       in mul n thm end;
   483
   484     val rewr = Simplifier.rewrite simpset';
   485     val rewrite_concl = Conv.fconv_rule (Conv.concl_conv ~1 (Conv.arg_conv
   486       (Conv.binop_conv rewr)));
           (* Rewrite a remaining premise with the simpset and discharge it
              with trueI. *)
   487     fun discharge_prem thm = if Thm.nprems_of thm = 0 then thm else
   488       let val cv = Conv.arg1_conv (Conv.arg_conv rewr)
   489       in Thm.implies_elim (Conv.fconv_rule cv thm) LA_Logic.trueI end
   490
           (* n * thm via a mult_mono_thms rule instantiated with the numeral
              num_of thy T n; falls back to mult_by_add when no rule applies
              or instantiation fails with CTERM/THM. *)
   491     fun mult n thm =
   492       (case use_first mult_mono_thms thm of
   493         NONE => mult_by_add n thm
   494       | SOME mth =>
   495           let
   496             val cv = mth |> Thm.cprop_of |> Drule.strip_imp_concl
   497               |> Thm.dest_arg |> Thm.dest_arg1 |> Thm.dest_arg1
   498             val T = #T (Thm.rep_cterm cv)
   499           in
   500             mth
   501             |> Thm.instantiate ([], [(cv, num_of thy T n)])
   502             |> rewrite_concl
   503             |> discharge_prem
   504             handle CTERM _ => mult_by_add n thm
   505                  | THM _ => mult_by_add n thm
   506           end);
   507
           (* Negative factors (only equations may be scaled by them) are
              handled by multiplying with |n| and applying sym. *)
   508     fun mult_thm (n, thm) =
   509       if n = ~1 then thm RS LA_Logic.sym
   510       else if n < 0 then mult (~n) thm RS LA_Logic.sym
   511       else mult n thm;
   512
   513     fun simp thm =
   514       let val thm' = trace_thm ctxt "Simplified:" (full_simplify simpset' thm)
   515       in if LA_Logic.is_False thm' then raise FalseE thm' else thm' end;
   516
           (* Translate a justification tree into a theorem; sums are
              simplified after each addition. *)
   517     fun mk (Asm i) = trace_thm ctxt ("Asm " ^ string_of_int i) (nth asms i)
   518       | mk (Nat i) = trace_thm ctxt ("Nat " ^ string_of_int i) (LA_Logic.mk_nat_thm thy (nth atoms i))
   519       | mk (LessD j) = trace_thm ctxt "L" (hd ([mk j] RL lessD))
   520       | mk (NotLeD j) = trace_thm ctxt "NLe" (mk j RS LA_Logic.not_leD)
   521       | mk (NotLeDD j) = trace_thm ctxt "NLeD" (hd ([mk j RS LA_Logic.not_leD] RL lessD))
   522       | mk (NotLessD j) = trace_thm ctxt "NL" (mk j RS LA_Logic.not_lessD)
   523       | mk (Added (j1, j2)) = simp (trace_thm ctxt "+" (add_thms (mk j1) (mk j2)))
   524       | mk (Multiplied (n, j)) =
   525           (trace_msg ("*" ^ string_of_int n); trace_thm ctxt "*" (mult_thm (n, mk j)))
   526
   527   in
   528     let
   529       val _ = trace_msg "mkthm";
   530       val thm = trace_thm ctxt "Final thm:" (mk just);
   531       val fls = simplify simpset' thm;
   532       val _ = trace_thm ctxt "After simplification:" fls;
             (* Diagnostic when the final theorem is not False; reported
                for counts 1 .. warning_count_max only. *)
   533       val _ =
   534         if LA_Logic.is_False fls then ()
   535         else
   536           let val count = CRITICAL (fn () => Unsynchronized.inc warning_count) in
   537             if count > warning_count_max then ()
   538             else
   539               (tracing (cat_lines
   540                 (["Assumptions:"] @ map (Display.string_of_thm ctxt) asms @ [""] @
   541                  ["Proved:", Display.string_of_thm ctxt fls, ""] @
   542                  (if count <> warning_count_max then []
   543                   else ["\n(Reached maximal message count -- disabling future warnings)"])));
   544                 warning "Linear arithmetic should have refuted the assumptions.\n\
   545                   \Please inform Tobias Nipkow (nipkow@in.tum.de).")
   546           end;
   547     in fls end
   548     handle FalseE thm => trace_thm ctxt "False reached early:" thm
   549   end;
   550
   551 end;
   552 
   553 fun coeff poly atom =
   554   AList.lookup Pattern.aeconv poly atom |> the_default 0;
   555 
   556 fun integ(rlhs,r,rel,rrhs,s,d) =
   557 let val (rn,rd) = Rat.quotient_of_rat r and (sn,sd) = Rat.quotient_of_rat s
   558     val m = Integer.lcms(map (abs o snd o Rat.quotient_of_rat) (r :: s :: map snd rlhs @ map snd rrhs))
   559     fun mult(t,r) =
   560         let val (i,j) = Rat.quotient_of_rat r
   561         in (t,i * (m div j)) end
   562 in (m,(map mult rlhs, rn*(m div rd), rel, map mult rrhs, sn*(m div sd), d)) end
   563 
   564 fun mklineq n atoms =
   565   fn (item, k) =>
   566   let val (m, (lhs,i,rel,rhs,j,discrete)) = integ item
   567       val lhsa = map (coeff lhs) atoms
   568       and rhsa = map (coeff rhs) atoms
   569       val diff = map2 (curry (op -)) rhsa lhsa
   570       val c = i-j
   571       val just = Asm k
   572       fun lineq(c,le,cs,j) = Lineq(c,le,cs, if m=1 then j else Multiplied(m,j))
   573   in case rel of
   574       "<="   => lineq(c,Le,diff,just)
   575      | "~<=" => if discrete
   576                 then lineq(1-c,Le,map (op ~) diff,NotLeDD(just))
   577                 else lineq(~c,Lt,map (op ~) diff,NotLeD(just))
   578      | "<"   => if discrete
   579                 then lineq(c+1,Le,diff,LessD(just))
   580                 else lineq(c,Lt,diff,just)
   581      | "~<"  => lineq(~c,Le,map (op~) diff,NotLessD(just))
   582      | "="   => lineq(c,Eq,diff,just)
   583      | _     => sys_error("mklineq" ^ rel)
   584   end;
   585 
   586 (* ------------------------------------------------------------------------- *)
   587 (* Print (counter) example                                                   *)
   588 (* ------------------------------------------------------------------------- *)
   589 
   590 fun print_atom((a,d),r) =
   591   let val (p,q) = Rat.quotient_of_rat r
   592       val s = if d then string_of_int p else
   593               if p = 0 then "0"
   594               else string_of_int p ^ "/" ^ string_of_int q
   595   in a ^ " = " ^ s end;
   596 
   597 fun produce_ex sds =
   598   curry (op ~~) sds
   599   #> map print_atom
   600   #> commas
   601   #> curry (op ^) "Counterexample (possibly spurious):\n";
   602 
   603 fun trace_ex ctxt params atoms discr n (hist: history) =
   604   case hist of
   605     [] => ()
   606   | (v, lineqs) :: hist' =>
   607       let
   608         val frees = map Free params
   609         fun show_term t = Syntax.string_of_term ctxt (subst_bounds (frees, t))
   610         val start =
   611           if v = ~1 then (hist', findex0 discr n lineqs)
   612           else (hist, replicate n Rat.zero)
   613         val ex = SOME (produce_ex (map show_term atoms ~~ discr)
   614             (uncurry (fold (findex1 discr)) start))
   615           handle NoEx => NONE
   616       in
   617         case ex of
   618           SOME s => (warning "Linear arithmetic failed - see trace for a counterexample."; tracing s)
   619         | NONE => warning "Linear arithmetic failed"
   620       end;
   621 
   622 (* ------------------------------------------------------------------------- *)
   623 
   624 fun mknat (pTs : typ list) (ixs : int list) (atom : term, i : int) : lineq option =
   625   if LA_Logic.is_nat (pTs, atom)
   626   then let val l = map (fn j => if j=i then 1 else 0) ixs
   627        in SOME (Lineq (0, Le, l, Nat i)) end
   628   else NONE;
   629 
   630 (* This code is tricky. It takes a list of premises in the order they occur
   631 in the subgoal. Numerical premises are coded as SOME(tuple), non-numerical
   632 ones as NONE. Going through the premises, each numeric one is converted into
   633 a Lineq. The tricky bit is to convert ~= which is split into two cases < and
   634 >. Thus split_items returns a list of equation systems. This may blow up if
   635 there are many ~=, but in practice it does not seem to happen. The really
   636 tricky bit is to arrange the order of the cases such that they coincide with
   637 the order in which the cases are in the end generated by the tactic that
   638 applies the generated refutation thms (see function 'refute_tac').
   639 
   640 For variables n of type nat, a constraint 0 <= n is added.
   641 *)
   642 
   643 (* FIXME: To optimize, the splitting of cases and the search for refutations *)
   644 (*        could be intertwined: separate the first (fully split) case,       *)
   645 (*        refute it, continue with splitting and refuting.  Terminate with   *)
   646 (*        failure as soon as a case could not be refuted; i.e. delay further *)
   647 (*        splitting until after a refutation for other cases has been found. *)
   648 
(* split_items ctxt do_pre split_neq (Ts, terms):                          *)
(*   Ts are the parameter types of the subgoal, terms its hypotheses.     *)
(*   Returns one entry per case to be refuted; each entry pairs a type    *)
(*   list (as returned by the preprocessing step) with the recognized     *)
(*   (in)equations of that case, each tagged with its hypothesis index.   *)
(*   - do_pre:    run LA_Data.pre_decomp first (otherwise a single case   *)
(*                is produced from (Ts, terms) as given);                 *)
(*   - split_neq: split every '~=' into a '<' case and a '>' case;        *)
(*                otherwise '~=' hypotheses are replaced by NONE, i.e.    *)
(*                ignored (but still counted when numbering).             *)
(*   Indices count every hypothesis of the (preprocessed, split) case,    *)
(*   including those LA_Data.decomp does not recognize.                   *)
fun split_items ctxt do_pre split_neq (Ts, terms) : (typ list * (LA_Data.decomp * int) list) list =
let
  (* splits inequalities '~=' into '<' and '>'; this corresponds to *)
  (* 'REPEAT_DETERM (eresolve_tac neqE i)' at the theorem/tactic    *)
  (* level                                                          *)
  (* FIXME: this is currently sensitive to the order of theorems in *)
  (*        neqE:  The theorem for type "nat" must come first.  A   *)
  (*        better (i.e. less likely to break when neqE changes)    *)
  (*        implementation should *test* which theorem from neqE    *)
  (*        can be applied, and split the premise accordingly.      *)
  fun elim_neq (ineqs : (LA_Data.decomp option * bool) list) :
               (LA_Data.decomp option * bool) list list =
  let
    (* The bool paired with each item is its is_nat flag. If nat_only *)
    (* is set, only '~=' items flagged as nat are split. A split      *)
    (* removes the '~=' item and appends the new '<' item at the end  *)
    (* of the remaining list; the cases for l < r come before the     *)
    (* cases for r < l.                                               *)
    fun elim_neq' nat_only ([] : (LA_Data.decomp option * bool) list) :
                  (LA_Data.decomp option * bool) list list =
          [[]]
      | elim_neq' nat_only ((NONE, is_nat) :: ineqs) =
          map (cons (NONE, is_nat)) (elim_neq' nat_only ineqs)
      | elim_neq' nat_only ((ineq as (SOME (l, i, rel, r, j, d), is_nat)) :: ineqs) =
          if rel = "~=" andalso (not nat_only orelse is_nat) then
            (* [| ?l ~= ?r; ?l < ?r ==> ?R; ?r < ?l ==> ?R |] ==> ?R *)
            elim_neq' nat_only (ineqs @ [(SOME (l, i, "<", r, j, d), is_nat)]) @
            elim_neq' nat_only (ineqs @ [(SOME (r, j, "<", l, i, d), is_nat)])
          else
            map (cons ineq) (elim_neq' nat_only ineqs)
  in
    (* first pass: split nat '~=' only; second pass: split the rest *)
    ineqs |> elim_neq' true
          |> maps (elim_neq' false)
  end

  (* used when splitting is disabled: drop '~=' items, keep position *)
  fun ignore_neq (NONE, bool) = (NONE, bool)
    | ignore_neq (ineq as SOME (_, _, rel, _, _, _), bool) =
      if rel = "~=" then (NONE, bool) else (ineq, bool)

  (* pair each recognized item with its index n; unrecognized items *)
  (* (NONE) are dropped but still consume an index                 *)
  fun number_hyps _ []             = []
    | number_hyps n (NONE::xs)     = number_hyps (n+1) xs
    | number_hyps n ((SOME x)::xs) = (x, n) :: number_hyps (n+1) xs

  val result = (Ts, terms)
    |> (* user-defined preprocessing of the subgoal *)
       (if do_pre then LA_Data.pre_decomp ctxt else Library.single)
    |> tap (fn subgoals => trace_msg ("Preprocessing yields " ^
         string_of_int (length subgoals) ^ " subgoal(s) total."))
    |> (* produce the internal encoding of (in-)equalities *)
       map (apsnd (map (fn t => (LA_Data.decomp ctxt t, LA_Data.domain_is_nat t))))
    |> (* splitting of inequalities *)
       map (apsnd (if split_neq then elim_neq else
                     Library.single o map ignore_neq))
    (* flatten to one entry per (preprocessed subgoal, case) pair and *)
    (* drop the is_nat flags                                          *)
    |> maps (fn (Ts, subgoals) => map (pair Ts o map fst) subgoals)
    |> (* numbering of hypotheses, ignoring irrelevant ones *)
       map (apsnd (number_hyps 0))
in
  trace_msg ("Splitting of inequalities yields " ^
    string_of_int (length result) ^ " subgoal(s) total.");
  result
end;
   705 
   706 fun add_datoms (dats : (bool * term) list, ((lhs,_,_,rhs,_,d) : LA_Data.decomp, _)) :
   707   (bool * term) list =
   708   union_bterm (map (pair d o fst) lhs) (union_bterm (map (pair d o fst) rhs) dats);
   709 
   710 fun discr (initems : (LA_Data.decomp * int) list) : bool list =
   711   map fst (Library.foldl add_datoms ([],initems));
   712 
(* refutes ctxt params show_ex cases js: try to eliminate each case of
   'cases' (as produced by split_items), in order. Every successful
   elimination appends its justification to js. The result is SOME of all
   justifications, in case order, if every case is refuted. At the first
   case that cannot be refuted the result is NONE; if show_ex is set, the
   failed elimination history is first reported via trace_ex. *)
fun refutes ctxt params show_ex :
    (typ list * (LA_Data.decomp * int) list) list -> injust list -> injust list option =
  let
    fun refute ((Ts, initems : (LA_Data.decomp * int) list) :: initemss) (js: injust list) =
          let
            (* atoms occurring in this case, indexed 0 .. n-1 *)
            val atoms = atoms_of (map fst initems)
            val n = length atoms
            val mkleq = mklineq n atoms
            val ixs = 0 upto (n - 1)
            val iatoms = atoms ~~ ixs
            (* extra constraints produced by mknat for individual atoms;
               presumably the 0 <= n facts for nat atoms -- TODO confirm *)
            val natlineqs = map_filter (mknat Ts ixs) iatoms
            val ineqs = map mkleq initems @ natlineqs
          in case elim (ineqs, []) of
               Success j =>
                 (trace_msg ("Contradiction! (" ^ string_of_int (length js + 1) ^ ")");
                  refute initemss (js @ [j]))
             | Failure hist =>
                 (if not show_ex then ()
                  else
                    let
                      (* fresh names for the parameters, plus invented names for
                         any further types in Ts beyond the parameters
                         (NOTE(review): assumes length Ts >= length params) *)
                      val (param_names, ctxt') = ctxt |> Variable.variant_fixes (map fst params)
                      val (more_names, ctxt'') = ctxt' |> Variable.variant_fixes
                        (Name.invents (Variable.names_of ctxt') Name.uu (length Ts - length params))
                      val params' = (more_names @ param_names) ~~ Ts
                    in
                      trace_ex ctxt'' params' atoms (discr initems) n hist
                    end; NONE)
          end
      | refute [] js = SOME js
  in refute end;
   743 
   744 fun refute ctxt params show_ex do_pre split_neq terms : injust list option =
   745   refutes ctxt params show_ex (split_items ctxt do_pre split_neq
   746     (map snd params, terms)) [];
   747 
   748 fun count P xs = length (filter P xs);
   749 
(* prove ctxt params show_ex do_pre Hs concl: search a refutation of the
   hypotheses Hs together with the negated conclusion.

   Returns (split_neq, result):
   - split_neq tells whether '~=' hypotheses were case-split. Splitting
     happens only if their number is at most fast_arith_neq_limit;
     otherwise they are ignored.
   - result is the list of justifications found by refute, or NONE.

   If concl cannot be negated (neg_prop raises TERM), (false, NONE) is
   returned. *)
fun prove ctxt params show_ex do_pre Hs concl : bool * injust list option =
  let
    val _ = trace_msg "prove:"
    (* append the negated conclusion to 'Hs' -- this corresponds to     *)
    (* 'DETERM (resolve_tac [LA_Logic.notI, LA_Logic.ccontr] i)' at the *)
    (* theorem/tactic level                                             *)
    val Hs' = Hs @ [LA_Logic.neg_prop concl]
    fun is_neq NONE                 = false
      | is_neq (SOME (_,_,r,_,_,_)) = (r = "~=")
    val neq_limit = Config.get ctxt LA_Data.fast_arith_neq_limit
    (* NOTE(review): the '~=' hypotheses are counted on Hs' as given,    *)
    (* i.e. before any preprocessing that refute may apply (do_pre)     *)
    val split_neq = count is_neq (map (LA_Data.decomp ctxt) Hs') <= neq_limit
  in
    if split_neq then ()
    else
      trace_msg ("fast_arith_neq_limit exceeded (current value is " ^
        string_of_int neq_limit ^ "), ignoring all inequalities");
    (split_neq, refute ctxt params show_ex do_pre split_neq Hs')
  end handle TERM ("neg_prop", _) =>
    (* since no meta-logic negation is available, we can only fail if   *)
    (* the conclusion is not of the form 'Trueprop $ _' (simply         *)
    (* dropping the conclusion doesn't work either, because even        *)
    (* 'False' does not imply arbitrary 'concl::prop')                  *)
    (trace_msg "prove failed (cannot negate conclusion).";
      (false, NONE));
   774 
(* refute_tac ss (i, split_neq, justs): replay on subgoal i a refutation
   found by prove. The steps are:
   1. Turn the subgoal into a refutation goal with notI / ccontr.
   2. Preprocess it with LA_Data.pre_tac.
   3. For each justification in justs, in order, optionally split the '~='
      premises with neqE and close one resulting case.
   The order of justs must agree with the case order computed by
   split_items. *)
fun refute_tac ss (i, split_neq, justs) =
  fn state =>
    let
      val ctxt = Simplifier.the_context ss;
      val _ = trace_thm ctxt
        ("refute_tac (on subgoal " ^ string_of_int i ^ ", with " ^
          string_of_int (length justs) ^ " justification(s)):") state
      val {neqE, ...} = get_data ctxt;
      (* close one case, always working on subgoal i: split its '~='  *)
      (* premises if requested, then resolve with the theorem that    *)
      (* mkthm builds from the case's premises and justification j.   *)
      (* Presumably the next case then becomes subgoal i -- confirm.  *)
      fun just1 j =
        (* eliminate inequalities *)
        (if split_neq then
          REPEAT_DETERM (eresolve_tac neqE i)
        else
          all_tac) THEN
          PRIMITIVE (trace_thm ctxt "State after neqE:") THEN
          (* use theorems generated from the actual justifications *)
          Subgoal.FOCUS (fn {prems, ...} => rtac (mkthm ss prems j) 1) ctxt i
    in
      (* rewrite "[| A1; ...; An |] ==> B" to "[| A1; ...; An; ~B |] ==> False" *)
      DETERM (resolve_tac [LA_Logic.notI, LA_Logic.ccontr] i) THEN
      (* user-defined preprocessing of the subgoal *)
      DETERM (LA_Data.pre_tac ctxt i) THEN
      PRIMITIVE (trace_thm ctxt "State after pre_tac:") THEN
      (* prove every resulting subgoal, using its justification *)
      EVERY (map just1 justs)
    end  state;
   801 
   802 (*
   803 Fast but very incomplete decider. Only premises and conclusions
   804 that are already (negated) (in)equations are taken into account.
   805 *)
   806 fun simpset_lin_arith_tac ss show_ex = SUBGOAL (fn (A, i) =>
   807   let
   808     val ctxt = Simplifier.the_context ss
   809     val params = rev (Logic.strip_params A)
   810     val Hs = Logic.strip_assums_hyp A
   811     val concl = Logic.strip_assums_concl A
   812     val _ = trace_term ctxt ("Trying to refute subgoal " ^ string_of_int i) A
   813   in
   814     case prove ctxt params show_ex true Hs concl of
   815       (_, NONE) => (trace_msg "Refutation failed."; no_tac)
   816     | (split_neq, SOME js) => (trace_msg "Refutation succeeded.";
   817                                refute_tac ss (i, split_neq, js))
   818   end);
   819 
   820 fun cut_lin_arith_tac ss =
   821   cut_facts_tac (Simplifier.prems_of_ss ss) THEN'
   822   simpset_lin_arith_tac ss false;
   823 
   824 fun lin_arith_tac ctxt =
   825   simpset_lin_arith_tac (Simplifier.context ctxt Simplifier.empty_ss);
   826 
   827 
   828 
   829 (** Forward proof from theorems **)
   830 
(* More tricky code. Needs to arrange the proofs of the multiple cases (due
to splits of ~= premises) such that their order coincides with the order of
the cases generated by function split_items. *)
   834 
(* Case-split tree built by splitasms and consumed by fwdproof.
   Tip thms: a single case whose hypotheses are thms.
   Spl (spl, ct1, t1, ct2, t2): a '~=' assumption was eliminated by the
     neqE instance spl of the form (ct1 ==> ?R) ==> (ct2 ==> ?R) ==> ?R.
     t1 is the tree of the cases under the extra hypothesis ct1, and t2
     the tree of the cases under ct2. *)
datatype splittree = Tip of thm list
                   | Spl of thm * cterm * splittree * cterm * splittree;
   837 
   838 (* "(ct1 ==> ?R) ==> (ct2 ==> ?R) ==> ?R" is taken to (ct1, ct2) *)
   839 
   840 fun extract (imp : cterm) : cterm * cterm =
   841 let val (Il, r)    = Thm.dest_comb imp
   842     val (_, imp1)  = Thm.dest_comb Il
   843     val (Ict1, _)  = Thm.dest_comb imp1
   844     val (_, ct1)   = Thm.dest_comb Ict1
   845     val (Ir, _)    = Thm.dest_comb r
   846     val (_, Ict2r) = Thm.dest_comb Ir
   847     val (Ict2, _)  = Thm.dest_comb Ict2r
   848     val (_, ct2)   = Thm.dest_comb Ict2
   849 in (ct1, ct2) end;
   850 
   851 fun splitasms ctxt (asms : thm list) : splittree =
   852 let val {neqE, ...} = get_data ctxt
   853     fun elim_neq (asms', []) = Tip (rev asms')
   854       | elim_neq (asms', asm::asms) =
   855       (case get_first (fn th => SOME (asm COMP th) handle THM _ => NONE) neqE of
   856         SOME spl =>
   857           let val (ct1, ct2) = extract (cprop_of spl)
   858               val thm1 = assume ct1
   859               val thm2 = assume ct2
   860           in Spl (spl, ct1, elim_neq (asms', asms@[thm1]), ct2, elim_neq (asms', asms@[thm2]))
   861           end
   862       | NONE => elim_neq (asm::asms', asms))
   863 in elim_neq ([], asms) end;
   864 
(* fwdproof ss tree js: walk the split tree, consuming the justifications
   js in order: one per Tip, left subtree before right subtree.
   - At a Tip, the case theorem is built by mkthm from the case's
     hypotheses and the next justification.
   - At a split node, the two case theorems are discharged with
     implies_intr and combined with the neqE instance by COMP.
   Returns the resulting theorem and the unused justifications. A Tip
   reached with no justification left is not matched by any clause. *)
fun fwdproof ss (Tip asms : splittree) (j::js : injust list) = (mkthm ss asms j, js)
  | fwdproof ss (Spl (thm, ct1, tree1, ct2, tree2)) js =
      let
        val (thm1, js1) = fwdproof ss tree1 js
        val (thm2, js2) = fwdproof ss tree2 js1
        val thm1' = implies_intr ct1 thm1
        val thm2' = implies_intr ct2 thm2
        (* thm is (ct1 ==> ?R) ==> (ct2 ==> ?R) ==> ?R: thm1' discharges its
           first premise, thm2' the remaining one *)
      in (thm2' COMP (thm1' COMP thm), js2) end;
      (* FIXME needs handle THM _ => NONE ? *)
   874 
(* prover ss thms Tconcl js split_neq pos: turn the justifications js
   (found by prove for the hypotheses thms and conclusion Tconcl) into a
   forward proof. The steps are:
   1. Assume the negation of Tconcl.
   2. Case-split all assumptions if split_neq.
   3. Run fwdproof to derive the contradiction.
   4. Discharge the assumption and apply ccontr (pos) or notI (not pos).
   5. Turn the result into an equation with LA_Logic.mk_Eq.
   Any THM failure along the way yields NONE. *)
fun prover ss thms Tconcl (js : injust list) split_neq pos : thm option =
  let
    val ctxt = Simplifier.the_context ss
    val thy = ProofContext.theory_of ctxt
    val nTconcl = LA_Logic.neg_prop Tconcl
    val cnTconcl = cterm_of thy nTconcl
    val nTconclthm = assume cnTconcl
    (* the negated conclusion goes last, as in prove *)
    val tree = (if split_neq then splitasms ctxt else Tip) (thms @ [nTconclthm])
    val (Falsethm, _) = fwdproof ss tree js
    (* pos: Tconcl is the positive goal, proved by ccontr from its negation;
       otherwise Tconcl is itself a negation, proved by notI *)
    val contr = if pos then LA_Logic.ccontr else LA_Logic.notI
    val concl = implies_intr cnTconcl Falsethm COMP contr
  in SOME (trace_thm ctxt "Proved by lin. arith. prover:" (LA_Logic.mk_Eq concl)) end
  (*in case concl contains ?-var, which makes assume fail:*)   (* FIXME Variable.import_terms *)
  handle THM _ => NONE;
   889 
   890 (* PRE: concl is not negated!
   891    This assumption is OK because
   892    1. lin_arith_simproc tries both to prove and disprove concl and
   893    2. lin_arith_simproc is applied by the Simplifier which
   894       dives into terms and will thus try the non-negated concl anyway.
   895 *)
   896 fun lin_arith_simproc ss concl =
   897   let
   898     val ctxt = Simplifier.the_context ss
   899     val thms = maps LA_Logic.atomize (Simplifier.prems_of_ss ss)
   900     val Hs = map Thm.prop_of thms
   901     val Tconcl = LA_Logic.mk_Trueprop concl
   902   in
   903     case prove ctxt [] false false Hs Tconcl of (* concl provable? *)
   904       (split_neq, SOME js) => prover ss thms Tconcl js split_neq true
   905     | (_, NONE) =>
   906         let val nTconcl = LA_Logic.neg_prop Tconcl in
   907           case prove ctxt [] false false Hs nTconcl of (* ~concl provable? *)
   908             (split_neq, SOME js) => prover ss thms nTconcl js split_neq false
   909           | (_, NONE) => NONE
   910         end
   911   end;
   912 
   913 end;