--- a/src/Provers/Arith/fast_lin_arith.ML Mon Jan 04 15:07:47 1999 +0100
+++ b/src/Provers/Arith/fast_lin_arith.ML Mon Jan 04 15:08:40 1999 +0100
@@ -28,6 +28,9 @@
val decomp: term ->
((term * int)list * int * string * (term * int)list * int)option
val simp: thm -> thm
+ val is_False: thm -> bool
+ val is_nat: typ list * term -> bool
+ val mk_nat_thm: Sign.sg -> term -> thm
end;
(*
decomp(`x Rel y') should yield (p,i,Rel,q,j)
@@ -38,6 +41,9 @@
simp must reduce contradictory <= to False.
It should also cancel common summands to keep <= reduced;
otherwise <= can grow to massive proportions.
+
+is_nat(parameter-types,t) = t:nat
+mk_nat_thm(t) = "0 <= t"
*)
functor Fast_Lin_Arith(LA_Data:LIN_ARITH_DATA) =
@@ -45,11 +51,15 @@
(*** A fast decision procedure ***)
(*** Code ported from HOL Light ***)
-(* possible optimizations: eliminate eqns first; use (var,coeff) rep *)
+(* possible optimizations:
+ use (var,coeff) rep or vector rep tp save space;
+ treat non-negative atoms separately rather than adding 0 <= atom
+*)
datatype lineq_type = Eq | Le | Lt;
-datatype injust = Given of int
+datatype injust = Asm of int
+ | Nat of int (* index of atom *)
| Fwd of injust * thm
| Multiplied of int * injust
| Added of injust * injust;
@@ -148,8 +158,17 @@
| extract xs [] = (None,xs)
in extract [] end;
+(*
+fun print_ineqs ineqs =
+ writeln(cat_lines(""::map (fn Lineq(c,t,l,_) =>
+ string_of_int c ^
+ (case t of Eq => " = " | Lt=> " < " | Le => " <= ") ^
+ commas(map string_of_int l)) ineqs));
+*)
+
fun elim ineqs =
- let val (triv,nontriv) = partition is_trivial ineqs in
+ let (*val dummy = print_ineqs ineqs;*)
+ val (triv,nontriv) = partition is_trivial ineqs in
if not(null triv)
then case find_first is_answer triv of
None => elim nontriv | some => some
@@ -188,9 +207,22 @@
(* Translate back a proof. *)
(* ------------------------------------------------------------------------- *)
-(* FIXME OPTIMIZE!!!! *)
-fun mkproof asms just =
- let fun addthms thm1 thm2 =
+(* FIXME OPTIMIZE!!!!
+ Addition/Multiplication need i*t representation rather than t+t+...
+
+Simplification may detect a contradiction 'prematurely' due to type
+information: n+1 <= 0 is simplified to False and does not need to be crossed
+with 0 <= n.
+*)
+local
+ exception FalseE of thm
+in
+fun mkproof sg asms just =
+ let val atoms = foldl (fn (ats,(lhs,_,_,rhs,_)) =>
+ map fst lhs union (map fst rhs union ats))
+ ([], mapfilter (LA_Data.decomp o concl_of) asms)
+
+ fun addthms thm1 thm2 =
let val conj = thm1 RS (thm2 RS LA_Data.conjI)
in the(get_first (fn th => Some(conj RS th) handle _ => None)
LA_Data.add_mono_thms)
@@ -200,13 +232,18 @@
let fun mul(i,th) = if i=1 then th else mul(i-1, addthms thm th)
in if n < 0 then mul(~n,thm) RS LA_Data.sym else mul(n,thm) end;
- fun mk(Given i) = nth_elem(i,asms)
+ fun simp thm =
+ let val thm' = LA_Data.simp thm
+ in if LA_Data.is_False thm' then raise FalseE thm' else thm' end
+
+ fun mk(Asm i) = nth_elem(i,asms)
+ | mk(Nat(i)) = LA_Data.mk_nat_thm sg (nth_elem(i,atoms))
| mk(Fwd(j,thm)) = mk j RS thm
- | mk(Added(j1,j2)) = LA_Data.simp(addthms (mk j1) (mk j2))
+ | mk(Added(j1,j2)) = simp(addthms (mk j1) (mk j2))
| mk(Multiplied(n,j)) = multn(n,mk j)
- in LA_Data.simp(mk just) end;
-
+ in LA_Data.simp(mk just) handle FalseE thm => thm end
+end;
fun coeff poly atom = case assoc(poly,atom) of None => 0 | Some i => i;
@@ -217,7 +254,7 @@
and rhsa = map (coeff rhs) atoms
val diff = map2 (op -) (rhsa,lhsa)
val c = i-j
- val just = Given k
+ val just = Asm k
in case rel of
"<=" => Some(Lineq(c,Le,diff,just))
| "~<=" => Some(Lineq(1-c,Le,map (op ~) diff,Fwd(just,LA_Data.not_leD)))
@@ -229,52 +266,67 @@
end
end;
-fun abstract items =
+fun mknat pTs ixs (atom,i) =
+ if LA_Data.is_nat(pTs,atom)
+ then let val l = map (fn j => if j=i then 1 else 0) ixs
+ in Some(Lineq(0,Le,l,Nat(i))) end
+ else None
+
+fun abstract pTs items =
let val atoms = foldl (fn (ats,((lhs,_,_,rhs,_),_)) =>
(map fst lhs) union ((map fst rhs) union ats))
([],items)
- in mapfilter (mklineq atoms) items end;
+ val ixs = 0 upto (length(atoms)-1)
+ val iatoms = atoms ~~ ixs
+ in mapfilter (mklineq atoms) items @ mapfilter (mknat pTs ixs) iatoms end;
(* Ordinary refutation *)
-fun refute1_tac items =
- let val lineqs = abstract items
+fun refute1_tac pTs items =
+ let val lineqs = abstract pTs items
in case elim lineqs of
None => K no_tac
| Some(Lineq(_,_,_,j)) =>
- resolve_tac [LA_Data.notI,LA_Data.ccontr] THEN'
- METAHYPS (fn asms => rtac (mkproof asms j) 1)
+ fn i => fn state =>
+ let val sg = #sign(rep_thm state)
+ in resolve_tac [LA_Data.notI,LA_Data.ccontr] i THEN
+ METAHYPS (fn asms => rtac (mkproof sg asms j) 1) i
+ end state
end;
(* Double refutation caused by equality in conclusion *)
-fun refute2_tac items (rhs,i,_,lhs,j) nHs =
- (case elim (abstract(items@[((rhs,i,"<",lhs,j),nHs)])) of
+fun refute2_tac pTs items (rhs,i,_,lhs,j) nHs =
+ (case elim (abstract pTs (items@[((rhs,i,"<",lhs,j),nHs)])) of
None => K no_tac
| Some(Lineq(_,_,_,j1)) =>
- (case elim (abstract(items@[((lhs,j,"<",rhs,i),nHs)])) of
+ (case elim (abstract pTs (items@[((lhs,j,"<",rhs,i),nHs)])) of
None => K no_tac
| Some(Lineq(_,_,_,j2)) =>
- rtac LA_Data.ccontr THEN' etac LA_Data.nat_neqE THEN'
- METAHYPS (fn asms => rtac (mkproof asms j1) 1) THEN'
- METAHYPS (fn asms => rtac (mkproof asms j2) 1) ));
+ fn i => fn state =>
+ let val sg = #sign(rep_thm state)
+ in rtac LA_Data.ccontr i THEN etac LA_Data.nat_neqE i THEN
+ METAHYPS (fn asms => rtac (mkproof sg asms j1) 1) i THEN
+ METAHYPS (fn asms => rtac (mkproof sg asms j2) 1) i
+ end state));
(*
Fast but very incomplete decider. Only premises and conclusions
that are already (negated) (in)equations are taken into account.
*)
val lin_arith_tac = SUBGOAL (fn (A,n) =>
- let val Hs = Logic.strip_assums_hyp A
+ let val pTs = rev(map snd (Logic.strip_params A))
+ val Hs = Logic.strip_assums_hyp A
val nHs = length Hs
val His = Hs ~~ (0 upto (nHs-1))
val Hitems = mapfilter (fn (h,i) => case LA_Data.decomp h of
None => None | Some(it) => Some(it,i)) His
in case LA_Data.decomp(Logic.strip_assums_concl A) of
- None => if null Hitems then no_tac else refute1_tac Hitems n
+ None => if null Hitems then no_tac else refute1_tac pTs Hitems n
| Some(citem as (r,i,rel,l,j)) =>
if rel = "="
- then refute2_tac Hitems citem nHs n
+ then refute2_tac pTs Hitems citem nHs n
else let val neg::rel0 = explode rel
val nrel = if neg = "~" then implode rel0 else "~"^rel
- in refute1_tac (Hitems@[((r,i,nrel,l,j),nHs)]) n end
+ in refute1_tac pTs (Hitems@[((r,i,nrel,l,j),nHs)]) n end
end);
fun cut_lin_arith_tac thms i = cut_facts_tac thms i THEN lin_arith_tac i;