src/HOL/Tools/Nitpick/nitpick_mono.ML
changeset 35386 45a4e19d3ebd
parent 35385 29f81babefd7
child 35665 ff2bf50505ab
--- a/src/HOL/Tools/Nitpick/nitpick_mono.ML	Thu Feb 25 16:33:39 2010 +0100
+++ b/src/HOL/Tools/Nitpick/nitpick_mono.ML	Fri Feb 26 16:49:46 2010 +0100
@@ -10,7 +10,7 @@
   type hol_context = Nitpick_HOL.hol_context
 
   val formulas_monotonic :
-    hol_context -> bool -> typ -> term list * term list * term -> bool
+    hol_context -> bool -> typ -> term list * term list -> bool
 end;
 
 structure Nitpick_Mono : NITPICK_MONO =
@@ -34,7 +34,7 @@
   MRec of string * typ list
 
 datatype mterm =
-  MAtom of term * mtyp |
+  MRaw of term * mtyp |
   MAbs of string * typ * mtyp * sign_atom * mterm |
   MApp of mterm * mterm
 
@@ -76,7 +76,7 @@
 fun string_for_literal (x, sn) = string_for_var x ^ " = " ^ string_for_sign sn
 
 val bool_M = MType (@{type_name bool}, [])
-val irrelevant_M = MType (nitpick_prefix ^ "irrelevant", [])
+val dummy_M = MType (nitpick_prefix ^ "dummy", [])
 
 (* mtyp -> bool *)
 fun is_MRec (MRec _) = true
@@ -102,16 +102,19 @@
         val need_parens = (prec < outer_prec)
       in
         (if need_parens then "(" else "") ^
-        (case M of
-           MAlpha => "\<alpha>"
-         | MFun (M1, a, M2) =>
-           aux (prec + 1) M1 ^ " \<Rightarrow>\<^bsup>" ^
-           string_for_sign_atom a ^ "\<^esup> " ^ aux prec M2
-         | MPair (M1, M2) => aux (prec + 1) M1 ^ " \<times> " ^ aux prec M2
-         | MType (s, []) =>
-           if s = @{type_name prop} orelse s = @{type_name bool} then "o" else s
-         | MType (s, Ms) => "(" ^ commas (map (aux 0) Ms) ^ ") " ^ s
-         | MRec (s, _) => "[" ^ s ^ "]") ^
+        (if M = dummy_M then
+           "_"
+         else case M of
+             MAlpha => "\<alpha>"
+           | MFun (M1, a, M2) =>
+             aux (prec + 1) M1 ^ " \<Rightarrow>\<^bsup>" ^
+             string_for_sign_atom a ^ "\<^esup> " ^ aux prec M2
+           | MPair (M1, M2) => aux (prec + 1) M1 ^ " \<times> " ^ aux prec M2
+           | MType (s, []) =>
+             if s = @{type_name prop} orelse s = @{type_name bool} then "o"
+             else s
+           | MType (s, Ms) => "(" ^ commas (map (aux 0) Ms) ^ ") " ^ s
+           | MRec (s, _) => "[" ^ s ^ "]") ^
         (if need_parens then ")" else "")
       end
   in aux 0 end
@@ -122,7 +125,7 @@
   | flatten_mtype M = [M]
 
 (* mterm -> bool *)
-fun precedence_of_mterm (MAtom _) = no_prec
+fun precedence_of_mterm (MRaw _) = no_prec
   | precedence_of_mterm (MAbs _) = 1
   | precedence_of_mterm (MApp _) = 2
 
@@ -139,7 +142,7 @@
       in
         (if need_parens then "(" else "") ^
         (case m of
-           MAtom (t, M) => Syntax.string_of_term ctxt t ^ mtype_annotation M
+           MRaw (t, M) => Syntax.string_of_term ctxt t ^ mtype_annotation M
          | MAbs (s, _, M, a, m) =>
            "\<lambda>" ^ s ^ mtype_annotation M ^ ".\<^bsup>" ^
            string_for_sign_atom a ^ "\<^esup> " ^ aux prec m
@@ -149,7 +152,7 @@
   in aux 0 end
 
 (* mterm -> mtyp *)
-fun mtype_of_mterm (MAtom (_, M)) = M
+fun mtype_of_mterm (MRaw (_, M)) = M
   | mtype_of_mterm (MAbs (_, _, M, a, m)) = MFun (M, a, mtype_of_mterm m)
   | mtype_of_mterm (MApp (m1, _)) =
     case mtype_of_mterm m1 of
@@ -545,19 +548,28 @@
 fun solve _ UnsolvableCSet = (print_g "*** Problem: Unsolvable"; NONE)
   | solve max_var (CSet (lits, comps, sexps)) =
     let
+      (* (int -> bool option) -> literal list option *)
+      fun do_assigns assigns =
+        SOME (literals_from_assignments max_var assigns lits
+              |> tap print_solution)
       val _ = print_problem lits comps sexps
       val prop = PropLogic.all (map prop_for_literal lits @
                                 map prop_for_comp comps @
                                 map prop_for_sign_expr sexps)
-      (* use the first ML solver (to avoid startup overhead) *)
-      val solvers = !SatSolver.solvers
-                    |> filter (member (op =) ["dptsat", "dpll"] o fst)
+      val default_val = bool_from_sign Minus
     in
-      case snd (hd solvers) prop of
-        SatSolver.SATISFIABLE assigns =>
-        SOME (literals_from_assignments max_var assigns lits
-              |> tap print_solution)
-      | _ => NONE
+      if PropLogic.eval (K default_val) prop then
+        do_assigns (K (SOME default_val))
+      else
+        let
+          (* use the first ML solver (to avoid startup overhead) *)
+          val solvers = !SatSolver.solvers
+                        |> filter (member (op =) ["dptsat", "dpll"] o fst)
+        in
+          case snd (hd solvers) prop of
+            SatSolver.SATISFIABLE assigns => do_assigns assigns
+          | _ => NONE
+        end
     end
 
 type mtype_schema = mtyp * constraint_set
@@ -580,7 +592,7 @@
   handle List.Empty => initial_gamma
 
 (* mdata -> term -> accumulator -> mterm * accumulator *)
-fun consider_term (mdata as {hol_ctxt = {thy, ctxt, stds, fast_descrs,
+fun consider_term (mdata as {hol_ctxt as {thy, ctxt, stds, fast_descrs,
                                          def_table, ...},
                              alpha_T, max_fresh, ...}) =
   let
@@ -595,7 +607,7 @@
     fun do_all T (gamma, cset) =
       let
         val abs_M = mtype_for (domain_type (domain_type T))
-        val body_M = mtype_for (range_type T)
+        val body_M = mtype_for (body_type T)
       in
         (MFun (MFun (abs_M, S Minus, body_M), S Minus, body_M),
          (gamma, cset |> add_mtype_is_right_total abs_M))
@@ -641,9 +653,9 @@
         pair (MFun (M, S Minus, if n = 0 then a_M else b_M))
       | M => raise MTYPE ("Nitpick_Mono.consider_term.do_nth_pair_sel", [M])
     (* mtyp * accumulator *)
-    val mtype_unsolvable = (irrelevant_M, unsolvable_accum)
+    val mtype_unsolvable = (dummy_M, unsolvable_accum)
     (* term -> mterm * accumulator *)
-    fun mterm_unsolvable t = (MAtom (t, irrelevant_M), unsolvable_accum)
+    fun mterm_unsolvable t = (MRaw (t, dummy_M), unsolvable_accum)
     (* term -> string -> typ -> term -> term -> term -> accumulator
        -> mterm * accumulator *)
     fun do_bounded_quantifier t0 abs_s abs_T connective_t bound_t body_t accum =
@@ -657,10 +669,11 @@
         val bound_M = mtype_of_mterm bound_m
         val (M1, a, M2) = dest_MFun bound_M
       in
-        (MApp (MAtom (t0, MFun (bound_M, S Minus, bool_M)),
+        (MApp (MRaw (t0, MFun (bound_M, S Minus, bool_M)),
                MAbs (abs_s, abs_T, M1, a,
-                     MApp (MApp (MAtom (connective_t, irrelevant_M),
-                                 MApp (bound_m, MAtom (Bound 0, M1))),
+                     MApp (MApp (MRaw (connective_t,
+                                       mtype_for (fastype_of connective_t)),
+                                 MApp (bound_m, MRaw (Bound 0, M1))),
                            body_m))), accum)
       end
     (* term -> accumulator -> mterm * accumulator *)
@@ -678,10 +691,14 @@
               | @{const_name "=="} => do_equals T accum
               | @{const_name All} => do_all T accum
               | @{const_name Ex} =>
-                do_term (@{const Not}
-                         $ (HOLogic.eq_const (domain_type T)
-                            $ Abs (Name.uu, T, @{const False}))) accum
-                |>> mtype_of_mterm
+                let val set_T = domain_type T in
+                  do_term (Abs (Name.uu, set_T,
+                                @{const Not} $ (HOLogic.mk_eq
+                                    (Abs (Name.uu, domain_type set_T,
+                                          @{const False}),
+                                     Bound 0)))) accum
+                  |>> mtype_of_mterm
+                end
               | @{const_name "op ="} => do_equals T accum
               | @{const_name The} => (print_g "*** The"; mtype_unsolvable)
               | @{const_name Eps} => (print_g "*** Eps"; mtype_unsolvable)
@@ -719,9 +736,12 @@
               | @{const_name rtrancl} =>
                 (print_g "*** rtrancl"; mtype_unsolvable)
               | @{const_name finite} =>
-                let val M1 = mtype_for (domain_type (domain_type T)) in
-                  (MFun (pos_set_mtype_for_dom M1, S Minus, bool_M), accum)
-                end
+                if is_finite_type hol_ctxt T then
+                  let val M1 = mtype_for (domain_type (domain_type T)) in
+                    (MFun (pos_set_mtype_for_dom M1, S Minus, bool_M), accum)
+                  end
+                else
+                  (print_g "*** finite"; mtype_unsolvable)
               | @{const_name rel_comp} =>
                 let
                   val x = Unsynchronized.inc max_fresh
@@ -807,7 +827,7 @@
                   let val M = mtype_for T in
                     (M, ({bounds = bounds, frees = frees,
                           consts = (x, M) :: consts}, cset))
-                  end) |>> curry MAtom t
+                  end) |>> curry MRaw t
          | Free (x as (_, T)) =>
            (case AList.lookup (op =) frees x of
               SOME M => (M, accum)
@@ -815,12 +835,12 @@
               let val M = mtype_for T in
                 (M, ({bounds = bounds, frees = (x, M) :: frees,
                       consts = consts}, cset))
-              end) |>> curry MAtom t
+              end) |>> curry MRaw t
          | Var _ => (print_g "*** Var"; mterm_unsolvable t)
-         | Bound j => (MAtom (t, nth bounds j), accum)
+         | Bound j => (MRaw (t, nth bounds j), accum)
          | Abs (s, T, t' as @{const False}) =>
            let val (M1, a, M2) = mfun_for T bool_T in
-             (MAbs (s, T, M1, a, MAtom (t', M2)), accum)
+             (MAbs (s, T, M1, a, MRaw (t', M2)), accum)
            end
          | Abs (s, T, t') =>
            ((case t' of
@@ -850,88 +870,109 @@
            in
              case accum of
                (_, UnsolvableCSet) => mterm_unsolvable t
-             | _ => (MApp (m1, m2), accum)
+             | _ =>
+               let
+                 val M11 = mtype_of_mterm m1 |> dest_MFun |> #1
+                 val M2 = mtype_of_mterm m2
+               in (MApp (m1, m2), accum ||> add_is_sub_mtype M2 M11) end
            end)
         |> tap (fn (m, _) => print_g ("  \<Gamma> \<turnstile> " ^
                                       string_for_mterm ctxt m))
   in do_term end
 
-(* mdata -> sign -> term -> accumulator -> accumulator *)
+(* mdata -> styp -> term -> term -> mterm * accumulator *)
+fun consider_general_equals mdata (x as (_, T)) t1 t2 accum =
+  let
+    val (m1, accum) = consider_term mdata t1 accum
+    val (m2, accum) = consider_term mdata t2 accum
+    val M1 = mtype_of_mterm m1
+    val M2 = mtype_of_mterm m2
+    val body_M = fresh_mtype_for_type mdata (nth_range_type 2 T)
+  in
+    (MApp (MApp (MRaw (Const x,
+         MFun (M1, S Minus, MFun (M2, S Minus, body_M))), m1), m2),
+     accum ||> add_mtypes_equal M1 M2)
+  end
+
+(* mdata -> sign -> term -> accumulator -> mterm * accumulator *)
 fun consider_general_formula (mdata as {hol_ctxt = {ctxt, ...}, ...}) =
   let
     (* typ -> mtyp *)
     val mtype_for = fresh_mtype_for_type mdata
-    (* term -> accumulator -> mtyp * accumulator *)
-    val do_term = apfst mtype_of_mterm oo consider_term mdata
-    (* sign -> term -> accumulator -> accumulator *)
-    fun do_formula _ _ (_, UnsolvableCSet) = unsolvable_accum
-      | do_formula sn t (accum as (gamma, cset)) =
+    (* term -> accumulator -> mterm * accumulator *)
+    val do_term = consider_term mdata
+    (* sign -> term -> accumulator -> mterm * accumulator *)
+    fun do_formula _ t (_, UnsolvableCSet) =
+        (MRaw (t, dummy_M), unsolvable_accum)
+      | do_formula sn t accum =
         let
-          (* term -> accumulator -> accumulator *)
-          val do_co_formula = do_formula sn
-          val do_contra_formula = do_formula (negate sn)
-          (* string -> typ -> term -> accumulator *)
-          fun do_quantifier quant_s abs_T body_t =
+          (* styp -> string -> typ -> term -> mterm * accumulator *)
+          fun do_quantifier (quant_x as (quant_s, _)) abs_s abs_T body_t =
             let
               val abs_M = mtype_for abs_T
               val side_cond = ((sn = Minus) = (quant_s = @{const_name Ex}))
-              val cset = cset |> side_cond ? add_mtype_is_right_total abs_M
+              val (body_m, accum) =
+                accum ||> side_cond ? add_mtype_is_right_total abs_M
+                      |>> push_bound abs_M |> do_formula sn body_t
+              val body_M = mtype_of_mterm body_m
             in
-              (gamma |> push_bound abs_M, cset)
-              |> do_co_formula body_t |>> pop_bound
+              (MApp (MRaw (Const quant_x, MFun (abs_M, S Minus, body_M)),
+                     MAbs (abs_s, abs_T, abs_M, S Minus, body_m)),
+               accum |>> pop_bound)
             end
-          (* typ -> term -> accumulator *)
-          fun do_bounded_quantifier abs_T body_t =
-            accum |>> push_bound (mtype_for abs_T) |> do_co_formula body_t
-                  |>> pop_bound
-          (* term -> term -> accumulator *)
-          fun do_equals t1 t2 =
+          (* styp -> term -> term -> mterm * accumulator *)
+          fun do_equals x t1 t2 =
             case sn of
-              Plus => do_term t accum |> snd
-            | Minus => let
-                         val (M1, accum) = do_term t1 accum
-                         val (M2, accum) = do_term t2 accum
-                       in accum ||> add_mtypes_equal M1 M2 end
+              Plus => do_term t accum
+            | Minus => consider_general_equals mdata x t1 t2 accum
         in
           case t of
-            Const (s0 as @{const_name all}, _) $ Abs (_, T1, t1) =>
-            do_quantifier s0 T1 t1
-          | Const (@{const_name "=="}, _) $ t1 $ t2 => do_equals t1 t2
-          | @{const "==>"} $ t1 $ t2 =>
-            accum |> do_contra_formula t1 |> do_co_formula t2
-          | @{const Trueprop} $ t1 => do_co_formula t1 accum
-          | @{const Not} $ t1 => do_contra_formula t1 accum
-          | Const (@{const_name All}, _)
-            $ Abs (_, T1, t1 as @{const "op -->"} $ (_ $ Bound 0) $ _) =>
-            do_bounded_quantifier T1 t1
-          | Const (s0 as @{const_name All}, _) $ Abs (_, T1, t1) =>
-            do_quantifier s0 T1 t1
-          | Const (@{const_name Ex}, _)
-            $ Abs (_, T1, t1 as @{const "op &"} $ (_ $ Bound 0) $ _) =>
-            do_bounded_quantifier T1 t1
-          | Const (s0 as @{const_name Ex}, T0) $ (t1 as Abs (_, T1, t1')) =>
+            Const (x as (@{const_name all}, _)) $ Abs (s1, T1, t1) =>
+            do_quantifier x s1 T1 t1
+          | Const (x as (@{const_name "=="}, _)) $ t1 $ t2 => do_equals x t1 t2
+          | @{const Trueprop} $ t1 =>
+            let val (m1, accum) = do_formula sn t1 accum in
+              (MApp (MRaw (@{const Trueprop}, mtype_for (bool_T --> prop_T)),
+                     m1), accum)
+            end
+          | @{const Not} $ t1 =>
+            let val (m1, accum) = do_formula (negate sn) t1 accum in
+              (MApp (MRaw (@{const Not}, mtype_for (bool_T --> bool_T)), m1),
+               accum)
+            end
+          | Const (x as (@{const_name All}, _)) $ Abs (s1, T1, t1) =>
+            do_quantifier x s1 T1 t1
+          | Const (x0 as (s0 as @{const_name Ex}, T0))
+            $ (t1 as Abs (s1, T1, t1')) =>
             (case sn of
-               Plus => do_quantifier s0 T1 t1'
+               Plus => do_quantifier x0 s1 T1 t1'
              | Minus =>
+               (* ### do elsewhere *)
                do_term (@{const Not}
                         $ (HOLogic.eq_const (domain_type T0) $ t1
-                           $ Abs (Name.uu, T1, @{const False}))) accum |> snd)
-          | Const (@{const_name "op ="}, _) $ t1 $ t2 => do_equals t1 t2
-          | @{const "op &"} $ t1 $ t2 =>
-            accum |> do_co_formula t1 |> do_co_formula t2
-          | @{const "op |"} $ t1 $ t2 =>
-            accum |> do_co_formula t1 |> do_co_formula t2
-          | @{const "op -->"} $ t1 $ t2 =>
-            accum |> do_contra_formula t1 |> do_co_formula t2
-          | Const (@{const_name If}, _) $ t1 $ t2 $ t3 =>
-            accum |> do_term t1 |> snd |> fold do_co_formula [t2, t3]
-          | Const (@{const_name Let}, _) $ t1 $ t2 =>
-            do_co_formula (betapply (t2, t1)) accum
-          | _ => do_term t accum |> snd
+                           $ Abs (Name.uu, T1, @{const False}))) accum)
+          | Const (x as (@{const_name "op ="}, _)) $ t1 $ t2 =>
+            do_equals x t1 t2
+          | (t0 as Const (s0, _)) $ t1 $ t2 =>
+            if s0 = @{const_name "==>"} orelse s0 = @{const_name "op &"} orelse
+               s0 = @{const_name "op |"} orelse s0 = @{const_name "op -->"} then
+              let
+                val impl = (s0 = @{const_name "==>"} orelse
+                           s0 = @{const_name "op -->"})
+                val (m1, accum) = do_formula (sn |> impl ? negate) t1 accum
+                val (m2, accum) = do_formula sn t2 accum
+              in
+                (MApp (MApp (MRaw (t0, mtype_for (fastype_of t0)), m1), m2),
+                 accum)
+              end 
+            else
+              do_term t accum
+          | _ => do_term t accum
         end
-        |> tap (fn _ => print_g ("\<Gamma> \<turnstile> " ^
-                                 Syntax.string_of_term ctxt t ^
-                                 " : o\<^sup>" ^ string_for_sign sn))
+        |> tap (fn (m, _) =>
+                   print_g ("\<Gamma> \<turnstile> " ^
+                            string_for_mterm ctxt m ^ " : o\<^sup>" ^
+                            string_for_sign sn))
   in do_formula end
 
 (* The harmless axiom optimization below is somewhat too aggressive in the face
@@ -947,46 +988,69 @@
   |> (forall (member (op =) harmless_consts o original_name o fst)
       orf exists (member (op =) bounteous_consts o fst))
 
-(* mdata -> sign -> term -> accumulator -> accumulator *)
-fun consider_nondefinitional_axiom (mdata as {hol_ctxt, ...}) sn t =
-  not (is_harmless_axiom hol_ctxt t) ? consider_general_formula mdata sn t
+(* mdata -> term -> accumulator -> mterm * accumulator *)
+fun consider_nondefinitional_axiom (mdata as {hol_ctxt, ...}) t =
+  if is_harmless_axiom hol_ctxt t then pair (MRaw (t, dummy_M))
+  else consider_general_formula mdata Plus t
 
-(* mdata -> term -> accumulator -> accumulator *)
+(* mdata -> term -> accumulator -> mterm * accumulator *)
 fun consider_definitional_axiom (mdata as {hol_ctxt as {thy, ...}, ...}) t =
   if not (is_constr_pattern_formula thy t) then
-    consider_nondefinitional_axiom mdata Plus t
+    consider_nondefinitional_axiom mdata t
   else if is_harmless_axiom hol_ctxt t then
-    I
+    pair (MRaw (t, dummy_M))
   else
     let
-      (* term -> accumulator -> mtyp * accumulator *)
-      val do_term = apfst mtype_of_mterm oo consider_term mdata
-      (* typ -> term -> accumulator -> accumulator *)
-      fun do_all abs_T body_t accum =
-        let val abs_M = fresh_mtype_for_type mdata abs_T in
-          accum |>> push_bound abs_M |> do_formula body_t |>> pop_bound
+      (* typ -> mtyp *)
+      val mtype_for = fresh_mtype_for_type mdata
+      (* term -> accumulator -> mterm * accumulator *)
+      val do_term = consider_term mdata
+      (* term -> string -> typ -> term -> accumulator -> mterm * accumulator *)
+      fun do_all quant_t abs_s abs_T body_t accum =
+        let
+          val abs_M = mtype_for abs_T
+          val (body_m, accum) =
+            accum |>> push_bound abs_M |> do_formula body_t
+          val body_M = mtype_of_mterm body_m
+        in
+          (MApp (MRaw (quant_t, MFun (abs_M, S Minus, body_M)),
+                 MAbs (abs_s, abs_T, abs_M, S Minus, body_m)),
+           accum |>> pop_bound)
         end
-      (* term -> term -> accumulator -> accumulator *)
-      and do_implies t1 t2 = do_term t1 #> snd #> do_formula t2
-      and do_equals t1 t2 accum =
+      (* term -> term -> term -> accumulator -> mterm * accumulator *)
+      and do_conjunction t0 t1 t2 accum =
         let
-          val (M1, accum) = do_term t1 accum
-          val (M2, accum) = do_term t2 accum
-        in accum ||> add_mtypes_equal M1 M2 end
+          val (m1, accum) = do_formula t1 accum
+          val (m2, accum) = do_formula t2 accum
+        in
+          (MApp (MApp (MRaw (t0, mtype_for (fastype_of t0)), m1), m2), accum)
+        end
+      and do_implies t0 t1 t2 accum =
+        let
+          val (m1, accum) = do_term t1 accum
+          val (m2, accum) = do_formula t2 accum
+        in
+          (MApp (MApp (MRaw (t0, mtype_for (fastype_of t0)), m1), m2), accum)
+        end
       (* term -> accumulator -> accumulator *)
-      and do_formula _ (_, UnsolvableCSet) = unsolvable_accum
+      and do_formula t (_, UnsolvableCSet) =
+          (MRaw (t, dummy_M), unsolvable_accum)
         | do_formula t accum =
           case t of
-            Const (@{const_name all}, _) $ Abs (_, T1, t1) => do_all T1 t1 accum
+            (t0 as Const (@{const_name all}, _)) $ Abs (s1, T1, t1) =>
+            do_all t0 s1 T1 t1 accum
           | @{const Trueprop} $ t1 => do_formula t1 accum
-          | Const (@{const_name "=="}, _) $ t1 $ t2 => do_equals t1 t2 accum
-          | @{const "==>"} $ t1 $ t2 => do_implies t1 t2 accum
-          | @{const Pure.conjunction} $ t1 $ t2 =>
-            accum |> do_formula t1 |> do_formula t2
-          | Const (@{const_name All}, _) $ Abs (_, T1, t1) => do_all T1 t1 accum
-          | Const (@{const_name "op ="}, _) $ t1 $ t2 => do_equals t1 t2 accum
-          | @{const "op &"} $ t1 $ t2 => accum |> do_formula t1 |> do_formula t2
-          | @{const "op -->"} $ t1 $ t2 => do_implies t1 t2 accum
+          | Const (x as (@{const_name "=="}, _)) $ t1 $ t2 =>
+            consider_general_equals mdata x t1 t2 accum
+          | (t0 as @{const "==>"}) $ t1 $ t2 => do_implies t0 t1 t2 accum
+          | (t0 as @{const Pure.conjunction}) $ t1 $ t2 =>
+            do_conjunction t0 t1 t2 accum
+          | (t0 as Const (@{const_name All}, _)) $ Abs (s0, T1, t1) =>
+            do_all t0 s0 T1 t1 accum
+          | Const (x as (@{const_name "op ="}, _)) $ t1 $ t2 =>
+            consider_general_equals mdata x t1 t2 accum
+          | (t0 as @{const "op &"}) $ t1 $ t2 => do_conjunction t0 t1 t2 accum
+          | (t0 as @{const "op -->"}) $ t1 $ t2 => do_implies t0 t1 t2 accum
           | _ => raise TERM ("Nitpick_Mono.consider_definitional_axiom.\
                              \do_formula", [t])
     in do_formula t end
@@ -1002,20 +1066,27 @@
   map (fn (x, M) => string_for_mtype_of_term ctxt lits (Const x) M) consts
   |> cat_lines |> print_g
 
-(* hol_context -> bool -> typ -> term list * term list * term -> bool *)
+(* ('a -> 'b -> 'c * 'd) -> 'a -> 'c list * 'b -> 'c list * 'd *)
+fun gather f t (ms, accum) =
+  let val (m, accum) = f t accum in (m :: ms, accum) end
+
+(* hol_context -> bool -> typ -> term list * term list -> bool *)
 fun formulas_monotonic (hol_ctxt as {ctxt, ...}) binarize alpha_T
-                       (def_ts, nondef_ts, core_t) =
+                       (nondef_ts, def_ts) =
   let
     val _ = print_g ("****** Monotonicity analysis: " ^
                      string_for_mtype MAlpha ^ " is " ^
                      Syntax.string_of_typ ctxt alpha_T)
     val mdata as {max_fresh, constr_cache, ...} =
       initial_mdata hol_ctxt binarize alpha_T
-    val (gamma as {frees, consts, ...}, cset) =
-      (initial_gamma, slack)
-      |> fold (consider_definitional_axiom mdata) def_ts
-      |> fold (consider_nondefinitional_axiom mdata Plus) nondef_ts
-      |> consider_general_formula mdata Plus core_t
+
+    val accum = (initial_gamma, slack)
+    val (nondef_ms, accum) =
+      ([], accum) |> gather (consider_general_formula mdata Plus) (hd nondef_ts)
+                  |> fold (gather (consider_nondefinitional_axiom mdata))
+                          (tl nondef_ts)
+    val (def_ms, (gamma, cset)) =
+      ([], accum) |> fold (gather (consider_definitional_axiom mdata)) def_ts
   in
     case solve (!max_fresh) cset of
       SOME lits => (print_mtype_context ctxt lits gamma; true)