src/HOL/SMT/Tools/smt_normalize.ML
changeset 36896 c030819254d3
parent 36893 48cf03469dc6
--- a/src/HOL/SMT/Tools/smt_normalize.ML	Wed May 12 23:53:59 2010 +0200
+++ b/src/HOL/SMT/Tools/smt_normalize.ML	Wed May 12 23:54:00 2010 +0200
@@ -16,10 +16,7 @@
 
 signature SMT_NORMALIZE =
 sig
-  val instantiate_free: cterm * cterm -> thm -> thm
-  val discharge_definition: cterm -> thm -> thm
-
-  val normalize: Proof.context -> thm list -> cterm list * thm list
+  val normalize: thm list -> Proof.context -> thm list * Proof.context
 end
 
 structure SMT_Normalize: SMT_NORMALIZE =
@@ -31,18 +28,6 @@
 fun if_conv c cv1 cv2 ct = (if c (Thm.term_of ct) then cv1 else cv2) ct
 fun if_true_conv c cv = if_conv c cv Conv.all_conv
 
-fun instantiate_free (cv, ct) =
-  (Term.exists_subterm (equal (Thm.term_of cv)) o Thm.prop_of) ??
-  (Thm.forall_elim ct o Thm.forall_intr cv)
-
-fun discharge_definition ct thm =
-  let val (cv, cu) = Thm.dest_equals ct
-  in
-    Thm.implies_intr ct thm
-    |> instantiate_free (cv, cu)
-    |> (fn thm => Thm.implies_elim thm (Thm.reflexive cu))
-  end
-
 
 
 (* simplification of trivial distincts (distinct should have at least
@@ -332,35 +317,34 @@
   fun inst_meta cT = Thm.instantiate_cterm ([(meta_eqT, cT)], []) meta_eq
   fun mk_meta_eq ct cu = Thm.mk_binop (inst_meta (Thm.ctyp_of_term ct)) ct cu
 
-  fun norm_meta_def cv thm = 
-    let val thm' = Thm.combination thm (Thm.reflexive cv)
-    in Thm.transitive thm' (Thm.beta_conversion false (Thm.rhs_of thm')) end
-
   fun cert ctxt = Thm.cterm_of (ProofContext.theory_of ctxt)
 
-  val fresh_name = yield_singleton Name.variants
-
   fun used_vars cvs ct =
     let
       val lookup = AList.lookup (op aconv) (map (` Thm.term_of) cvs)
-      val add = (fn (SOME ct) => insert (op aconvc) ct | _ => I)
+      val add = (fn SOME ct => insert (op aconvc) ct | _ => I)
     in Term.fold_aterms (add o lookup) (Thm.term_of ct) [] end
-  fun make_def cvs eq = Thm.symmetric (fold norm_meta_def cvs eq)
-  fun add_def ct thm = Termtab.update (Thm.term_of ct, (serial (), thm))
 
-  fun replace ctxt cvs ct (cx as (nctxt, defs)) =
+  fun apply cv thm = 
+    let val thm' = Thm.combination thm (Thm.reflexive cv)
+    in Thm.transitive thm' (Thm.beta_conversion false (Thm.rhs_of thm')) end
+  fun apply_def cvs eq = Thm.symmetric (fold apply cvs eq)
+
+  fun replace_lambda cvs ct (cx as (ctxt, defs)) =
     let
       val cvs' = used_vars cvs ct
       val ct' = fold_rev Thm.cabs cvs' ct
     in
       (case Termtab.lookup defs (Thm.term_of ct') of
-        SOME (_, eq) => (make_def cvs' eq, cx)
+        SOME eq => (apply_def cvs' eq, cx)
       | NONE =>
           let
-            val {T, ...} = Thm.rep_cterm ct'
-            val (n, nctxt') = fresh_name "" nctxt
-            val eq = Thm.assume (mk_meta_eq (cert ctxt (Free (n, T))) ct')
-          in (make_def cvs' eq, (nctxt', add_def ct' eq defs)) end)
+            val {T, ...} = Thm.rep_cterm ct' and n = Name.uu
+            val (n', ctxt') = yield_singleton Variable.variant_fixes n ctxt
+            val cu = mk_meta_eq (cert ctxt (Free (n', T))) ct'
+            val (eq, ctxt'') = yield_singleton Assumption.add_assumes cu ctxt'
+            val defs' = Termtab.update (Thm.term_of ct', eq) defs
+          in (apply_def cvs' eq, (ctxt'', defs')) end)
     end
 
   fun none ct cx = (Thm.reflexive ct, cx)
@@ -368,28 +352,25 @@
     let val (cu1, cu2) = Thm.dest_comb ct
     in cx |> f cu1 ||>> g cu2 |>> uncurry Thm.combination end
   fun in_arg f = in_comb none f
-  fun in_abs f cvs ct (nctxt, defs) =
-    let
-      val (n, nctxt') = fresh_name Name.uu nctxt
-      val (cv, cu) = Thm.dest_abs (SOME n) ct
-    in f (cv :: cvs) cu (nctxt', defs) |>> Thm.abstract_rule n cv end
-
-  fun replace_lambdas ctxt =
+  fun in_abs f cvs ct (ctxt, defs) =
     let
-      fun repl cvs ct =
-        (case Thm.term_of ct of
-          Const (@{const_name All}, _) $ Abs _ => in_arg (in_abs repl cvs)
-        | Const (@{const_name Ex}, _) $ Abs _ => in_arg (in_abs repl cvs)
-        | Const _ $ Abs _ => in_arg (at_lambda cvs)
-        | Const (@{const_name Let}, _) $ _ $ Abs _ =>
-            in_comb (in_arg (repl cvs)) (in_abs repl cvs)
-        | Abs _ => at_lambda cvs
-        | _ $ _ => in_comb (repl cvs) (repl cvs)
-        | _ => none) ct
-      and at_lambda cvs ct =
-        in_abs repl cvs ct #-> (fn thm =>
-        replace ctxt cvs (Thm.rhs_of thm) #>> Thm.transitive thm)
-    in repl [] end
+      val (n, ctxt') = yield_singleton Variable.variant_fixes Name.uu ctxt
+      val (cv, cu) = Thm.dest_abs (SOME n) ct
+    in  (ctxt', defs) |> f (cv :: cvs) cu |>> Thm.abstract_rule n cv end
+
+  fun traverse cvs ct =
+    (case Thm.term_of ct of
+      Const (@{const_name All}, _) $ Abs _ => in_arg (in_abs traverse cvs)
+    | Const (@{const_name Ex}, _) $ Abs _ => in_arg (in_abs traverse cvs)
+    | Const (@{const_name Let}, _) $ _ $ Abs _ =>
+        in_comb (in_arg (traverse cvs)) (in_abs traverse cvs)
+    | Abs _ => at_lambda cvs
+    | _ $ _ => in_comb (traverse cvs) (traverse cvs)
+    | _ => none) ct
+
+  and at_lambda cvs ct =
+    in_abs traverse cvs ct #-> (fn thm =>
+    replace_lambda cvs (Thm.rhs_of thm) #>> Thm.transitive thm)
 
   fun has_free_lambdas t =
     (case t of
@@ -400,26 +381,17 @@
     | Abs _ => true
     | u1 $ u2 => has_free_lambdas u1 orelse has_free_lambdas u2
     | _ => false)
+
+  fun lift_lm f thm cx =
+    if not (has_free_lambdas (Thm.prop_of thm)) then (thm, cx)
+    else cx |> f (Thm.cprop_of thm) |>> (fn thm' => Thm.equal_elim thm' thm)
 in
-fun lift_lambdas ctxt thms =
+fun lift_lambdas thms ctxt =
   let
-    val declare_frees = fold (Thm.fold_terms Term.declare_term_frees)
-    fun rewrite f thm cx =
-      if not (has_free_lambdas (Thm.prop_of thm)) then (thm, cx)
-      else f (Thm.cprop_of thm) cx |>> (fn thm' => Thm.equal_elim thm' thm)
-
-    val rev_int_fst_ord = rev_order o int_ord o pairself fst
-    fun ordered_values tab =
-      Termtab.fold (fn (_, x) => OrdList.insert rev_int_fst_ord x) tab []
-      |> map snd
-
-    val (thms', (_, defs)) =
-      (declare_frees thms (Name.make_context []), Termtab.empty)
-      |> fold_map (rewrite (replace_lambdas ctxt)) thms
-    val eqs = ordered_values defs
-  in
-    (maps (#hyps o Thm.crep_thm) eqs, map (normalize_rule ctxt) eqs @ thms')
-  end
+    val cx = (ctxt, Termtab.empty)
+    val (thms', (ctxt', defs)) = fold_map (lift_lm (traverse [])) thms cx
+    val eqs = Termtab.fold (cons o normalize_rule ctxt' o snd) defs []
+  in (eqs @ thms', ctxt') end
 end
 
 
@@ -483,14 +455,16 @@
 
 (* combined normalization *)
 
-fun normalize ctxt thms =
+fun normalize thms ctxt =
   thms
   |> trivial_distinct ctxt
   |> rewrite_bool_cases ctxt
   |> normalize_numerals ctxt
   |> nat_as_int ctxt
   |> map (unfold_defs ctxt #> normalize_rule ctxt)
-  |> lift_lambdas ctxt
-  |> apsnd (explicit_application ctxt)
+  |> rpair ctxt
+  |-> SMT_Monomorph.monomorph
+  |-> lift_lambdas
+  |-> (fn thms' => `(fn ctxt' => explicit_application ctxt' thms'))
 
 end