# HG changeset patch # User boehmes # Date 1244492977 -7200 # Node ID e0f2bb4b002167070c1fe1443fe35f4a380abc13 # Parent 00ede188c5d6cded86581ddf300f2fcaa918cd84 fast_lin_arith uses proper multiplication instead of unfolding to additions diff -r 00ede188c5d6 -r e0f2bb4b0021 src/HOL/Tools/int_arith.ML --- a/src/HOL/Tools/int_arith.ML Mon Jun 08 20:43:57 2009 +0200 +++ b/src/HOL/Tools/int_arith.ML Mon Jun 08 22:29:37 2009 +0200 @@ -87,6 +87,12 @@ val global_setup = Simplifier.map_simpset (fn simpset => simpset addsimprocs [fast_int_arith_simproc]); + +fun number_of thy T n = + if not (Sign.of_sort thy (T, @{sort number})) + then raise CTERM ("number_of", []) + else Numeral.mk_cnumber (Thm.ctyp_of thy T) n + val setup = Lin_Arith.add_inj_thms [@{thm zle_int} RS iffD2, @{thm int_int_eq} RS iffD2] #> Lin_Arith.add_lessD @{thm zless_imp_add1_zle} @@ -95,6 +101,7 @@ #> Lin_Arith.add_simprocs (Numeral_Simprocs.assoc_fold_simproc :: zero_one_idom_simproc :: Numeral_Simprocs.combine_numerals :: Numeral_Simprocs.cancel_numerals) + #> Lin_Arith.set_number_of number_of #> Lin_Arith.add_inj_const (@{const_name of_nat}, HOLogic.natT --> HOLogic.intT) #> Lin_Arith.add_discrete_type @{type_name Int.int} diff -r 00ede188c5d6 -r e0f2bb4b0021 src/HOL/Tools/lin_arith.ML --- a/src/HOL/Tools/lin_arith.ML Mon Jun 08 20:43:57 2009 +0200 +++ b/src/HOL/Tools/lin_arith.ML Mon Jun 08 22:29:37 2009 +0200 @@ -16,6 +16,8 @@ val add_simprocs: simproc list -> Context.generic -> Context.generic val add_inj_const: string * typ -> Context.generic -> Context.generic val add_discrete_type: string -> Context.generic -> Context.generic + val set_number_of: (theory -> typ -> int -> cterm) -> Context.generic -> + Context.generic val setup: Context.generic -> Context.generic val global_setup: theory -> theory val split_limit: int Config.T @@ -36,6 +38,7 @@ val conjI = conjI; val notI = notI; val sym = sym; +val trueI = TrueI; val not_lessD = @{thm linorder_not_less} RS iffD1; val not_leD = @{thm linorder_not_le} RS iffD1; @@ -274,7 +277,6 @@ | domain_is_nat (_ $ (Const ("Not", _) $ (Const (_, T) $ _ $ _))) = nT T | domain_is_nat _ = false; -val mk_number = HOLogic.mk_number; (*---------------------------------------------------------------------------*) (* the following code performs splitting of certain constants (e.g. min, *) @@ -752,23 +754,30 @@ val map_data = Fast_Arith.map_data; -fun map_inj_thms f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset} = +fun map_inj_thms f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset, number_of} = {add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms, inj_thms = f inj_thms, - lessD = lessD, neqE = neqE, simpset = simpset}; + lessD = lessD, neqE = neqE, simpset = simpset, number_of = number_of}; -fun map_lessD f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset} = +fun map_lessD f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset, number_of} = {add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms, inj_thms = inj_thms, - lessD = f lessD, neqE = neqE, simpset = simpset}; + lessD = f lessD, neqE = neqE, simpset = simpset, number_of = number_of}; -fun map_simpset f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset} = +fun map_simpset f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset, number_of} = {add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms, inj_thms = inj_thms, - lessD = lessD, neqE = neqE, simpset = f simpset}; + lessD = lessD, neqE = neqE, simpset = f simpset, number_of = number_of}; + +fun map_number_of f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset, number_of} = + {add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms, inj_thms = inj_thms, + lessD = lessD, neqE = neqE, simpset = simpset, number_of = f number_of}; fun add_inj_thms thms = Fast_Arith.map_data (map_inj_thms (append thms)); fun add_lessD thm = Fast_Arith.map_data (map_lessD (fn thms => thms @ [thm])); fun add_simps thms = Fast_Arith.map_data (map_simpset (fn simpset => simpset addsimps thms)); fun add_simprocs procs = Fast_Arith.map_data (map_simpset (fn simpset => simpset addsimprocs procs)); +fun set_number_of f = Fast_Arith.map_data (map_number_of (K (serial (), f))) + + fun simple_tac ctxt = Fast_Arith.lin_arith_tac ctxt false; val lin_arith_tac = Fast_Arith.lin_arith_tac; val trace = Fast_Arith.trace; @@ -778,13 +787,16 @@ Most of the work is done by the cancel tactics. *) val init_arith_data = - Fast_Arith.map_data (fn {add_mono_thms, mult_mono_thms, inj_thms, lessD, ...} => + Fast_Arith.map_data (fn {add_mono_thms, mult_mono_thms, inj_thms, lessD, number_of, ...} => {add_mono_thms = @{thms add_mono_thms_ordered_semiring} @ @{thms add_mono_thms_ordered_field} @ add_mono_thms, - mult_mono_thms = @{thm mult_strict_left_mono} :: @{thm mult_left_mono} :: mult_mono_thms, + mult_mono_thms = @{thm mult_strict_left_mono} :: @{thm mult_left_mono} :: + @{lemma "a = b ==> c*a = c*b" by (rule arg_cong)} :: mult_mono_thms, inj_thms = inj_thms, lessD = lessD @ [@{thm "Suc_leI"}], neqE = [@{thm linorder_neqE_nat}, @{thm linorder_neqE_ordered_idom}], simpset = HOL_basic_ss + addsimps @{thms ring_distribs} + addsimps [@{thm if_True}, @{thm if_False}] addsimps [@{thm "monoid_add_class.add_0_left"}, @{thm "monoid_add_class.add_0_right"}, @@ -795,7 +807,8 @@ addsimprocs [ab_group_add_cancel.sum_conv, ab_group_add_cancel.rel_conv] (*abel_cancel helps it work in abstract algebraic domains*) addsimprocs Nat_Arith.nat_cancel_sums_add - addcongs [if_weak_cong]}) #> + addcongs [if_weak_cong], + number_of = number_of}) #> add_discrete_type @{type_name nat}; fun add_arith_facts ss = diff -r 00ede188c5d6 -r e0f2bb4b0021 src/Provers/Arith/fast_lin_arith.ML --- a/src/Provers/Arith/fast_lin_arith.ML Mon Jun 08 20:43:57 2009 +0200 +++ b/src/Provers/Arith/fast_lin_arith.ML Mon Jun 08 22:29:37 2009 +0200 @@ -1,6 +1,6 @@ (* Title: Provers/Arith/fast_lin_arith.ML ID: $Id$ - Author: Tobias Nipkow and Tjark Weber + Author: Tobias Nipkow and Tjark Weber and Sascha Boehme A generic linear arithmetic package. It provides two tactics (cut_lin_arith_tac, lin_arith_tac) and a simplification procedure @@ -21,6 +21,7 @@ val not_lessD : thm (* ~(m < n) ==> n <= m *) val not_leD : thm (* ~(m <= n) ==> n < m *) val sym : thm (* x = y ==> y = x *) + val trueI : thm (* True *) val mk_Eq : thm -> thm val atomize : thm -> thm list val mk_Trueprop : term -> term @@ -56,7 +57,6 @@ (*preprocessing, performed on the goal -- must do the same as 'pre_decomp':*) val pre_tac: Proof.context -> int -> tactic - val mk_number: typ -> int -> term (*the limit on the number of ~= allowed; because each ~= is split into two cases, this can lead to an explosion*) @@ -90,9 +90,11 @@ val lin_arith_tac: Proof.context -> bool -> int -> tactic val lin_arith_simproc: simpset -> term -> thm option val map_data: ({add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list, - lessD: thm list, neqE: thm list, simpset: Simplifier.simpset} + lessD: thm list, neqE: thm list, simpset: Simplifier.simpset, + number_of : serial * (theory -> typ -> int -> cterm)} -> {add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list, - lessD: thm list, neqE: thm list, simpset: Simplifier.simpset}) + lessD: thm list, neqE: thm list, simpset: Simplifier.simpset, + number_of : serial * (theory -> typ -> int -> cterm)}) -> Context.generic -> Context.generic val trace: bool ref val warning_count: int ref; @@ -105,6 +107,8 @@ (** theory data **) +fun no_number_of _ _ _ = raise CTERM ("number_of", []) + structure Data = GenericDataFun ( type T = @@ -113,22 +117,27 @@ inj_thms: thm list, lessD: thm list, neqE: thm list, - simpset: Simplifier.simpset}; + simpset: Simplifier.simpset, + number_of : serial * (theory -> typ -> int -> cterm)}; val empty = {add_mono_thms = [], mult_mono_thms = [], inj_thms = [], - lessD = [], neqE = [], simpset = Simplifier.empty_ss}; + lessD = [], neqE = [], simpset = Simplifier.empty_ss, + number_of = (serial (), no_number_of) }; val extend = I; fun merge _ ({add_mono_thms= add_mono_thms1, mult_mono_thms= mult_mono_thms1, inj_thms= inj_thms1, - lessD = lessD1, neqE=neqE1, simpset = simpset1}, + lessD = lessD1, neqE=neqE1, simpset = simpset1, + number_of = (number_of1 as (s1, _))}, {add_mono_thms= add_mono_thms2, mult_mono_thms= mult_mono_thms2, inj_thms= inj_thms2, - lessD = lessD2, neqE=neqE2, simpset = simpset2}) = + lessD = lessD2, neqE=neqE2, simpset = simpset2, + number_of = (number_of2 as (s2, _))}) = {add_mono_thms = Thm.merge_thms (add_mono_thms1, add_mono_thms2), mult_mono_thms = Thm.merge_thms (mult_mono_thms1, mult_mono_thms2), inj_thms = Thm.merge_thms (inj_thms1, inj_thms2), lessD = Thm.merge_thms (lessD1, lessD2), neqE = Thm.merge_thms (neqE1, neqE2), - simpset = Simplifier.merge_ss (simpset1, simpset2)}; + simpset = Simplifier.merge_ss (simpset1, simpset2), + number_of = if s1 > s2 then number_of1 else number_of2}; ); val map_data = Data.map; @@ -320,7 +329,7 @@ else (m1,m2) val (p1,p2) = if ty1=Eq andalso ty2=Eq andalso (n1 = ~1 orelse n2 = ~1) then (~n1,~n2) else (n1,n2) - in add_ineq (multiply_ineq n1 i1) (multiply_ineq n2 i2) end; + in add_ineq (multiply_ineq p1 i1) (multiply_ineq p2 i2) end; (* ------------------------------------------------------------------------- *) (* The main refutation-finding code. *) @@ -328,7 +337,7 @@ fun is_trivial (Lineq(_,_,l,_)) = forall (fn i => i=0) l; -fun is_answer (ans as Lineq(k,ty,l,_)) = +fun is_contradictory (ans as Lineq(k,ty,l,_)) = case ty of Eq => k <> 0 | Le => k > 0 | Lt => k >= 0; fun calc_blowup l = @@ -347,13 +356,10 @@ (* blowup (number of consequences generated) and eliminates it. *) (* ------------------------------------------------------------------------- *) -fun allpairs f xs ys = - maps (fn x => map (fn y => f x y) ys) xs; - fun extract_first p = - let fun extract xs (y::ys) = if p y then (SOME y,xs@ys) - else extract (y::xs) ys - | extract xs [] = (NONE,xs) + let + fun extract xs (y::ys) = if p y then (y, xs @ ys) else extract (y::xs) ys + | extract xs [] = raise Empty in extract [] end; fun print_ineqs ineqs = @@ -368,10 +374,10 @@ datatype result = Success of injust | Failure of history; fun elim (ineqs, hist) = - let val dummy = print_ineqs ineqs + let val _ = print_ineqs ineqs val (triv, nontriv) = List.partition is_trivial ineqs in if not (null triv) - then case Library.find_first is_answer triv of + then case Library.find_first is_contradictory triv of NONE => elim (nontriv, hist) | SOME(Lineq(_,_,_,j)) => Success j else @@ -379,11 +385,12 @@ else let val (eqs, noneqs) = List.partition (fn (Lineq(_,ty,_,_)) => ty=Eq) nontriv in if not (null eqs) then - let val clist = Library.foldl (fn (cs,Lineq(_,_,l,_)) => l union cs) ([],eqs) - val sclist = sort (fn (x,y) => int_ord (abs x, abs y)) - (List.filter (fn i => i<>0) clist) - val c = hd sclist - val (SOME(eq as Lineq(_,_,ceq,_)),othereqs) = + let val c = + fold (fn Lineq(_,_,l,_) => fn cs => l union cs) eqs [] + |> filter (fn i => i <> 0) + |> sort (int_ord o pairself abs) + |> hd + val (eq as Lineq(_,_,ceq,_),othereqs) = extract_first (fn Lineq(_,_,l,_) => c mem l) eqs val v = find_index_eq c ceq val (ioth,roth) = List.partition (fn (Lineq(_,_,l,_)) => nth l v = 0) @@ -402,7 +409,7 @@ let val (c,v) = hd(sort (fn (x,y) => int_ord(fst(x),fst(y))) nziblows) val (no,yes) = List.partition (fn (Lineq(_,_,l,_)) => nth l v = 0) ineqs val (pos,neg) = List.partition(fn (Lineq(_,_,l,_)) => nth l v > 0) yes - in elim(no @ allpairs (elim_var v) pos neg, (v,nontriv)::hist) end + in elim(no @ map_product (elim_var v) pos neg, (v,nontriv)::hist) end end end end; @@ -427,11 +434,12 @@ val union_bterm = curry (gen_union (fn ((b:bool, t), (b', t')) => b = b' andalso Pattern.aeconv (t, t'))); -(* FIXME OPTIMIZE!!!! (partly done already) - Addition/Multiplication need i*t representation rather than t+t+... - Get rid of Mulitplied(2). For Nat LA_Data.mk_number should return Suc^n - because Numerals are not known early enough. +fun add_atoms (lhs, _, _, rhs, _, _) = + union_term (map fst lhs) o union_term (map fst rhs); +fun atoms_of ds = fold add_atoms ds []; + +(* Simplification may detect a contradiction 'prematurely' due to type information: n+1 <= 0 is simplified to False and does not need to be crossed with 0 <= n. @@ -444,58 +452,78 @@ let val ctxt = Simplifier.the_context ss; val thy = ProofContext.theory_of ctxt; - val {add_mono_thms, mult_mono_thms, inj_thms, lessD, simpset, ...} = get_data ctxt; + val {add_mono_thms, mult_mono_thms, inj_thms, lessD, simpset, + number_of = (_, num_of), ...} = get_data ctxt; val simpset' = Simplifier.inherit_context ss simpset; - val atoms = Library.foldl (fn (ats, (lhs,_,_,rhs,_,_)) => - union_term (map fst lhs) (union_term (map fst rhs) ats)) - ([], List.mapPartial (fn thm => if Thm.no_prems thm - then LA_Data.decomp ctxt (Thm.concl_of thm) - else NONE) asms) + fun only_concl f thm = + if Thm.no_prems thm then f (Thm.concl_of thm) else NONE; + val atoms = atoms_of (map_filter (only_concl (LA_Data.decomp ctxt)) asms); + + fun use_first rules thm = + get_first (fn th => SOME (thm RS th) handle THM _ => NONE) rules + + fun add2 thm1 thm2 = + use_first add_mono_thms (thm1 RS (thm2 RS LA_Logic.conjI)); + fun try_add thms thm = get_first (fn th => add2 th thm) thms; - fun add2 thm1 thm2 = - let val conj = thm1 RS (thm2 RS LA_Logic.conjI) - in get_first (fn th => SOME(conj RS th) handle THM _ => NONE) add_mono_thms - end; - fun try_add [] _ = NONE - | try_add (thm1::thm1s) thm2 = case add2 thm1 thm2 of - NONE => try_add thm1s thm2 | some => some; + fun add_thms thm1 thm2 = + (case add2 thm1 thm2 of + NONE => + (case try_add ([thm1] RL inj_thms) thm2 of + NONE => + (the (try_add ([thm2] RL inj_thms) thm1) + handle Option => + (trace_thm "" thm1; trace_thm "" thm2; + sys_error "Linear arithmetic: failed to add thms")) + | SOME thm => thm) + | SOME thm => thm); + + fun mult_by_add n thm = + let fun mul i th = if i = 1 then th else mul (i - 1) (add_thms thm th) + in mul n thm end; - fun addthms thm1 thm2 = - case add2 thm1 thm2 of - NONE => (case try_add ([thm1] RL inj_thms) thm2 of - NONE => ( the (try_add ([thm2] RL inj_thms) thm1) - handle Option => - (trace_thm "" thm1; trace_thm "" thm2; - sys_error "Linear arithmetic: failed to add thms") - ) - | SOME thm => thm) - | SOME thm => thm; - - fun multn(n,thm) = - let fun mul(i,th) = if i=1 then th else mul(i-1, addthms thm th) - in if n < 0 then mul(~n,thm) RS LA_Logic.sym else mul(n,thm) end; + val rewr = Simplifier.rewrite simpset'; + val rewrite_concl = Conv.fconv_rule (Conv.concl_conv ~1 (Conv.arg_conv + (Conv.binop_conv rewr))); + fun discharge_prem thm = if Thm.nprems_of thm = 0 then thm else + let val cv = Conv.arg1_conv (Conv.arg_conv rewr) + in Thm.implies_elim (Conv.fconv_rule cv thm) LA_Logic.trueI end - fun multn2(n,thm) = - let val SOME(mth) = - get_first (fn th => SOME(thm RS th) handle THM _ => NONE) mult_mono_thms - fun cvar(th,_ $ (_ $ _ $ var)) = cterm_of (Thm.theory_of_thm th) var; - val cv = cvar(mth, hd(prems_of mth)); - val ct = cterm_of thy (LA_Data.mk_number (#T (rep_cterm cv)) n) - in instantiate ([],[(cv,ct)]) mth end + fun mult n thm = + (case use_first mult_mono_thms thm of + NONE => mult_by_add n thm + | SOME mth => + let + val cv = mth |> Thm.cprop_of |> Drule.strip_imp_concl + |> Thm.dest_arg |> Thm.dest_arg1 |> Thm.dest_arg1 + val T = #T (Thm.rep_cterm cv) + in + mth + |> Thm.instantiate ([], [(cv, num_of thy T n)]) + |> rewrite_concl + |> discharge_prem + handle CTERM _ => mult_by_add n thm + | THM _ => mult_by_add n thm + end); - fun simp thm = - let val thm' = trace_thm "Simplified:" (full_simplify simpset' thm) - in if LA_Logic.is_False thm' then raise FalseE thm' else thm' end + fun mult_thm (n, thm) = + if n = ~1 then thm RS LA_Logic.sym + else if n < 0 then mult (~n) thm RS LA_Logic.sym + else mult n thm; + + fun simp thm = + let val thm' = trace_thm "Simplified:" (full_simplify simpset' thm) + in if LA_Logic.is_False thm' then raise FalseE thm' else thm' end; - fun mk (Asm i) = trace_thm ("Asm " ^ string_of_int i) (nth asms i) - | mk (Nat i) = trace_thm ("Nat " ^ string_of_int i) (LA_Logic.mk_nat_thm thy (nth atoms i)) - | mk (LessD j) = trace_thm "L" (hd ([mk j] RL lessD)) - | mk (NotLeD j) = trace_thm "NLe" (mk j RS LA_Logic.not_leD) - | mk (NotLeDD j) = trace_thm "NLeD" (hd ([mk j RS LA_Logic.not_leD] RL lessD)) - | mk (NotLessD j) = trace_thm "NL" (mk j RS LA_Logic.not_lessD) - | mk (Added (j1, j2)) = simp (trace_thm "+" (addthms (mk j1) (mk j2))) - | mk (Multiplied (n, j)) = (trace_msg ("*" ^ string_of_int n); trace_thm "*" (multn (n, mk j))) - | mk (Multiplied2 (n, j)) = simp (trace_msg ("**" ^ string_of_int n); trace_thm "**" (multn2 (n, mk j))) + fun mk (Asm i) = trace_thm ("Asm " ^ string_of_int i) (nth asms i) + | mk (Nat i) = trace_thm ("Nat " ^ string_of_int i) (LA_Logic.mk_nat_thm thy (nth atoms i)) + | mk (LessD j) = trace_thm "L" (hd ([mk j] RL lessD)) + | mk (NotLeD j) = trace_thm "NLe" (mk j RS LA_Logic.not_leD) + | mk (NotLeDD j) = trace_thm "NLeD" (hd ([mk j RS LA_Logic.not_leD] RL lessD)) + | mk (NotLessD j) = trace_thm "NL" (mk j RS LA_Logic.not_lessD) + | mk (Added (j1, j2)) = simp (trace_thm "+" (add_thms (mk j1) (mk j2))) + | mk (Multiplied (n, j)) = (trace_msg ("*" ^ string_of_int n); trace_thm "*" (mult_thm (n, mk j))) + | mk (Multiplied2 (n, j)) = (trace_msg ("**" ^ string_of_int n); trace_thm "**" (mult_thm (n, mk j))) in let @@ -676,9 +704,6 @@ result end; -fun add_atoms (ats : term list, ((lhs,_,_,rhs,_,_) : LA_Data.decomp, _)) : term list = - union_term (map fst lhs) (union_term (map fst rhs) ats); - fun add_datoms (dats : (bool * term) list, ((lhs,_,_,rhs,_,d) : LA_Data.decomp, _)) : (bool * term) list = union_bterm (map (pair d o fst) lhs) (union_bterm (map (pair d o fst) rhs) dats); @@ -691,7 +716,7 @@ let fun refute ((Ts, initems : (LA_Data.decomp * int) list) :: initemss) (js: injust list) = let - val atoms = Library.foldl add_atoms ([], initems) + val atoms = atoms_of (map fst initems) val n = length atoms val mkleq = mklineq n atoms val ixs = 0 upto (n - 1)