proper simproc_setup;
authorwenzelm
Mon, 29 Jan 2007 19:58:14 +0100
changeset 22205 23bd1ed32ac0
parent 22204 33da3a55c00e
child 22206 8cc04341de38
proper simproc_setup; tuned ML setup;
src/HOL/ex/Binary.thy
--- a/src/HOL/ex/Binary.thy	Mon Jan 29 19:58:14 2007 +0100
+++ b/src/HOL/ex/Binary.thy	Mon Jan 29 19:58:14 2007 +0100
@@ -20,6 +20,30 @@
     "bit n True = 2 * n + 1"
   unfolding bit_def by simp_all
 
+ML {*
+  fun dest_bit (Const ("False", _)) = 0
+    | dest_bit (Const ("True", _)) = 1
+    | dest_bit t = raise TERM ("dest_bit", [t]);
+
+  fun dest_binary (Const ("HOL.zero", Type ("nat", _))) = 0
+    | dest_binary (Const ("HOL.one", Type ("nat", _))) = 1
+    | dest_binary (Const ("Binary.bit", _) $ bs $ b) =
+        2 * dest_binary bs + IntInf.fromInt (dest_bit b)
+    | dest_binary t = raise TERM ("dest_binary", [t]);
+
+  fun mk_bit 0 = @{term False}
+    | mk_bit 1 = @{term True}
+    | mk_bit _ = raise TERM ("mk_bit", []);
+
+  fun mk_binary 0 = @{term "0::nat"}
+    | mk_binary 1 = @{term "1::nat"}
+    | mk_binary n =
+        if n < 0 then raise TERM ("mk_binary", [])
+        else
+          let val (q, r) = IntInf.divMod (n, 2)
+          in @{term bit} $ mk_binary q $ mk_bit (IntInf.toInt r) end;
+*}
+
 
 subsection {* Direct operations -- plain normalization *}
 
@@ -87,92 +111,78 @@
 qed
 
 ML {*
-  fun dest_bit (Const ("False", _)) = 0
-    | dest_bit (Const ("True", _)) = 1
-    | dest_bit t = raise TERM ("dest_bit", [t]);
-
-  fun dest_binary (Const ("HOL.zero", Type ("nat", _))) = 0
-    | dest_binary (Const ("HOL.one", Type ("nat", _))) = 1
-    | dest_binary (Const ("Binary.bit", _) $ bs $ b) =
-        2 * dest_binary bs + IntInf.fromInt (dest_bit b)
-    | dest_binary t = raise TERM ("dest_binary", [t]);
+local
+  infix ==;
+  val op == = Logic.mk_equals;
+  fun plus m n = @{term "plus :: nat \<Rightarrow> nat \<Rightarrow> nat"} $ m $ n;
+  fun mult m n = @{term "times :: nat \<Rightarrow> nat \<Rightarrow> nat"} $ m $ n;
 
-  fun mk_bit 0 = @{term False}
-    | mk_bit 1 = @{term True}
-    | mk_bit _ = raise TERM ("mk_bit", []);
-
-  fun mk_binary 0 = @{term "0::nat"}
-    | mk_binary 1 = @{term "1::nat"}
-    | mk_binary n =
-        if n < 0 then raise TERM ("mk_binary", [])
-        else
-          let val (q, r) = IntInf.divMod (n, 2)
-          in @{term bit} $ mk_binary q $ mk_bit (IntInf.toInt r) end;
-*}
-
-ML {*
-local
   val binary_ss = HOL_basic_ss addsimps @{thms binary_simps};
   fun prove ctxt prop =
     Goal.prove ctxt [] [] prop (fn _ => ALLGOALS (full_simp_tac binary_ss));
 
-  infix ==;
-  val op == = Logic.mk_equals;
-
-  fun plus m n = @{term "plus :: nat \<Rightarrow> nat \<Rightarrow> nat"} $ m $ n;
-  fun mult m n = @{term "times :: nat \<Rightarrow> nat \<Rightarrow> nat"} $ m $ n;
-
-
-  exception FAIL;
-  fun the_arg t = (t, dest_binary t handle TERM _ => raise FAIL);
+  fun binary_proc proc ss ct =
+    (case Thm.term_of ct of
+      _ $ t $ u =>
+      (case try (pairself (`dest_binary)) (t, u) of
+        SOME args => proc (Simplifier.the_context ss) args
+      | NONE => NONE)
+    | _ => NONE);
+in
 
-  val read = Thm.cterm_of @{theory} o Sign.read_term @{theory};
-  fun mk_proc name pat proc = Simplifier.mk_simproc' name [read pat]
-    (fn ss => fn ct =>
-      (case Thm.term_of ct of
-        _ $ t $ u =>
-          (SOME (proc (Simplifier.the_context ss) (the_arg t) (the_arg u)) handle FAIL => NONE)
-      | _ => NONE));
+val less_eq_proc = binary_proc (fn ctxt => fn ((m, t), (n, u)) =>
+  let val k = n - m in
+    if k >= 0 then
+      SOME (@{thm binary_less_eq(1)} OF [prove ctxt (u == plus t (mk_binary k))])
+    else
+      SOME (@{thm binary_less_eq(2)} OF
+        [prove ctxt (t == plus (plus u (mk_binary (~ k - 1))) (mk_binary 1))])
+  end);
 
-
-  val less_eq_simproc = mk_proc "binary_nat_less_eq" "?m <= (?n::nat)"
-    (fn ctxt => fn (t, m) => fn (u, n) =>
-      let val k = n - m in
-        if k >= 0 then @{thm binary_less_eq(1)} OF [prove ctxt (u == plus t (mk_binary k))]
-        else @{thm binary_less_eq(2)} OF
-          [prove ctxt (t == plus (plus u (mk_binary (~ k - 1))) (mk_binary 1))]
-      end);
+val less_proc = binary_proc (fn ctxt => fn ((m, t), (n, u)) =>
+  let val k = m - n in
+    if k >= 0 then
+      SOME (@{thm binary_less(1)} OF [prove ctxt (t == plus u (mk_binary k))])
+    else
+      SOME (@{thm binary_less(2)} OF
+        [prove ctxt (u == plus (plus t (mk_binary (~ k - 1))) (mk_binary 1))])
+  end);
 
-  val less_simproc = mk_proc "binary_nat_less" "?m < (?n::nat)"
-    (fn ctxt => fn (t, m) => fn (u, n) =>
-      let val k = m - n in
-        if k >= 0 then @{thm binary_less(1)} OF [prove ctxt (t == plus u (mk_binary k))]
-        else @{thm binary_less(2)} OF
-          [prove ctxt (u == plus (plus t (mk_binary (~ k - 1))) (mk_binary 1))]
-      end);
+val diff_proc = binary_proc (fn ctxt => fn ((m, t), (n, u)) =>
+  let val k = m - n in
+    if k >= 0 then
+      SOME (@{thm binary_diff(1)} OF [prove ctxt (t == plus u (mk_binary k))])
+    else
+      SOME (@{thm binary_diff(2)} OF [prove ctxt (u == plus t (mk_binary (~ k)))])
+  end);
 
-  val diff_simproc = mk_proc "binary_nat_diff" "?m - (?n::nat)"
-    (fn ctxt => fn (t, m) => fn (u, n) =>
-      let val k = m - n in
-        if k >= 0 then @{thm binary_diff(1)} OF [prove ctxt (t == plus u (mk_binary k))]
-        else @{thm binary_diff(2)} OF [prove ctxt (u == plus t (mk_binary (~ k)))]
-      end);
+fun divmod_proc rule = binary_proc (fn ctxt => fn ((m, t), (n, u)) =>
+  if n = 0 then NONE
+  else
+    let val (k, l) = IntInf.divMod (m, n)
+    in SOME (rule OF [prove ctxt (t == plus (mult u (mk_binary k)) (mk_binary l))]) end);
+
+end;
+*}
 
-  fun divmod_proc rule ctxt (t, m) (u, n) =
-    if n = 0 then raise FAIL
-    else
-      let val (k, l) = IntInf.divMod (m, n)
-      in rule OF [prove ctxt (t == plus (mult u (mk_binary k)) (mk_binary l))] end;
+simproc_setup binary_nat_less_eq ("m <= (n::nat)") = {* K less_eq_proc *}
+simproc_setup binary_nat_less ("m < (n::nat)") = {* K less_proc *}
+simproc_setup binary_nat_diff ("m - (n::nat)") = {* K diff_proc *}
+simproc_setup binary_nat_div ("m div (n::nat)") = {* K (divmod_proc @{thm binary_divmod(1)}) *}
+simproc_setup binary_nat_mod ("m mod (n::nat)") = {* K (divmod_proc @{thm binary_divmod(2)}) *}
 
-  val div_simproc = mk_proc "binary_nat_div" "?m div (?n::nat)"
-    (divmod_proc @{thm binary_divmod(1)});
-  val mod_simproc = mk_proc "binary_nat_mod" "?m mod (?n::nat)"
-    (divmod_proc @{thm binary_divmod(2)});
-in
-  val binary_nat_simprocs =
-    [less_eq_simproc, less_simproc, diff_simproc, div_simproc, mod_simproc];
-end
-*}
+method_setup binary_simp = {*
+  Method.no_args (Method.SIMPLE_METHOD'
+    (full_simp_tac
+      (HOL_basic_ss
+        addsimps @{thms binary_simps}
+        addsimprocs
+         [@{simproc binary_nat_less_eq},
+          @{simproc binary_nat_less},
+          @{simproc binary_nat_diff},
+          @{simproc binary_nat_div},
+          @{simproc binary_nat_mod}])))
+*} "binary simplification"
 
 
 subsection {* Concrete syntax *}
@@ -198,12 +208,6 @@
 
 subsection {* Examples *}
 
-method_setup binary_simp = {*
-  Method.no_args (Method.SIMPLE_METHOD'
-    (full_simp_tac (HOL_basic_ss addsimps @{thms binary_simps} addsimprocs binary_nat_simprocs)))
-*} "binary simplification"
-
-
 lemma "$6 = 6"
   by (simp add: bit_simps)