# HG changeset patch # User nipkow # Date 912182099 -3600 # Node ID aeb97860d35210f1ae470eb7319c84f5a58fe891 # Parent ec5c3d17969f0eb1da676762786c9e6aec478975 Replaced the puny nat_transitive.ML by the general fast_lin_arith.ML. diff -r ec5c3d17969f -r aeb97860d352 src/Provers/Arith/fast_lin_arith.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/Provers/Arith/fast_lin_arith.ML Fri Nov 27 16:54:59 1998 +0100 @@ -0,0 +1,282 @@ +(* Title: Provers/Arith/fast_lin_arith.ML + ID: $Id$ + Author: Tobias Nipkow + Copyright 1998 TU Munich + +A generic linear arithmetic package. At the moment only used for nat. +The two tactics provided: + lin_arith_tac: int -> tactic +cut_lin_arith_tac: thms -> int -> tactic +Only take premises and conclusions into account +that are already (negated) (in)equations. +*) + +(*** Data needed for setting up the linear arithmetic package ***) + +signature LIN_ARITH_DATA = +sig + val add_mono_thms: thm list + (* [| i rel1 j; m rel2 n |] ==> i + m rel3 j + n *) + val conjI: thm + val ccontr: thm (* (~ P ==> False) ==> P *) + val lessD: thm (* m < n ==> Suc m <= n *) + val nat_neqE: thm (* [| m ~= n; m < n ==> P; n < m ==> P |] ==> P *) + val notI: thm (* (P ==> False) ==> ~ P *) + val not_leD: thm (* ~(m <= n) ==> Suc n <= m *) + val not_lessD: thm (* ~(m < n) ==> n < m *) + val sym: thm (* x = y ==> y = x *) + val decomp: term -> + ((term * int)list * int * string * (term * int)list * int)option + val simp: thm -> thm +end; +(* +decomp(`x Rel y') should yield (p,i,Rel,q,j) + where Rel is one of "<", "~<", "<=", "~<=" and "=" and + p/q is the decomposition of the sum terms x/y into a list + of summand * multiplicity pairs and a constant summand. + +simp must reduce contradictory <= to False. + It should also cancel common summands to keep <= reduced; + otherwise <= can grow to massive proportions. +*) + +functor Fast_Lin_Arith(LA_Data:LIN_ARITH_DATA) = +struct + +(*** A fast decision procedure ***) +(*** Code ported from HOL Light ***) +(* possible optimizations: eliminate eqns first; use (var,coeff) rep *) + +datatype lineq_type = Eq | Le | Lt; + +datatype injust = Given of int + | Fwd of injust * thm + | Multiplied of int * injust + | Added of injust * injust; + +datatype lineq = Lineq of int * lineq_type * int list * injust; + +(* ------------------------------------------------------------------------- *) +(* Calculate new (in)equality type after addition. *) +(* ------------------------------------------------------------------------- *) + +fun find_add_type(Eq,x) = x + | find_add_type(x,Eq) = x + | find_add_type(_,Lt) = Lt + | find_add_type(Lt,_) = Lt + | find_add_type(Le,Le) = Le; + +(* ------------------------------------------------------------------------- *) +(* Multiply out an (in)equation. *) +(* ------------------------------------------------------------------------- *) + +fun multiply_ineq n (i as Lineq(k,ty,l,just)) = + if n = 1 then i + else if n = 0 andalso ty = Lt then sys_error "multiply_ineq" + else if n < 0 andalso (ty=Le orelse ty=Lt) then sys_error "multiply_ineq" + else Lineq(n * k,ty,map (apl(n,op * )) l,Multiplied(n,just)); + +(* ------------------------------------------------------------------------- *) +(* Add together (in)equations. *) +(* ------------------------------------------------------------------------- *) + +fun add_ineq (i1 as Lineq(k1,ty1,l1,just1)) (i2 as Lineq(k2,ty2,l2,just2)) = + let val l = map2 (op +) (l1,l2) + in Lineq(k1+k2,find_add_type(ty1,ty2),l,Added(just1,just2)) end; + +(* ------------------------------------------------------------------------- *) +(* Elimination of variable between a single pair of (in)equations. *) +(* If they're both inequalities, 1st coefficient must be +ve, 2nd -ve. *) +(* ------------------------------------------------------------------------- *) + +fun gcd x y = + let fun gxd x y = + if y = 0 then x else gxd y (x mod y) + in if x < y then gxd y x else gxd x y end; + +fun lcm x y = (x * y) div gcd x y; + +fun el 0 (h::_) = h + | el n (_::t) = el (n - 1) t + | el _ _ = sys_error "el"; + +fun elim_var v (i1 as Lineq(k1,ty1,l1,just1)) (i2 as Lineq(k2,ty2,l2,just2)) = + let val c1 = el v l1 and c2 = el v l2 + val m = lcm (abs c1) (abs c2) + val m1 = m div (abs c1) and m2 = m div (abs c2) + val (n1,n2) = + if (c1 >= 0) = (c2 >= 0) + then if ty1 = Eq then (~m1,m2) + else if ty2 = Eq then (m1,~m2) + else sys_error "elim_var" + else (m1,m2) + val (p1,p2) = if ty1=Eq andalso ty2=Eq andalso (n1 = ~1 orelse n2 = ~1) + then (~n1,~n2) else (n1,n2) + in add_ineq (multiply_ineq n1 i1) (multiply_ineq n2 i2) end; + +(* ------------------------------------------------------------------------- *) +(* The main refutation-finding code. *) +(* ------------------------------------------------------------------------- *) + +fun is_trivial (Lineq(_,_,l,_)) = forall (fn i => i=0) l; + +fun is_answer (ans as Lineq(k,ty,l,_)) = + case ty of Eq => k <> 0 | Le => k > 0 | Lt => k >= 0; + +fun calc_blowup l = + let val (p,n) = partition (apl(0,op<)) (filter (apl(0,op<>)) l) + in (length p) * (length n) end; + +(* ------------------------------------------------------------------------- *) +(* Main elimination code: *) +(* *) +(* (1) Looks for immediate solutions (false assertions with no variables). *) +(* *) +(* (2) If there are any equations, picks a variable with the lowest absolute *) +(* coefficient in any of them, and uses it to eliminate. *) +(* *) +(* (3) Otherwise, chooses a variable in the inequality to minimize the *) +(* blowup (number of consequences generated) and eliminates it. *) +(* ------------------------------------------------------------------------- *) + +fun allpairs f xs ys = + flat(map (fn x => map (fn y => f x y) ys) xs); + +fun extract_first p = + let fun extract xs (y::ys) = if p y then (Some y,xs@ys) + else extract (y::xs) ys + | extract xs [] = (None,xs) + in extract [] end; + +fun elim ineqs = + let val (triv,nontriv) = partition is_trivial ineqs in + if not(null triv) + then case find_first is_answer triv of + None => elim nontriv | some => some + else + if null nontriv then None else + let val (eqs,noneqs) = partition (fn (Lineq(_,ty,_,_)) => ty=Eq) nontriv in + if not(null eqs) then + let val clist = foldl (fn (cs,Lineq(_,_,l,_)) => l union cs) ([],eqs) + val sclist = sort (fn (x,y) => int_ord(abs(x),abs(y))) + (filter (fn i => i<>0) clist) + val c = hd sclist + val (Some(eq as Lineq(_,_,ceq,_)),othereqs) = + extract_first (fn Lineq(_,_,l,_) => c mem l) eqs + val v = find_index (fn k => k=c) ceq + val (ioth,roth) = partition (fn (Lineq(_,_,l,_)) => el v l = 0) + (othereqs @ noneqs) + val others = map (elim_var v eq) roth @ ioth + in elim others end + else + let val lists = map (fn (Lineq(_,_,l,_)) => l) noneqs + val numlist = 0 upto (length(hd lists) - 1) + val coeffs = map (fn i => map (el i) lists) numlist + val blows = map calc_blowup coeffs + val iblows = blows ~~ numlist + val nziblows = filter (fn (i,_) => i<>0) iblows + in if null nziblows then None else + let val (c,v) = hd(sort (fn (x,y) => int_ord(fst(x),fst(y))) nziblows) + val (no,yes) = partition (fn (Lineq(_,_,l,_)) => el v l = 0) ineqs + val (pos,neg) = partition(fn (Lineq(_,_,l,_)) => el v l > 0) yes + in elim (no @ allpairs (elim_var v) pos neg) end + end + end + end; + +(* ------------------------------------------------------------------------- *) +(* Translate back a proof. *) +(* ------------------------------------------------------------------------- *) + +(* FIXME OPTIMIZE!!!! *) +fun mkproof asms just = + let fun addthms thm1 thm2 = + let val conj = thm1 RS (thm2 RS LA_Data.conjI) + in the(get_first (fn th => Some(conj RS th) handle _ => None) + LA_Data.add_mono_thms) + end; + + fun multn(n,thm) = + let fun mul(i,th) = if i=1 then th else mul(i-1, addthms thm th) + in if n < 0 then mul(~n,thm) RS LA_Data.sym else mul(n,thm) end; + + fun mk(Given i) = nth_elem(i,asms) + | mk(Fwd(j,thm)) = mk j RS thm + | mk(Added(j1,j2)) = LA_Data.simp(addthms (mk j1) (mk j2)) + | mk(Multiplied(n,j)) = multn(n,mk j) + + in LA_Data.simp(mk just) end; + + +fun coeff poly atom = case assoc(poly,atom) of None => 0 | Some i => i; + +fun mklineq atoms = + let val n = length atoms in + fn ((lhs,i,rel,rhs,j),k) => + let val lhsa = map (coeff lhs) atoms + and rhsa = map (coeff rhs) atoms + val diff = map2 (op -) (rhsa,lhsa) + val c = i-j + val just = Given k + in case rel of + "<=" => Some(Lineq(c,Le,diff,just)) + | "~<=" => Some(Lineq(1-c,Le,map (op ~) diff,Fwd(just,LA_Data.not_leD))) + | "<" => Some(Lineq(c+1,Le,diff,Fwd(just,LA_Data.lessD))) + | "~<" => Some(Lineq(~c,Le,map (op~) diff,Fwd(just,LA_Data.not_lessD))) + | "=" => Some(Lineq(c,Eq,diff,just)) + | "~=" => None + | _ => sys_error("mklineq" ^ rel) + end + end; + +fun abstract items = + let val atoms = foldl (fn (ats,((lhs,_,_,rhs,_),_)) => + (map fst lhs) union ((map fst rhs) union ats)) + ([],items) + in mapfilter (mklineq atoms) items end; + +(* Ordinary refutation *) +fun refute1_tac items = + let val lineqs = abstract items + in case elim lineqs of + None => K no_tac + | Some(Lineq(_,_,_,j)) => + resolve_tac [LA_Data.notI,LA_Data.ccontr] THEN' + METAHYPS (fn asms => rtac (mkproof asms j) 1) + end; + +(* Double refutation caused by equality in conclusion *) +fun refute2_tac items (rhs,i,_,lhs,j) nHs = + (case elim (abstract(items@[((rhs,i,"<",lhs,j),nHs)])) of + None => K no_tac + | Some(Lineq(_,_,_,j1)) => + (case elim (abstract(items@[((lhs,j,"<",rhs,i),nHs)])) of + None => K no_tac + | Some(Lineq(_,_,_,j2)) => + rtac LA_Data.ccontr THEN' etac LA_Data.nat_neqE THEN' + METAHYPS (fn asms => rtac (mkproof asms j1) 1) THEN' + METAHYPS (fn asms => rtac (mkproof asms j2) 1) )); + +(* +Fast but very incomplete decider. Only premises and conclusions +that are already (negated) (in)equations are taken into account. +*) +val lin_arith_tac = SUBGOAL (fn (A,n) => + let val Hs = Logic.strip_assums_hyp A + val nHs = length Hs + val His = Hs ~~ (0 upto (nHs-1)) + val Hitems = mapfilter (fn (h,i) => case LA_Data.decomp h of + None => None | Some(it) => Some(it,i)) His + in case LA_Data.decomp(Logic.strip_assums_concl A) of + None => if null Hitems then no_tac else refute1_tac Hitems n + | Some(citem as (r,i,rel,l,j)) => + if rel = "=" + then refute2_tac Hitems citem nHs n + else let val neg::rel0 = explode rel + val nrel = if neg = "~" then implode rel0 else "~"^rel + in refute1_tac (Hitems@[((r,i,nrel,l,j),nHs)]) n end + end); + +fun cut_lin_arith_tac thms i = cut_facts_tac thms i THEN lin_arith_tac i; + +end; diff -r ec5c3d17969f -r aeb97860d352 src/Provers/Arith/nat_transitive.ML --- a/src/Provers/Arith/nat_transitive.ML Fri Nov 27 16:46:01 1998 +0100 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,253 +0,0 @@ -(* Title: Provers/nat_transitive.ML - ID: $Id$ - Author: Tobias Nipkow - Copyright 1996 TU Munich -*) - -(*** -A very simple package for inequalities over nat. -It uses all premises of the form - -t = u, t < u, t <= u, ~(t < u), ~(t <= u) - -where t and u must be of type nat, to -1. either derive a contradiction, - in which case the conclusion can be any term, -2. or prove the conclusion, which must be of the same form as the premises. - -The package -- does not deal with the relation ~= -- treats `pred', +, *, ... as atomic terms. Hence it can prove - [| x < y+z; y+z < u |] ==> Suc x < u - but not - [| x < y+z; z < u |] ==> Suc x < y+u -- takes only (in)equalities which are atomic premises into account. It does - not deal with logical operators like -->, & etc. Hence it cannot prove - [| x < y+z & y+z < u |] ==> Suc x < u - -In order not to fall foul of the above limitations, the following hints are -useful: - -1. You may need to run `by(safe_tac HOL_cs)' in order to bring out the atomic - premises. - -2. To get rid of ~= in the premises, it is advisable to use a rule like - nat_neqE = "[| (m::nat) ~= n; m < n ==> P; n < m ==> P |] ==> P" : thm - (the name nat_eqE is chosen in HOL), for example as follows: - by(safe_tac (HOL_cs addSEs [nat_neqE]) - -3. To get rid of `pred', you may be able to do the following: - expand `pred(m)' into `case m of 0 => 0 | Suc n => n' and use split_tac - to turn the case-expressions into logical case distinctions. In HOL: - simp_tac (... addsimps [pred_def] setloop (split_tac [expand_nat_case])) - -The basic tactic is `trans_tac'. In order to use `trans_tac' as a solver in -the simplifier, `cut_trans_tac' is also provided, which cuts the given thms -in as facts. - -Notes: -- It should easily be possible to adapt this package to other numeric types - like int. -- There is ample scope for optimizations, which so far have not proved - necessary. -- The code can be simplified by adding the negated conclusion to the - premises to derive a contradiction. However, this would restrict the - package to classical logics. -***) - -(* The package works for arbitrary logics. - You just need to instantiate the following parameter structure. -*) -signature LESS_ARITH = -sig - val lessI: thm (* n < Suc n *) - val zero_less_Suc: thm (* 0 < Suc n *) - val less_reflE: thm (* n < n ==> P *) - val less_zeroE: thm (* n < 0 ==> P *) - val less_incr: thm (* m < n ==> Suc m < Suc n *) - val less_decr: thm (* Suc m < Suc n ==> m < n *) - val less_incr_rhs: thm (* m < n ==> m < Suc n *) - val less_decr_lhs: thm (* Suc m < n ==> m < n *) - val less_trans_Suc: thm (* [| i < j; j < k |] ==> Suc i < k *) - val leD: thm (* m <= n ==> m < Suc n *) - val not_lessD: thm (* ~(m < n) ==> n < Suc m *) - val not_leD: thm (* ~(m <= n) ==> n < m *) - val eqD1: thm (* m = n ==> m < Suc n *) - val eqD2: thm (* m = n ==> m < Suc n *) - val not_lessI: thm (* n < Suc m ==> ~(m < n) *) - val leI: thm (* m < Suc n ==> m <= n *) - val not_leI: thm (* n < m ==> ~(m <= n) *) - val eqI: thm (* [| m < Suc n; n < Suc m |] ==> n = m *) - val is_zero: term -> bool - val decomp: term -> (term * int * string * term * int)option -(* decomp(`Suc^i(x) Rel Suc^j(y)') should yield (x,i,Rel,y,j) - where Rel is one of "<", "~<", "<=", "~<=" and "=" *) -end; - - -signature TRANS_TAC = -sig - val trans_tac: int -> tactic - val cut_trans_tac: thm list -> int -> tactic -end; - -functor Trans_Tac_Fun(Less:LESS_ARITH):TRANS_TAC = -struct - -datatype proof = Asm of int - | Thm of proof list * thm - | Incr1 of proof * int (* Increment 1 side *) - | Incr2 of proof * int (* Increment 2 sides *); - - -(*** Turn proof objects into thms ***) - -fun incr2(th,i) = if i=0 then th else - if i>0 then incr2(th RS Less.less_incr,i-1) - else incr2(th RS Less.less_decr,i+1); - -fun incr1(th,i) = if i=0 then th else - if i>0 then incr1(th RS Less.less_incr_rhs,i-1) - else incr1(th RS Less.less_decr_lhs,i+1); - -fun prove asms = - let fun pr(Asm i) = nth_elem(i,asms) - | pr(Thm(prfs,thm)) = (map pr prfs) MRS thm - | pr(Incr1(p,i)) = incr1(pr p,i) - | pr(Incr2(p,i)) = incr2(pr p,i) - in pr end; - -(*** Internal representation of inequalities -(x,i,y,j) means x+i < y+j. -Leads to simpler case distinctions than the normalized x < y+k -***) -type less = term * int * term * int * proof; - -(*** raised when contradiction is found ***) -exception Contr of proof; - -(*** raised when goal can't be proved ***) -exception Cant; - -infix subsumes; - -fun (x,i,y,j:int,_) subsumes (x',i',y',j',_) = - x=x' andalso y=y' andalso j-i<=j'-i'; - -fun trivial(x,i:int,y,j,_) = (x=y orelse Less.is_zero(x)) andalso i=j - then raise Contr(Thm([Incr1(Incr2(p,~j),j-i)],Less.less_reflE)) else - if Less.is_zero(y) andalso i>=j - then raise Contr(Thm([Incr2(p,~j)],Less.less_zeroE)) - else less; - -fun mktrans((x,i,_,j,p):less,(_,k,z,l,q)) = - ctest(if j >= k - then (x,i+1,z,l+(j-k),Thm([p,Incr2(q,j-k)],Less.less_trans_Suc)) - else (x,i+(k-j)+1,z,l,Thm([Incr2(p,k-j),q],Less.less_trans_Suc))); - -fun trans (new as (x,i,y,j,p)) olds = - let fun tr(news, old as (x1,i1,y1,j1,p1):less) = - if y1=x then mktrans(old,new)::news else - if x1=y then mktrans(new,old)::news else news - in foldl tr ([],olds) end; - -fun close1(olds: less list)(new:less):less list = - if trivial new orelse exists (fn old => old subsumes new) olds then olds - else let val news = trans new olds - in close (add new (olds,[])) news end -and close (olds: less list) ([]:less list) = olds - | close olds ((new:less)::news) = close (close1 olds (ctest new)) news; - -(*** end of transitive closure ***) - -(* recognize and solve trivial goal *) -fun triv_sol(x,i,y,j,_) = - if x=y andalso i (case find_first (fn fact => fact subsumes less) facts of - None => raise Cant - | Some(a,m,b,n,p) => Incr1(Incr2(p,j-n),n+i-m-j)) - | Some prf => prf; - -(* turn term into a less-tuple *) -fun mkasm(t,n) = - case Less.decomp(t) of - Some(x,i,rel,y,j) => (case rel of - "<" => [(x,i,y,j,Asm n)] - | "~<" => [(y,j,x,i+1,Thm([Asm n],Less.not_lessD))] - | "<=" => [(x,i,y,j+1,Thm([Asm n],Less.leD))] - | "~<=" => [(y,j,x,i,Thm([Asm n],Less.not_leD))] - | "=" => [(x,i,y,j+1,Thm([Asm n],Less.eqD1)), - (y,j,x,i+1,Thm([Asm n],Less.eqD2))] - | "~=" => [] - | _ => error("trans_tac/decomp: unknown relation " ^ rel)) - | None => []; - -(* mkconcl t returns a pair (goals,proof) where goals is a list of *) -(* less-subgoals to solve, and proof the validation which proves the concl t *) -(* from the subgoals. Asm ~1 is dummy *) -fun mkconcl t = - case Less.decomp(t) of - Some(x,i,rel,y,j) => (case rel of - "<" => ([(x,i,y,j,Asm ~1)],Asm 0) - | "~<" => ([(y,j,x,i+1,Asm ~1)],Thm([Asm 0],Less.not_lessI)) - | "<=" => ([(x,i,y,j+1,Asm ~1)],Thm([Asm 0],Less.leI)) - | "~<=" => ([(y,j,x,i,Asm ~1)],Thm([Asm 0],Less.not_leI)) - | "=" => ([(x,i,y,j+1,Asm ~1),(y,j,x,i+1,Asm ~1)], - Thm([Asm 0,Asm 1],Less.eqI)) - | "~=" => raise Cant - | _ => error("trans_tac/decomp: unknown relation " ^ rel)) - | None => raise Cant; - - -val trans_tac = SUBGOAL (fn (A,n) => - let val Hs = Logic.strip_assums_hyp A - val C = Logic.strip_assums_concl A - val lesss = flat(ListPair.map mkasm (Hs, 0 upto (length Hs - 1))) - val clesss = close [] lesss - val (subgoals,prf) = mkconcl C - val prfs = map (solve clesss) subgoals - in METAHYPS (fn asms => let val thms = map (prove asms) prfs - in rtac (prove thms prf) 1 end) n - end - handle Contr(p) => METAHYPS (fn asms => rtac (prove asms p) 1) n - | Cant => no_tac); - -fun cut_trans_tac thms = cut_facts_tac thms THEN' trans_tac; - -end; - -(*** Tests -fun test s = prove_goal Nat.thy ("!!m::nat." ^ s) (fn _ => [trans_tac 1]); - -test "[| i Suc(Suc i) < m"; -test "[| i Suc(Suc(Suc i)) <= m"; -test "[| i ~ m <= Suc(Suc i)"; -test "[| i ~ m < Suc(Suc(Suc i))"; -test "[| i m = Suc(Suc(Suc i))"; -test "[| i m = Suc(Suc(Suc i))"; -***)