src/HOL/Tools/numeral_simprocs.ML
changeset 44945 2625de88c994
parent 44064 5bce8ff0d9ae
child 44947 8ae418dfe561
--- a/src/HOL/Tools/numeral_simprocs.ML	Sat Sep 17 04:41:44 2011 +0200
+++ b/src/HOL/Tools/numeral_simprocs.ML	Sat Sep 17 00:37:21 2011 +0200
@@ -16,6 +16,9 @@
 
 signature NUMERAL_SIMPROCS =
 sig
+  val prep_simproc: theory -> string * string list * (theory -> simpset -> term -> thm option)
+    -> simproc
+  val trans_tac: thm option -> tactic
   val assoc_fold_simproc: simproc
   val combine_numerals: simproc
   val cancel_numerals: simproc list
@@ -30,6 +33,12 @@
 structure Numeral_Simprocs : NUMERAL_SIMPROCS =
 struct
 
+fun prep_simproc thy (name, pats, proc) =
+  Simplifier.simproc_global thy name pats proc;
+
+fun trans_tac NONE  = all_tac
+  | trans_tac (SOME th) = ALLGOALS (rtac (th RS trans));
+
 val mk_number = Arith_Data.mk_number;
 val mk_sum = Arith_Data.mk_sum;
 val long_mk_sum = Arith_Data.long_mk_sum;
@@ -199,13 +208,13 @@
 val norm_ss3 = num_ss addsimps minus_from_mult_simps @ @{thms add_ac} @ @{thms mult_ac}
 
 structure CancelNumeralsCommon =
-  struct
-  val mk_sum            = mk_sum
-  val dest_sum          = dest_sum
-  val mk_coeff          = mk_coeff
-  val dest_coeff        = dest_coeff 1
-  val find_first_coeff  = find_first_coeff []
-  fun trans_tac _       = Arith_Data.trans_tac
+struct
+  val mk_sum = mk_sum
+  val dest_sum = dest_sum
+  val mk_coeff = mk_coeff
+  val dest_coeff = dest_coeff 1
+  val find_first_coeff = find_first_coeff []
+  val trans_tac = K trans_tac
 
   fun norm_tac ss =
     ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss1))
@@ -215,12 +224,11 @@
   val numeral_simp_ss = HOL_ss addsimps add_0s @ simps
   fun numeral_simp_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
   val simplify_meta_eq = Arith_Data.simplify_meta_eq (add_0s @ mult_1s)
-  end;
-
+  val prove_conv = Arith_Data.prove_conv
+end;
 
 structure EqCancelNumerals = CancelNumeralsFun
  (open CancelNumeralsCommon
-  val prove_conv = Arith_Data.prove_conv
   val mk_bal   = HOLogic.mk_eq
   val dest_bal = HOLogic.dest_bin @{const_name HOL.eq} Term.dummyT
   val bal_add1 = @{thm eq_add_iff1} RS trans
@@ -229,7 +237,6 @@
 
 structure LessCancelNumerals = CancelNumeralsFun
  (open CancelNumeralsCommon
-  val prove_conv = Arith_Data.prove_conv
   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less}
   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less} Term.dummyT
   val bal_add1 = @{thm less_add_iff1} RS trans
@@ -238,7 +245,6 @@
 
 structure LeCancelNumerals = CancelNumeralsFun
  (open CancelNumeralsCommon
-  val prove_conv = Arith_Data.prove_conv
   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less_eq}
   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less_eq} Term.dummyT
   val bal_add1 = @{thm le_add_iff1} RS trans
@@ -246,7 +252,7 @@
 );
 
 val cancel_numerals =
-  map (Arith_Data.prep_simproc @{theory})
+  map (prep_simproc @{theory})
    [("inteq_cancel_numerals",
      ["(l::'a::number_ring) + m = n",
       "(l::'a::number_ring) = m + n",
@@ -273,17 +279,17 @@
      K LeCancelNumerals.proc)];
 
 structure CombineNumeralsData =
-  struct
-  type coeff            = int
-  val iszero            = (fn x => x = 0)
-  val add               = op +
-  val mk_sum            = long_mk_sum    (*to work for e.g. 2*x + 3*x *)
-  val dest_sum          = dest_sum
-  val mk_coeff          = mk_coeff
-  val dest_coeff        = dest_coeff 1
-  val left_distrib      = @{thm combine_common_factor} RS trans
-  val prove_conv        = Arith_Data.prove_conv_nohyps
-  fun trans_tac _       = Arith_Data.trans_tac
+struct
+  type coeff = int
+  val iszero = (fn x => x = 0)
+  val add  = op +
+  val mk_sum = long_mk_sum    (*to work for e.g. 2*x + 3*x *)
+  val dest_sum = dest_sum
+  val mk_coeff = mk_coeff
+  val dest_coeff = dest_coeff 1
+  val left_distrib = @{thm combine_common_factor} RS trans
+  val prove_conv = Arith_Data.prove_conv_nohyps
+  val trans_tac = K trans_tac
 
   fun norm_tac ss =
     ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss1))
@@ -293,23 +299,23 @@
   val numeral_simp_ss = HOL_ss addsimps add_0s @ simps
   fun numeral_simp_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
   val simplify_meta_eq = Arith_Data.simplify_meta_eq (add_0s @ mult_1s)
-  end;
+end;
 
 structure CombineNumerals = CombineNumeralsFun(CombineNumeralsData);
 
 (*Version for fields, where coefficients can be fractions*)
 structure FieldCombineNumeralsData =
-  struct
-  type coeff            = int * int
-  val iszero            = (fn (p, q) => p = 0)
-  val add               = add_frac
-  val mk_sum            = long_mk_sum
-  val dest_sum          = dest_sum
-  val mk_coeff          = mk_fcoeff
-  val dest_coeff        = dest_fcoeff 1
-  val left_distrib      = @{thm combine_common_factor} RS trans
-  val prove_conv        = Arith_Data.prove_conv_nohyps
-  fun trans_tac _       = Arith_Data.trans_tac
+struct
+  type coeff = int * int
+  val iszero = (fn (p, q) => p = 0)
+  val add = add_frac
+  val mk_sum = long_mk_sum
+  val dest_sum = dest_sum
+  val mk_coeff = mk_fcoeff
+  val dest_coeff = dest_fcoeff 1
+  val left_distrib = @{thm combine_common_factor} RS trans
+  val prove_conv = Arith_Data.prove_conv_nohyps
+  val trans_tac = K trans_tac
 
   val norm_ss1a = norm_ss1 addsimps inverse_1s @ divide_simps
   fun norm_tac ss =
@@ -320,18 +326,18 @@
   val numeral_simp_ss = HOL_ss addsimps add_0s @ simps @ [@{thm add_frac_eq}]
   fun numeral_simp_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
   val simplify_meta_eq = Arith_Data.simplify_meta_eq (add_0s @ mult_1s @ divide_1s)
-  end;
+end;
 
 structure FieldCombineNumerals = CombineNumeralsFun(FieldCombineNumeralsData);
 
 val combine_numerals =
-  Arith_Data.prep_simproc @{theory}
+  prep_simproc @{theory}
     ("int_combine_numerals", 
      ["(i::'a::number_ring) + j", "(i::'a::number_ring) - j"], 
      K CombineNumerals.proc);
 
 val field_combine_numerals =
-  Arith_Data.prep_simproc @{theory}
+  prep_simproc @{theory}
     ("field_combine_numerals", 
      ["(i::'a::{field_inverse_zero, number_ring}) + j",
       "(i::'a::{field_inverse_zero, number_ring}) - j"], 
@@ -351,15 +357,15 @@
 structure Semiring_Times_Assoc = Assoc_Fold (Semiring_Times_Assoc_Data);
 
 val assoc_fold_simproc =
-  Arith_Data.prep_simproc @{theory}
+  prep_simproc @{theory}
    ("semiring_assoc_fold", ["(a::'a::comm_semiring_1_cancel) * b"],
     K Semiring_Times_Assoc.proc);
 
 structure CancelNumeralFactorCommon =
-  struct
-  val mk_coeff          = mk_coeff
-  val dest_coeff        = dest_coeff 1
-  fun trans_tac _       = Arith_Data.trans_tac
+struct
+  val mk_coeff = mk_coeff
+  val dest_coeff = dest_coeff 1
+  val trans_tac = K trans_tac
 
   val norm_ss1 = HOL_ss addsimps minus_from_mult_simps @ mult_1s
   val norm_ss2 = HOL_ss addsimps simps @ mult_minus_simps
@@ -375,12 +381,12 @@
   val simplify_meta_eq = Arith_Data.simplify_meta_eq
     [@{thm Nat.add_0}, @{thm Nat.add_0_right}, @{thm mult_zero_left},
       @{thm mult_zero_right}, @{thm mult_Bit1}, @{thm mult_1_right}];
-  end
+  val prove_conv = Arith_Data.prove_conv
+end
 
 (*Version for semiring_div*)
 structure DivCancelNumeralFactor = CancelNumeralFactorFun
  (open CancelNumeralFactorCommon
-  val prove_conv = Arith_Data.prove_conv
   val mk_bal   = HOLogic.mk_binop @{const_name Divides.div}
   val dest_bal = HOLogic.dest_bin @{const_name Divides.div} Term.dummyT
   val cancel = @{thm div_mult_mult1} RS trans
@@ -390,7 +396,6 @@
 (*Version for fields*)
 structure DivideCancelNumeralFactor = CancelNumeralFactorFun
  (open CancelNumeralFactorCommon
-  val prove_conv = Arith_Data.prove_conv
   val mk_bal   = HOLogic.mk_binop @{const_name Fields.divide}
   val dest_bal = HOLogic.dest_bin @{const_name Fields.divide} Term.dummyT
   val cancel = @{thm mult_divide_mult_cancel_left} RS trans
@@ -399,7 +404,6 @@
 
 structure EqCancelNumeralFactor = CancelNumeralFactorFun
  (open CancelNumeralFactorCommon
-  val prove_conv = Arith_Data.prove_conv
   val mk_bal   = HOLogic.mk_eq
   val dest_bal = HOLogic.dest_bin @{const_name HOL.eq} Term.dummyT
   val cancel = @{thm mult_cancel_left} RS trans
@@ -408,7 +412,6 @@
 
 structure LessCancelNumeralFactor = CancelNumeralFactorFun
  (open CancelNumeralFactorCommon
-  val prove_conv = Arith_Data.prove_conv
   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less}
   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less} Term.dummyT
   val cancel = @{thm mult_less_cancel_left} RS trans
@@ -417,7 +420,6 @@
 
 structure LeCancelNumeralFactor = CancelNumeralFactorFun
  (open CancelNumeralFactorCommon
-  val prove_conv = Arith_Data.prove_conv
   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less_eq}
   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less_eq} Term.dummyT
   val cancel = @{thm mult_le_cancel_left} RS trans
@@ -425,7 +427,7 @@
 )
 
 val cancel_numeral_factors =
-  map (Arith_Data.prep_simproc @{theory})
+  map (prep_simproc @{theory})
    [("ring_eq_cancel_numeral_factor",
      ["(l::'a::{idom,number_ring}) * m = n",
       "(l::'a::{idom,number_ring}) = m * n"],
@@ -449,7 +451,7 @@
      K DivideCancelNumeralFactor.proc)];
 
 val field_cancel_numeral_factors =
-  map (Arith_Data.prep_simproc @{theory})
+  map (prep_simproc @{theory})
    [("field_eq_cancel_numeral_factor",
      ["(l::'a::{field,number_ring}) * m = n",
       "(l::'a::{field,number_ring}) = m * n"],
@@ -499,22 +501,22 @@
 end
 
 structure CancelFactorCommon =
-  struct
-  val mk_sum            = long_mk_prod
-  val dest_sum          = dest_prod
-  val mk_coeff          = mk_coeff
-  val dest_coeff        = dest_coeff
-  val find_first        = find_first_t []
-  fun trans_tac _       = Arith_Data.trans_tac
+struct
+  val mk_sum = long_mk_prod
+  val dest_sum = dest_prod
+  val mk_coeff = mk_coeff
+  val dest_coeff = dest_coeff
+  val find_first = find_first_t []
+  val trans_tac = K trans_tac
   val norm_ss = HOL_ss addsimps mult_1s @ @{thms mult_ac}
   fun norm_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss))
   val simplify_meta_eq  = cancel_simplify_meta_eq 
-  end;
+  val prove_conv = Arith_Data.prove_conv
+end;
 
 (*mult_cancel_left requires a ring with no zero divisors.*)
 structure EqCancelFactor = ExtractCommonTermFun
  (open CancelFactorCommon
-  val prove_conv = Arith_Data.prove_conv
   val mk_bal   = HOLogic.mk_eq
   val dest_bal = HOLogic.dest_bin @{const_name HOL.eq} Term.dummyT
   fun simp_conv _ _ = SOME @{thm mult_cancel_left}
@@ -523,7 +525,6 @@
 (*for ordered rings*)
 structure LeCancelFactor = ExtractCommonTermFun
  (open CancelFactorCommon
-  val prove_conv = Arith_Data.prove_conv
   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less_eq}
   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less_eq} Term.dummyT
   val simp_conv = sign_conv
@@ -533,7 +534,6 @@
 (*for ordered rings*)
 structure LessCancelFactor = ExtractCommonTermFun
  (open CancelFactorCommon
-  val prove_conv = Arith_Data.prove_conv
   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less}
   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less} Term.dummyT
   val simp_conv = sign_conv
@@ -543,7 +543,6 @@
 (*for semirings with division*)
 structure DivCancelFactor = ExtractCommonTermFun
  (open CancelFactorCommon
-  val prove_conv = Arith_Data.prove_conv
   val mk_bal   = HOLogic.mk_binop @{const_name Divides.div}
   val dest_bal = HOLogic.dest_bin @{const_name Divides.div} Term.dummyT
   fun simp_conv _ _ = SOME @{thm div_mult_mult1_if}
@@ -551,7 +550,6 @@
 
 structure ModCancelFactor = ExtractCommonTermFun
  (open CancelFactorCommon
-  val prove_conv = Arith_Data.prove_conv
   val mk_bal   = HOLogic.mk_binop @{const_name Divides.mod}
   val dest_bal = HOLogic.dest_bin @{const_name Divides.mod} Term.dummyT
   fun simp_conv _ _ = SOME @{thm mod_mult_mult1}
@@ -560,7 +558,6 @@
 (*for idoms*)
 structure DvdCancelFactor = ExtractCommonTermFun
  (open CancelFactorCommon
-  val prove_conv = Arith_Data.prove_conv
   val mk_bal   = HOLogic.mk_binrel @{const_name Rings.dvd}
   val dest_bal = HOLogic.dest_bin @{const_name Rings.dvd} Term.dummyT
   fun simp_conv _ _ = SOME @{thm dvd_mult_cancel_left}
@@ -569,14 +566,13 @@
 (*Version for all fields, including unordered ones (type complex).*)
 structure DivideCancelFactor = ExtractCommonTermFun
  (open CancelFactorCommon
-  val prove_conv = Arith_Data.prove_conv
   val mk_bal   = HOLogic.mk_binop @{const_name Fields.divide}
   val dest_bal = HOLogic.dest_bin @{const_name Fields.divide} Term.dummyT
   fun simp_conv _ _ = SOME @{thm mult_divide_mult_cancel_left_if}
 );
 
 val cancel_factors =
-  map (Arith_Data.prep_simproc @{theory})
+  map (prep_simproc @{theory})
    [("ring_eq_cancel_factor",
      ["(l::'a::idom) * m = n",
       "(l::'a::idom) = m * n"],