# HG changeset patch # User nipkow # Date 915458920 -3600 # Node ID b21813d1b7012dd8c69ddadbf3b16b16e9d469e2 # Parent fdf4638bf7263f88d242e90031edbbe345d7b4c9 Version 1 of linear arithmetic for nat. diff -r fdf4638bf726 -r b21813d1b701 src/Provers/Arith/fast_lin_arith.ML --- a/src/Provers/Arith/fast_lin_arith.ML Mon Jan 04 15:07:47 1999 +0100 +++ b/src/Provers/Arith/fast_lin_arith.ML Mon Jan 04 15:08:40 1999 +0100 @@ -28,6 +28,9 @@ val decomp: term -> ((term * int)list * int * string * (term * int)list * int)option val simp: thm -> thm + val is_False: thm -> bool + val is_nat: typ list * term -> bool + val mk_nat_thm: Sign.sg -> term -> thm end; (* decomp(`x Rel y') should yield (p,i,Rel,q,j) @@ -38,6 +41,9 @@ simp must reduce contradictory <= to False. It should also cancel common summands to keep <= reduced; otherwise <= can grow to massive proportions. + +is_nat(parameter-types,t) = t:nat +mk_nat_thm(t) = "0 <= t" *) functor Fast_Lin_Arith(LA_Data:LIN_ARITH_DATA) = @@ -45,11 +51,15 @@ (*** A fast decision procedure ***) (*** Code ported from HOL Light ***) -(* possible optimizations: eliminate eqns first; use (var,coeff) rep *) +(* possible optimizations: + use (var,coeff) rep or vector rep tp save space; + treat non-negative atoms separately rather than adding 0 <= atom +*) datatype lineq_type = Eq | Le | Lt; -datatype injust = Given of int +datatype injust = Asm of int + | Nat of int (* index of atom *) | Fwd of injust * thm | Multiplied of int * injust | Added of injust * injust; @@ -148,8 +158,17 @@ | extract xs [] = (None,xs) in extract [] end; +(* +fun print_ineqs ineqs = + writeln(cat_lines(""::map (fn Lineq(c,t,l,_) => + string_of_int c ^ + (case t of Eq => " = " | Lt=> " < " | Le => " <= ") ^ + commas(map string_of_int l)) ineqs)); +*) + fun elim ineqs = - let val (triv,nontriv) = partition is_trivial ineqs in + let (*val dummy = print_ineqs ineqs;*) + val (triv,nontriv) = partition is_trivial ineqs in if not(null triv) then case find_first is_answer triv of None => elim nontriv | some => some @@ -188,9 +207,22 @@ (* Translate back a proof. *) (* ------------------------------------------------------------------------- *) -(* FIXME OPTIMIZE!!!! *) -fun mkproof asms just = - let fun addthms thm1 thm2 = +(* FIXME OPTIMIZE!!!! + Addition/Multiplication need i*t representation rather than t+t+... + +Simplification may detect a contradiction 'prematurely' due to type +information: n+1 <= 0 is simplified to False and does not need to be crossed +with 0 <= n. +*) +local + exception FalseE of thm +in +fun mkproof sg asms just = + let val atoms = foldl (fn (ats,(lhs,_,_,rhs,_)) => + map fst lhs union (map fst rhs union ats)) + ([], mapfilter (LA_Data.decomp o concl_of) asms) + + fun addthms thm1 thm2 = let val conj = thm1 RS (thm2 RS LA_Data.conjI) in the(get_first (fn th => Some(conj RS th) handle _ => None) LA_Data.add_mono_thms) @@ -200,13 +232,18 @@ let fun mul(i,th) = if i=1 then th else mul(i-1, addthms thm th) in if n < 0 then mul(~n,thm) RS LA_Data.sym else mul(n,thm) end; - fun mk(Given i) = nth_elem(i,asms) + fun simp thm = + let val thm' = LA_Data.simp thm + in if LA_Data.is_False thm' then raise FalseE thm' else thm' end + + fun mk(Asm i) = nth_elem(i,asms) + | mk(Nat(i)) = LA_Data.mk_nat_thm sg (nth_elem(i,atoms)) | mk(Fwd(j,thm)) = mk j RS thm - | mk(Added(j1,j2)) = LA_Data.simp(addthms (mk j1) (mk j2)) + | mk(Added(j1,j2)) = simp(addthms (mk j1) (mk j2)) | mk(Multiplied(n,j)) = multn(n,mk j) - in LA_Data.simp(mk just) end; - + in LA_Data.simp(mk just) handle FalseE thm => thm end +end; fun coeff poly atom = case assoc(poly,atom) of None => 0 | Some i => i; @@ -217,7 +254,7 @@ and rhsa = map (coeff rhs) atoms val diff = map2 (op -) (rhsa,lhsa) val c = i-j - val just = Given k + val just = Asm k in case rel of "<=" => Some(Lineq(c,Le,diff,just)) | "~<=" => Some(Lineq(1-c,Le,map (op ~) diff,Fwd(just,LA_Data.not_leD))) @@ -229,52 +266,67 @@ end end; -fun abstract items = +fun mknat pTs ixs (atom,i) = + if LA_Data.is_nat(pTs,atom) + then let val l = map (fn j => if j=i then 1 else 0) ixs + in Some(Lineq(0,Le,l,Nat(i))) end + else None + +fun abstract pTs items = let val atoms = foldl (fn (ats,((lhs,_,_,rhs,_),_)) => (map fst lhs) union ((map fst rhs) union ats)) ([],items) - in mapfilter (mklineq atoms) items end; + val ixs = 0 upto (length(atoms)-1) + val iatoms = atoms ~~ ixs + in mapfilter (mklineq atoms) items @ mapfilter (mknat pTs ixs) iatoms end; (* Ordinary refutation *) -fun refute1_tac items = - let val lineqs = abstract items +fun refute1_tac pTs items = + let val lineqs = abstract pTs items in case elim lineqs of None => K no_tac | Some(Lineq(_,_,_,j)) => - resolve_tac [LA_Data.notI,LA_Data.ccontr] THEN' - METAHYPS (fn asms => rtac (mkproof asms j) 1) + fn i => fn state => + let val sg = #sign(rep_thm state) + in resolve_tac [LA_Data.notI,LA_Data.ccontr] i THEN + METAHYPS (fn asms => rtac (mkproof sg asms j) 1) i + end state end; (* Double refutation caused by equality in conclusion *) -fun refute2_tac items (rhs,i,_,lhs,j) nHs = - (case elim (abstract(items@[((rhs,i,"<",lhs,j),nHs)])) of +fun refute2_tac pTs items (rhs,i,_,lhs,j) nHs = + (case elim (abstract pTs (items@[((rhs,i,"<",lhs,j),nHs)])) of None => K no_tac | Some(Lineq(_,_,_,j1)) => - (case elim (abstract(items@[((lhs,j,"<",rhs,i),nHs)])) of + (case elim (abstract pTs (items@[((lhs,j,"<",rhs,i),nHs)])) of None => K no_tac | Some(Lineq(_,_,_,j2)) => - rtac LA_Data.ccontr THEN' etac LA_Data.nat_neqE THEN' - METAHYPS (fn asms => rtac (mkproof asms j1) 1) THEN' - METAHYPS (fn asms => rtac (mkproof asms j2) 1) )); + fn i => fn state => + let val sg = #sign(rep_thm state) + in rtac LA_Data.ccontr i THEN etac LA_Data.nat_neqE i THEN + METAHYPS (fn asms => rtac (mkproof sg asms j1) 1) i THEN + METAHYPS (fn asms => rtac (mkproof sg asms j2) 1) i + end state)); (* Fast but very incomplete decider. Only premises and conclusions that are already (negated) (in)equations are taken into account. *) val lin_arith_tac = SUBGOAL (fn (A,n) => - let val Hs = Logic.strip_assums_hyp A + let val pTs = rev(map snd (Logic.strip_params A)) + val Hs = Logic.strip_assums_hyp A val nHs = length Hs val His = Hs ~~ (0 upto (nHs-1)) val Hitems = mapfilter (fn (h,i) => case LA_Data.decomp h of None => None | Some(it) => Some(it,i)) His in case LA_Data.decomp(Logic.strip_assums_concl A) of - None => if null Hitems then no_tac else refute1_tac Hitems n + None => if null Hitems then no_tac else refute1_tac pTs Hitems n | Some(citem as (r,i,rel,l,j)) => if rel = "=" - then refute2_tac Hitems citem nHs n + then refute2_tac pTs Hitems citem nHs n else let val neg::rel0 = explode rel val nrel = if neg = "~" then implode rel0 else "~"^rel - in refute1_tac (Hitems@[((r,i,nrel,l,j),nHs)]) n end + in refute1_tac pTs (Hitems@[((r,i,nrel,l,j),nHs)]) n end end); fun cut_lin_arith_tac thms i = cut_facts_tac thms i THEN lin_arith_tac i;