improved preprocessing
authorhaftmann
Mon, 21 Aug 2006 11:02:43 +0200
changeset 20404 1a29e6c3ab04
parent 20403 14d5f6ed5602
child 20405 8276fd8d1919
improved preprocessing
src/Pure/Tools/codegen_theorems.ML
--- a/src/Pure/Tools/codegen_theorems.ML	Mon Aug 21 11:02:42 2006 +0200
+++ b/src/Pure/Tools/codegen_theorems.ML	Mon Aug 21 11:02:43 2006 +0200
@@ -21,6 +21,7 @@
   val notify_dirty: theory -> theory;
 
   val extr_typ: theory -> thm -> typ;
+  val rewrite_fun: thm list -> thm -> thm;
   val common_typ: theory -> (thm -> typ) -> thm list -> thm list;
   val preprocess: theory -> thm list -> thm list;
 
@@ -51,7 +52,7 @@
 (* diagnostics *)
 
 val debug = ref false;
-fun debug_msg f x = (if !debug then Output.debug (f x) else (); x);
+fun debug_msg f x = (if !debug then Output.tracing (f x) else (); x);
 
 
 (* auxiliary *)
@@ -184,6 +185,9 @@
 fun err_thm msg thm =
   error (msg ^ ": " ^ string_of_thm thm);
 
+val mk_rule =
+  #mk o #mk_rews o snd o MetaSimplifier.rep_ss o Simplifier.simpset_of;
+
 fun abs_norm thy thm =
   let
     fun expvars t =
@@ -255,6 +259,9 @@
           drop (eq::eqs) (filter_out (matches eq) eqs')
   in drop [] eqs end;
 
+fun drop_refl thy = filter_out (is_equal o Term.fast_term_ord o Logic.dest_equals
+  o ObjectLogic.drop_judgment thy o Drule.plain_prop_of);
+
 fun make_eq thy =
   let
     val ((_, atomize), _) = get_obj thy;
@@ -401,14 +408,11 @@
       (Pretty.writeln o Pretty.block o Pretty.fbreaks) ([
         Pretty.str "code generation theorems:",
         Pretty.str "function theorems:" ] @
-        (*Pretty.fbreaks ( *)
           map (fn (c, thms) =>
             (Pretty.block o Pretty.fbreaks) (
               (Pretty.str o CodegenConsts.string_of_const thy) c  :: map pretty_thm (rev thms)
             )
-          ) funs
-        (*) *) @ [
-        Pretty.fbrk,
+          ) funs @ [
         Pretty.block (
           Pretty.str "inlined theorems:"
           :: Pretty.fbrk
@@ -543,13 +547,19 @@
 fun extr_typ thy thm = case dest_fun thy thm
  of (_, (ty, _)) => ty;
 
-fun rewrite_rhs conv thm = (case (Drule.strip_comb o cprop_of) thm
- of (ct', [ct1, ct2]) => (case term_of ct'
-     of Const ("==", _) =>
-          Thm.equal_elim (combination (combination (reflexive ct') (reflexive ct1))
-            (conv ct2)) thm
-      | _ => raise ERROR "rewrite_rhs")
-  | _ => raise ERROR "rewrite_rhs");
+fun rewrite_fun rewrites thm =
+  let
+    val rewrite = Tactic.rewrite true rewrites;
+    val (ct_eq, [ct_lhs, ct_rhs]) = (Drule.strip_comb o cprop_of) thm;
+    val Const ("==", _) = term_of ct_eq;
+    val (ct_f, ct_args) = Drule.strip_comb ct_lhs;
+    val rhs' = rewrite ct_rhs;
+    val args' = map rewrite ct_args;
+    val lhs' = Thm.symmetric (fold (fn th1 => fn th2 => Thm.combination th2 th1)
+      args' (Thm.reflexive ct_f));
+  in
+    Thm.transitive (Thm.transitive lhs' thm) rhs'
+  end handle Bind => raise ERROR "rewrite_fun"
 
 fun common_typ thy _ [] = []
   | common_typ thy _ [thm] = [thm]
@@ -585,13 +595,13 @@
       not (Sign.typ_instance thy (extr_typ thy thm1, extr_typ thy thm2));
     fun unvarify thms =
       #2 (#1 (Variable.import true thms (ProofContext.init thy)));
-    val unfold_thms = Tactic.rewrite true (map (make_eq thy) (the_unfolds thy));
+    val unfold_thms = map (make_eq thy) (the_unfolds thy);
   in
     thms
     |> map (make_eq thy)
     |> map (Thm.transfer thy)
     |> fold (fn f => f thy) (the_preprocs thy)
-    |> map (rewrite_rhs unfold_thms)
+    |> map (rewrite_fun unfold_thms)
     |> debug_msg (fn _ => "[cg_thm] sorting")
     |> debug_msg (commas o map string_of_thm)
     |> sort (make_ord cmp_thms)
@@ -601,6 +611,7 @@
     |> debug_msg (fn _ => "[cg_thm] abs_norm")
     |> debug_msg (commas o map string_of_thm)
     |> map (abs_norm thy)
+    |> drop_refl thy
     |> burrow_thms (
         debug_msg (fn _ => "[cg_thm] canonical tvars")
         #> debug_msg (string_of_thm)
@@ -684,10 +695,10 @@
     fun mk_eqs (vs, cos) =
       let val cos' = rev cos
       in (op @) (fold (mk_eq vs) (product cos' cos') ([], [])) end;
-  in
-    map (fn t => Goal.prove_global thy [] []
-        (ObjectLogic.ensure_propT thy t) (K tac)) (mk_eqs vs_cos)
-  end;
+    val ts = (map (ObjectLogic.ensure_propT thy) o mk_eqs) vs_cos;
+    fun prove t = if !quick_and_dirty then SkipProof.make_thm thy (Logic.varify t)
+      else Goal.prove_global thy [] [] t (K tac);
+  in map prove ts end;
 
 fun get_datatypes thy dtco =
   let