src/HOL/Tools/SMT/smt_normalize.ML
changeset 40278 0fc78bb54f18
parent 40275 eed48b11abdb
child 40279 96365b4ae7b6
--- a/src/HOL/Tools/SMT/smt_normalize.ML	Fri Oct 29 18:17:08 2010 +0200
+++ b/src/HOL/Tools/SMT/smt_normalize.ML	Fri Oct 29 18:17:09 2010 +0200
@@ -19,7 +19,8 @@
 sig
   type extra_norm = bool -> (int * thm) list -> Proof.context ->
     (int * thm) list * Proof.context
-  val normalize: extra_norm -> bool -> (int * thm) list -> Proof.context ->
+  val normalize: (Proof.context -> (thm -> string) -> thm -> unit) -> bool ->
+    extra_norm -> bool -> (int * thm) list -> Proof.context ->
     (int * thm) list * Proof.context
   val atomize_conv: Proof.context -> conv
   val eta_expand_conv: (Proof.context -> conv) -> Proof.context -> conv
@@ -486,18 +487,28 @@
 
 fun with_context f irules ctxt = (f ctxt irules, ctxt)
 
-fun normalize extra_norm with_datatypes irules ctxt =
-  irules
-  |> trivial_distinct ctxt
-  |> rewrite_bool_cases ctxt
-  |> normalize_numerals ctxt
-  |> nat_as_int ctxt
-  |> rpair ctxt
-  |-> extra_norm with_datatypes
-  |-> with_context (fn cx => map (apsnd (normalize_rule cx)))
-  |-> SMT_Monomorph.monomorph
-  |-> lift_lambdas
-  |-> with_context explicit_application
-  |-> (if with_datatypes then datatype_selectors else pair)
+fun normalize trace keep_assms extra_norm with_datatypes irules ctxt =
+  let
+    fun norm f ctxt' (i, thm) =
+      if keep_assms then SOME (i, f ctxt' thm)
+      else
+        (case try (f ctxt') thm of
+          SOME thm' => SOME (i, thm')
+        | NONE => (trace ctxt' (prefix ("SMT warning: " ^
+            "dropping assumption: ") o Display.string_of_thm ctxt') thm; NONE))
+  in
+    irules
+    |> trivial_distinct ctxt
+    |> rewrite_bool_cases ctxt
+    |> normalize_numerals ctxt
+    |> nat_as_int ctxt
+    |> rpair ctxt
+    |-> extra_norm with_datatypes
+    |-> with_context (map_filter o norm normalize_rule)
+    |-> SMT_Monomorph.monomorph
+    |-> lift_lambdas
+    |-> with_context explicit_application
+    |-> (if with_datatypes then datatype_selectors else pair)
+  end
 
 end