(* Title: HOL/arith_data.ML
ID: $Id$
Author: Markus Wenzel, Stefan Berghofer and Tobias Nipkow
Various arithmetic proof procedures.
*)
(*---------------------------------------------------------------------------*)
(* 1. Cancellation of common terms *)
(*---------------------------------------------------------------------------*)
(*Utilities shared by the nat cancellation simprocs: building and taking
  apart sums over nat, proving the resulting conversions, and turning
  (name, patterns, procedure) triples into simprocs.*)
structure NatArithUtils =
struct
(** abstract syntax of structure nat: 0, Suc, + **)
(* mk_sum, mk_norm_sum *)
(*the nat numeral 1; dest_sum emits it for every Suc it strips off*)
val one = HOLogic.mk_nat 1;
val mk_plus = HOLogic.mk_binop "op +";
(*right-nested sum t1 + (t2 + (... + tn)); the empty sum is 0*)
fun mk_sum [] = HOLogic.zero
| mk_sum [t] = t
| mk_sum (t :: ts) = mk_plus (t, mk_sum ts);
(*normal form of sums: Suc (... (Suc (a + (b + ...))))*)
(*every summand equal to one is turned into a Suc wrapped around the sum
  of the remaining summands*)
fun mk_norm_sum ts =
let val (ones, sums) = List.partition (equal one) ts in
funpow (length ones) HOLogic.mk_Suc (mk_sum sums)
end;
(* dest_sum *)
val dest_plus = HOLogic.dest_bin "op +" HOLogic.natT;
(*flatten a nat term into its list of summands: 0 yields [], each Suc
  contributes one, + is split recursively, anything else is one summand*)
fun dest_sum tm =
if HOLogic.is_zero tm then []
else
(case try HOLogic.dest_Suc tm of
SOME t => one :: dest_sum t
| NONE =>
(case try dest_plus tm of
SOME (t, u) => dest_sum t @ dest_sum u
| NONE => [tm]));
(** generic proof tools **)
(* prove conversions *)
val mk_eqv = HOLogic.mk_Trueprop o HOLogic.mk_eq;
(*prove_conv expand_tac norm_tac sg ss (t, u): prove t = u by expand_tac
  followed by norm_tac ss and return the result as a meta-equality.  If the
  proof raises ERROR, re-raise an error message naming the goal.*)
fun prove_conv expand_tac norm_tac sg ss tu =
mk_meta_eq (prove_goalw_cterm_nocheck [] (cterm_of sg (mk_eqv tu))
(K [expand_tac, norm_tac ss]))
handle ERROR => error ("The error(s) above occurred while trying to prove " ^
(string_of_cterm (cterm_of sg (mk_eqv tu))));
(*from t = s and u = t conclude u = s*)
val subst_equals = prove_goal HOL.thy "[| t = s; u = t |] ==> u = s"
(fn prems => [cut_facts_tac prems 1, SIMPSET' asm_simp_tac 1]);
(* rewriting *)
(*simplify all goals with HOL_ss plus rules, inheriting the bound-variable
  context of ss via Simplifier.inherit_bounds*)
fun simp_all_tac rules ss =
ALLGOALS (simp_tac (Simplifier.inherit_bounds ss HOL_ss addsimps rules));
(*rewrite rules for + and for multiplication with Suc and 0*)
val add_rules = [add_Suc, add_Suc_right, add_0, add_0_right];
val mult_rules = [mult_Suc, mult_Suc_right, mult_0, mult_0_right];
(*build a simproc in the current theory context from (name, patterns, proc)*)
fun prep_simproc (name, pats, proc) =
Simplifier.simproc (the_context ()) name pats proc;
end;
(*Interface of ArithData: the nat cancellation simprocs*)
signature ARITH_DATA =
sig
(*cancellation simprocs for =, < and <= on nat*)
val nat_cancel_sums_add: simproc list
(*nat_cancel_sums_add plus the cancellation simproc for nat subtraction*)
val nat_cancel_sums: simproc list
end;
(*Simprocs that cancel common summands on both sides of nat =, <, <= and -,
  built by instantiating CancelSumsFun once per relation.*)
structure ArithData: ARITH_DATA =
struct
open NatArithUtils;
(** cancel common summands **)
(*sum syntax and proof tools shared by all instances below; normalisation
  simplifies with add_rules and then with add_ac*)
structure Sum =
struct
val mk_sum = mk_norm_sum;
val dest_sum = dest_sum;
val prove_conv = prove_conv;
fun norm_tac ss = simp_all_tac add_rules ss THEN simp_all_tac add_ac ss;
end;
(*gen_uncancel_tac rule ct: resolve goal 1 with rule RS subst_equals, whose
  second term variable (in instantiate' order) is instantiated to ct*)
fun gen_uncancel_tac rule ct =
rtac (instantiate' [] [NONE, SOME ct] (rule RS subst_equals)) 1;
(* nat eq *)
structure EqCancelSums = CancelSumsFun
(struct
open Sum;
val mk_bal = HOLogic.mk_eq;
val dest_bal = HOLogic.dest_bin "op =" HOLogic.natT;
val uncancel_tac = gen_uncancel_tac nat_add_left_cancel;
end);
(* nat less *)
structure LessCancelSums = CancelSumsFun
(struct
open Sum;
val mk_bal = HOLogic.mk_binrel "op <";
val dest_bal = HOLogic.dest_bin "op <" HOLogic.natT;
val uncancel_tac = gen_uncancel_tac nat_add_left_cancel_less;
end);
(* nat le *)
structure LeCancelSums = CancelSumsFun
(struct
open Sum;
val mk_bal = HOLogic.mk_binrel "op <=";
val dest_bal = HOLogic.dest_bin "op <=" HOLogic.natT;
val uncancel_tac = gen_uncancel_tac nat_add_left_cancel_le;
end);
(* nat diff *)
structure DiffCancelSums = CancelSumsFun
(struct
open Sum;
val mk_bal = HOLogic.mk_binop "op -";
val dest_bal = HOLogic.dest_bin "op -" HOLogic.natT;
val uncancel_tac = gen_uncancel_tac diff_cancel;
end);
(** prepare nat_cancel simprocs **)
(*one simproc per relation; each fires on a sum or Suc on either side*)
val nat_cancel_sums_add = map prep_simproc
[("nateq_cancel_sums",
["(l::nat) + m = n", "(l::nat) = m + n", "Suc m = n", "m = Suc n"], EqCancelSums.proc),
("natless_cancel_sums",
["(l::nat) + m < n", "(l::nat) < m + n", "Suc m < n", "m < Suc n"], LessCancelSums.proc),
("natle_cancel_sums",
["(l::nat) + m <= n", "(l::nat) <= m + n", "Suc m <= n", "m <= Suc n"], LeCancelSums.proc)];
(*the above plus cancellation inside nat subtraction*)
val nat_cancel_sums = nat_cancel_sums_add @
[prep_simproc ("natdiff_cancel_sums",
["((l::nat) + m) - n", "(l::nat) - (m + n)", "Suc m - n", "m - Suc n"], DiffCancelSums.proc)];
end;
(*make nat_cancel_sums_add and nat_cancel_sums available unqualified*)
open ArithData;
(*---------------------------------------------------------------------------*)
(* 2. Linear arithmetic *)
(*---------------------------------------------------------------------------*)
(* Parameters data for general linear arithmetic functor *)
(*Logical parameters (HOL rules and term operations) for the generic
  linear-arithmetic functor Fast_Lin_Arith.*)
structure LA_Logic: LIN_ARITH_LOGIC =
struct
val ccontr = ccontr;
val conjI = conjI;
val notI = notI;
val sym = sym;
(*turn a negated < (resp. <=) fact into its linear-order counterpart*)
val not_lessD = linorder_not_less RS iffD1;
val not_leD = linorder_not_le RS iffD1;
(*resolve thm with Eq_FalseI; if that raises THM, resolve with Eq_TrueI*)
fun mk_Eq thm = (thm RS Eq_FalseI) handle THM _ => (thm RS Eq_TrueI);
val mk_Trueprop = HOLogic.mk_Trueprop;
(*split a theorem whose proposition is a conjunction, recursively, into the
  list of its conjunct theorems; any other theorem gives a singleton list*)
fun atomize thm = case #prop(rep_thm thm) of
Const("Trueprop",_) $ (Const("op &",_) $ _ $ _) =>
atomize(thm RS conjunct1) @ atomize(thm RS conjunct2)
| _ => [thm];
(*negate the body of a Trueprop-style application: strip a leading Not,
  otherwise wrap the body in Not*)
fun neg_prop(TP$(Const("Not",_)$t)) = TP$t
| neg_prop(TP$t) = TP $ (Const("Not",HOLogic.boolT-->HOLogic.boolT)$t);
(*does the theorem's proposition have the form _ $ False?*)
fun is_False thm =
let val _ $ t = #prop(rep_thm thm)
in t = Const("False",HOLogic.boolT) end;
(*NOTE(review): fastype_of1 takes a (bound-variable types, term) pair, so t
  is presumably such a pair; confirm against callers*)
fun is_nat(t) = fastype_of1 t = HOLogic.natT;
(*the instance of le0 with its variable ?n instantiated to t*)
fun mk_nat_thm sg t =
let val ct = cterm_of sg t and cn = cterm_of sg (Var(("n",0),HOLogic.natT))
in instantiate ([],[(cn,ct)]) le0 end;
end;
(* arith theory data *)
(*Per-theory configuration of the arith tactic:
  splits     - split rules applied by raw_arith_tac before the decision procedure
  inj_consts - constants whose argument is decomposed in place of the
               application (see demult and poly in LA_Data_Ref)
  discrete   - names of types treated as discrete (see allows_lin_arith)
  presburger - optional fallback tactic tried by presburger_tac*)
structure ArithTheoryData = TheoryDataFun
(struct
val name = "HOL/arith";
type T = {splits: thm list, inj_consts: (string * typ)list, discrete: string list, presburger: (int -> tactic) option};
val empty = {splits = [], inj_consts = [], discrete = [], presburger = NONE};
(*copy and extend leave the data unchanged*)
val copy = I;
val extend = I;
(*merge field by field; for presburger the first theory's tactic wins
  when both are set*)
fun merge _ ({splits= splits1, inj_consts= inj_consts1, discrete= discrete1, presburger= presburger1},
{splits= splits2, inj_consts= inj_consts2, discrete= discrete2, presburger= presburger2}) =
{splits = Drule.merge_rules (splits1, splits2),
inj_consts = merge_lists inj_consts1 inj_consts2,
discrete = merge_lists discrete1 discrete2,
presburger = (case presburger1 of NONE => presburger2 | p => p)};
fun print _ _ = ();
end);
(*Attribute function: record thm as an arith split rule in thy and return
  the updated theory together with the unchanged theorem.*)
fun arith_split_add (thy, thm) =
  let
    fun add_split {splits, inj_consts, discrete, presburger} =
      {splits = thm :: splits, inj_consts = inj_consts,
       discrete = discrete, presburger = presburger}
  in
    (ArithTheoryData.map add_split thy, thm)
  end;

(*Theory transformer declaring the type name d as discrete for arith.*)
fun arith_discrete d =
  let
    fun add_discrete {splits, inj_consts, discrete, presburger} =
      {splits = splits, inj_consts = inj_consts,
       discrete = d :: discrete, presburger = presburger}
  in
    ArithTheoryData.map add_discrete
  end;

(*Theory transformer registering c as a constant for arith to look through.*)
fun arith_inj_const c =
  let
    fun add_inj_const {splits, inj_consts, discrete, presburger} =
      {splits = splits, inj_consts = c :: inj_consts,
       discrete = discrete, presburger = presburger}
  in
    ArithTheoryData.map add_inj_const
  end;
structure LA_Data_Ref: LIN_ARITH_DATA =
struct

(* Decomposition of terms *)

(*nT T: T is a function type whose argument type is nat.  Used by poly to
  keep nat subtraction and nat negation as opaque atoms.*)
fun nT (Type("fun",[N,_])) = N = HOLogic.natT
  | nT _ = false;

(*add_atom (t, m, (p, i)): add multiplicity m to atom t in the association
  list p, inserting t if it is absent; the constant part i is passed through.*)
fun add_atom(t,m,(p,i)) =
  (case AList.lookup (op =) p t of
     NONE => ((t, m) :: p, i)
   | SOME n => (AList.update (op =) (t, ratadd (n, m)) p, i));

(*Raised on division by the numeral 0; decomp2 turns it into NONE.*)
exception Zero;

(*rat_of_term (numt, dent): the rational numt / dent of two binary numerals;
  raises Zero if dent is 0.*)
fun rat_of_term(numt,dent) =
  let val num = HOLogic.dest_binum numt and den = HOLogic.dest_binum dent
  in if den = 0 then raise Zero else int_ratdiv(num,den) end;

(* Warning: in rare cases number_of encloses a non-numeral,
   in which case dest_binum raises TERM; hence all the handles below.
   Same for Suc-terms that turn out not to be numerals -
   although the simplifier should eliminate those anyway...
*)

(*number_of_Sucs t: the number of Suc applications around zero in t;
  raises TERM if the innermost term is not zero.*)
fun number_of_Sucs (Const("Suc",_) $ n) = number_of_Sucs n + 1
  | number_of_Sucs t = if HOLogic.is_zero t then 0
                       else raise TERM("number_of_Sucs",[]);

(* decompose nested multiplications, bracketing them to the right and combining all
   their coefficients.
   demult inj_consts (t, m) returns (SOME u, m') for a remaining atom u with
   coefficient m', or (NONE, m') when no atom remains and m' is the numeric
   value of the product scaled by m.
*)
fun demult inj_consts =
  let
    fun demult((mC as Const("op *",_)) $ s $ t,m) = ((case s of
          Const("Numeral.number_of",_)$n
            => demult(t,ratmul(m,rat_of_intinf(HOLogic.dest_binum n)))
        | Const("uminus",_)$(Const("Numeral.number_of",_)$n)
            => demult(t,ratmul(m,rat_of_intinf(~(HOLogic.dest_binum n))))
        | Const("Suc",_) $ _
            => demult(t,ratmul(m,rat_of_int(number_of_Sucs s)))
        | Const("op *",_) $ s1 $ s2 => demult(mC $ s1 $ (mC $ s2 $ t),m)
        | Const("HOL.divide",_) $ numt $ (Const("Numeral.number_of",_)$dent) =>
            let val den = HOLogic.dest_binum dent
            in if den = 0 then raise Zero
               else demult(mC $ numt $ t,ratmul(m, ratinv(rat_of_intinf den)))
            end
        | _ => atomult(mC,s,t,m)
        ) handle TERM _ => atomult(mC,s,t,m))
      | demult(atom as Const("HOL.divide",_) $ t $ (Const("Numeral.number_of",_)$dent), m) =
          (let val den = HOLogic.dest_binum dent
           in if den = 0 then raise Zero else demult(t,ratmul(m, ratinv(rat_of_intinf den))) end
           handle TERM _ => (SOME atom,m))
      | demult(Const("0",_),m) = (NONE, rat_of_int 0)
      | demult(Const("1",_),m) = (NONE, m)
      | demult(t as Const("Numeral.number_of",_)$n,m) =
          ((NONE,ratmul(m,rat_of_intinf(HOLogic.dest_binum n)))
           handle TERM _ => (SOME t,m))
      | demult(Const("uminus",_)$t, m) = demult(t,ratmul(m,rat_of_int(~1)))
      | demult(t as Const f $ x, m) =
          (if f mem inj_consts then SOME x else SOME t,m)
      | demult(atom,m) = (SOME atom,m)
    (*atomult (mC, atom, t, m): product whose left factor is the atom*)
    and atomult(mC,atom,t,m) = (case demult(t,m) of (NONE,m') => (SOME atom,m')
                                | (SOME t',m') => (SOME(mC $ atom $ t'),m'))
  in demult end;

(*decomp2 inj_consts (rel, lhs, rhs): linearise both sides of a relation into
  (atom/coefficient list, constant) pairs.  Returns SOME (p, i, r, q, j), where
  r is the short name of the relation, for <, <= and =; NONE for any other
  relation or when a division by the numeral 0 occurs.*)
fun decomp2 inj_consts (rel,lhs,rhs) =
  let
    (* Turn term into list of summand * multiplicity plus a constant *)
    fun poly(Const("op +",_) $ s $ t, m, pi) = poly(s,m,poly(t,m,pi))
      | poly(all as Const("op -",T) $ s $ t, m, pi) =
          if nT T then add_atom(all,m,pi) else poly(s,m,poly(t,ratneg m,pi))
      | poly(all as Const("uminus",T) $ t, m, pi) =
          if nT T then add_atom(all,m,pi) else poly(t,ratneg m,pi)
      | poly(Const("0",_), _, pi) = pi
      | poly(Const("1",_), m, (p,i)) = (p,ratadd(i,m))
      | poly(Const("Suc",_)$t, m, (p,i)) = poly(t, m, (p,ratadd(i,m)))
      | poly(t as Const("op *",_) $ _ $ _, m, pi as (p,i)) =
          (case demult inj_consts (t,m) of
             (* A purely numeric product contributes its value m', which
                already includes m, to the constant part, exactly as in the
                HOL.divide case below.  Adding m here dropped every numeric
                factor, e.g. 2 times 3 contributed 1 instead of 6. *)
             (NONE,m') => (p,ratadd(i,m'))
           | (SOME u,m') => add_atom(u,m',pi))
      | poly(t as Const("HOL.divide",_) $ _ $ _, m, pi as (p,i)) =
          (case demult inj_consts (t,m) of
             (NONE,m') => (p,ratadd(i,m'))
           | (SOME u,m') => add_atom(u,m',pi))
      | poly(all as (Const("Numeral.number_of",_)$t,m,(p,i))) =
          ((p,ratadd(i,ratmul(m,rat_of_intinf(HOLogic.dest_binum t))))
           handle TERM _ => add_atom all)
      | poly(all as Const f $ x, m, pi) =
          if f mem inj_consts then poly(x,m,pi) else add_atom(all,m,pi)
      | poly x = add_atom x;
    val (p,i) = poly(lhs,rat_of_int 1,([],rat_of_int 0))
    and (q,j) = poly(rhs,rat_of_int 1,([],rat_of_int 0))
  in case rel of
       "op <"  => SOME(p,i,"<",q,j)
     | "op <=" => SOME(p,i,"<=",q,j)
     | "op ="  => SOME(p,i,"=",q,j)
     | _       => NONE
  end handle Zero => NONE;

(*negate a decomposed relation by prefixing its name with a tilde*)
fun negate(SOME(x,i,rel,y,j,d)) = SOME(x,i,"~"^rel,y,j,d)
  | negate NONE = NONE;

(*is U of sort Ring_and_Field.ordered_idom in sg?*)
fun of_lin_arith_sort sg U =
  Type.of_sort (Sign.tsig_of sg) (U,["Ring_and_Field.ordered_idom"]);

(*allows_lin_arith sg discrete U = (allowed, is_discrete).  A type constant
  is allowed if it has the ordered_idom sort or its name is listed in
  discrete, and is discrete iff it is listed; any other type is allowed iff
  it has the sort, and is never discrete.*)
fun allows_lin_arith sg discrete (U as Type(D,[])) =
      if of_lin_arith_sort sg U
      then (true, D mem discrete)
      else (* special cases *)
        if D mem discrete then (true,true) else (false,false)
  | allows_lin_arith sg discrete U = (of_lin_arith_sort sg U, false);

(*decomp1 data (T, (rel, lhs, rhs)), with T the type of the relation
  constant: decompose the relation if its argument type permits linear
  arithmetic, adding the discreteness flag as the last component.*)
fun decomp1 (sg,discrete,inj_consts) (T,xxx) =
  (case T of
     Type("fun",[U,_]) =>
       (case allows_lin_arith sg discrete U of
          (true,d) => (case decomp2 inj_consts xxx of NONE => NONE
                       | SOME(p,i,rel,q,j) => SOME(p,i,rel,q,j,d))
        | (false,_) => NONE)
   | _ => NONE);

(*Decompose a Trueprop-wrapped, possibly negated, binary relation.
  NOTE: this shadows the decomp2 defined above; decomp1 was bound earlier
  and therefore still uses the relation decomposer.*)
fun decomp2 data (_$(Const(rel,T)$lhs$rhs)) = decomp1 data (T,(rel,lhs,rhs))
  | decomp2 data (_$(Const("Not",_)$(Const(rel,T)$lhs$rhs))) =
      negate(decomp1 data (T,(rel,lhs,rhs)))
  | decomp2 data _ = NONE;

(*decompose using the discrete types and inj_consts recorded in sg*)
fun decomp sg =
  let val {discrete, inj_consts, ...} = ArithTheoryData.get sg
  in decomp2 (sg,discrete,inj_consts) end;

(*the numeral term of type T for the integer n*)
fun number_of(n,T) = HOLogic.number_of_const T $ (HOLogic.mk_bin n);

end;
(*The linear-arithmetic decision procedure for HOL, obtained by applying
  the generic functor to the logic and term-decomposition parameters.*)
structure Fast_Arith =
  Fast_Lin_Arith(structure LA_Logic = LA_Logic and LA_Data = LA_Data_Ref);

(*top-level aliases for the procedure's tactics and settings;
  fast_arith_tac is lin_arith_tac with its flag fixed to false*)
val fast_arith_tac = Fast_Arith.lin_arith_tac false;
val fast_ex_arith_tac = Fast_Arith.lin_arith_tac;
val trace_arith = Fast_Arith.trace;
val fast_arith_neq_limit = Fast_Arith.fast_arith_neq_limit;
local
(*Suc (i + j) = i + j + Suc 0, proved by simp in theory Nat*)
val isolateSuc =
let val thy = theory "Nat"
in prove_goal thy "Suc(i+j) = i+j + Suc 0"
(fn _ => [simp_tac (simpset_of thy) 1])
end;
(* reduce contradictory <= to False.
Most of the work is done by the cancel tactics.
*)
val add_rules =
[add_zero_left,add_zero_right,Zero_not_Suc,Suc_not_Zero,le_0_eq,
One_nat_def,isolateSuc,
order_less_irrefl, zero_neq_one, zero_less_one, zero_le_one,
zero_neq_one RS not_sym, not_one_le_zero, not_one_less_zero];
(*monotonicity of + with respect to <= and = over
  pordered_ab_semigroup_add, each proved by blast using add_mono*)
val add_mono_thms_ordered_semiring = map (fn s => prove_goal (the_context ()) s
(fn prems => [cut_facts_tac prems 1,
blast_tac (claset() addIs [add_mono]) 1]))
["(i <= j) & (k <= l) ==> i + k <= j + (l::'a::pordered_ab_semigroup_add)",
"(i = j) & (k <= l) ==> i + k <= j + (l::'a::pordered_ab_semigroup_add)",
"(i <= j) & (k = l) ==> i + k <= j + (l::'a::pordered_ab_semigroup_add)",
"(i = j) & (k = l) ==> i + k = j + (l::'a::pordered_ab_semigroup_add)"
];
(*simpset used to prove the strict monotonicity facts below*)
val mono_ss = simpset() addsimps
[add_mono,add_strict_mono,add_less_le_mono,add_le_less_mono];
(*strict monotonicity of + over pordered_cancel_ab_semigroup_add*)
val add_mono_thms_ordered_field =
map (fn s => prove_goal (the_context ()) s
(fn prems => [cut_facts_tac prems 1, asm_simp_tac mono_ss 1]))
["(i<j) & (k=l) ==> i+k < j+(l::'a::pordered_cancel_ab_semigroup_add)",
"(i=j) & (k<l) ==> i+k < j+(l::'a::pordered_cancel_ab_semigroup_add)",
"(i<j) & (k<=l) ==> i+k < j+(l::'a::pordered_cancel_ab_semigroup_add)",
"(i<=j) & (k<l) ==> i+k < j+(l::'a::pordered_cancel_ab_semigroup_add)",
"(i<j) & (k<l) ==> i+k < j+(l::'a::pordered_cancel_ab_semigroup_add)"];
in
(*Setup functions for linear arithmetic:
  - Fast_Arith.setup;
  - extension of the Fast_Arith data with the monotonicity facts above,
    Suc_leI as an extra lessD rule, the nat and ordered_idom disequality
    elimination rules, and a simpset built from add_rules and cancellation
    simprocs;
  - initialisation of ArithTheoryData;
  - declaration of nat as a discrete type.*)
val init_lin_arith_data =
Fast_Arith.setup @
[Fast_Arith.map_data (fn {add_mono_thms, mult_mono_thms, inj_thms, lessD, ...} =>
{add_mono_thms = add_mono_thms @
add_mono_thms_ordered_semiring @ add_mono_thms_ordered_field,
mult_mono_thms = mult_mono_thms,
inj_thms = inj_thms,
lessD = lessD @ [Suc_leI],
neqE = [linorder_neqE_nat,
get_thm (theory "Ring_and_Field") (Name "linorder_neqE_ordered_idom")],
simpset = HOL_basic_ss addsimps add_rules
addsimprocs [ab_group_add_cancel.sum_conv,
ab_group_add_cancel.rel_conv]
(*abel_cancel helps it work in abstract algebraic domains*)
addsimprocs nat_cancel_sums_add}),
ArithTheoryData.init, arith_discrete "nat"];
end;
(*Simproc that runs the linear arithmetic prover on nat <, <= and =.*)
val fast_nat_arith_simproc =
  let
    val nat_relation_patterns =
      ["(m::nat) < n","(m::nat) <= n", "(m::nat) = n"];
  in
    Simplifier.simproc (the_context ()) "fast_nat_arith"
      nat_relation_patterns Fast_Arith.lin_arith_prover
  end;
(* Because of fast_nat_arith_simproc, the arithmetic solver is really only
useful to detect inconsistencies among the premises for subgoals which are
*not* themselves (in)equalities, because the latter activate
fast_nat_arith_simproc anyway. However, it seems cheaper to activate the
solver all the time rather than add the additional check. *)
(* arith proof method *)
(* FIXME: K true should be replaced by a sensible test to speed things up
in case there are lots of irrelevant terms involved;
elimination of min/max can be optimized:
(max m n + k <= r) = (m+k <= r & n+k <= r)
(l <= min m n + k) = (l <= m+k & l <= n+k)
*)
local
(*Refutation-based arithmetic: apply the split rules recorded in the
  theory's ArithTheoryData, eliminate disequalities with linorder_neqE, and
  then run fast_ex_arith_tac with the flag ex.*)
fun raw_arith_tac ex i st =
refute_tac (K true)
(REPEAT o split_tac (#splits (ArithTheoryData.get (Thm.theory_of_thm st))))
((REPEAT_DETERM o etac linorder_neqE) THEN' fast_ex_arith_tac ex)
i st;
(*Run the Presburger tactic recorded in the theory data, with a warning,
  if one is present; otherwise fail.*)
fun presburger_tac i st =
(case ArithTheoryData.get (Thm.theory_of_thm st) of
{presburger = SOME tac, ...} =>
(warning "Trying full Presburger arithmetic ..."; tac i st)
| _ => no_tac st);
in
(*fast_arith_tac, falling back to atomizing and raw_arith_tac true*)
val simple_arith_tac = FIRST' [fast_arith_tac,
ObjectLogic.atomize_tac THEN' raw_arith_tac true];
(*as simple_arith_tac, with Presburger arithmetic as a last resort*)
val arith_tac = FIRST' [fast_arith_tac,
ObjectLogic.atomize_tac THEN' raw_arith_tac true,
presburger_tac];
(*as arith_tac, but with raw_arith_tac's flag set to false*)
val silent_arith_tac = FIRST' [fast_arith_tac,
ObjectLogic.atomize_tac THEN' raw_arith_tac false,
presburger_tac];
(*proof method: insert prems and the chained facts, then apply arith_tac
  to the first goal*)
fun arith_method prems =
Method.METHOD (fn facts => HEADGOAL (Method.insert_tac (prems @ facts) THEN' arith_tac));
end;
(* antisymmetry:
combines x <= y (or ~(y < x)) and y <= x (or ~(x < y)) into x = y
local
val antisym = mk_meta_eq order_antisym
val not_lessD = linorder_not_less RS iffD1
fun prp t thm = (#prop(rep_thm thm) = t)
in
fun antisym_eq prems thm =
let
val r = #prop(rep_thm thm);
in
case r of
Tr $ ((c as Const("op <=",T)) $ s $ t) =>
let val r' = Tr $ (c $ t $ s)
in
case Library.find_first (prp r') prems of
NONE =>
let val r' = Tr $ (HOLogic.Not $ (Const("op <",T) $ s $ t))
in case Library.find_first (prp r') prems of
NONE => []
| SOME thm' => [(thm' RS not_lessD) RS (thm RS antisym)]
end
| SOME thm' => [thm' RS (thm RS antisym)]
end
| Tr $ (Const("Not",_) $ (Const("op <",T) $ s $ t)) =>
let val r' = Tr $ (Const("op <=",T) $ s $ t)
in
case Library.find_first (prp r') prems of
NONE =>
let val r' = Tr $ (HOLogic.Not $ (Const("op <",T) $ t $ s))
in case Library.find_first (prp r') prems of
NONE => []
| SOME thm' =>
[(thm' RS not_lessD) RS ((thm RS not_lessD) RS antisym)]
end
| SOME thm' => [thm' RS ((thm RS not_lessD) RS antisym)]
end
| _ => []
end
handle THM _ => []
end;
*)
(* theory setup *)
(*Theory setup: install the nat cancellation simprocs, the linear-arithmetic
  data, the lin. arith. solver and fast_nat_arith_simproc into the simpset;
  declare the arith proof method and the arith_split attribute.*)
val arith_setup =
  [Simplifier.change_simpset_of (op addsimprocs) nat_cancel_sums] @
  init_lin_arith_data @
  [Simplifier.change_simpset_of (op addSolver)
     (mk_solver' "lin. arith." Fast_Arith.cut_lin_arith_tac),
   Simplifier.change_simpset_of (op addsimprocs) [fast_nat_arith_simproc],
   (*user-visible description: fixed the misspelling of arithmetic*)
   Method.add_methods
     [("arith", (arith_method o #2) oo Method.syntax Args.bang_facts,
       "decide linear arithmetic")],
   Attrib.add_attributes [("arith_split",
     (Attrib.no_args arith_split_add,
      Attrib.no_args Attrib.undef_local_attribute),
     "declaration of split rules for arithmetic procedure")]];