--- a/src/HOL/Tools/lin_arith.ML Mon Jun 08 20:43:57 2009 +0200
+++ b/src/HOL/Tools/lin_arith.ML Mon Jun 08 22:29:37 2009 +0200
@@ -16,6 +16,8 @@
val add_simprocs: simproc list -> Context.generic -> Context.generic
val add_inj_const: string * typ -> Context.generic -> Context.generic
val add_discrete_type: string -> Context.generic -> Context.generic
+ val set_number_of: (theory -> typ -> int -> cterm) -> Context.generic ->
+ Context.generic
val setup: Context.generic -> Context.generic
val global_setup: theory -> theory
val split_limit: int Config.T
@@ -36,6 +38,7 @@
val conjI = conjI;
val notI = notI;
val sym = sym;
+val trueI = TrueI;
val not_lessD = @{thm linorder_not_less} RS iffD1;
val not_leD = @{thm linorder_not_le} RS iffD1;
@@ -274,7 +277,6 @@
| domain_is_nat (_ $ (Const ("Not", _) $ (Const (_, T) $ _ $ _))) = nT T
| domain_is_nat _ = false;
-val mk_number = HOLogic.mk_number;
(*---------------------------------------------------------------------------*)
(* the following code performs splitting of certain constants (e.g. min, *)
@@ -752,23 +754,30 @@
val map_data = Fast_Arith.map_data;
-fun map_inj_thms f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset} =
+fun map_inj_thms f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset, number_of} =
{add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms, inj_thms = f inj_thms,
- lessD = lessD, neqE = neqE, simpset = simpset};
+ lessD = lessD, neqE = neqE, simpset = simpset, number_of = number_of};
-fun map_lessD f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset} =
+fun map_lessD f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset, number_of} =
{add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms, inj_thms = inj_thms,
- lessD = f lessD, neqE = neqE, simpset = simpset};
+ lessD = f lessD, neqE = neqE, simpset = simpset, number_of = number_of};
-fun map_simpset f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset} =
+fun map_simpset f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset, number_of} =
{add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms, inj_thms = inj_thms,
- lessD = lessD, neqE = neqE, simpset = f simpset};
+ lessD = lessD, neqE = neqE, simpset = f simpset, number_of = number_of};
+
+fun map_number_of f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset, number_of} =
+ {add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms, inj_thms = inj_thms,
+ lessD = lessD, neqE = neqE, simpset = simpset, number_of = f number_of};
fun add_inj_thms thms = Fast_Arith.map_data (map_inj_thms (append thms));
fun add_lessD thm = Fast_Arith.map_data (map_lessD (fn thms => thms @ [thm]));
fun add_simps thms = Fast_Arith.map_data (map_simpset (fn simpset => simpset addsimps thms));
fun add_simprocs procs = Fast_Arith.map_data (map_simpset (fn simpset => simpset addsimprocs procs));
+fun set_number_of f = Fast_Arith.map_data (map_number_of (K (serial (), f)))
+
+
fun simple_tac ctxt = Fast_Arith.lin_arith_tac ctxt false;
val lin_arith_tac = Fast_Arith.lin_arith_tac;
val trace = Fast_Arith.trace;
@@ -778,13 +787,16 @@
Most of the work is done by the cancel tactics. *)
val init_arith_data =
- Fast_Arith.map_data (fn {add_mono_thms, mult_mono_thms, inj_thms, lessD, ...} =>
+ Fast_Arith.map_data (fn {add_mono_thms, mult_mono_thms, inj_thms, lessD, number_of, ...} =>
{add_mono_thms = @{thms add_mono_thms_ordered_semiring} @ @{thms add_mono_thms_ordered_field} @ add_mono_thms,
- mult_mono_thms = @{thm mult_strict_left_mono} :: @{thm mult_left_mono} :: mult_mono_thms,
+ mult_mono_thms = @{thm mult_strict_left_mono} :: @{thm mult_left_mono} ::
+ @{lemma "a = b ==> c*a = c*b" by (rule arg_cong)} :: mult_mono_thms,
inj_thms = inj_thms,
lessD = lessD @ [@{thm "Suc_leI"}],
neqE = [@{thm linorder_neqE_nat}, @{thm linorder_neqE_ordered_idom}],
simpset = HOL_basic_ss
+ addsimps @{thms ring_distribs}
+ addsimps [@{thm if_True}, @{thm if_False}]
addsimps
[@{thm "monoid_add_class.add_0_left"},
@{thm "monoid_add_class.add_0_right"},
@@ -795,7 +807,8 @@
addsimprocs [ab_group_add_cancel.sum_conv, ab_group_add_cancel.rel_conv]
(*abel_cancel helps it work in abstract algebraic domains*)
addsimprocs Nat_Arith.nat_cancel_sums_add
- addcongs [if_weak_cong]}) #>
+ addcongs [if_weak_cong],
+ number_of = number_of}) #>
add_discrete_type @{type_name nat};
fun add_arith_facts ss =
--- a/src/Provers/Arith/fast_lin_arith.ML Mon Jun 08 20:43:57 2009 +0200
+++ b/src/Provers/Arith/fast_lin_arith.ML Mon Jun 08 22:29:37 2009 +0200
@@ -1,6 +1,6 @@
(* Title: Provers/Arith/fast_lin_arith.ML
ID: $Id$
- Author: Tobias Nipkow and Tjark Weber
+ Author: Tobias Nipkow and Tjark Weber and Sascha Boehme
A generic linear arithmetic package. It provides two tactics
(cut_lin_arith_tac, lin_arith_tac) and a simplification procedure
@@ -21,6 +21,7 @@
val not_lessD : thm (* ~(m < n) ==> n <= m *)
val not_leD : thm (* ~(m <= n) ==> n < m *)
val sym : thm (* x = y ==> y = x *)
+ val trueI : thm (* True *)
val mk_Eq : thm -> thm
val atomize : thm -> thm list
val mk_Trueprop : term -> term
@@ -56,7 +57,6 @@
(*preprocessing, performed on the goal -- must do the same as 'pre_decomp':*)
val pre_tac: Proof.context -> int -> tactic
- val mk_number: typ -> int -> term
(*the limit on the number of ~= allowed; because each ~= is split
into two cases, this can lead to an explosion*)
@@ -90,9 +90,11 @@
val lin_arith_tac: Proof.context -> bool -> int -> tactic
val lin_arith_simproc: simpset -> term -> thm option
val map_data: ({add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
- lessD: thm list, neqE: thm list, simpset: Simplifier.simpset}
+ lessD: thm list, neqE: thm list, simpset: Simplifier.simpset,
+ number_of : serial * (theory -> typ -> int -> cterm)}
-> {add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
- lessD: thm list, neqE: thm list, simpset: Simplifier.simpset})
+ lessD: thm list, neqE: thm list, simpset: Simplifier.simpset,
+ number_of : serial * (theory -> typ -> int -> cterm)})
-> Context.generic -> Context.generic
val trace: bool ref
val warning_count: int ref;
@@ -105,6 +107,8 @@
(** theory data **)
+fun no_number_of _ _ _ = raise CTERM ("number_of", [])
+
structure Data = GenericDataFun
(
type T =
@@ -113,22 +117,27 @@
inj_thms: thm list,
lessD: thm list,
neqE: thm list,
- simpset: Simplifier.simpset};
+ simpset: Simplifier.simpset,
+ number_of : serial * (theory -> typ -> int -> cterm)};
val empty = {add_mono_thms = [], mult_mono_thms = [], inj_thms = [],
- lessD = [], neqE = [], simpset = Simplifier.empty_ss};
+ lessD = [], neqE = [], simpset = Simplifier.empty_ss,
+ number_of = (serial (), no_number_of) };
val extend = I;
fun merge _
({add_mono_thms= add_mono_thms1, mult_mono_thms= mult_mono_thms1, inj_thms= inj_thms1,
- lessD = lessD1, neqE=neqE1, simpset = simpset1},
+ lessD = lessD1, neqE=neqE1, simpset = simpset1,
+ number_of = (number_of1 as (s1, _))},
{add_mono_thms= add_mono_thms2, mult_mono_thms= mult_mono_thms2, inj_thms= inj_thms2,
- lessD = lessD2, neqE=neqE2, simpset = simpset2}) =
+ lessD = lessD2, neqE=neqE2, simpset = simpset2,
+ number_of = (number_of2 as (s2, _))}) =
{add_mono_thms = Thm.merge_thms (add_mono_thms1, add_mono_thms2),
mult_mono_thms = Thm.merge_thms (mult_mono_thms1, mult_mono_thms2),
inj_thms = Thm.merge_thms (inj_thms1, inj_thms2),
lessD = Thm.merge_thms (lessD1, lessD2),
neqE = Thm.merge_thms (neqE1, neqE2),
- simpset = Simplifier.merge_ss (simpset1, simpset2)};
+ simpset = Simplifier.merge_ss (simpset1, simpset2),
+ number_of = if s1 > s2 then number_of1 else number_of2};
);
val map_data = Data.map;
@@ -320,7 +329,7 @@
else (m1,m2)
val (p1,p2) = if ty1=Eq andalso ty2=Eq andalso (n1 = ~1 orelse n2 = ~1)
then (~n1,~n2) else (n1,n2)
- in add_ineq (multiply_ineq n1 i1) (multiply_ineq n2 i2) end;
+ in add_ineq (multiply_ineq p1 i1) (multiply_ineq p2 i2) end;
(* ------------------------------------------------------------------------- *)
(* The main refutation-finding code. *)
@@ -328,7 +337,7 @@
fun is_trivial (Lineq(_,_,l,_)) = forall (fn i => i=0) l;
-fun is_answer (ans as Lineq(k,ty,l,_)) =
+fun is_contradictory (ans as Lineq(k,ty,l,_)) =
case ty of Eq => k <> 0 | Le => k > 0 | Lt => k >= 0;
fun calc_blowup l =
@@ -347,13 +356,10 @@
(* blowup (number of consequences generated) and eliminates it. *)
(* ------------------------------------------------------------------------- *)
-fun allpairs f xs ys =
- maps (fn x => map (fn y => f x y) ys) xs;
-
fun extract_first p =
- let fun extract xs (y::ys) = if p y then (SOME y,xs@ys)
- else extract (y::xs) ys
- | extract xs [] = (NONE,xs)
+ let
+ fun extract xs (y::ys) = if p y then (y, xs @ ys) else extract (y::xs) ys
+ | extract xs [] = raise Empty
in extract [] end;
fun print_ineqs ineqs =
@@ -368,10 +374,10 @@
datatype result = Success of injust | Failure of history;
fun elim (ineqs, hist) =
- let val dummy = print_ineqs ineqs
+ let val _ = print_ineqs ineqs
val (triv, nontriv) = List.partition is_trivial ineqs in
if not (null triv)
- then case Library.find_first is_answer triv of
+ then case Library.find_first is_contradictory triv of
NONE => elim (nontriv, hist)
| SOME(Lineq(_,_,_,j)) => Success j
else
@@ -379,11 +385,12 @@
else
let val (eqs, noneqs) = List.partition (fn (Lineq(_,ty,_,_)) => ty=Eq) nontriv in
if not (null eqs) then
- let val clist = Library.foldl (fn (cs,Lineq(_,_,l,_)) => l union cs) ([],eqs)
- val sclist = sort (fn (x,y) => int_ord (abs x, abs y))
- (List.filter (fn i => i<>0) clist)
- val c = hd sclist
- val (SOME(eq as Lineq(_,_,ceq,_)),othereqs) =
+ let val c =
+ fold (fn Lineq(_,_,l,_) => fn cs => l union cs) eqs []
+ |> filter (fn i => i <> 0)
+ |> sort (int_ord o pairself abs)
+ |> hd
+ val (eq as Lineq(_,_,ceq,_),othereqs) =
extract_first (fn Lineq(_,_,l,_) => c mem l) eqs
val v = find_index_eq c ceq
val (ioth,roth) = List.partition (fn (Lineq(_,_,l,_)) => nth l v = 0)
@@ -402,7 +409,7 @@
let val (c,v) = hd(sort (fn (x,y) => int_ord(fst(x),fst(y))) nziblows)
val (no,yes) = List.partition (fn (Lineq(_,_,l,_)) => nth l v = 0) ineqs
val (pos,neg) = List.partition(fn (Lineq(_,_,l,_)) => nth l v > 0) yes
- in elim(no @ allpairs (elim_var v) pos neg, (v,nontriv)::hist) end
+ in elim(no @ map_product (elim_var v) pos neg, (v,nontriv)::hist) end
end
end
end;
@@ -427,11 +434,12 @@
val union_bterm = curry (gen_union
(fn ((b:bool, t), (b', t')) => b = b' andalso Pattern.aeconv (t, t')));
-(* FIXME OPTIMIZE!!!! (partly done already)
- Addition/Multiplication need i*t representation rather than t+t+...
- Get rid of Mulitplied(2). For Nat LA_Data.mk_number should return Suc^n
- because Numerals are not known early enough.
+fun add_atoms (lhs, _, _, rhs, _, _) =
+ union_term (map fst lhs) o union_term (map fst rhs);
+fun atoms_of ds = fold add_atoms ds [];
+
+(*
Simplification may detect a contradiction 'prematurely' due to type
information: n+1 <= 0 is simplified to False and does not need to be crossed
with 0 <= n.
@@ -444,58 +452,78 @@
let
val ctxt = Simplifier.the_context ss;
val thy = ProofContext.theory_of ctxt;
- val {add_mono_thms, mult_mono_thms, inj_thms, lessD, simpset, ...} = get_data ctxt;
+ val {add_mono_thms, mult_mono_thms, inj_thms, lessD, simpset,
+ number_of = (_, num_of), ...} = get_data ctxt;
val simpset' = Simplifier.inherit_context ss simpset;
- val atoms = Library.foldl (fn (ats, (lhs,_,_,rhs,_,_)) =>
- union_term (map fst lhs) (union_term (map fst rhs) ats))
- ([], List.mapPartial (fn thm => if Thm.no_prems thm
- then LA_Data.decomp ctxt (Thm.concl_of thm)
- else NONE) asms)
+ fun only_concl f thm =
+ if Thm.no_prems thm then f (Thm.concl_of thm) else NONE;
+ val atoms = atoms_of (map_filter (only_concl (LA_Data.decomp ctxt)) asms);
+
+ fun use_first rules thm =
+ get_first (fn th => SOME (thm RS th) handle THM _ => NONE) rules
+
+ fun add2 thm1 thm2 =
+ use_first add_mono_thms (thm1 RS (thm2 RS LA_Logic.conjI));
+ fun try_add thms thm = get_first (fn th => add2 th thm) thms;
- fun add2 thm1 thm2 =
- let val conj = thm1 RS (thm2 RS LA_Logic.conjI)
- in get_first (fn th => SOME(conj RS th) handle THM _ => NONE) add_mono_thms
- end;
- fun try_add [] _ = NONE
- | try_add (thm1::thm1s) thm2 = case add2 thm1 thm2 of
- NONE => try_add thm1s thm2 | some => some;
+ fun add_thms thm1 thm2 =
+ (case add2 thm1 thm2 of
+ NONE =>
+ (case try_add ([thm1] RL inj_thms) thm2 of
+ NONE =>
+ (the (try_add ([thm2] RL inj_thms) thm1)
+ handle Option =>
+ (trace_thm "" thm1; trace_thm "" thm2;
+ sys_error "Linear arithmetic: failed to add thms"))
+ | SOME thm => thm)
+ | SOME thm => thm);
+
+ fun mult_by_add n thm =
+ let fun mul i th = if i = 1 then th else mul (i - 1) (add_thms thm th)
+ in mul n thm end;
- fun addthms thm1 thm2 =
- case add2 thm1 thm2 of
- NONE => (case try_add ([thm1] RL inj_thms) thm2 of
- NONE => ( the (try_add ([thm2] RL inj_thms) thm1)
- handle Option =>
- (trace_thm "" thm1; trace_thm "" thm2;
- sys_error "Linear arithmetic: failed to add thms")
- )
- | SOME thm => thm)
- | SOME thm => thm;
-
- fun multn(n,thm) =
- let fun mul(i,th) = if i=1 then th else mul(i-1, addthms thm th)
- in if n < 0 then mul(~n,thm) RS LA_Logic.sym else mul(n,thm) end;
+ val rewr = Simplifier.rewrite simpset';
+ val rewrite_concl = Conv.fconv_rule (Conv.concl_conv ~1 (Conv.arg_conv
+ (Conv.binop_conv rewr)));
+ fun discharge_prem thm = if Thm.nprems_of thm = 0 then thm else
+ let val cv = Conv.arg1_conv (Conv.arg_conv rewr)
+ in Thm.implies_elim (Conv.fconv_rule cv thm) LA_Logic.trueI end
- fun multn2(n,thm) =
- let val SOME(mth) =
- get_first (fn th => SOME(thm RS th) handle THM _ => NONE) mult_mono_thms
- fun cvar(th,_ $ (_ $ _ $ var)) = cterm_of (Thm.theory_of_thm th) var;
- val cv = cvar(mth, hd(prems_of mth));
- val ct = cterm_of thy (LA_Data.mk_number (#T (rep_cterm cv)) n)
- in instantiate ([],[(cv,ct)]) mth end
+ fun mult n thm =
+ (case use_first mult_mono_thms thm of
+ NONE => mult_by_add n thm
+ | SOME mth =>
+ let
+ val cv = mth |> Thm.cprop_of |> Drule.strip_imp_concl
+ |> Thm.dest_arg |> Thm.dest_arg1 |> Thm.dest_arg1
+ val T = #T (Thm.rep_cterm cv)
+ in
+ mth
+ |> Thm.instantiate ([], [(cv, num_of thy T n)])
+ |> rewrite_concl
+ |> discharge_prem
+ handle CTERM _ => mult_by_add n thm
+ | THM _ => mult_by_add n thm
+ end);
- fun simp thm =
- let val thm' = trace_thm "Simplified:" (full_simplify simpset' thm)
- in if LA_Logic.is_False thm' then raise FalseE thm' else thm' end
+ fun mult_thm (n, thm) =
+ if n = ~1 then thm RS LA_Logic.sym
+ else if n < 0 then mult (~n) thm RS LA_Logic.sym
+ else mult n thm;
+
+ fun simp thm =
+ let val thm' = trace_thm "Simplified:" (full_simplify simpset' thm)
+ in if LA_Logic.is_False thm' then raise FalseE thm' else thm' end;
- fun mk (Asm i) = trace_thm ("Asm " ^ string_of_int i) (nth asms i)
- | mk (Nat i) = trace_thm ("Nat " ^ string_of_int i) (LA_Logic.mk_nat_thm thy (nth atoms i))
- | mk (LessD j) = trace_thm "L" (hd ([mk j] RL lessD))
- | mk (NotLeD j) = trace_thm "NLe" (mk j RS LA_Logic.not_leD)
- | mk (NotLeDD j) = trace_thm "NLeD" (hd ([mk j RS LA_Logic.not_leD] RL lessD))
- | mk (NotLessD j) = trace_thm "NL" (mk j RS LA_Logic.not_lessD)
- | mk (Added (j1, j2)) = simp (trace_thm "+" (addthms (mk j1) (mk j2)))
- | mk (Multiplied (n, j)) = (trace_msg ("*" ^ string_of_int n); trace_thm "*" (multn (n, mk j)))
- | mk (Multiplied2 (n, j)) = simp (trace_msg ("**" ^ string_of_int n); trace_thm "**" (multn2 (n, mk j)))
+ fun mk (Asm i) = trace_thm ("Asm " ^ string_of_int i) (nth asms i)
+ | mk (Nat i) = trace_thm ("Nat " ^ string_of_int i) (LA_Logic.mk_nat_thm thy (nth atoms i))
+ | mk (LessD j) = trace_thm "L" (hd ([mk j] RL lessD))
+ | mk (NotLeD j) = trace_thm "NLe" (mk j RS LA_Logic.not_leD)
+ | mk (NotLeDD j) = trace_thm "NLeD" (hd ([mk j RS LA_Logic.not_leD] RL lessD))
+ | mk (NotLessD j) = trace_thm "NL" (mk j RS LA_Logic.not_lessD)
+ | mk (Added (j1, j2)) = simp (trace_thm "+" (add_thms (mk j1) (mk j2)))
+ | mk (Multiplied (n, j)) = (trace_msg ("*" ^ string_of_int n); trace_thm "*" (mult_thm (n, mk j)))
+ | mk (Multiplied2 (n, j)) = (trace_msg ("**" ^ string_of_int n); trace_thm "**" (mult_thm (n, mk j)))
in
let
@@ -676,9 +704,6 @@
result
end;
-fun add_atoms (ats : term list, ((lhs,_,_,rhs,_,_) : LA_Data.decomp, _)) : term list =
- union_term (map fst lhs) (union_term (map fst rhs) ats);
-
fun add_datoms (dats : (bool * term) list, ((lhs,_,_,rhs,_,d) : LA_Data.decomp, _)) :
(bool * term) list =
union_bterm (map (pair d o fst) lhs) (union_bterm (map (pair d o fst) rhs) dats);
@@ -691,7 +716,7 @@
let
fun refute ((Ts, initems : (LA_Data.decomp * int) list) :: initemss) (js: injust list) =
let
- val atoms = Library.foldl add_atoms ([], initems)
+ val atoms = atoms_of (map fst initems)
val n = length atoms
val mkleq = mklineq n atoms
val ixs = 0 upto (n - 1)