src/Provers/Arith/fast_lin_arith.ML
author wenzelm
Sat Nov 04 15:24:40 2017 +0100 (20 months ago)
changeset 67003 49850a679c2c
parent 66035 de6cd60b1226
child 67649 1e1782c1aedf
permissions -rw-r--r--
more robust sorted_entries;
     1 (*  Title:      Provers/Arith/fast_lin_arith.ML
     2     Author:     Tobias Nipkow and Tjark Weber and Sascha Boehme
     3 
     4 A generic linear arithmetic package.
     5 
     6 Only take premises and conclusions into account that are already
     7 (negated) (in)equations. lin_arith_simproc tries to prove or disprove
     8 the term.
     9 *)
    10 
    11 (*** Data needed for setting up the linear arithmetic package ***)
    12 
(*Object-logic specific theorems and operations needed by the package;
  the operations mk_Eq, neg_prop, is_nat and mk_nat_thm are described in
  the comment following this signature.*)
signature LIN_ARITH_LOGIC =
sig
  val conjI       : thm  (* P ==> Q ==> P & Q *)
  val ccontr      : thm  (* (~ P ==> False) ==> P *)
  val notI        : thm  (* (P ==> False) ==> ~ P *)
  val not_lessD   : thm  (* ~(m < n) ==> n <= m *)
  val not_leD     : thm  (* ~(m <= n) ==> n < m *)
  val sym         : thm  (* x = y ==> y = x *)
  val trueI       : thm  (* True *)
  val mk_Eq       : thm -> thm  (* (negated) (in)equation to `in == False' / `in == True' *)
  val atomize     : thm -> thm list  (* NOTE(review): contract not visible in this part of the file -- confirm *)
  val mk_Trueprop : term -> term  (* presumably wraps an object-logic term in Trueprop -- confirm *)
  val neg_prop    : term -> term  (* negation under Trueprop; raises TERM ("neg_prop", [t]) otherwise *)
  val is_False    : thm -> bool  (* presumably: does the theorem state False? mkthm uses it to detect a refutation *)
  val is_nat      : typ list * term -> bool  (* does the term have type nat (given parameter types)? *)
  val mk_nat_thm  : theory -> term -> thm  (* "0 <= t" *)
end;
    30 (*
    31 mk_Eq(~in) = `in == False'
    32 mk_Eq(in) = `in == True'
    33 where `in' is an (in)equality.
    34 
    35 neg_prop(t) = neg  if t is wrapped up in Trueprop and neg is the
    36   (logically) negated version of t (again wrapped up in Trueprop),
    37   where the negation of a negative term is the term itself (no
    38   double negation!); raises TERM ("neg_prop", [t]) if t is not of
    39   the form 'Trueprop $ _'
    40 
    41 is_nat(parameter-types,t) =  t:nat
    42 mk_nat_thm(t) = "0 <= t"
    43 *)
    44 
(*Data supplied by the instantiating logic; see the comment after this
  signature for decomp, domain_is_nat and the pre_decomp/pre_tac contract.*)
signature LIN_ARITH_DATA =
sig
  (*internal representation of linear (in-)equations:*)
  (*(lhs summands with coefficients, lhs constant, relation name,
     rhs summands with coefficients, rhs constant, discrete domain?)*)
  type decomp = (term * Rat.rat) list * Rat.rat * string * (term * Rat.rat) list * Rat.rat * bool
  (*NONE for premises that are not recognized as (in)equations*)
  val decomp: Proof.context -> term -> decomp option
  val domain_is_nat: term -> bool

  (*abstraction for proof replay*)
  (*mkthm applies abstract to assumption propositions and abstract_arith to
    atoms, threading the (term * term) list and the context through.
    NOTE(review): presumably the list records abstracted subterms and their
    replacements -- confirm with the instantiation*)
  val abstract_arith: term -> (term * term) list * Proof.context ->
    term * ((term * term) list * Proof.context)
  val abstract: term -> (term * term) list * Proof.context ->
    term * ((term * term) list * Proof.context)

  (*preprocessing, performed on a representation of subgoals as list of premises:*)
  (*one subgoal may yield several*)
  val pre_decomp: Proof.context -> typ list * term list -> (typ list * term list) list

  (*preprocessing, performed on the goal -- must do the same as 'pre_decomp':*)
  val pre_tac: Proof.context -> int -> tactic

  (*the limit on the number of ~= allowed; because each ~= is split
    into two cases, this can lead to an explosion*)
  val neq_limit: int Config.T

  (*enables the tracing output of this package*)
  val trace: bool Config.T
end;
    70 (*
    71 decomp(`x Rel y') should yield (p,i,Rel,q,j,d)
   where Rel is one of "<", "~<", "<=", "~<=", "=" and "~=", and
    73          p (q, respectively) is the decomposition of the sum term x
    74          (y, respectively) into a list of summand * multiplicity
    75          pairs and a constant summand and d indicates if the domain
    76          is discrete.
    77 
domain_is_nat(`x Rel y') should yield true iff x is of type "nat".
    79 
    80 The relationship between pre_decomp and pre_tac is somewhat tricky.  The
    81 internal representation of a subgoal and the corresponding theorem must
    82 be modified by pre_decomp (pre_tac, resp.) in a corresponding way.  See
    83 the comment for split_items below.  (This is even necessary for eta- and
    84 beta-equivalent modifications, as some of the lin. arith. code is not
    85 insensitive to them.)
    86 
    87 Simpset must reduce contradictory <= to False.
    88    It should also cancel common summands to keep <= reduced;
    89    otherwise <= can grow to massive proportions.
    90 *)
    91 
(*Interface of the linear arithmetic package produced by Fast_Lin_Arith.*)
signature FAST_LIN_ARITH =
sig
  (*NOTE(review): the tactics and the simproc are defined beyond the part
    of the file reviewed here -- their contracts are not documented from it*)
  val prems_lin_arith_tac: Proof.context -> int -> tactic
  val lin_arith_tac: Proof.context -> int -> tactic
  val lin_arith_simproc: Proof.context -> cterm -> thm option
  (*transform the package data stored in a generic context*)
  val map_data:
    ({add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
      lessD: thm list, neqE: thm list, simpset: simpset,
      number_of: (Proof.context -> typ -> int -> cterm) option} ->
     {add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
      lessD: thm list, neqE: thm list, simpset: simpset,
      number_of: (Proof.context -> typ -> int -> cterm) option}) ->
      Context.generic -> Context.generic
  (*new theorems are stored context-free (Thm.trim_context)*)
  val add_inj_thms: thm list -> Context.generic -> Context.generic
  val add_lessD: thm -> Context.generic -> Context.generic
  (*extend the package simpset*)
  val add_simps: thm list -> Context.generic -> Context.generic
  val add_simprocs: simproc list -> Context.generic -> Context.generic
  (*install the function producing numerals of a given type*)
  val set_number_of: (Proof.context -> typ -> int -> cterm) -> Context.generic -> Context.generic
end;
   111 
   112 functor Fast_Lin_Arith
   113   (structure LA_Logic: LIN_ARITH_LOGIC and LA_Data: LIN_ARITH_DATA): FAST_LIN_ARITH =
   114 struct
   115 
   116 
   117 (** theory data **)
   118 
   119 structure Data = Generic_Data
   120 (
   121   type T =
   122    {add_mono_thms: thm list,
   123     mult_mono_thms: thm list,
   124     inj_thms: thm list,
   125     lessD: thm list,
   126     neqE: thm list,
   127     simpset: simpset,
   128     number_of: (Proof.context -> typ -> int -> cterm) option};
   129 
   130   val empty : T =
   131    {add_mono_thms = [], mult_mono_thms = [], inj_thms = [],
   132     lessD = [], neqE = [], simpset = empty_ss,
   133     number_of = NONE};
   134   val extend = I;
   135   fun merge
   136     ({add_mono_thms = add_mono_thms1, mult_mono_thms = mult_mono_thms1, inj_thms = inj_thms1,
   137       lessD = lessD1, neqE = neqE1, simpset = simpset1, number_of = number_of1},
   138      {add_mono_thms = add_mono_thms2, mult_mono_thms = mult_mono_thms2, inj_thms = inj_thms2,
   139       lessD = lessD2, neqE = neqE2, simpset = simpset2, number_of = number_of2}) : T =
   140     {add_mono_thms = Thm.merge_thms (add_mono_thms1, add_mono_thms2),
   141      mult_mono_thms = Thm.merge_thms (mult_mono_thms1, mult_mono_thms2),
   142      inj_thms = Thm.merge_thms (inj_thms1, inj_thms2),
   143      lessD = Thm.merge_thms (lessD1, lessD2),
   144      neqE = Thm.merge_thms (neqE1, neqE2),
   145      simpset = merge_ss (simpset1, simpset2),
   146      number_of = merge_options (number_of1, number_of2)};
   147 );
   148 
   149 val map_data = Data.map;
   150 val get_data = Data.get o Context.Proof;
   151 
   152 fun get_neqE ctxt = map (Thm.transfer (Proof_Context.theory_of ctxt)) (#neqE (get_data ctxt));
   153 
   154 fun map_inj_thms f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset, number_of} =
   155   {add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms, inj_thms = f inj_thms,
   156     lessD = lessD, neqE = neqE, simpset = simpset, number_of = number_of};
   157 
   158 fun map_lessD f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset, number_of} =
   159   {add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms, inj_thms = inj_thms,
   160     lessD = f lessD, neqE = neqE, simpset = simpset, number_of = number_of};
   161 
   162 fun map_simpset f context =
   163   map_data (fn {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset, number_of} =>
   164     {add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms, inj_thms = inj_thms,
   165       lessD = lessD, neqE = neqE, simpset = simpset_map (Context.proof_of context) f simpset,
   166       number_of = number_of}) context;
   167 
   168 fun add_inj_thms thms = map_data (map_inj_thms (append (map Thm.trim_context thms)));
   169 fun add_lessD thm = map_data (map_lessD (fn thms => thms @ [Thm.trim_context thm]));
   170 fun add_simps thms = map_simpset (fn ctxt => ctxt addsimps thms);
   171 fun add_simprocs procs = map_simpset (fn ctxt => ctxt addsimprocs procs);
   172 
   173 fun set_number_of f =
   174   map_data (fn {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset, ...} =>
   175    {add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms, inj_thms = inj_thms,
   176     lessD = lessD, neqE = neqE, simpset = simpset, number_of = SOME f});
   177 
   178 fun number_of ctxt =
   179   (case get_data ctxt of
   180     {number_of = SOME f, ...} => f ctxt
   181   | _ => fn _ => fn _ => raise CTERM ("number_of", []));
   182 
   183 
   184 
   185 (*** A fast decision procedure ***)
   186 (*** Code ported from HOL Light ***)
   187 (* possible optimizations:
   use (var,coeff) rep or vector rep to save space;
   189    treat non-negative atoms separately rather than adding 0 <= atom
   190 *)
   191 
(*relation between the constant and the linear combination of a lineq*)
datatype lineq_type = Eq | Le | Lt;

(*justification of a derived (in)equation; replayed by mkthm*)
datatype injust = Asm of int                 (* numbered premise, see split_items *)
                | Nat of int (* index of atom *)
                | LessD of injust            (* first applicable lessD rule *)
                | NotLessD of injust         (* LA_Logic.not_lessD *)
                | NotLeD of injust           (* LA_Logic.not_leD *)
                | NotLeDD of injust          (* not_leD, then a lessD rule *)
                | Multiplied of int * injust (* scaled by a factor *)
                | Added of injust * injust;  (* sum of two *)

(*Lineq (k, ty, cs, j) stands for  k ty cs_0 * a_0 + ... + cs_(n-1) * a_(n-1)
  over the atoms a_i, justified by j; e.g. mknat builds 0 <= a_i as
  Lineq (0, Le, [0, ..., 1, ..., 0], Nat i)*)
datatype lineq = Lineq of int * lineq_type * int list * injust;
   204 
   205 (* ------------------------------------------------------------------------- *)
   206 (* Finding a (counter) example from the trace of a failed elimination        *)
   207 (* ------------------------------------------------------------------------- *)
   208 (* Examples are represented as rational numbers,                             *)
(* Don't blame John Harrison for this code - it is entirely mine. TN         *)
   210 
(*NOTE(review): NoEx is not raised in the part of this file reviewed here;
  it presumably belonged to the counterexample finder whose remains are
  described below -- confirm before removing*)
exception NoEx;
   212 
   213 (* Coding: (i,true,cs) means i <= cs and (i,false,cs) means i < cs.
   214    In general, true means the bound is included, false means it is excluded.
   215    Need to know if it is a lower or upper bound for unambiguous interpretation!
   216 *)
   217 
   218 (* ------------------------------------------------------------------------- *)
   219 (* End of counterexample finder. The actual decision procedure starts here.  *)
   220 (* ------------------------------------------------------------------------- *)
   221 
   222 (* ------------------------------------------------------------------------- *)
   223 (* Calculate new (in)equality type after addition.                           *)
   224 (* ------------------------------------------------------------------------- *)
   225 
   226 fun find_add_type(Eq,x) = x
   227   | find_add_type(x,Eq) = x
   228   | find_add_type(_,Lt) = Lt
   229   | find_add_type(Lt,_) = Lt
   230   | find_add_type(Le,Le) = Le;
   231 
   232 (* ------------------------------------------------------------------------- *)
   233 (* Multiply out an (in)equation.                                             *)
   234 (* ------------------------------------------------------------------------- *)
   235 
   236 fun multiply_ineq n (i as Lineq(k,ty,l,just)) =
   237   if n = 1 then i
   238   else if n = 0 andalso ty = Lt then raise Fail "multiply_ineq"
   239   else if n < 0 andalso (ty=Le orelse ty=Lt) then raise Fail "multiply_ineq"
   240   else Lineq (n * k, ty, map (Integer.mult n) l, Multiplied (n, just));
   241 
   242 (* ------------------------------------------------------------------------- *)
   243 (* Add together (in)equations.                                               *)
   244 (* ------------------------------------------------------------------------- *)
   245 
   246 fun add_ineq (Lineq (k1,ty1,l1,just1)) (Lineq (k2,ty2,l2,just2)) =
   247   let val l = map2 Integer.add l1 l2
   248   in Lineq(k1+k2,find_add_type(ty1,ty2),l,Added(just1,just2)) end;
   249 
   250 (* ------------------------------------------------------------------------- *)
   251 (* Elimination of variable between a single pair of (in)equations.           *)
   252 (* If they're both inequalities, 1st coefficient must be +ve, 2nd -ve.       *)
   253 (* ------------------------------------------------------------------------- *)
   254 
   255 fun elim_var v (i1 as Lineq(_,ty1,l1,_)) (i2 as Lineq(_,ty2,l2,_)) =
   256   let val c1 = nth l1 v and c2 = nth l2 v
   257       val m = Integer.lcm c1 c2
   258       val m1 = m div (abs c1) and m2 = m div (abs c2)
   259       val (n1,n2) =
   260         if (c1 >= 0) = (c2 >= 0)
   261         then if ty1 = Eq then (~m1,m2)
   262              else if ty2 = Eq then (m1,~m2)
   263                   else raise Fail "elim_var"
   264         else (m1,m2)
   265       val (p1,p2) = if ty1=Eq andalso ty2=Eq andalso (n1 = ~1 orelse n2 = ~1)
   266                     then (~n1,~n2) else (n1,n2)
   267   in add_ineq (multiply_ineq p1 i1) (multiply_ineq p2 i2) end;
   268 
   269 (* ------------------------------------------------------------------------- *)
   270 (* The main refutation-finding code.                                         *)
   271 (* ------------------------------------------------------------------------- *)
   272 
   273 fun is_trivial (Lineq(_,_,l,_)) = forall (fn i => i=0) l;
   274 
   275 fun is_contradictory (Lineq(k,ty,_,_)) =
   276   case ty  of Eq => k <> 0 | Le => k > 0 | Lt => k >= 0;
   277 
   278 fun calc_blowup l =
   279   let val (p,n) = List.partition (curry (op <) 0) (filter (curry (op <>) 0) l)
   280   in length p * length n end;
   281 
   282 (* ------------------------------------------------------------------------- *)
   283 (* Main elimination code:                                                    *)
   284 (*                                                                           *)
   285 (* (1) Looks for immediate solutions (false assertions with no variables).   *)
   286 (*                                                                           *)
   287 (* (2) If there are any equations, picks a variable with the lowest absolute *)
   288 (* coefficient in any of them, and uses it to eliminate.                     *)
   289 (*                                                                           *)
   290 (* (3) Otherwise, chooses a variable in the inequality to minimize the       *)
   291 (* blowup (number of consequences generated) and eliminates it.              *)
   292 (* ------------------------------------------------------------------------- *)
   293 
   294 fun extract_first p =
   295   let
   296     fun extract xs (y::ys) = if p y then (y, xs @ ys) else extract (y::xs) ys
   297       | extract _ [] = raise List.Empty
   298   in extract [] end;
   299 
   300 fun print_ineqs ctxt ineqs =
   301   if Config.get ctxt LA_Data.trace then
   302      tracing(cat_lines(""::map (fn Lineq(c,t,l,_) =>
   303        string_of_int c ^
   304        (case t of Eq => " =  " | Lt=> " <  " | Le => " <= ") ^
   305        commas(map string_of_int l)) ineqs))
   306   else ();
   307 
(*elimination history, most recent step first: the index of the eliminated
  variable (~1 when no variable could be eliminated) paired with the
  non-trivial (in)equations it was eliminated from*)
type history = (int * lineq list) list;
(*outcome of elim: justification of a false variable-free (in)equation, or
  the history of the failed search*)
datatype result = Success of injust | Failure of history;
   310 
(*Search for a refutation of ineqs by variable elimination, following the
  strategy (1)-(3) described above. Returns Success with the justification
  of a false variable-free (in)equation, or Failure with the history.*)
fun elim ctxt (ineqs, hist) =
  let val _ = print_ineqs ctxt ineqs
      val (triv, nontriv) = List.partition is_trivial ineqs in
  (*(1) a false variable-free entry refutes the system; the other trivial
    entries carry no information and are dropped*)
  if not (null triv)
  then case find_first is_contradictory triv of
         NONE => elim ctxt (nontriv, hist)
       | SOME(Lineq(_,_,_,j)) => Success j
  else
  if null nontriv then Failure hist
  else
  let val (eqs, noneqs) = List.partition (fn (Lineq(_,ty,_,_)) => ty=Eq) nontriv in
  if not (null eqs) then
     (*(2) c is the nonzero coefficient of least absolute value occurring in
       any equation; v is the position of its first occurrence in the first
       equation containing it. That equation eliminates v from all others.*)
     let val c =
           fold (fn Lineq(_,_,l,_) => fn cs => union (op =) l cs) eqs []
           |> filter (fn i => i <> 0)
           |> sort (int_ord o apply2 abs)
           |> hd
         val (eq as Lineq(_,_,ceq,_),othereqs) =
               extract_first (fn Lineq(_,_,l,_) => member (op =) l c) eqs
         val v = find_index (fn v => v = c) ceq
         val (ioth,roth) = List.partition (fn (Lineq(_,_,l,_)) => nth l v = 0)
                                     (othereqs @ noneqs)
         val others = map (elim_var v eq) roth @ ioth
     in elim ctxt (others,(v,nontriv)::hist) end
  else
  (*(3) Fourier-Motzkin step on the variable with the least nonzero blowup.
    A variable occurring with only one sign has blowup 0 and is not chosen;
    if no variable has nonzero blowup, the search fails. Here ineqs
    coincides with noneqs: there are no trivial entries and no equations.*)
  let val lists = map (fn (Lineq(_,_,l,_)) => l) noneqs
      val numlist = 0 upto (length (hd lists) - 1)
      val coeffs = map (fn i => map (fn xs => nth xs i) lists) numlist
      val blows = map calc_blowup coeffs
      val iblows = blows ~~ numlist
      val nziblows = filter_out (fn (i, _) => i = 0) iblows
  in if null nziblows then Failure((~1,nontriv)::hist)
     else
     let val (_,v) = hd(sort (fn (x,y) => int_ord(fst(x),fst(y))) nziblows)
         val (no,yes) = List.partition (fn (Lineq(_,_,l,_)) => nth l v = 0) ineqs
         val (pos,neg) = List.partition(fn (Lineq(_,_,l,_)) => nth l v > 0) yes
     (*every positive occurrence is combined with every negative one*)
     in elim ctxt (no @ map_product (elim_var v) pos neg, (v,nontriv)::hist) end
  end
  end
  end;
   351 
   352 (* ------------------------------------------------------------------------- *)
   353 (* Translate back a proof.                                                   *)
   354 (* ------------------------------------------------------------------------- *)
   355 
   356 fun trace_thm ctxt msgs th =
   357  (if Config.get ctxt LA_Data.trace
   358   then tracing (cat_lines (msgs @ [Thm.string_of_thm ctxt th]))
   359   else (); th);
   360 
   361 fun trace_term ctxt msgs t =
   362  (if Config.get ctxt LA_Data.trace
   363   then tracing (cat_lines (msgs @ [Syntax.string_of_term ctxt t]))
   364   else (); t);
   365 
   366 fun trace_msg ctxt msg =
   367   if Config.get ctxt LA_Data.trace then tracing msg else ();
   368 
   369 val union_term = union Envir.aeconv;
   370 
   371 fun add_atoms (lhs, _, _, rhs, _, _) =
   372   union_term (map fst lhs) o union_term (map fst rhs);
   373 
   374 fun atoms_of ds = fold add_atoms ds [];
   375 
   376 (*
   377 Simplification may detect a contradiction 'prematurely' due to type
   378 information: n+1 <= 0 is simplified to False and does not need to be crossed
   379 with 0 <= n.
   380 *)
(*FalseE is raised as soon as simplification of an intermediate sum yields
  False; it carries that theorem, the abstracted hypotheses and the context
  needed to finish the proof.*)
local
  exception FalseE of thm * (int * cterm) list * Proof.context
in

(*mkthm ctxt asms just: replay the justification just against the premise
  theorems asms. The result is expected to be False (otherwise a warning
  is issued). Premises are abstracted with LA_Data.abstract and assumed as
  hypotheses, which are discharged at the end by resolution with asms.*)
fun mkthm ctxt asms (just: injust) =
  let
    val thy = Proof_Context.theory_of ctxt;
    (*package rules, transferred to the current theory*)
    val {add_mono_thms = add_mono_thms0, mult_mono_thms = mult_mono_thms0,
      inj_thms = inj_thms0, lessD = lessD0, simpset, ...} = get_data ctxt;
    val add_mono_thms = map (Thm.transfer thy) add_mono_thms0;
    val mult_mono_thms = map (Thm.transfer thy) mult_mono_thms0;
    val inj_thms = map (Thm.transfer thy) inj_thms0;
    val lessD = map (Thm.transfer thy) lessD0;

    val number_of = number_of ctxt;
    val simpset_ctxt = put_simpset simpset ctxt;
    fun only_concl f thm =
      if Thm.no_prems thm then f (Thm.concl_of thm) else NONE;
    (*atoms of the premise-free asms; Nat i justifications index into this list*)
    val atoms = atoms_of (map_filter (only_concl (LA_Data.decomp ctxt)) asms);

    (*result of resolving thm with the first rule that applies*)
    fun use_first rules thm =
      get_first (fn th => SOME (thm RS th) handle THM _ => NONE) rules

    (*add two (in)equations via the conjunction and an add_mono rule*)
    fun add2 thm1 thm2 =
      use_first add_mono_thms (thm1 RS (thm2 RS LA_Logic.conjI));
    fun try_add thms thm = get_first (fn th => add2 th thm) thms;

    (*if direct addition fails, retry after applying an inj_thms rule to
      thm1, then to thm2; Fail if nothing works*)
    fun add_thms thm1 thm2 =
      (case add2 thm1 thm2 of
        NONE =>
          (case try_add ([thm1] RL inj_thms) thm2 of
            NONE =>
              (the (try_add ([thm2] RL inj_thms) thm1)
                handle Option.Option =>
                  (trace_thm ctxt [] thm1; trace_thm ctxt [] thm2;
                   raise Fail "Linear arithmetic: failed to add thms"))
          | SOME thm => thm)
      | SOME thm => thm);

    (*n-fold sum of thm with itself; only terminates for n >= 1*)
    fun mult_by_add n thm =
      let fun mul i th = if i = 1 then th else mul (i - 1) (add_thms thm th)
      in mul n thm end;

    val rewr = Simplifier.rewrite simpset_ctxt;
    (*rewrite both arguments of the conclusion's relation*)
    val rewrite_concl = Conv.fconv_rule (Conv.concl_conv ~1 (Conv.arg_conv
      (Conv.binop_conv rewr)));
    (*discharge the first premise, which must rewrite to True*)
    fun discharge_prem thm = if Thm.nprems_of thm = 0 then thm else
      let val cv = Conv.arg1_conv (Conv.arg_conv rewr)
      in Thm.implies_elim (Conv.fconv_rule cv thm) LA_Logic.trueI end

    (*Multiply by n >= 1: use a mult_mono rule, instantiating the schematic
      variable found inside its conclusion (presumably the factor) with the
      numeral n; fall back to repeated addition.*)
    fun mult n thm =
      (case use_first mult_mono_thms thm of
        NONE => mult_by_add n thm
      | SOME mth =>
          let
            val cv = mth |> Thm.cprop_of |> Drule.strip_imp_concl
              |> Thm.dest_arg |> Thm.dest_arg1 |> Thm.dest_arg1
            val T = Thm.typ_of_cterm cv
          in
            mth
            |> Thm.instantiate ([], [(dest_Var (Thm.term_of cv), number_of T n)])
            |> rewrite_concl
            |> discharge_prem
            handle CTERM _ => mult_by_add n thm
                 | THM _ => mult_by_add n thm
          end);

    (*negative factors only arise for equations (see multiply_ineq), whose
      sides are then swapped with sym*)
    fun mult_thm n thm =
      if n = ~1 then thm RS LA_Logic.sym
      else if n < 0 then mult (~n) thm RS LA_Logic.sym
      else mult n thm;

    (*simplify an intermediate sum; stop early if it becomes False*)
    fun simp thm (cx as (_, hyps, ctxt')) =
      let val thm' = trace_thm ctxt ["Simplified:"] (full_simplify simpset_ctxt thm)
      in if LA_Logic.is_False thm' then raise FalseE (thm', hyps, ctxt') else (thm', cx) end;

    (*hypothesis for premise i, abstracted and assumed once, then reused
      from hyps*)
    fun abs_thm i (cx as (terms, hyps, ctxt)) =
      (case AList.lookup (op =) hyps i of
        SOME ct => (Thm.assume ct, cx)
      | NONE =>
          let
            val thm = nth asms i
            val (t, (terms', ctxt')) = LA_Data.abstract (Thm.prop_of thm) (terms, ctxt)
            val ct = Thm.cterm_of ctxt' t
          in (Thm.assume ct, (terms', (i, ct) :: hyps, ctxt')) end);

    (*"0 <= t" for the abstracted atom t*)
    fun nat_thm t (terms, hyps, ctxt) =
      let val (t', (terms', ctxt')) = LA_Data.abstract_arith t (terms, ctxt)
      in (LA_Logic.mk_nat_thm thy t', (terms', hyps, ctxt')) end;

    (*replay a justification, threading the state (terms, hyps, ctxt)*)
    fun step0 msg (thm, cx) = (trace_thm ctxt [msg] thm, cx)
    fun step1 msg j f cx = mk j cx |>> f |>> trace_thm ctxt [msg]
    and step2 msg j1 j2 f cx = mk j1 cx ||>> mk j2 |>> f |>> trace_thm ctxt [msg]

    and mk (Asm i) cx = step0 ("Asm " ^ string_of_int i) (abs_thm i cx)
      | mk (Nat i) cx = step0 ("Nat " ^ string_of_int i) (nat_thm (nth atoms i) cx)
      | mk (LessD j) cx = step1 "L" j (fn thm => hd ([thm] RL lessD)) cx
      | mk (NotLeD j) cx = step1 "NLe" j (fn thm => thm RS LA_Logic.not_leD) cx
      | mk (NotLeDD j) cx = step1 "NLeD" j (fn thm => hd ([thm RS LA_Logic.not_leD] RL lessD)) cx
      | mk (NotLessD j) cx = step1 "NL" j (fn thm => thm RS LA_Logic.not_lessD) cx
      | mk (Added (j1, j2)) cx = step2 "+" j1 j2 (uncurry add_thms) cx |-> simp
      | mk (Multiplied (n, j)) cx =
          (trace_msg ctxt ("*" ^ string_of_int n); step1 "*" j (mult_thm n) cx)

    (*Discharge the hypotheses: turn them into premises, export from ctxt'
      to ctxt, and resolve each premise with the original theorem nth asms i.*)
    fun finish ctxt' hyps thm =
      thm
      |> fold_rev (Thm.implies_intr o snd) hyps
      |> singleton (Variable.export ctxt' ctxt)
      |> fold (fn (i, _) => fn thm => nth asms i RS thm) hyps
  in
    let
      val _ = trace_msg ctxt "mkthm";
      val (thm, (_, hyps, ctxt')) = mk just ([], [], ctxt);
      val _ = trace_thm ctxt ["Final thm:"] thm;
      val fls = simplify simpset_ctxt thm;
      val _ = trace_thm ctxt ["After simplification:"] fls;
      (*the replayed refutation should simplify to False*)
      val _ =
        if LA_Logic.is_False fls then ()
        else
         (tracing (cat_lines
           (["Assumptions:"] @ map (Thm.string_of_thm ctxt) asms @ [""] @
            ["Proved:", Thm.string_of_thm ctxt fls, ""]));
          warning "Linear arithmetic should have refuted the assumptions.\n\
            \Please inform Tobias Nipkow.")
    in finish ctxt' hyps fls end
    handle FalseE (thm, hyps, ctxt') =>
      trace_thm ctxt ["False reached early:"] (finish ctxt' hyps thm)
  end;

end;
   511 
   512 fun coeff poly atom =
   513   AList.lookup Envir.aeconv poly atom |> the_default 0;
   514 
   515 fun integ(rlhs,r,rel,rrhs,s,d) =
   516 let val (rn,rd) = Rat.dest r and (sn,sd) = Rat.dest s
   517     val m = Integer.lcms(map (snd o Rat.dest) (r :: s :: map snd rlhs @ map snd rrhs))
   518     fun mult(t,r) =
   519         let val (i,j) = Rat.dest r
   520         in (t,i * (m div j)) end
   521 in (m,(map mult rlhs, rn*(m div rd), rel, map mult rrhs, sn*(m div sd), d)) end
   522 
   523 fun mklineq atoms =
   524   fn (item, k) =>
   525   let val (m, (lhs,i,rel,rhs,j,discrete)) = integ item
   526       val lhsa = map (coeff lhs) atoms
   527       and rhsa = map (coeff rhs) atoms
   528       val diff = map2 (curry (op -)) rhsa lhsa
   529       val c = i-j
   530       val just = Asm k
   531       fun lineq(c,le,cs,j) = Lineq(c,le,cs, if m=1 then j else Multiplied(m,j))
   532   in case rel of
   533       "<="   => lineq(c,Le,diff,just)
   534      | "~<=" => if discrete
   535                 then lineq(1-c,Le,map (op ~) diff,NotLeDD(just))
   536                 else lineq(~c,Lt,map (op ~) diff,NotLeD(just))
   537      | "<"   => if discrete
   538                 then lineq(c+1,Le,diff,LessD(just))
   539                 else lineq(c,Lt,diff,just)
   540      | "~<"  => lineq(~c,Le,map (op~) diff,NotLessD(just))
   541      | "="   => lineq(c,Eq,diff,just)
   542      | _     => raise Fail ("mklineq" ^ rel)
   543   end;
   544 
   545 (* ------------------------------------------------------------------------- *)
   546 (* Print (counter) example                                                   *)
   547 (* ------------------------------------------------------------------------- *)
   548 
   549 (* ------------------------------------------------------------------------- *)
   550 
   551 fun mknat (pTs : typ list) (ixs : int list) (atom : term, i : int) : lineq option =
   552   if LA_Logic.is_nat (pTs, atom)
   553   then let val l = map (fn j => if j=i then 1 else 0) ixs
   554        in SOME (Lineq (0, Le, l, Nat i)) end
   555   else NONE;
   556 
   557 (* This code is tricky. It takes a list of premises in the order they occur
   558 in the subgoal. Numerical premises are coded as SOME(tuple), non-numerical
   559 ones as NONE. Going through the premises, each numeric one is converted into
   560 a Lineq. The tricky bit is to convert ~= which is split into two cases < and
   561 >. Thus split_items returns a list of equation systems. This may blow up if
   562 there are many ~=, but in practice it does not seem to happen. The really
   563 tricky bit is to arrange the order of the cases such that they coincide with
   564 the order in which the cases are in the end generated by the tactic that
   565 applies the generated refutation thms (see function 'refute_tac').
   566 
   567 For variables n of type nat, a constraint 0 <= n is added.
   568 *)
   569 
   570 (* FIXME: To optimize, the splitting of cases and the search for refutations *)
   571 (*        could be intertwined: separate the first (fully split) case,       *)
   572 (*        refute it, continue with splitting and refuting.  Terminate with   *)
   573 (*        failure as soon as a case could not be refuted; i.e. delay further *)
   574 (*        splitting until after a refutation for other cases has been found. *)
   575 
(*Turn the premises (Ts, terms) of a subgoal into the (in)equation systems
  to be refuted: optional preprocessing (LA_Data.pre_decomp), decomposition,
  and case splitting of ~= premises (if split_neq). Each decomposed premise
  is paired with its position in the (possibly extended) premise list.
  Non-arithmetic premises, and ~= premises when split_neq is false, are
  dropped but still counted.*)
fun split_items ctxt do_pre split_neq (Ts, terms) : (typ list * (LA_Data.decomp * int) list) list =
let
  (* splits inequalities '~=' into '<' and '>'; this corresponds to *)
  (* 'REPEAT_DETERM (eresolve_tac neqE i)' at the theorem/tactic    *)
  (* level                                                          *)
  (* FIXME: this is currently sensitive to the order of theorems in *)
  (*        neqE:  The theorem for type "nat" must come first.  A   *)
  (*        better (i.e. less likely to break when neqE changes)    *)
  (*        implementation should *test* which theorem from neqE    *)
  (*        can be applied, and split the premise accordingly.      *)
  fun elim_neq (ineqs : (LA_Data.decomp option * bool) list) :
               (LA_Data.decomp option * bool) list list =
  let
    (* the two cases of a split premise are appended at the end of the *)
    (* list; with nat_only, only premises with is_nat are split        *)
    fun elim_neq' _ ([] : (LA_Data.decomp option * bool) list) :
                  (LA_Data.decomp option * bool) list list =
          [[]]
      | elim_neq' nat_only ((NONE, is_nat) :: ineqs) =
          map (cons (NONE, is_nat)) (elim_neq' nat_only ineqs)
      | elim_neq' nat_only ((ineq as (SOME (l, i, rel, r, j, d), is_nat)) :: ineqs) =
          if rel = "~=" andalso (not nat_only orelse is_nat) then
            (* [| ?l ~= ?r; ?l < ?r ==> ?R; ?r < ?l ==> ?R |] ==> ?R *)
            elim_neq' nat_only (ineqs @ [(SOME (l, i, "<", r, j, d), is_nat)]) @
            elim_neq' nat_only (ineqs @ [(SOME (r, j, "<", l, i, d), is_nat)])
          else
            map (cons ineq) (elim_neq' nat_only ineqs)
  in
    (* two passes: first the nat premises, then all remaining ones *)
    ineqs |> elim_neq' true
          |> maps (elim_neq' false)
  end

  (* without splitting, ~= premises are ignored (but keep their position) *)
  fun ignore_neq (NONE, bool) = (NONE, bool)
    | ignore_neq (ineq as SOME (_, _, rel, _, _, _), bool) =
      if rel = "~=" then (NONE, bool) else (ineq, bool)

  (* pair each decomposed premise with its position; NONE entries are *)
  (* dropped but counted                                              *)
  fun number_hyps _ []             = []
    | number_hyps n (NONE::xs)     = number_hyps (n+1) xs
    | number_hyps n ((SOME x)::xs) = (x, n) :: number_hyps (n+1) xs

  val result = (Ts, terms)
    |> (* user-defined preprocessing of the subgoal *)
       (if do_pre then LA_Data.pre_decomp ctxt else Library.single)
    |> tap (fn subgoals => trace_msg ctxt ("Preprocessing yields " ^
         string_of_int (length subgoals) ^ " subgoal(s) total."))
    |> (* produce the internal encoding of (in-)equalities *)
       map (apsnd (map (fn t => (LA_Data.decomp ctxt t, LA_Data.domain_is_nat t))))
    |> (* splitting of inequalities *)
       map (apsnd (if split_neq then elim_neq else
                     Library.single o map ignore_neq))
    |> maps (fn (Ts, subgoals) => map (pair Ts o map fst) subgoals)
    |> (* numbering of hypotheses, ignoring irrelevant ones *)
       map (apsnd (number_hyps 0))
in
  trace_msg ctxt ("Splitting of inequalities yields " ^
    string_of_int (length result) ^ " subgoal(s) total.");
  result
end;
   632 
   633 fun refutes ctxt :
   634     (typ list * (LA_Data.decomp * int) list) list -> injust list -> injust list option =
   635   let
   636     fun refute ((Ts, initems : (LA_Data.decomp * int) list) :: initemss) (js: injust list) =
   637           let
   638             val atoms = atoms_of (map fst initems)
   639             val n = length atoms
   640             val mkleq = mklineq atoms
   641             val ixs = 0 upto (n - 1)
   642             val iatoms = atoms ~~ ixs
   643             val natlineqs = map_filter (mknat Ts ixs) iatoms
   644             val ineqs = map mkleq initems @ natlineqs
   645           in
   646             (case elim ctxt (ineqs, []) of
   647                Success j =>
   648                  (trace_msg ctxt ("Contradiction! (" ^ string_of_int (length js + 1) ^ ")");
   649                   refute initemss (js @ [j]))
   650              | Failure _ => NONE)
   651           end
   652       | refute [] js = SOME js
   653   in refute end;
   654 
   655 fun refute ctxt params do_pre split_neq terms : injust list option =
   656   refutes ctxt (split_items ctxt do_pre split_neq (map snd params, terms)) [];
   657 
   658 fun count P xs = length (filter P xs);
   659 
   660 fun prove ctxt params do_pre Hs concl : bool * injust list option =
   661   let
   662     val _ = trace_msg ctxt "prove:"
   663     (* append the negated conclusion to 'Hs' -- this corresponds to     *)
   664     (* 'DETERM (resolve_tac [LA_Logic.notI, LA_Logic.ccontr] i)' at the *)
   665     (* theorem/tactic level                                             *)
   666     val Hs' = Hs @ [LA_Logic.neg_prop concl]
   667     fun is_neq NONE                 = false
   668       | is_neq (SOME (_,_,r,_,_,_)) = (r = "~=")
   669     val neq_limit = Config.get ctxt LA_Data.neq_limit
   670     val split_neq = count is_neq (map (LA_Data.decomp ctxt) Hs') <= neq_limit
   671   in
   672     if split_neq then ()
   673     else
   674       trace_msg ctxt ("neq_limit exceeded (current value is " ^
   675         string_of_int neq_limit ^ "), ignoring all inequalities");
   676     (split_neq, refute ctxt params do_pre split_neq Hs')
   677   end handle TERM ("neg_prop", _) =>
   678     (* since no meta-logic negation is available, we can only fail if   *)
   679     (* the conclusion is not of the form 'Trueprop $ _' (simply         *)
   680     (* dropping the conclusion doesn't work either, because even        *)
   681     (* 'False' does not imply arbitrary 'concl::prop')                  *)
   682     (trace_msg ctxt "prove failed (cannot negate conclusion).";
   683       (false, NONE));
   684 
(* Tactic replaying, on subgoal i of the proof state, a refutation found by
   prove.  split_neq is the flag returned by prove (see simpset_lin_arith_tac):
   it decides whether "~=" assumptions are eliminated with the neqE rules here.
   The justifications in justs are used strictly in list order, each one
   closing the subgoal currently at position i.
   NOTE(review): this relies on the cases produced by neqE elimination arising
   in the same order as the justifications computed by split_items. *)
fun refute_tac ctxt (i, split_neq, justs) =
  fn state =>
    let
      val _ = trace_thm ctxt
        ["refute_tac (on subgoal " ^ string_of_int i ^ ", with " ^
          string_of_int (length justs) ^ " justification(s)):"] state
      val neqE = get_neqE ctxt;
      (* close the current subgoal i using the single justification j *)
      fun just1 j =
        (* eliminate inequalities *)
        (if split_neq then
          REPEAT_DETERM (eresolve_tac ctxt neqE i)
        else
          all_tac) THEN
          PRIMITIVE (trace_thm ctxt ["State after neqE:"]) THEN
          (* use theorems generated from the actual justifications *)
          Subgoal.FOCUS (fn {prems, ...} => resolve_tac ctxt [mkthm ctxt prems j] 1) ctxt i
    in
      (* rewrite "[| A1; ...; An |] ==> B" to "[| A1; ...; An; ~B |] ==> False" *)
      DETERM (resolve_tac ctxt [LA_Logic.notI, LA_Logic.ccontr] i) THEN
      (* user-defined preprocessing of the subgoal *)
      DETERM (LA_Data.pre_tac ctxt i) THEN
      PRIMITIVE (trace_thm ctxt ["State after pre_tac:"]) THEN
      (* prove every resulting subgoal, using its justification *)
      EVERY (map just1 justs)
    (* the composed tactic is applied directly to the incoming state *)
    end  state;
   710 
   711 (*
   712 Fast but very incomplete decider. Only premises and conclusions
   713 that are already (negated) (in)equations are taken into account.
   714 *)
   715 fun simpset_lin_arith_tac ctxt = SUBGOAL (fn (A, i) =>
   716   let
   717     val params = rev (Logic.strip_params A)
   718     val Hs = Logic.strip_assums_hyp A
   719     val concl = Logic.strip_assums_concl A
   720     val _ = trace_term ctxt ["Trying to refute subgoal " ^ string_of_int i] A
   721   in
   722     case prove ctxt params true Hs concl of
   723       (_, NONE) => (trace_msg ctxt "Refutation failed."; no_tac)
   724     | (split_neq, SOME js) => (trace_msg ctxt "Refutation succeeded.";
   725                                refute_tac ctxt (i, split_neq, js))
   726   end);
   727 
   728 fun prems_lin_arith_tac ctxt =
   729   Method.insert_tac ctxt (Simplifier.prems_of ctxt) THEN'
   730   simpset_lin_arith_tac ctxt;
   731 
   732 fun lin_arith_tac ctxt =
   733   simpset_lin_arith_tac (empty_simpset ctxt);
   734 
   735 
   736 
   737 (** Forward proof from theorems **)
   738 
(* More tricky code: the proofs of the multiple cases (arising from splits of
~= premises) must be arranged so that their order coincides with the order of
the cases generated by function split_items. *)
   742 
(* Case-split tree built by splitasms and consumed by fwdproof.
   Tip asms: a leaf carrying the assumptions available in that case.
   Spl (spl, ct1, t1, ct2, t2): spl is an assumption composed with a neqE
   rule; ct1/ct2 are the two case hypotheses read off spl by extract; t1 and
   t2 are the subtrees for the case assuming ct1 and ct2, respectively. *)
datatype splittree = Tip of thm list
                   | Spl of thm * cterm * splittree * cterm * splittree;
   745 
   746 (* "(ct1 ==> ?R) ==> (ct2 ==> ?R) ==> ?R" is taken to (ct1, ct2) *)
   747 
   748 fun extract (imp : cterm) : cterm * cterm =
   749 let val (Il, r)    = Thm.dest_comb imp
   750     val (_, imp1)  = Thm.dest_comb Il
   751     val (Ict1, _)  = Thm.dest_comb imp1
   752     val (_, ct1)   = Thm.dest_comb Ict1
   753     val (Ir, _)    = Thm.dest_comb r
   754     val (_, Ict2r) = Thm.dest_comb Ir
   755     val (Ict2, _)  = Thm.dest_comb Ict2r
   756     val (_, ct2)   = Thm.dest_comb Ict2
   757 in (ct1, ct2) end;
   758 
   759 fun splitasms ctxt (asms : thm list) : splittree =
   760 let val neqE = get_neqE ctxt
   761     fun elim_neq [] (asms', []) = Tip (rev asms')
   762       | elim_neq [] (asms', asms) = Tip (rev asms' @ asms)
   763       | elim_neq (_ :: neqs) (asms', []) = elim_neq neqs ([],rev asms')
   764       | elim_neq (neqs as (neq :: _)) (asms', asm::asms) =
   765       (case get_first (fn th => SOME (asm COMP th) handle THM _ => NONE) [neq] of
   766         SOME spl =>
   767           let val (ct1, ct2) = extract (Thm.cprop_of spl)
   768               val thm1 = Thm.assume ct1
   769               val thm2 = Thm.assume ct2
   770           in Spl (spl, ct1, elim_neq neqs (asms', asms@[thm1]),
   771             ct2, elim_neq neqs (asms', asms@[thm2]))
   772           end
   773       | NONE => elim_neq neqs (asm::asms', asms))
   774 in elim_neq neqE ([], asms) end;
   775 
(* Derive a theorem from a split tree, consuming the justifications left to
   right.  Each Tip uses one justification (via mkthm), and the ct1 subtree is
   processed before the ct2 subtree.  Returns the derived theorem (in prover:
   the theorem concluding False) together with the remaining justifications.
   A Tip reached with no justification left matches no clause and raises Match. *)
fun fwdproof ctxt (Tip asms : splittree) (j::js : injust list) = (mkthm ctxt asms j, js)
  | fwdproof ctxt (Spl (thm, ct1, tree1, ct2, tree2)) js =
      let
        val (thm1, js1) = fwdproof ctxt tree1 js
        val (thm2, js2) = fwdproof ctxt tree2 js1
        (* discharge the case hypotheses that splitasms introduced with Thm.assume *)
        val thm1' = Thm.implies_intr ct1 thm1
        val thm2' = Thm.implies_intr ct2 thm2
      (* resolve both case proofs against the split rule instance thm *)
      in (thm2' COMP (thm1' COMP thm), js2) end;
      (* FIXME needs handle THM _ => NONE ? *)
   785 
   786 fun prover ctxt thms Tconcl (js : injust list) split_neq pos : thm option =
   787   let
   788     val nTconcl = LA_Logic.neg_prop Tconcl
   789     val cnTconcl = Thm.cterm_of ctxt nTconcl
   790     val nTconclthm = Thm.assume cnTconcl
   791     val tree = (if split_neq then splitasms ctxt else Tip) (thms @ [nTconclthm])
   792     val (Falsethm, _) = fwdproof ctxt tree js
   793     val contr = if pos then LA_Logic.ccontr else LA_Logic.notI
   794     val concl = Thm.implies_intr cnTconcl Falsethm COMP contr
   795   in SOME (trace_thm ctxt ["Proved by lin. arith. prover:"] (LA_Logic.mk_Eq concl)) end
   796   (*in case concl contains ?-var, which makes assume fail:*)   (* FIXME Variable.import_terms *)
   797   handle THM _ => NONE;
   798 
   799 (* PRE: concl is not negated!
   800    This assumption is OK because
   801    1. lin_arith_simproc tries both to prove and disprove concl and
   802    2. lin_arith_simproc is applied by the Simplifier which
   803       dives into terms and will thus try the non-negated concl anyway.
   804 *)
   805 fun lin_arith_simproc ctxt concl =
   806   let
   807     val thms = maps LA_Logic.atomize (Simplifier.prems_of ctxt)
   808     val Hs = map Thm.prop_of thms
   809     val Tconcl = LA_Logic.mk_Trueprop (Thm.term_of concl)
   810   in
   811     case prove ctxt [] false Hs Tconcl of (* concl provable? *)
   812       (split_neq, SOME js) => prover ctxt thms Tconcl js split_neq true
   813     | (_, NONE) =>
   814         let val nTconcl = LA_Logic.neg_prop Tconcl in
   815           case prove ctxt [] false Hs nTconcl of (* ~concl provable? *)
   816             (split_neq, SOME js) => prover ctxt thms nTconcl js split_neq false
   817           | (_, NONE) => NONE
   818         end
   819   end;
   820 
   821 end;