src/HOL/Tools/semiring_normalizer.ML
changeset 51717 9e7d1c139569
parent 47108 2a1953f0d20d
child 53078 cc06f17d8057
--- a/src/HOL/Tools/semiring_normalizer.ML	Tue Apr 16 17:54:14 2013 +0200
+++ b/src/HOL/Tools/semiring_normalizer.ML	Thu Apr 18 17:07:01 2013 +0200
@@ -27,10 +27,20 @@
   val semiring_normalizers_conv: cterm list -> cterm list * thm list
     -> cterm list * thm list -> cterm list * thm list ->
       (cterm -> bool) * conv * conv * conv -> (cterm -> cterm -> bool) ->
-        {add: conv, mul: conv, neg: conv, main: conv, pow: conv, sub: conv}
+        {add: Proof.context -> conv,
+         mul: Proof.context -> conv,
+         neg: Proof.context -> conv,
+         main: Proof.context -> conv,
+         pow: Proof.context -> conv,
+         sub: Proof.context -> conv}
   val semiring_normalizers_ord_wrapper:  Proof.context -> entry ->
     (cterm -> cterm -> bool) ->
-      {add: conv, mul: conv, neg: conv, main: conv, pow: conv, sub: conv}
+      {add: Proof.context -> conv,
+       mul: Proof.context -> conv,
+       neg: Proof.context -> conv,
+       main: Proof.context -> conv,
+       pow: Proof.context -> conv,
+       sub: Proof.context -> conv}
 
   val setup: theory -> theory
 end
@@ -177,9 +187,9 @@
           handle TERM _ => error "ring_dest_const")),
     mk_const = fn phi => fn cT => fn x => Numeral.mk_cnumber cT
       (case Rat.quotient_of_rat x of (i, 1) => i | _ => error "int_of_rat: bad int"),
-    conv = fn phi => fn _ => Simplifier.rewrite (HOL_basic_ss addsimps @{thms semiring_norm})
-      then_conv Simplifier.rewrite (HOL_basic_ss addsimps
-        @{thms numeral_1_eq_1})};
+    conv = fn phi => fn ctxt =>
+      Simplifier.rewrite (put_simpset HOL_basic_ss ctxt addsimps @{thms semiring_norm})
+      then_conv Simplifier.rewrite (put_simpset HOL_basic_ss ctxt addsimps @{thms numeral_1_eq_1})};
 
 fun field_funs key =
   let
@@ -208,7 +218,7 @@
      {is_const = K numeral_is_const,
       dest_const = K dest_const,
       mk_const = mk_const,
-      conv = K (K Numeral_Simprocs.field_comp_conv)}
+      conv = K Numeral_Simprocs.field_comp_conv}
   end;
 
 
@@ -236,23 +246,26 @@
 val dest_numeral = term_of #> HOLogic.dest_number #> snd;
 val is_numeral = can dest_numeral;
 
-val numeral01_conv = Simplifier.rewrite
-                         (HOL_basic_ss addsimps [@{thm numeral_1_eq_1}]);
-val zero1_numeral_conv = 
- Simplifier.rewrite (HOL_basic_ss addsimps [@{thm numeral_1_eq_1} RS sym]);
-fun zerone_conv cv = zero1_numeral_conv then_conv cv then_conv numeral01_conv;
+fun numeral01_conv ctxt =
+  Simplifier.rewrite (put_simpset HOL_basic_ss ctxt addsimps [@{thm numeral_1_eq_1}]);
+
+fun zero1_numeral_conv ctxt =
+  Simplifier.rewrite (put_simpset HOL_basic_ss ctxt addsimps [@{thm numeral_1_eq_1} RS sym]);
+
+fun zerone_conv ctxt cv =
+  zero1_numeral_conv ctxt then_conv cv then_conv numeral01_conv ctxt;
 val natarith = [@{thm "numeral_plus_numeral"}, @{thm "diff_nat_numeral"},
                 @{thm "numeral_times_numeral"}, @{thm "numeral_eq_iff"}, 
                 @{thm "numeral_less_iff"}];
 
-val nat_add_conv = 
- zerone_conv 
-  (Simplifier.rewrite 
-    (HOL_basic_ss 
-       addsimps @{thms arith_simps} @ natarith @ @{thms rel_simps}
-             @ [@{thm if_False}, @{thm if_True}, @{thm Nat.add_0}, @{thm add_Suc},
-                 @{thm add_numeral_left}, @{thm Suc_eq_plus1}]
-             @ map (fn th => th RS sym) @{thms numerals}));
+fun nat_add_conv ctxt =
+  zerone_conv ctxt
+    (Simplifier.rewrite 
+      (put_simpset HOL_basic_ss ctxt
+         addsimps @{thms arith_simps} @ natarith @ @{thms rel_simps}
+               @ [@{thm if_False}, @{thm if_True}, @{thm Nat.add_0}, @{thm add_Suc},
+                   @{thm add_numeral_left}, @{thm Suc_eq_plus1}]
+               @ map (fn th => th RS sym) @{thms numerals}));
 
 val zeron_tm = @{cterm "0::nat"};
 val onen_tm  = @{cterm "1::nat"};
@@ -316,7 +329,7 @@
 (* Also deals with "const * const", but both terms must involve powers of    *)
 (* the same variable, or both be constants, or behaviour may be incorrect.   *)
 
- fun powvar_mul_conv tm =
+ fun powvar_mul_conv ctxt tm =
   let
   val (l,r) = dest_mul tm
   in if is_semiring_constant l andalso is_semiring_constant r
@@ -328,16 +341,16 @@
          ((let val (rx,rn) = dest_pow r
                val th1 = inst_thm [(cx,lx),(cp,ln),(cq,rn)] pthm_29
                 val (tm1,tm2) = Thm.dest_comb(concl th1) in
-               Thm.transitive th1 (Drule.arg_cong_rule tm1 (nat_add_conv tm2)) end)
+               Thm.transitive th1 (Drule.arg_cong_rule tm1 (nat_add_conv ctxt tm2)) end)
            handle CTERM _ =>
             (let val th1 = inst_thm [(cx,lx),(cq,ln)] pthm_31
                  val (tm1,tm2) = Thm.dest_comb(concl th1) in
-               Thm.transitive th1 (Drule.arg_cong_rule tm1 (nat_add_conv tm2)) end)) end)
+               Thm.transitive th1 (Drule.arg_cong_rule tm1 (nat_add_conv ctxt tm2)) end)) end)
        handle CTERM _ =>
            ((let val (rx,rn) = dest_pow r
                 val th1 = inst_thm [(cx,rx),(cq,rn)] pthm_30
                 val (tm1,tm2) = Thm.dest_comb(concl th1) in
-               Thm.transitive th1 (Drule.arg_cong_rule tm1 (nat_add_conv tm2)) end)
+               Thm.transitive th1 (Drule.arg_cong_rule tm1 (nat_add_conv ctxt tm2)) end)
            handle CTERM _ => inst_thm [(cx,l)] pthm_32
 
 ))
@@ -353,7 +366,7 @@
 
 (* Conversion for "(monomial)^n", where n is a numeral.                      *)
 
- val monomial_pow_conv =
+ fun monomial_pow_conv ctxt =
   let
    fun monomial_pow tm bod ntm =
     if not(is_comb bod)
@@ -374,7 +387,7 @@
           then
             let val th1 = inst_thm [(cx,l),(cp,r),(cq,ntm)] pthm_34
                 val (l,r) = Thm.dest_comb(concl th1)
-           in Thm.transitive th1 (Drule.arg_cong_rule l (nat_add_conv r))
+           in Thm.transitive th1 (Drule.arg_cong_rule l (nat_add_conv ctxt r))
            end
            else
             if opr aconvc mul_tm
@@ -405,7 +418,7 @@
   end;
 
 (* Multiplication of canonical monomials.                                    *)
- val monomial_mul_conv =
+ fun monomial_mul_conv ctxt =
   let
    fun powvar tm =
     if is_semiring_constant tm then one_tm
@@ -435,7 +448,7 @@
              val th1 = inst_thm [(clx,lx),(cly,ly),(crx,rx),(cry,ry)] pthm_15
              val (tm1,tm2) = Thm.dest_comb(concl th1)
              val (tm3,tm4) = Thm.dest_comb tm1
-             val th2 = Drule.fun_cong_rule (Drule.arg_cong_rule tm3 (powvar_mul_conv tm4)) tm2
+             val th2 = Drule.fun_cong_rule (Drule.arg_cong_rule tm3 (powvar_mul_conv ctxt tm4)) tm2
              val th3 = Thm.transitive th1 th2
               val  (tm5,tm6) = Thm.dest_comb(concl th3)
               val  (tm7,tm8) = Thm.dest_comb tm6
@@ -458,7 +471,7 @@
            val th1 = inst_thm [(clx,lx),(cly,ly),(crx,r)] pthm_18
                  val (tm1,tm2) = Thm.dest_comb(concl th1)
            val (tm3,tm4) = Thm.dest_comb tm1
-           val th2 = Drule.fun_cong_rule (Drule.arg_cong_rule tm3 (powvar_mul_conv tm4)) tm2
+           val th2 = Drule.fun_cong_rule (Drule.arg_cong_rule tm3 (powvar_mul_conv ctxt tm4)) tm2
           in Thm.transitive th1 th2
           end
           else
@@ -480,7 +493,7 @@
               let val th1 = inst_thm [(clx,l),(crx,rx),(cry,ry)] pthm_21
                  val (tm1,tm2) = Thm.dest_comb(concl th1)
                  val (tm3,tm4) = Thm.dest_comb tm1
-             in Thm.transitive th1 (Drule.fun_cong_rule (Drule.arg_cong_rule tm3 (powvar_mul_conv tm4)) tm2)
+             in Thm.transitive th1 (Drule.fun_cong_rule (Drule.arg_cong_rule tm3 (powvar_mul_conv ctxt tm4)) tm2)
              end
              else if ord > 0 then
                  let val th1 = inst_thm [(clx,l),(crx,rx),(cry,ry)] pthm_22
@@ -493,7 +506,7 @@
         handle CTERM _ =>
           (let val vr = powvar r
                val  ord = vorder vl vr
-          in if ord = 0 then powvar_mul_conv tm
+          in if ord = 0 then powvar_mul_conv ctxt tm
               else if ord > 0 then inst_thm [(ca,l),(cb,r)] pthm_09
               else Thm.reflexive tm
           end)) end))
@@ -502,7 +515,7 @@
   end;
 (* Multiplication by monomial of a polynomial.                               *)
 
- val polynomial_monomial_mul_conv =
+ fun polynomial_monomial_mul_conv ctxt =
   let
    fun pmm_conv tm =
     let val (l,r) = dest_mul tm
@@ -511,10 +524,11 @@
           val th1 = inst_thm [(cx,l),(cy,y),(cz,z)] pthm_37
           val (tm1,tm2) = Thm.dest_comb(concl th1)
           val (tm3,tm4) = Thm.dest_comb tm1
-          val th2 = Thm.combination (Drule.arg_cong_rule tm3 (monomial_mul_conv tm4)) (pmm_conv tm2)
+          val th2 =
+            Thm.combination (Drule.arg_cong_rule tm3 (monomial_mul_conv ctxt tm4)) (pmm_conv tm2)
       in Thm.transitive th1 th2
       end)
-     handle CTERM _ => monomial_mul_conv tm)
+     handle CTERM _ => monomial_mul_conv ctxt tm)
    end
  in pmm_conv
  end;
@@ -592,7 +606,7 @@
 
 (* Addition of two polynomials.                                              *)
 
-val polynomial_add_conv =
+fun polynomial_add_conv ctxt =
  let
  fun dezero_rule th =
   let
@@ -690,25 +704,25 @@
 
 (* Multiplication of two polynomials.                                        *)
 
-val polynomial_mul_conv =
+fun polynomial_mul_conv ctxt =
  let
   fun pmul tm =
    let val (l,r) = dest_mul tm
    in
-    if not(is_add l) then polynomial_monomial_mul_conv tm
+    if not(is_add l) then polynomial_monomial_mul_conv ctxt tm
     else
      if not(is_add r) then
       let val th1 = inst_thm [(ca,l),(cb,r)] pthm_09
-      in Thm.transitive th1 (polynomial_monomial_mul_conv(concl th1))
+      in Thm.transitive th1 (polynomial_monomial_mul_conv ctxt (concl th1))
       end
      else
        let val (a,b) = dest_add l
            val th1 = inst_thm [(ca,a),(cb,b),(cc,r)] pthm_10
            val (tm1,tm2) = Thm.dest_comb(concl th1)
            val (tm3,tm4) = Thm.dest_comb tm1
-           val th2 = Drule.arg_cong_rule tm3 (polynomial_monomial_mul_conv tm4)
+           val th2 = Drule.arg_cong_rule tm3 (polynomial_monomial_mul_conv ctxt tm4)
            val th3 = Thm.transitive th1 (Thm.combination th2 (pmul tm2))
-       in Thm.transitive th3 (polynomial_add_conv (concl th3))
+       in Thm.transitive th3 (polynomial_add_conv ctxt (concl th3))
        end
    end
  in fn tm =>
@@ -724,12 +738,12 @@
 
 (* Power of polynomial (optimized for the monomial and trivial cases).       *)
 
-fun num_conv n =
-  nat_add_conv (Thm.apply @{cterm Suc} (Numeral.mk_cnumber @{ctyp nat} (dest_numeral n - 1)))
+fun num_conv ctxt n =
+  nat_add_conv ctxt (Thm.apply @{cterm Suc} (Numeral.mk_cnumber @{ctyp nat} (dest_numeral n - 1)))
   |> Thm.symmetric;
 
 
-val polynomial_pow_conv =
+fun polynomial_pow_conv ctxt =
  let
   fun ppow tm =
     let val (l,n) = dest_pow tm
@@ -737,52 +751,52 @@
      if n aconvc zeron_tm then inst_thm [(cx,l)] pthm_35
      else if n aconvc onen_tm then inst_thm [(cx,l)] pthm_36
      else
-         let val th1 = num_conv n
+         let val th1 = num_conv ctxt n
              val th2 = inst_thm [(cx,l),(cq,Thm.dest_arg (concl th1))] pthm_38
              val (tm1,tm2) = Thm.dest_comb(concl th2)
              val th3 = Thm.transitive th2 (Drule.arg_cong_rule tm1 (ppow tm2))
              val th4 = Thm.transitive (Drule.arg_cong_rule (Thm.dest_fun tm) th1) th3
-         in Thm.transitive th4 (polynomial_mul_conv (concl th4))
+         in Thm.transitive th4 (polynomial_mul_conv ctxt (concl th4))
          end
     end
  in fn tm =>
-       if is_add(Thm.dest_arg1 tm) then ppow tm else monomial_pow_conv tm
+       if is_add(Thm.dest_arg1 tm) then ppow tm else monomial_pow_conv ctxt tm
  end;
 
 (* Negation.                                                                 *)
 
-fun polynomial_neg_conv tm =
+fun polynomial_neg_conv ctxt tm =
    let val (l,r) = Thm.dest_comb tm in
         if not (l aconvc neg_tm) then raise CTERM ("polynomial_neg_conv",[tm]) else
         let val th1 = inst_thm [(cx',r)] neg_mul
             val th2 = Thm.transitive th1 (Conv.arg1_conv semiring_mul_conv (concl th1))
-        in Thm.transitive th2 (polynomial_monomial_mul_conv (concl th2))
+        in Thm.transitive th2 (polynomial_monomial_mul_conv ctxt (concl th2))
         end
    end;
 
 
 (* Subtraction.                                                              *)
-fun polynomial_sub_conv tm =
+fun polynomial_sub_conv ctxt tm =
   let val (l,r) = dest_sub tm
       val th1 = inst_thm [(cx',l),(cy',r)] sub_add
       val (tm1,tm2) = Thm.dest_comb(concl th1)
-      val th2 = Drule.arg_cong_rule tm1 (polynomial_neg_conv tm2)
-  in Thm.transitive th1 (Thm.transitive th2 (polynomial_add_conv (concl th2)))
+      val th2 = Drule.arg_cong_rule tm1 (polynomial_neg_conv ctxt tm2)
+  in Thm.transitive th1 (Thm.transitive th2 (polynomial_add_conv ctxt (concl th2)))
   end;
 
 (* Conversion from HOL term.                                                 *)
 
-fun polynomial_conv tm =
+fun polynomial_conv ctxt tm =
  if is_semiring_constant tm then semiring_add_conv tm
  else if not(is_comb tm) then Thm.reflexive tm
  else
   let val (lopr,r) = Thm.dest_comb tm
   in if lopr aconvc neg_tm then
-       let val th1 = Drule.arg_cong_rule lopr (polynomial_conv r)
-       in Thm.transitive th1 (polynomial_neg_conv (concl th1))
+       let val th1 = Drule.arg_cong_rule lopr (polynomial_conv ctxt r)
+       in Thm.transitive th1 (polynomial_neg_conv ctxt (concl th1))
        end
      else if lopr aconvc inverse_tm then
-       let val th1 = Drule.arg_cong_rule lopr (polynomial_conv r)
+       let val th1 = Drule.arg_cong_rule lopr (polynomial_conv ctxt r)
        in Thm.transitive th1 (semiring_mul_conv (concl th1))
        end
      else
@@ -791,14 +805,14 @@
          let val (opr,l) = Thm.dest_comb lopr
          in if opr aconvc pow_tm andalso is_numeral r
             then
-              let val th1 = Drule.fun_cong_rule (Drule.arg_cong_rule opr (polynomial_conv l)) r
-              in Thm.transitive th1 (polynomial_pow_conv (concl th1))
+              let val th1 = Drule.fun_cong_rule (Drule.arg_cong_rule opr (polynomial_conv ctxt l)) r
+              in Thm.transitive th1 (polynomial_pow_conv ctxt (concl th1))
               end
          else if opr aconvc divide_tm 
             then
-              let val th1 = Thm.combination (Drule.arg_cong_rule opr (polynomial_conv l)) 
-                                        (polynomial_conv r)
-                  val th2 = (Conv.rewr_conv divide_inverse then_conv polynomial_mul_conv)
+              let val th1 = Thm.combination (Drule.arg_cong_rule opr (polynomial_conv ctxt l)) 
+                                        (polynomial_conv ctxt r)
+                  val th2 = (Conv.rewr_conv divide_inverse then_conv polynomial_mul_conv ctxt)
                               (Thm.rhs_of th1)
               in Thm.transitive th1 th2
               end
@@ -806,10 +820,11 @@
               if opr aconvc add_tm orelse opr aconvc mul_tm orelse opr aconvc sub_tm
               then
                let val th1 =
-                    Thm.combination (Drule.arg_cong_rule opr (polynomial_conv l)) (polynomial_conv r)
-                   val f = if opr aconvc add_tm then polynomial_add_conv
-                      else if opr aconvc mul_tm then polynomial_mul_conv
-                      else polynomial_sub_conv
+                    Thm.combination
+                      (Drule.arg_cong_rule opr (polynomial_conv ctxt l)) (polynomial_conv ctxt r)
+                   val f = if opr aconvc add_tm then polynomial_add_conv ctxt
+                      else if opr aconvc mul_tm then polynomial_mul_conv ctxt
+                      else polynomial_sub_conv ctxt
                in Thm.transitive th1 (f (concl th1))
                end
               else Thm.reflexive tm
@@ -826,8 +841,10 @@
 end;
 
 val nat_exp_ss =
-  HOL_basic_ss addsimps (@{thms eval_nat_numeral} @ @{thms nat_arith} @ @{thms arith_simps} @ @{thms rel_simps})
-    addsimps [@{thm Let_def}, @{thm if_False}, @{thm if_True}, @{thm Nat.add_0}, @{thm add_Suc}];
+  simpset_of
+   (put_simpset HOL_basic_ss @{context}
+    addsimps (@{thms eval_nat_numeral} @ @{thms nat_arith} @ @{thms arith_simps} @ @{thms rel_simps})
+    addsimps [@{thm Let_def}, @{thm if_False}, @{thm if_True}, @{thm Nat.add_0}, @{thm add_Suc}]);
 
 fun simple_cterm_ord t u = Term_Ord.term_ord (term_of t, term_of u) = LESS;
 
@@ -838,15 +855,17 @@
                                      {conv, dest_const, mk_const, is_const}) ord =
   let
     val pow_conv =
-      Conv.arg_conv (Simplifier.rewrite nat_exp_ss)
+      Conv.arg_conv (Simplifier.rewrite (put_simpset nat_exp_ss ctxt))
       then_conv Simplifier.rewrite
-        (HOL_basic_ss addsimps [nth (snd semiring) 31, nth (snd semiring) 34])
+        (put_simpset HOL_basic_ss ctxt addsimps [nth (snd semiring) 31, nth (snd semiring) 34])
       then_conv conv ctxt
     val dat = (is_const, conv ctxt, conv ctxt, pow_conv)
   in semiring_normalizers_conv vars semiring ring field dat ord end;
 
 fun semiring_normalize_ord_wrapper ctxt ({vars, semiring, ring, field, idom, ideal}, {conv, dest_const, mk_const, is_const}) ord =
- #main (semiring_normalizers_ord_wrapper ctxt ({vars = vars, semiring = semiring, ring = ring, field = field, idom = idom, ideal = ideal},{conv = conv, dest_const = dest_const, mk_const = mk_const, is_const = is_const}) ord);
+ #main (semiring_normalizers_ord_wrapper ctxt
+  ({vars = vars, semiring = semiring, ring = ring, field = field, idom = idom, ideal = ideal},
+   {conv = conv, dest_const = dest_const, mk_const = mk_const, is_const = is_const}) ord) ctxt;
 
 fun semiring_normalize_wrapper ctxt data = 
   semiring_normalize_ord_wrapper ctxt data simple_cterm_ord;