diff -r 19b47bfac6ef -r 9e7d1c139569 src/HOL/Tools/semiring_normalizer.ML --- a/src/HOL/Tools/semiring_normalizer.ML Tue Apr 16 17:54:14 2013 +0200 +++ b/src/HOL/Tools/semiring_normalizer.ML Thu Apr 18 17:07:01 2013 +0200 @@ -27,10 +27,20 @@ val semiring_normalizers_conv: cterm list -> cterm list * thm list -> cterm list * thm list -> cterm list * thm list -> (cterm -> bool) * conv * conv * conv -> (cterm -> cterm -> bool) -> - {add: conv, mul: conv, neg: conv, main: conv, pow: conv, sub: conv} + {add: Proof.context -> conv, + mul: Proof.context -> conv, + neg: Proof.context -> conv, + main: Proof.context -> conv, + pow: Proof.context -> conv, + sub: Proof.context -> conv} val semiring_normalizers_ord_wrapper: Proof.context -> entry -> (cterm -> cterm -> bool) -> - {add: conv, mul: conv, neg: conv, main: conv, pow: conv, sub: conv} + {add: Proof.context -> conv, + mul: Proof.context -> conv, + neg: Proof.context -> conv, + main: Proof.context -> conv, + pow: Proof.context -> conv, + sub: Proof.context -> conv} val setup: theory -> theory end @@ -177,9 +187,9 @@ handle TERM _ => error "ring_dest_const")), mk_const = fn phi => fn cT => fn x => Numeral.mk_cnumber cT (case Rat.quotient_of_rat x of (i, 1) => i | _ => error "int_of_rat: bad int"), - conv = fn phi => fn _ => Simplifier.rewrite (HOL_basic_ss addsimps @{thms semiring_norm}) - then_conv Simplifier.rewrite (HOL_basic_ss addsimps - @{thms numeral_1_eq_1})}; + conv = fn phi => fn ctxt => + Simplifier.rewrite (put_simpset HOL_basic_ss ctxt addsimps @{thms semiring_norm}) + then_conv Simplifier.rewrite (put_simpset HOL_basic_ss ctxt addsimps @{thms numeral_1_eq_1})}; fun field_funs key = let @@ -208,7 +218,7 @@ {is_const = K numeral_is_const, dest_const = K dest_const, mk_const = mk_const, - conv = K (K Numeral_Simprocs.field_comp_conv)} + conv = K Numeral_Simprocs.field_comp_conv} end; @@ -236,23 +246,26 @@ val dest_numeral = term_of #> HOLogic.dest_number #> snd; val is_numeral = can dest_numeral; -val numeral01_conv = Simplifier.rewrite - (HOL_basic_ss addsimps [@{thm numeral_1_eq_1}]); -val zero1_numeral_conv = - Simplifier.rewrite (HOL_basic_ss addsimps [@{thm numeral_1_eq_1} RS sym]); -fun zerone_conv cv = zero1_numeral_conv then_conv cv then_conv numeral01_conv; +fun numeral01_conv ctxt = + Simplifier.rewrite (put_simpset HOL_basic_ss ctxt addsimps [@{thm numeral_1_eq_1}]); + +fun zero1_numeral_conv ctxt = + Simplifier.rewrite (put_simpset HOL_basic_ss ctxt addsimps [@{thm numeral_1_eq_1} RS sym]); + +fun zerone_conv ctxt cv = + zero1_numeral_conv ctxt then_conv cv then_conv numeral01_conv ctxt; val natarith = [@{thm "numeral_plus_numeral"}, @{thm "diff_nat_numeral"}, @{thm "numeral_times_numeral"}, @{thm "numeral_eq_iff"}, @{thm "numeral_less_iff"}]; -val nat_add_conv = - zerone_conv - (Simplifier.rewrite - (HOL_basic_ss - addsimps @{thms arith_simps} @ natarith @ @{thms rel_simps} - @ [@{thm if_False}, @{thm if_True}, @{thm Nat.add_0}, @{thm add_Suc}, - @{thm add_numeral_left}, @{thm Suc_eq_plus1}] - @ map (fn th => th RS sym) @{thms numerals})); +fun nat_add_conv ctxt = + zerone_conv ctxt + (Simplifier.rewrite + (put_simpset HOL_basic_ss ctxt + addsimps @{thms arith_simps} @ natarith @ @{thms rel_simps} + @ [@{thm if_False}, @{thm if_True}, @{thm Nat.add_0}, @{thm add_Suc}, + @{thm add_numeral_left}, @{thm Suc_eq_plus1}] + @ map (fn th => th RS sym) @{thms numerals})); val zeron_tm = @{cterm "0::nat"}; val onen_tm = @{cterm "1::nat"}; @@ -316,7 +329,7 @@ (* Also deals with "const * const", but both terms must involve powers of *) (* the same variable, or both be constants, or behaviour may be incorrect. *) - fun powvar_mul_conv tm = + fun powvar_mul_conv ctxt tm = let val (l,r) = dest_mul tm in if is_semiring_constant l andalso is_semiring_constant r @@ -328,16 +341,16 @@ ((let val (rx,rn) = dest_pow r val th1 = inst_thm [(cx,lx),(cp,ln),(cq,rn)] pthm_29 val (tm1,tm2) = Thm.dest_comb(concl th1) in - Thm.transitive th1 (Drule.arg_cong_rule tm1 (nat_add_conv tm2)) end) + Thm.transitive th1 (Drule.arg_cong_rule tm1 (nat_add_conv ctxt tm2)) end) handle CTERM _ => (let val th1 = inst_thm [(cx,lx),(cq,ln)] pthm_31 val (tm1,tm2) = Thm.dest_comb(concl th1) in - Thm.transitive th1 (Drule.arg_cong_rule tm1 (nat_add_conv tm2)) end)) end) + Thm.transitive th1 (Drule.arg_cong_rule tm1 (nat_add_conv ctxt tm2)) end)) end) handle CTERM _ => ((let val (rx,rn) = dest_pow r val th1 = inst_thm [(cx,rx),(cq,rn)] pthm_30 val (tm1,tm2) = Thm.dest_comb(concl th1) in - Thm.transitive th1 (Drule.arg_cong_rule tm1 (nat_add_conv tm2)) end) + Thm.transitive th1 (Drule.arg_cong_rule tm1 (nat_add_conv ctxt tm2)) end) handle CTERM _ => inst_thm [(cx,l)] pthm_32 )) @@ -353,7 +366,7 @@ (* Conversion for "(monomial)^n", where n is a numeral. *) - val monomial_pow_conv = + fun monomial_pow_conv ctxt = let fun monomial_pow tm bod ntm = if not(is_comb bod) @@ -374,7 +387,7 @@ then let val th1 = inst_thm [(cx,l),(cp,r),(cq,ntm)] pthm_34 val (l,r) = Thm.dest_comb(concl th1) - in Thm.transitive th1 (Drule.arg_cong_rule l (nat_add_conv r)) + in Thm.transitive th1 (Drule.arg_cong_rule l (nat_add_conv ctxt r)) end else if opr aconvc mul_tm @@ -405,7 +418,7 @@ end; (* Multiplication of canonical monomials. *) - val monomial_mul_conv = + fun monomial_mul_conv ctxt = let fun powvar tm = if is_semiring_constant tm then one_tm @@ -435,7 +448,7 @@ val th1 = inst_thm [(clx,lx),(cly,ly),(crx,rx),(cry,ry)] pthm_15 val (tm1,tm2) = Thm.dest_comb(concl th1) val (tm3,tm4) = Thm.dest_comb tm1 - val th2 = Drule.fun_cong_rule (Drule.arg_cong_rule tm3 (powvar_mul_conv tm4)) tm2 + val th2 = Drule.fun_cong_rule (Drule.arg_cong_rule tm3 (powvar_mul_conv ctxt tm4)) tm2 val th3 = Thm.transitive th1 th2 val (tm5,tm6) = Thm.dest_comb(concl th3) val (tm7,tm8) = Thm.dest_comb tm6 @@ -458,7 +471,7 @@ val th1 = inst_thm [(clx,lx),(cly,ly),(crx,r)] pthm_18 val (tm1,tm2) = Thm.dest_comb(concl th1) val (tm3,tm4) = Thm.dest_comb tm1 - val th2 = Drule.fun_cong_rule (Drule.arg_cong_rule tm3 (powvar_mul_conv tm4)) tm2 + val th2 = Drule.fun_cong_rule (Drule.arg_cong_rule tm3 (powvar_mul_conv ctxt tm4)) tm2 in Thm.transitive th1 th2 end else @@ -480,7 +493,7 @@ let val th1 = inst_thm [(clx,l),(crx,rx),(cry,ry)] pthm_21 val (tm1,tm2) = Thm.dest_comb(concl th1) val (tm3,tm4) = Thm.dest_comb tm1 - in Thm.transitive th1 (Drule.fun_cong_rule (Drule.arg_cong_rule tm3 (powvar_mul_conv tm4)) tm2) + in Thm.transitive th1 (Drule.fun_cong_rule (Drule.arg_cong_rule tm3 (powvar_mul_conv ctxt tm4)) tm2) end else if ord > 0 then let val th1 = inst_thm [(clx,l),(crx,rx),(cry,ry)] pthm_22 @@ -493,7 +506,7 @@ handle CTERM _ => (let val vr = powvar r val ord = vorder vl vr - in if ord = 0 then powvar_mul_conv tm + in if ord = 0 then powvar_mul_conv ctxt tm else if ord > 0 then inst_thm [(ca,l),(cb,r)] pthm_09 else Thm.reflexive tm end)) end)) @@ -502,7 +515,7 @@ end; (* Multiplication by monomial of a polynomial. *) - val polynomial_monomial_mul_conv = + fun polynomial_monomial_mul_conv ctxt = let fun pmm_conv tm = let val (l,r) = dest_mul tm @@ -511,10 +524,11 @@ val th1 = inst_thm [(cx,l),(cy,y),(cz,z)] pthm_37 val (tm1,tm2) = Thm.dest_comb(concl th1) val (tm3,tm4) = Thm.dest_comb tm1 - val th2 = Thm.combination (Drule.arg_cong_rule tm3 (monomial_mul_conv tm4)) (pmm_conv tm2) + val th2 = + Thm.combination (Drule.arg_cong_rule tm3 (monomial_mul_conv ctxt tm4)) (pmm_conv tm2) in Thm.transitive th1 th2 end) - handle CTERM _ => monomial_mul_conv tm) + handle CTERM _ => monomial_mul_conv ctxt tm) end in pmm_conv end; @@ -592,7 +606,7 @@ (* Addition of two polynomials. *) -val polynomial_add_conv = +fun polynomial_add_conv ctxt = let fun dezero_rule th = let @@ -690,25 +704,25 @@ (* Multiplication of two polynomials. *) -val polynomial_mul_conv = +fun polynomial_mul_conv ctxt = let fun pmul tm = let val (l,r) = dest_mul tm in - if not(is_add l) then polynomial_monomial_mul_conv tm + if not(is_add l) then polynomial_monomial_mul_conv ctxt tm else if not(is_add r) then let val th1 = inst_thm [(ca,l),(cb,r)] pthm_09 - in Thm.transitive th1 (polynomial_monomial_mul_conv(concl th1)) + in Thm.transitive th1 (polynomial_monomial_mul_conv ctxt (concl th1)) end else let val (a,b) = dest_add l val th1 = inst_thm [(ca,a),(cb,b),(cc,r)] pthm_10 val (tm1,tm2) = Thm.dest_comb(concl th1) val (tm3,tm4) = Thm.dest_comb tm1 - val th2 = Drule.arg_cong_rule tm3 (polynomial_monomial_mul_conv tm4) + val th2 = Drule.arg_cong_rule tm3 (polynomial_monomial_mul_conv ctxt tm4) val th3 = Thm.transitive th1 (Thm.combination th2 (pmul tm2)) - in Thm.transitive th3 (polynomial_add_conv (concl th3)) + in Thm.transitive th3 (polynomial_add_conv ctxt (concl th3)) end end in fn tm => @@ -724,12 +738,12 @@ (* Power of polynomial (optimized for the monomial and trivial cases). *) -fun num_conv n = - nat_add_conv (Thm.apply @{cterm Suc} (Numeral.mk_cnumber @{ctyp nat} (dest_numeral n - 1))) +fun num_conv ctxt n = + nat_add_conv ctxt (Thm.apply @{cterm Suc} (Numeral.mk_cnumber @{ctyp nat} (dest_numeral n - 1))) |> Thm.symmetric; -val polynomial_pow_conv = +fun polynomial_pow_conv ctxt = let fun ppow tm = let val (l,n) = dest_pow tm @@ -737,52 +751,52 @@ if n aconvc zeron_tm then inst_thm [(cx,l)] pthm_35 else if n aconvc onen_tm then inst_thm [(cx,l)] pthm_36 else - let val th1 = num_conv n + let val th1 = num_conv ctxt n val th2 = inst_thm [(cx,l),(cq,Thm.dest_arg (concl th1))] pthm_38 val (tm1,tm2) = Thm.dest_comb(concl th2) val th3 = Thm.transitive th2 (Drule.arg_cong_rule tm1 (ppow tm2)) val th4 = Thm.transitive (Drule.arg_cong_rule (Thm.dest_fun tm) th1) th3 - in Thm.transitive th4 (polynomial_mul_conv (concl th4)) + in Thm.transitive th4 (polynomial_mul_conv ctxt (concl th4)) end end in fn tm => - if is_add(Thm.dest_arg1 tm) then ppow tm else monomial_pow_conv tm + if is_add(Thm.dest_arg1 tm) then ppow tm else monomial_pow_conv ctxt tm end; (* Negation. *) -fun polynomial_neg_conv tm = +fun polynomial_neg_conv ctxt tm = let val (l,r) = Thm.dest_comb tm in if not (l aconvc neg_tm) then raise CTERM ("polynomial_neg_conv",[tm]) else let val th1 = inst_thm [(cx',r)] neg_mul val th2 = Thm.transitive th1 (Conv.arg1_conv semiring_mul_conv (concl th1)) - in Thm.transitive th2 (polynomial_monomial_mul_conv (concl th2)) + in Thm.transitive th2 (polynomial_monomial_mul_conv ctxt (concl th2)) end end; (* Subtraction. *) -fun polynomial_sub_conv tm = +fun polynomial_sub_conv ctxt tm = let val (l,r) = dest_sub tm val th1 = inst_thm [(cx',l),(cy',r)] sub_add val (tm1,tm2) = Thm.dest_comb(concl th1) - val th2 = Drule.arg_cong_rule tm1 (polynomial_neg_conv tm2) - in Thm.transitive th1 (Thm.transitive th2 (polynomial_add_conv (concl th2))) + val th2 = Drule.arg_cong_rule tm1 (polynomial_neg_conv ctxt tm2) + in Thm.transitive th1 (Thm.transitive th2 (polynomial_add_conv ctxt (concl th2))) end; (* Conversion from HOL term. *) -fun polynomial_conv tm = +fun polynomial_conv ctxt tm = if is_semiring_constant tm then semiring_add_conv tm else if not(is_comb tm) then Thm.reflexive tm else let val (lopr,r) = Thm.dest_comb tm in if lopr aconvc neg_tm then - let val th1 = Drule.arg_cong_rule lopr (polynomial_conv r) - in Thm.transitive th1 (polynomial_neg_conv (concl th1)) + let val th1 = Drule.arg_cong_rule lopr (polynomial_conv ctxt r) + in Thm.transitive th1 (polynomial_neg_conv ctxt (concl th1)) end else if lopr aconvc inverse_tm then - let val th1 = Drule.arg_cong_rule lopr (polynomial_conv r) + let val th1 = Drule.arg_cong_rule lopr (polynomial_conv ctxt r) in Thm.transitive th1 (semiring_mul_conv (concl th1)) end else @@ -791,14 +805,14 @@ let val (opr,l) = Thm.dest_comb lopr in if opr aconvc pow_tm andalso is_numeral r then - let val th1 = Drule.fun_cong_rule (Drule.arg_cong_rule opr (polynomial_conv l)) r - in Thm.transitive th1 (polynomial_pow_conv (concl th1)) + let val th1 = Drule.fun_cong_rule (Drule.arg_cong_rule opr (polynomial_conv ctxt l)) r + in Thm.transitive th1 (polynomial_pow_conv ctxt (concl th1)) end else if opr aconvc divide_tm then - let val th1 = Thm.combination (Drule.arg_cong_rule opr (polynomial_conv l)) - (polynomial_conv r) - val th2 = (Conv.rewr_conv divide_inverse then_conv polynomial_mul_conv) + let val th1 = Thm.combination (Drule.arg_cong_rule opr (polynomial_conv ctxt l)) + (polynomial_conv ctxt r) + val th2 = (Conv.rewr_conv divide_inverse then_conv polynomial_mul_conv ctxt) (Thm.rhs_of th1) in Thm.transitive th1 th2 end @@ -806,10 +820,11 @@ if opr aconvc add_tm orelse opr aconvc mul_tm orelse opr aconvc sub_tm then let val th1 = - Thm.combination (Drule.arg_cong_rule opr (polynomial_conv l)) (polynomial_conv r) - val f = if opr aconvc add_tm then polynomial_add_conv - else if opr aconvc mul_tm then polynomial_mul_conv - else polynomial_sub_conv + Thm.combination + (Drule.arg_cong_rule opr (polynomial_conv ctxt l)) (polynomial_conv ctxt r) + val f = if opr aconvc add_tm then polynomial_add_conv ctxt + else if opr aconvc mul_tm then polynomial_mul_conv ctxt + else polynomial_sub_conv ctxt in Thm.transitive th1 (f (concl th1)) end else Thm.reflexive tm @@ -826,8 +841,10 @@ end; val nat_exp_ss = - HOL_basic_ss addsimps (@{thms eval_nat_numeral} @ @{thms nat_arith} @ @{thms arith_simps} @ @{thms rel_simps}) - addsimps [@{thm Let_def}, @{thm if_False}, @{thm if_True}, @{thm Nat.add_0}, @{thm add_Suc}]; + simpset_of + (put_simpset HOL_basic_ss @{context} + addsimps (@{thms eval_nat_numeral} @ @{thms nat_arith} @ @{thms arith_simps} @ @{thms rel_simps}) + addsimps [@{thm Let_def}, @{thm if_False}, @{thm if_True}, @{thm Nat.add_0}, @{thm add_Suc}]); fun simple_cterm_ord t u = Term_Ord.term_ord (term_of t, term_of u) = LESS; @@ -838,15 +855,17 @@ {conv, dest_const, mk_const, is_const}) ord = let val pow_conv = - Conv.arg_conv (Simplifier.rewrite nat_exp_ss) + Conv.arg_conv (Simplifier.rewrite (put_simpset nat_exp_ss ctxt)) then_conv Simplifier.rewrite - (HOL_basic_ss addsimps [nth (snd semiring) 31, nth (snd semiring) 34]) + (put_simpset HOL_basic_ss ctxt addsimps [nth (snd semiring) 31, nth (snd semiring) 34]) then_conv conv ctxt val dat = (is_const, conv ctxt, conv ctxt, pow_conv) in semiring_normalizers_conv vars semiring ring field dat ord end; fun semiring_normalize_ord_wrapper ctxt ({vars, semiring, ring, field, idom, ideal}, {conv, dest_const, mk_const, is_const}) ord = - #main (semiring_normalizers_ord_wrapper ctxt ({vars = vars, semiring = semiring, ring = ring, field = field, idom = idom, ideal = ideal},{conv = conv, dest_const = dest_const, mk_const = mk_const, is_const = is_const}) ord); + #main (semiring_normalizers_ord_wrapper ctxt + ({vars = vars, semiring = semiring, ring = ring, field = field, idom = idom, ideal = ideal}, + {conv = conv, dest_const = dest_const, mk_const = mk_const, is_const = is_const}) ord) ctxt; fun semiring_normalize_wrapper ctxt data = semiring_normalize_ord_wrapper ctxt data simple_cterm_ord;