src/HOL/Tools/Nitpick/nitpick_mono.ML
changeset 37476 0681e46b4022
parent 37267 5c47d633c84d
child 37678 0040bafffdef
--- a/src/HOL/Tools/Nitpick/nitpick_mono.ML	Mon Jun 21 09:38:20 2010 +0200
+++ b/src/HOL/Tools/Nitpick/nitpick_mono.ML	Mon Jun 21 11:15:21 2010 +0200
@@ -54,8 +54,9 @@
 exception MTYPE of string * mtyp list * typ list
 exception MTERM of string * mterm list
 
-fun print_g (_ : string) = ()
-(* val print_g = tracing *)
+val debug_mono = false
+
+fun print_g f = () |> debug_mono ? tracing o f
 
 val string_for_var = signed_string_of_int
 fun string_for_vars sep [] = "0\<^bsub>" ^ sep ^ "\<^esub>"
@@ -401,10 +402,10 @@
                  [M1, M2], [])
 
 fun add_mtype_comp cmp M1 M2 ((lits, comps, sexps) : constraint_set) =
-    (print_g ("*** Add " ^ string_for_mtype M1 ^ " " ^ string_for_comp_op cmp ^
-              " " ^ string_for_mtype M2);
+    (print_g (fn () => "*** Add " ^ string_for_mtype M1 ^ " " ^
+                       string_for_comp_op cmp ^ " " ^ string_for_mtype M2);
      case do_mtype_comp cmp [] M1 M2 (SOME (lits, comps)) of
-       NONE => (print_g "**** Unsolvable"; raise UNSOLVABLE ())
+       NONE => (print_g (K "**** Unsolvable"); raise UNSOLVABLE ())
      | SOME (lits, comps) => (lits, comps, sexps))
 
 val add_mtypes_equal = add_mtype_comp Eq
@@ -446,10 +447,11 @@
     raise MTYPE ("Nitpick_Mono.do_notin_mtype_fv", [M], [])
 
 fun add_notin_mtype_fv sn M ((lits, comps, sexps) : constraint_set) =
-    (print_g ("*** Add " ^ string_for_mtype M ^ " is " ^
-              (case sn of Minus => "concrete" | Plus => "complete") ^ ".");
+    (print_g (fn () => "*** Add " ^ string_for_mtype M ^ " is " ^
+                       (case sn of Minus => "concrete" | Plus => "complete") ^
+                       ".");
      case do_notin_mtype_fv sn [] M (SOME (lits, sexps)) of
-       NONE => (print_g "**** Unsolvable"; raise UNSOLVABLE ())
+       NONE => (print_g (K "**** Unsolvable"); raise UNSOLVABLE ())
      | SOME (lits, sexps) => (lits, comps, sexps))
 
 val add_mtype_is_concrete = add_notin_mtype_fv Minus
@@ -491,15 +493,16 @@
   subscript_string_for_vars " \<and> " xs ^ " " ^ string_for_sign_atom a2
 
 fun print_problem lits comps sexps =
-  print_g ("*** Problem:\n" ^ cat_lines (map string_for_literal lits @
-                                         map string_for_comp comps @
-                                         map string_for_sign_expr sexps))
+  print_g (fn () => "*** Problem:\n" ^
+                    cat_lines (map string_for_literal lits @
+                               map string_for_comp comps @
+                               map string_for_sign_expr sexps))
 
 fun print_solution lits =
   let val (pos, neg) = List.partition (curry (op =) Plus o snd) lits in
-    print_g ("*** Solution:\n" ^
-             "+: " ^ commas (map (string_for_var o fst) pos) ^ "\n" ^
-             "-: " ^ commas (map (string_for_var o fst) neg))
+    print_g (fn () => "*** Solution:\n" ^
+                      "+: " ^ commas (map (string_for_var o fst) pos) ^ "\n" ^
+                      "-: " ^ commas (map (string_for_var o fst) neg))
   end
 
 fun solve max_var (lits, comps, sexps) =
@@ -550,6 +553,12 @@
                                           def_table, ...},
                              alpha_T, max_fresh, ...}) =
   let
+    fun is_enough_eta_expanded t =
+      case strip_comb t of
+        (Const x, ts) =>
+        the_default 0 (arity_of_built_in_const thy stds fast_descrs x)
+        <= length ts
+      | _ => true
     val mtype_for = fresh_mtype_for_type mdata false
     fun plus_set_mtype_for_dom M =
       MFun (M, S (if exists_alpha_sub_mtype M then Plus else Minus), bool_M)
@@ -640,8 +649,10 @@
                   |>> mtype_of_mterm
                 end
               | @{const_name "op ="} => do_equals T accum
-              | @{const_name The} => (print_g "*** The"; raise UNSOLVABLE ())
-              | @{const_name Eps} => (print_g "*** Eps"; raise UNSOLVABLE ())
+              | @{const_name The} =>
+                (print_g (K "*** The"); raise UNSOLVABLE ())
+              | @{const_name Eps} =>
+                (print_g (K "*** Eps"); raise UNSOLVABLE ())
               | @{const_name If} =>
                 do_robust_set_operation (range_type T) accum
                 |>> curry3 MFun bool_M (S Minus)
@@ -650,19 +661,6 @@
               | @{const_name snd} => do_nth_pair_sel 1 T accum 
               | @{const_name Id} =>
                 (MFun (mtype_for (domain_type T), S Minus, bool_M), accum)
-              | @{const_name insert} =>
-                let
-                  val set_T = domain_type (range_type T)
-                  val M1 = mtype_for (domain_type set_T)
-                  val M1' = plus_set_mtype_for_dom M1
-                  val M2 = mtype_for set_T
-                  val M3 = mtype_for set_T
-                in
-                  (MFun (M1, S Minus, MFun (M2, S Minus, M3)),
-                   (gamma, cset |> add_mtype_is_concrete M1
-                                |> add_is_sub_mtype M1' M3
-                                |> add_is_sub_mtype M2 M3))
-                end
               | @{const_name converse} =>
                 let
                   val x = Unsynchronized.inc max_fresh
@@ -720,25 +718,9 @@
                     val a_set_M = mtype_for (domain_type T)
                     val a_M = dest_MFun a_set_M |> #1
                   in (MFun (a_set_M, S Minus, a_M), accum) end
-                else if s = @{const_name minus_class.minus} andalso
-                        is_set_type (domain_type T) then
-                  let
-                    val set_T = domain_type T
-                    val left_set_M = mtype_for set_T
-                    val right_set_M = mtype_for set_T
-                  in
-                    (MFun (left_set_M, S Minus,
-                           MFun (right_set_M, S Minus, left_set_M)),
-                     (gamma, cset |> add_mtype_is_concrete right_set_M
-                                  |> add_is_sub_mtype right_set_M left_set_M))
-                  end
                 else if s = @{const_name ord_class.less_eq} andalso
                         is_set_type (domain_type T) then
                   do_fragile_set_operation T accum
-                else if (s = @{const_name semilattice_inf_class.inf} orelse
-                         s = @{const_name semilattice_sup_class.sup}) andalso
-                        is_set_type (domain_type T) then
-                  do_robust_set_operation T accum
                 else if is_sel s then
                   (mtype_for_sel mdata x, accum)
                 else if is_constr ctxt stds x then
@@ -758,7 +740,7 @@
                 (M, ({bound_Ts = bound_Ts, bound_Ms = bound_Ms,
                       frees = (x, M) :: frees, consts = consts}, cset))
               end) |>> curry MRaw t
-         | Var _ => (print_g "*** Var"; raise UNSOLVABLE ())
+         | Var _ => (print_g (K "*** Var"); raise UNSOLVABLE ())
          | Bound j => (MRaw (t, nth bound_Ms j), accum)
          | Abs (s, T, t') =>
            (case fin_fun_body T (fastype_of1 (T :: bound_Ts, t')) t' of
@@ -771,10 +753,16 @@
             | NONE =>
               ((case t' of
                   t1' $ Bound 0 =>
-                  if not (loose_bvar1 (t1', 0)) then
+                  if not (loose_bvar1 (t1', 0)) andalso
+                     is_enough_eta_expanded t1' then
                     do_term (incr_boundvars ~1 t1') accum
                   else
                     raise SAME ()
+                | (t11 as Const (@{const_name "op ="}, _)) $ Bound 0 $ t13 =>
+                  if not (loose_bvar1 (t13, 0)) then
+                    do_term (incr_boundvars ~1 (t11 $ t13)) accum
+                  else
+                    raise SAME ()
                 | _ => raise SAME ())
                handle SAME () =>
                       let
@@ -803,8 +791,8 @@
                val M2 = mtype_of_mterm m2
              in (MApp (m1, m2), accum ||> add_is_sub_mtype M2 M11) end
            end)
-        |> tap (fn (m, _) => print_g ("  \<Gamma> \<turnstile> " ^
-                                      string_for_mterm ctxt m))
+        |> tap (fn (m, _) => print_g (fn () => "  \<Gamma> \<turnstile> " ^
+                                               string_for_mterm ctxt m))
   in do_term end
 
 fun force_minus_funs 0 _ = I
@@ -902,9 +890,9 @@
           | _ => do_term t accum
         end
         |> tap (fn (m, _) =>
-                   print_g ("\<Gamma> \<turnstile> " ^
-                            string_for_mterm ctxt m ^ " : o\<^sup>" ^
-                            string_for_sign sn))
+                   print_g (fn () => "\<Gamma> \<turnstile> " ^
+                                     string_for_mterm ctxt m ^ " : o\<^sup>" ^
+                                     string_for_sign sn))
   in do_formula end
 
 (* The harmless axiom optimization below is somewhat too aggressive in the face
@@ -987,9 +975,10 @@
   Syntax.string_of_term ctxt t ^ " : " ^ string_for_mtype (resolve_mtype lits M)
 
 fun print_mtype_context ctxt lits ({frees, consts, ...} : mtype_context) =
-  map (fn (x, M) => string_for_mtype_of_term ctxt lits (Free x) M) frees @
-  map (fn (x, M) => string_for_mtype_of_term ctxt lits (Const x) M) consts
-  |> cat_lines |> print_g
+  print_g (fn () =>
+      map (fn (x, M) => string_for_mtype_of_term ctxt lits (Free x) M) frees @
+      map (fn (x, M) => string_for_mtype_of_term ctxt lits (Const x) M) consts
+      |> cat_lines)
 
 fun amass f t (ms, accum) =
   let val (m, accum) = f t accum in (m :: ms, accum) end
@@ -997,9 +986,9 @@
 fun infer which no_harmless (hol_ctxt as {ctxt, ...}) binarize alpha_T
           (nondef_ts, def_ts) =
   let
-    val _ = print_g ("****** " ^ which ^ " analysis: " ^
-                     string_for_mtype MAlpha ^ " is " ^
-                     Syntax.string_of_typ ctxt alpha_T)
+    val _ = print_g (fn () => "****** " ^ which ^ " analysis: " ^
+                              string_for_mtype MAlpha ^ " is " ^
+                              Syntax.string_of_typ ctxt alpha_T)
     val mdata as {max_fresh, constr_mcache, ...} =
       initial_mdata hol_ctxt binarize no_harmless alpha_T
     val accum = (initial_gamma, ([], [], []))
@@ -1064,26 +1053,21 @@
             in
               case t of
                 Const (x as (s, _)) =>
-                if s = @{const_name insert} then
-                  case nth_range_type 2 T' of
-                    set_T' as Type (@{type_name fin_fun}, [elem_T', _]) =>
-                      Abs (Name.uu, elem_T', Abs (Name.uu, set_T',
-                          Const (@{const_name If},
-                                 bool_T --> set_T' --> set_T' --> set_T')
-                          $ (Const (@{const_name is_unknown},
-                                    elem_T' --> bool_T) $ Bound 1)
-                          $ (Const (@{const_name unknown}, set_T'))
-                          $ (coerce_term hol_ctxt new_Ts T' T (Const x)
-                             $ Bound 1 $ Bound 0)))
-                  | _ => Const (s, T')
-                else if s = @{const_name finite} then
+                if s = @{const_name finite} then
                   case domain_type T' of
                     set_T' as Type (@{type_name fin_fun}, _) =>
                     Abs (Name.uu, set_T', @{const True})
                   | _ => Const (s, T')
                 else if s = @{const_name "=="} orelse
                         s = @{const_name "op ="} then
-                  Const (s, T')
+                  let
+                    val T =
+                      case T' of
+                        Type (_, [T1, Type (_, [T2, T3])]) =>
+                        T1 --> T2 --> T3
+                      | _ => raise TYPE ("Nitpick_Mono.finitize_funs.\
+                                         \term_from_mterm", [T, T'], [])
+                  in coerce_term hol_ctxt new_Ts T' T (Const (s, T)) end
                 else if is_built_in_const thy stds fast_descrs x then
                   coerce_term hol_ctxt new_Ts T' T t
                 else if is_constr ctxt stds x then