apply preprocessing simpset also to rhs of abstract code equations
authorhaftmann
Tue, 05 Jun 2012 10:12:54 +0200
changeset 48075 ec5e62b868eb
parent 48074 c6d514717d7b
child 48076 7a0b858fa63b
apply preprocessing simpset also to rhs of abstract code equations
src/HOL/Library/Target_Numeral.thy
src/Pure/Isar/code.ML
src/Tools/Code/code_preproc.ML
--- a/src/HOL/Library/Target_Numeral.thy	Tue Jun 05 07:11:49 2012 +0200
+++ b/src/HOL/Library/Target_Numeral.thy	Tue Jun 05 10:12:54 2012 +0200
@@ -661,30 +661,30 @@
   by (simp add: Target_Numeral.int_eq_iff)
 
 lemma [code abstract]:
-  "Target_Numeral.of_nat (m + n) = Target_Numeral.of_nat m + Target_Numeral.of_nat n"
+  "Target_Numeral.of_nat (m + n) = of_nat m + of_nat n"
   by (simp add: Target_Numeral.int_eq_iff)
 
 lemma [code abstract]:
-  "Target_Numeral.of_nat (Code_Nat.dup n) = Target_Numeral.dup (Target_Numeral.of_nat n)"
+  "Target_Numeral.of_nat (Code_Nat.dup n) = Target_Numeral.dup (of_nat n)"
   by (simp add: Target_Numeral.int_eq_iff Code_Nat.dup_def)
 
 lemma [code, code del]:
   "Code_Nat.sub = Code_Nat.sub" ..
 
 lemma [code abstract]:
-  "Target_Numeral.of_nat (m - n) = max 0 (Target_Numeral.of_nat m - Target_Numeral.of_nat n)"
+  "Target_Numeral.of_nat (m - n) = max 0 (of_nat m - of_nat n)"
   by (simp add: Target_Numeral.int_eq_iff)
 
 lemma [code abstract]:
-  "Target_Numeral.of_nat (m * n) = Target_Numeral.of_nat m * Target_Numeral.of_nat n"
+  "Target_Numeral.of_nat (m * n) = of_nat m * of_nat n"
   by (simp add: Target_Numeral.int_eq_iff of_nat_mult)
 
 lemma [code abstract]:
-  "Target_Numeral.of_nat (m div n) = Target_Numeral.of_nat m div Target_Numeral.of_nat n"
+  "Target_Numeral.of_nat (m div n) = of_nat m div of_nat n"
   by (simp add: Target_Numeral.int_eq_iff zdiv_int)
 
 lemma [code abstract]:
-  "Target_Numeral.of_nat (m mod n) = Target_Numeral.of_nat m mod Target_Numeral.of_nat n"
+  "Target_Numeral.of_nat (m mod n) = of_nat m mod of_nat n"
   by (simp add: Target_Numeral.int_eq_iff zmod_int)
 
 lemma [code]:
@@ -704,7 +704,7 @@
   by (simp add: less_int_def)
 
 lemma num_of_nat_code [code]:
-  "num_of_nat = Target_Numeral.num_of_int \<circ> Target_Numeral.of_nat"
+  "num_of_nat = Target_Numeral.num_of_int \<circ> of_nat"
   by (simp add: fun_eq_iff num_of_int_def of_nat_def)
 
 lemma (in semiring_1) of_nat_code:
--- a/src/Pure/Isar/code.ML	Tue Jun 05 07:11:49 2012 +0200
+++ b/src/Pure/Isar/code.ML	Tue Jun 05 10:12:54 2012 +0200
@@ -65,7 +65,8 @@
   val get_type_of_constr_or_abstr: theory -> string -> (string * bool) option
   val is_constr: theory -> string -> bool
   val is_abstr: theory -> string -> bool
-  val get_cert: theory -> ((thm * bool) list -> (thm * bool) list) -> string -> cert
+  val get_cert: theory -> { functrans: ((thm * bool) list -> (thm * bool) list option) list,
+    ss: simpset } -> string -> cert
   val get_case_scheme: theory -> string -> (int * (int * string option list)) option
   val get_case_cong: theory -> string -> thm option
   val undefineds: theory -> string list
@@ -854,23 +855,39 @@
   |> Option.map (snd o fst)
   |> the_default empty_fun_spec
 
-fun get_cert thy f c = case retrieve_raw thy c
- of Default (_, eqns_lazy) => Lazy.force eqns_lazy
-      |> (map o apfst) (Thm.transfer thy)
-      |> f
-      |> (map o apfst) (AxClass.unoverload thy)
-      |> cert_of_eqns thy c
-  | Eqns eqns => eqns
-      |> (map o apfst) (Thm.transfer thy)
-      |> f
-      |> (map o apfst) (AxClass.unoverload thy)
-      |> cert_of_eqns thy c
-  | Proj (_, tyco) =>
-      cert_of_proj thy c tyco
-  | Abstr (abs_thm, tyco) => abs_thm
-      |> Thm.transfer thy
-      |> AxClass.unoverload thy
-      |> cert_of_abs thy tyco c;
+fun eqn_conv conv ct =
+  let
+    fun lhs_conv ct = if can Thm.dest_comb ct
+      then Conv.combination_conv lhs_conv conv ct
+      else Conv.all_conv ct;
+  in Conv.combination_conv (Conv.arg_conv lhs_conv) conv ct end;
+
+fun rewrite_eqn thy conv ss =
+  let
+    val ctxt = Proof_Context.init_global thy;
+    val rewrite = Conv.fconv_rule (conv (Simplifier.rewrite ss));
+  in singleton (Variable.trade (K (map rewrite)) ctxt) end;
+
+fun cert_of_eqns_preprocess thy functrans ss c =
+  (map o apfst) (Thm.transfer thy)
+  #> perhaps (perhaps_loop (perhaps_apply functrans))
+  #> (map o apfst) (rewrite_eqn thy eqn_conv ss) 
+  #> (map o apfst) (AxClass.unoverload thy)
+  #> cert_of_eqns thy c;
+
+fun get_cert thy { functrans, ss } c =
+  case retrieve_raw thy c
+   of Default (_, eqns_lazy) => Lazy.force eqns_lazy
+        |> cert_of_eqns_preprocess thy functrans ss c
+    | Eqns eqns => eqns
+        |> cert_of_eqns_preprocess thy functrans ss c
+    | Proj (_, tyco) =>
+        cert_of_proj thy c tyco
+    | Abstr (abs_thm, tyco) => abs_thm
+        |> Thm.transfer thy
+        |> rewrite_eqn thy Conv.arg_conv ss
+        |> AxClass.unoverload thy
+        |> cert_of_abs thy tyco c;
 
 
 (* cases *)
--- a/src/Tools/Code/code_preproc.ML	Tue Jun 05 07:11:49 2012 +0200
+++ b/src/Tools/Code/code_preproc.ML	Tue Jun 05 10:12:54 2012 +0200
@@ -14,7 +14,6 @@
   val del_functrans: string -> theory -> theory
   val simple_functrans: (theory -> thm list -> thm list option)
     -> theory -> (thm * bool) list -> (thm * bool) list option
-  val preprocess_functrans: theory -> (thm * bool) list -> (thm * bool) list
   val print_codeproc: theory -> unit
 
   type code_algebra
@@ -124,15 +123,6 @@
 
 fun trans_conv_rule conv thm = Thm.transitive thm ((conv o Thm.rhs_of) thm);
 
-fun eqn_conv conv ct =
-  let
-    fun lhs_conv ct = if can Thm.dest_comb ct
-      then Conv.combination_conv lhs_conv conv ct
-      else Conv.all_conv ct;
-  in Conv.combination_conv (Conv.arg_conv lhs_conv) conv ct end;
-
-val rewrite_eqn = Conv.fconv_rule o eqn_conv o Simplifier.rewrite;
-
 fun term_of_conv thy conv =
   Thm.cterm_of thy
   #> conv
@@ -148,22 +138,6 @@
     val resubst = curry (Term.betapplys o swap) all_vars;
   in (resubst, term_of_conv thy conv (fold_rev lambda all_vars t)) end;
 
-
-fun preprocess_functrans thy = 
-  let
-    val functrans = (map (fn (_, (_, f)) => f thy) o #functrans
-      o the_thmproc) thy;
-  in perhaps (perhaps_loop (perhaps_apply functrans)) end;
-
-fun preprocess thy =
-  let
-    val ctxt = Proof_Context.init_global thy;
-    val pre = (Simplifier.global_context thy o #pre o the_thmproc) thy;
-  in
-    preprocess_functrans thy
-    #> (map o apfst) (singleton (Variable.trade (K (map (rewrite_eqn pre))) ctxt)) 
-  end;
-
 fun preprocess_conv thy =
   let
     val pre = (Simplifier.global_context thy o #pre o the_thmproc) thy;
@@ -285,7 +259,10 @@
   case try (Graph.get_node eqngr) c
    of SOME (lhs, cert) => ((lhs, []), cert)
     | NONE => let
-        val cert = Code.get_cert thy (preprocess thy) c;
+        val functrans = (map (fn (_, (_, f)) => f thy)
+          o #functrans o the_thmproc) thy;
+        val pre = (Simplifier.global_context thy o #pre o the_thmproc) thy;
+        val cert = Code.get_cert thy { functrans = functrans, ss = pre } c;
         val (lhs, rhss) = Code.typargs_deps_of_cert thy cert;
       in ((lhs, rhss), cert) end;