src/HOL/Tools/Nitpick/nitpick.ML
changeset 35335 f715cfde056a
parent 35334 b83b9f2a4b92
child 35384 88dbcfe75c45
--- a/src/HOL/Tools/Nitpick/nitpick.ML	Tue Feb 23 16:53:13 2010 +0100
+++ b/src/HOL/Tools/Nitpick/nitpick.ML	Tue Feb 23 19:10:25 2010 +0100
@@ -58,8 +58,8 @@
   val register_codatatype : typ -> string -> styp list -> theory -> theory
   val unregister_codatatype : typ -> theory -> theory
   val pick_nits_in_term :
-    Proof.state -> params -> bool -> int -> int -> int -> term list -> term
-    -> string * Proof.state
+    Proof.state -> params -> bool -> int -> int -> int -> (term * term) list
+    -> term list -> term -> string * Proof.state
   val pick_nits_in_subgoal :
     Proof.state -> params -> bool -> int -> int -> string * Proof.state
 end;
@@ -187,10 +187,10 @@
 (* (unit -> string) -> Pretty.T *)
 fun plazy f = Pretty.blk (0, pstrs (f ()))
 
-(* Time.time -> Proof.state -> params -> bool -> int -> int -> int -> term
-   -> string * Proof.state *)
+(* Time.time -> Proof.state -> params -> bool -> int -> int -> int
+   -> (term * term) list -> term list -> term -> string * Proof.state *)
 fun pick_them_nits_in_term deadline state (params : params) auto i n step
-                           orig_assm_ts orig_t =
+                           subst orig_assm_ts orig_t =
   let
     val timer = Timer.startRealTimer ()
     val thy = Proof.theory_of state
@@ -237,6 +237,7 @@
       if passed_deadline deadline then raise TimeLimit.TimeOut
       else raise Interrupt
 
+    val orig_assm_ts = if assms orelse auto then orig_assm_ts else []
     val _ =
       if step = 0 then
         print_m (fn () => "Nitpicking formula...")
@@ -249,10 +250,7 @@
                    "goal")) [Logic.list_implies (orig_assm_ts, orig_t)]))
     val neg_t = if falsify then Logic.mk_implies (orig_t, @{prop False})
                 else orig_t
-    val assms_t = if assms orelse auto then
-                    Logic.mk_conjunction_list (neg_t :: orig_assm_ts)
-                  else
-                    neg_t
+    val assms_t = Logic.mk_conjunction_list (neg_t :: orig_assm_ts)
     val (assms_t, evals) =
       assms_t :: evals |> merge_type_vars ? merge_type_vars_in_terms
                        |> pairf hd tl
@@ -265,12 +263,12 @@
 *)
     val max_bisim_depth = fold Integer.max bisim_depths ~1
     val case_names = case_const_names thy stds
-    val (defs, built_in_nondefs, user_nondefs) = all_axioms_of thy
-    val def_table = const_def_table ctxt defs
+    val (defs, built_in_nondefs, user_nondefs) = all_axioms_of thy subst
+    val def_table = const_def_table ctxt subst defs
     val nondef_table = const_nondef_table (built_in_nondefs @ user_nondefs)
-    val simp_table = Unsynchronized.ref (const_simp_table ctxt)
-    val psimp_table = const_psimp_table ctxt
-    val intro_table = inductive_intro_table ctxt def_table
+    val simp_table = Unsynchronized.ref (const_simp_table ctxt subst)
+    val psimp_table = const_psimp_table ctxt subst
+    val intro_table = inductive_intro_table ctxt subst def_table
     val ground_thm_table = ground_theorem_table thy
     val ersatz_table = ersatz_table thy
     val (hol_ctxt as {wf_cache, ...}) =
@@ -941,10 +939,10 @@
            else
              error "Nitpick was interrupted."
 
-(* Proof.state -> params -> bool -> int -> int -> int -> term
-   -> string * Proof.state *)
+(* Proof.state -> params -> bool -> int -> int -> int -> (term * term) list
+   -> term list -> term -> string * Proof.state *)
 fun pick_nits_in_term state (params as {debug, timeout, expect, ...}) auto i n
-                      step orig_assm_ts orig_t =
+                      step subst orig_assm_ts orig_t =
   if getenv "KODKODI" = "" then
     (if auto then ()
      else warning (Pretty.string_of (plazy install_kodkodi_message));
@@ -954,13 +952,27 @@
       val deadline = Option.map (curry Time.+ (Time.now ())) timeout
       val outcome as (outcome_code, _) =
         time_limit (if debug then NONE else timeout)
-            (pick_them_nits_in_term deadline state params auto i n step
+            (pick_them_nits_in_term deadline state params auto i n step subst
                                     orig_assm_ts) orig_t
     in
       if expect = "" orelse outcome_code = expect then outcome
       else error ("Unexpected outcome: " ^ quote outcome_code ^ ".")
     end
 
+(* string list -> term -> bool *)
+fun is_fixed_equation fixes
+                      (Const (@{const_name "=="}, _) $ Free (s, _) $ Const _) =
+    member (op =) fixes s
+  | is_fixed_equation _ _ = false
+(* Proof.context -> term list * term -> (term * term) list * term list * term *)
+fun extract_fixed_frees ctxt (assms, t) =
+  let
+    val fixes = Variable.fixes_of ctxt |> map snd
+    val (subst, other_assms) =
+      List.partition (is_fixed_equation fixes) assms
+      |>> map Logic.dest_equals
+  in (subst, other_assms, subst_atomic subst t) end
+
 (* Proof.state -> params -> bool -> int -> int -> string * Proof.state *)
 fun pick_nits_in_subgoal state params auto i step =
   let
@@ -971,12 +983,11 @@
       0 => (priority "No subgoal!"; ("none", state))
     | n =>
       let
+        val (t, frees) = Logic.goal_params t i
+        val t = subst_bounds (frees, t)
         val assms = map term_of (Assumption.all_assms_of ctxt)
-        val (t, frees) = Logic.goal_params t i
-      in
-        pick_nits_in_term state params auto i n step assms
-                          (subst_bounds (frees, t))
-      end
+        val (subst, assms, t) = extract_fixed_frees ctxt (assms, t)
+      in pick_nits_in_term state params auto i n step subst assms t end
   end
 
 end;