src/HOL/Tools/Nitpick/nitpick_hol.ML
changeset 34121 5e831d805118
parent 33978 2380c1dac86e
child 34123 c4988215a691
--- a/src/HOL/Tools/Nitpick/nitpick_hol.ML	Mon Dec 14 12:14:12 2009 +0100
+++ b/src/HOL/Tools/Nitpick/nitpick_hol.ML	Mon Dec 14 12:30:26 2009 +0100
@@ -52,9 +52,9 @@
   val unbox_type : typ -> typ
   val string_for_type : Proof.context -> typ -> string
   val prefix_name : string -> string -> string
+  val shortest_name : string -> string
   val short_name : string -> string
-  val short_const_name : string -> string
-  val shorten_const_names_in_term : term -> term
+  val shorten_names_in_term : term -> term
   val type_match : theory -> typ * typ -> bool
   val const_match : theory -> styp * styp -> bool
   val term_match : theory -> term * term -> bool
@@ -197,12 +197,14 @@
 (* term * term -> term *)
 fun s_conj (t1, @{const True}) = t1
   | s_conj (@{const True}, t2) = t2
-  | s_conj (t1, t2) = if @{const False} mem [t1, t2] then @{const False}
-                      else HOLogic.mk_conj (t1, t2)
+  | s_conj (t1, t2) =
+    if t1 = @{const False} orelse t2 = @{const False} then @{const False}
+    else HOLogic.mk_conj (t1, t2)
 fun s_disj (t1, @{const False}) = t1
   | s_disj (@{const False}, t2) = t2
-  | s_disj (t1, t2) = if @{const True} mem [t1, t2] then @{const True}
-                      else HOLogic.mk_disj (t1, t2)
+  | s_disj (t1, t2) =
+    if t1 = @{const True} orelse t2 = @{const True} then @{const True}
+    else HOLogic.mk_disj (t1, t2)
 (* term -> term -> term *)
 fun mk_exists v t =
   HOLogic.exists_const (fastype_of v) $ lambda v (incr_boundvars 1 t)
@@ -213,7 +215,7 @@
   | strip_connective _ t = [t]
 (* term -> term list * term *)
 fun strip_any_connective (t as (t0 $ t1 $ t2)) =
-    if t0 mem [@{const "op &"}, @{const "op |"}] then
+    if t0 = @{const "op &"} orelse t0 = @{const "op |"} then
       (strip_connective t0 t, t0)
     else
       ([t], @{const Not})
@@ -347,19 +349,24 @@
 (* string -> string -> string *)
 val prefix_name = Long_Name.qualify o Long_Name.base_name
 (* string -> string *)
-fun short_name s = List.last (space_explode "." s) handle List.Empty => ""
+fun shortest_name s = List.last (space_explode "." s) handle List.Empty => ""
 (* string -> term -> term *)
 val prefix_abs_vars = Term.map_abs_vars o prefix_name
 (* term -> term *)
-val shorten_abs_vars = Term.map_abs_vars short_name
+val shorten_abs_vars = Term.map_abs_vars shortest_name
 (* string -> string *)
-fun short_const_name s =
+fun short_name s =
   case space_explode name_sep s of
     [_] => s |> String.isPrefix nitpick_prefix s ? unprefix nitpick_prefix
-  | ss => map short_name ss |> space_implode "_"
+  | ss => map shortest_name ss |> space_implode "_"
+(* typ -> typ *)
+fun shorten_names_in_type (Type (s, Ts)) =
+    Type (short_name s, map shorten_names_in_type Ts)
+  | shorten_names_in_type T = T
 (* term -> term *)
-val shorten_const_names_in_term =
-  map_aterms (fn Const (s, T) => Const (short_const_name s, T) | t => t)
+val shorten_names_in_term =
+  map_aterms (fn Const (s, T) => Const (short_name s, T) | t => t)
+  #> map_types shorten_names_in_type
 
 (* theory -> typ * typ -> bool *)
 fun type_match thy (T1, T2) =
@@ -371,7 +378,7 @@
 (* theory -> term * term -> bool *)
 fun term_match thy (Const x1, Const x2) = const_match thy (x1, x2)
   | term_match thy (Free (s1, T1), Free (s2, T2)) =
-    const_match thy ((short_name s1, T1), (short_name s2, T2))
+    const_match thy ((shortest_name s1, T1), (shortest_name s2, T2))
   | term_match thy (t1, t2) = t1 aconv t2
 
 (* typ -> bool *)
@@ -391,7 +398,7 @@
 fun is_gfp_iterator_type (Type (s, _)) = String.isPrefix gfp_iterator_prefix s
   | is_gfp_iterator_type _ = false
 val is_fp_iterator_type = is_lfp_iterator_type orf is_gfp_iterator_type
-val is_boolean_type = equal prop_T orf equal bool_T
+fun is_boolean_type T = (T = prop_T orelse T = bool_T)
 val is_integer_type =
   member (op =) [nat_T, int_T, @{typ bisim_iterator}] orf is_fp_iterator_type
 val is_record_type = not o null o Record.dest_recTs
@@ -458,6 +465,14 @@
         | NONE => case T of Type (s, Ts) => Type (s, map subst Ts) | _ => T
     in subst T end
 
+(* FIXME: Use antiquotation for "code_numeral" below or detect "rep_datatype",
+   e.g., by adding a field to "Datatype_Aux.info". *)
+(* string -> bool *)
+val is_basic_datatype =
+    member (op =) [@{type_name "*"}, @{type_name bool}, @{type_name unit},
+                   @{type_name nat}, @{type_name int},
+                   "Code_Numeral.code_numeral"]
+
 (* theory -> typ -> typ -> typ -> typ *)
 fun instantiate_type thy T1 T1' T2 =
   Same.commit (Envir.subst_type_same
@@ -486,8 +501,11 @@
     val constr_xs = map (apsnd (repair_constr_type thy co_T)) constr_xs
     val (co_s, co_Ts) = dest_Type co_T
     val _ =
-      if forall is_TFree co_Ts andalso not (has_duplicates (op =) co_Ts) then ()
-      else raise TYPE ("Nitpick_HOL.register_codatatype", [co_T], [])
+      if forall is_TFree co_Ts andalso not (has_duplicates (op =) co_Ts)
+         andalso co_s <> "fun" andalso not (is_basic_datatype co_s) then
+        ()
+      else
+        raise TYPE ("Nitpick_HOL.register_codatatype", [co_T], [])
     val codatatypes = AList.update (op =) (co_s, (case_name, constr_xs))
                                    codatatypes
   in Data.put {frac_types = frac_types, codatatypes = codatatypes} thy end
@@ -516,12 +534,6 @@
           Rep_inverse = SOME Rep_inverse}
   | NONE => NONE
 
-(* FIXME: use antiquotation for "code_numeral" below or detect "rep_datatype",
-   e.g., by adding a field to "Datatype_Aux.info". *)
-(* string -> bool *)
-fun is_basic_datatype s =
-    s mem [@{type_name "*"}, @{type_name bool}, @{type_name unit},
-           @{type_name nat}, @{type_name int}, "Code_Numeral.code_numeral"]
 (* theory -> string -> bool *)
 val is_typedef = is_some oo typedef_info
 val is_real_datatype = is_some oo Datatype.get_info
@@ -568,14 +580,15 @@
 val num_record_fields = Integer.add 1 o length o fst oo Record.get_extT_fields
 (* theory -> string -> typ -> int *)
 fun no_of_record_field thy s T1 =
-  find_index (equal s o fst) (Record.get_extT_fields thy T1 ||> single |> op @)
+  find_index (curry (op =) s o fst)
+             (Record.get_extT_fields thy T1 ||> single |> op @)
 (* theory -> styp -> bool *)
 fun is_record_get thy (s, Type ("fun", [T1, _])) =
-    exists (equal s o fst) (all_record_fields thy T1)
+    exists (curry (op =) s o fst) (all_record_fields thy T1)
   | is_record_get _ _ = false
 fun is_record_update thy (s, T) =
   String.isSuffix Record.updateN s andalso
-  exists (equal (unsuffix Record.updateN s) o fst)
+  exists (curry (op =) (unsuffix Record.updateN s) o fst)
          (all_record_fields thy (body_type T))
   handle TYPE _ => false
 fun is_abs_fun thy (s, Type ("fun", [_, Type (s', _)])) =
@@ -608,11 +621,11 @@
   end
   handle TYPE ("dest_Type", _, _) => false
 fun is_constr_like thy (s, T) =
-  s mem [@{const_name FunBox}, @{const_name PairBox}] orelse
+  s = @{const_name FunBox} orelse s = @{const_name PairBox} orelse
   let val (x as (s, T)) = (s, unbox_type T) in
     Refute.is_IDT_constructor thy x orelse is_record_constr x
     orelse (is_abs_fun thy x andalso is_pure_typedef thy (range_type T))
-    orelse s mem [@{const_name Zero_Rep}, @{const_name Suc_Rep}]
+    orelse s = @{const_name Zero_Rep} orelse s = @{const_name Suc_Rep}
     orelse x = (@{const_name zero_nat_inst.zero_nat}, nat_T)
     orelse is_coconstr thy x
   end
@@ -644,10 +657,11 @@
 fun is_boxing_worth_it (ext_ctxt : extended_context) boxy T =
   case T of
     Type ("fun", _) =>
-    boxy mem [InPair, InFunLHS] andalso not (is_boolean_type (body_type T))
+    (boxy = InPair orelse boxy = InFunLHS)
+    andalso not (is_boolean_type (body_type T))
   | Type ("*", Ts) =>
-    boxy mem [InPair, InFunRHS1, InFunRHS2]
-    orelse (boxy mem [InExpr, InFunLHS]
+    boxy = InPair orelse boxy = InFunRHS1 orelse boxy = InFunRHS2
+    orelse ((boxy = InExpr orelse boxy = InFunLHS)
             andalso exists (is_boxing_worth_it ext_ctxt InPair)
                            (map (box_type ext_ctxt InPair) Ts))
   | _ => false
@@ -660,7 +674,7 @@
 and box_type ext_ctxt boxy T =
   case T of
     Type (z as ("fun", [T1, T2])) =>
-    if not (boxy mem [InConstr, InSel])
+    if boxy <> InConstr andalso boxy <> InSel
        andalso should_box_type ext_ctxt boxy z then
       Type (@{type_name fun_box},
             [box_type ext_ctxt InFunLHS T1, box_type ext_ctxt InFunRHS1 T2])
@@ -672,8 +686,8 @@
       Type (@{type_name pair_box}, map (box_type ext_ctxt InSel) Ts)
     else
       Type ("*", map (box_type ext_ctxt
-                               (if boxy mem [InConstr, InSel] then boxy
-                                else InPair)) Ts)
+                          (if boxy = InConstr orelse boxy = InSel then boxy
+                           else InPair)) Ts)
   | _ => T
 
 (* styp -> styp *)
@@ -922,7 +936,7 @@
   let
     (* typ list -> typ -> int *)
     fun aux avoid T =
-      (if T mem avoid then
+      (if member (op =) avoid T then
          0
        else case T of
          Type ("fun", [T1, T2]) =>
@@ -957,7 +971,7 @@
                 |> map (Integer.prod o map (aux (T :: avoid)) o binder_types
                         o snd)
             in
-              if exists (equal 0) constr_cards then 0
+              if exists (curry (op =) 0) constr_cards then 0
               else Integer.sum constr_cards
             end)
        | _ => raise SAME ())
@@ -989,8 +1003,8 @@
 
 (* theory -> string -> bool *)
 fun is_funky_typedef_name thy s =
-  s mem [@{type_name unit}, @{type_name "*"}, @{type_name "+"},
-         @{type_name int}]
+  member (op =) [@{type_name unit}, @{type_name "*"}, @{type_name "+"},
+                 @{type_name int}] s
   orelse is_frac_type thy (Type (s, []))
 (* theory -> term -> bool *)
 fun is_funky_typedef thy (Type (s, _)) = is_funky_typedef_name thy s
@@ -1063,10 +1077,11 @@
     val (built_in_nondefs, user_nondefs) =
       List.partition (is_typedef_axiom thy false) user_nondefs
       |>> append built_in_nondefs
-    val defs = (thy |> PureThy.all_thms_of
-                    |> filter (equal Thm.definitionK o Thm.get_kind o snd)
-                    |> map (prop_of o snd) |> filter is_plain_definition) @
-               user_defs @ built_in_defs
+    val defs =
+      (thy |> PureThy.all_thms_of
+           |> filter (curry (op =) Thm.definitionK o Thm.get_kind o snd)
+           |> map (prop_of o snd) |> filter is_plain_definition) @
+      user_defs @ built_in_defs
   in (defs, built_in_nondefs, user_nondefs) end
 
 (* bool -> styp -> int option *)
@@ -1111,7 +1126,7 @@
   else
     these (Symtab.lookup table s)
     |> map_filter (try (Refute.specialize_type thy x))
-    |> filter (equal (Const x) o term_under_def)
+    |> filter (curry (op =) (Const x) o term_under_def)
 
 (* theory -> term -> term option *)
 fun normalized_rhs_of thy t =
@@ -1152,7 +1167,8 @@
     (* term -> bool *)
     fun is_good_arg (Bound _) = true
       | is_good_arg (Const (s, _)) =
-        s mem [@{const_name True}, @{const_name False}, @{const_name undefined}]
+        s = @{const_name True} orelse s = @{const_name False}
+        orelse s = @{const_name undefined}
       | is_good_arg _ = false
   in
     case t |> strip_abs_body |> strip_comb of
@@ -1598,9 +1614,9 @@
   let
     val prems = Logic.strip_imp_prems t |> map (ObjectLogic.atomize_term thy)
     val concl = Logic.strip_imp_concl t |> ObjectLogic.atomize_term thy
-    val (main, side) = List.partition (exists_Const (equal x)) prems
+    val (main, side) = List.partition (exists_Const (curry (op =) x)) prems
     (* term -> bool *)
-     val is_good_head = equal (Const x) o head_of
+     val is_good_head = curry (op =) (Const x) o head_of
   in
     if forall is_good_head main then (side, main, concl) else raise NO_TRIPLE ()
   end
@@ -1693,7 +1709,7 @@
         (x as (s, _)) =
   case triple_lookup (const_match thy) wfs x of
     SOME (SOME b) => b
-  | _ => s mem [@{const_name Nats}, @{const_name fold_graph'}]
+  | _ => s = @{const_name Nats} orelse s = @{const_name fold_graph'}
          orelse case AList.lookup (op =) (!wf_cache) x of
                   SOME (_, wf) => wf
                 | NONE =>
@@ -1730,7 +1746,7 @@
       | do_disjunct j t =
         case num_occs_of_bound_in_term j t of
           0 => true
-        | 1 => exists (equal (Bound j) o head_of) (conjuncts t)
+        | 1 => exists (curry (op =) (Bound j) o head_of) (conjuncts t)
         | _ => false
     (* term -> bool *)
     fun do_lfp_def (Const (@{const_name lfp}, _) $ t2) =
@@ -1774,7 +1790,7 @@
                   t
               end
           val (nonrecs, recs) =
-            List.partition (equal 0 o num_occs_of_bound_in_term j)
+            List.partition (curry (op =) 0 o num_occs_of_bound_in_term j)
                            (disjuncts body)
           val base_body = nonrecs |> List.foldl s_disj @{const False}
           val step_body = recs |> map (repair_rec j)
@@ -1923,7 +1939,7 @@
   | Type ("*", Ts) => fold (add_ground_types ext_ctxt) Ts accum
   | Type (@{type_name itself}, [T1]) => add_ground_types ext_ctxt T1 accum
   | Type (_, Ts) =>
-    if T mem @{typ prop} :: @{typ bool} :: @{typ unit} :: accum then
+    if member (op =) (@{typ prop} :: @{typ bool} :: @{typ unit} :: accum) T then
       accum
     else
       T :: accum
@@ -1962,7 +1978,7 @@
          andalso has_heavy_bounds_or_vars Ts level t_comb
          andalso not (loose_bvar (t_comb, level)) then
         let
-          val (j, seen) = case find_index (equal t_comb) seen of
+          val (j, seen) = case find_index (curry (op =) t_comb) seen of
                             ~1 => (0, t_comb :: seen)
                           | j => (j, seen)
         in (fresh_value_var Ts k (length seen) j t_comb, seen) end
@@ -2046,7 +2062,7 @@
          (Const (x as (s, T)), args) =>
          let val arg_Ts = binder_types T in
            if length arg_Ts = length args
-              andalso (is_constr thy x orelse s mem [@{const_name Pair}]
+              andalso (is_constr thy x orelse s = @{const_name Pair}
                        orelse x = dest_Const @{const Suc})
               andalso (not careful orelse not (is_Var t1)
                        orelse String.isPrefix val_var_prefix
@@ -2141,7 +2157,8 @@
     (* term list -> (indexname * typ) list -> indexname * typ -> term -> term
        -> term -> term *)
     and aux_eq prems zs z t' t1 t2 =
-      if not (z mem zs) andalso not (exists_subterm (equal (Var z)) t') then
+      if not (member (op =) zs z)
+         andalso not (exists_subterm (curry (op =) (Var z)) t') then
         aux prems zs (subst_free [(Var z, t')] t2)
       else
         aux (t1 :: prems) (Term.add_vars t1 zs) t2
@@ -2299,8 +2316,8 @@
          (t0 as Const (s0, _)) $ Abs (s1, T1, t1 as _ $ _) =>
          if s0 = quant_s andalso length Ts < quantifier_cluster_max_size then
            aux s0 (s1 :: ss) (T1 :: Ts) t1
-         else if quant_s = ""
-                 andalso s0 mem [@{const_name All}, @{const_name Ex}] then
+         else if quant_s = "" andalso (s0 = @{const_name All}
+                                       orelse s0 = @{const_name Ex}) then
            aux s0 [s1] [T1] t1
          else
            raise SAME ()
@@ -2330,7 +2347,8 @@
                      | cost boundss_cum_costs (j :: js) =
                        let
                          val (yeas, nays) =
-                           List.partition (fn (bounds, _) => j mem bounds)
+                           List.partition (fn (bounds, _) =>
+                                              member (op =) bounds j)
                                           boundss_cum_costs
                          val yeas_bounds = big_union fst yeas
                          val yeas_cost = Integer.sum (map snd yeas)
@@ -2339,7 +2357,7 @@
                    val js = all_permutations (index_seq 0 num_Ts)
                             |> map (`(cost (t_boundss ~~ t_costs)))
                             |> sort (int_ord o pairself fst) |> hd |> snd
-                   val back_js = map (fn j => find_index (equal j) js)
+                   val back_js = map (fn j => find_index (curry (op =) j) js)
                                      (index_seq 0 num_Ts)
                    val ts = map (renumber_bounds 0 num_Ts (nth back_js o flip))
                                 ts
@@ -2355,7 +2373,8 @@
                      | build ts_cum_bounds (j :: js) =
                        let
                          val (yeas, nays) =
-                           List.partition (fn (_, bounds) => j mem bounds)
+                           List.partition (fn (_, bounds) =>
+                                              member (op =) bounds j)
                                           ts_cum_bounds
                            ||> map (apfst (incr_boundvars ~1))
                        in
@@ -2548,7 +2567,7 @@
         if t = Const x then
           list_comb (Const x', extra_args @ filter_out_indices fixed_js args)
         else
-          let val j = find_index (equal t) fixed_params in
+          let val j = find_index (curry (op =) t) fixed_params in
             list_comb (if j >= 0 then nth fixed_args j else t, args)
           end
   in aux [] t end
@@ -2582,7 +2601,7 @@
                       else case term_under_def t of Const x => [x] | _ => []
       (* term list -> typ list -> term -> term *)
       fun aux args Ts (Const (x as (s, T))) =
-          ((if not (x mem blacklist) andalso not (null args)
+          ((if not (member (op =) blacklist x) andalso not (null args)
                andalso not (String.isPrefix special_prefix s)
                andalso is_equational_fun ext_ctxt x then
               let
@@ -2607,7 +2626,8 @@
                 (* int -> term *)
                 fun var_for_bound_no j =
                   Var ((bound_var_prefix ^
-                        nat_subscript (find_index (equal j) bound_js + 1), k),
+                        nat_subscript (find_index (curry (op =) j) bound_js
+                                       + 1), k),
                        nth Ts j)
                 val fixed_args_in_axiom =
                   map (curry subst_bounds
@@ -2739,7 +2759,8 @@
                                \coerce_term", [t']))
         | (Type (new_s, new_Ts as [new_T1, new_T2]),
            Type (old_s, old_Ts as [old_T1, old_T2])) =>
-          if old_s mem [@{type_name fun_box}, @{type_name pair_box}, "*"] then
+          if old_s = @{type_name fun_box} orelse old_s = @{type_name pair_box}
+             orelse old_s = "*" then
             case constr_expand ext_ctxt old_T t of
               Const (@{const_name FunBox}, _) $ t1 =>
               if new_s = "fun" then
@@ -2770,13 +2791,13 @@
            fold (add_boxed_types_for_var z) [(T1, t1), (T2, t2)]
          | _ => raise TYPE ("Nitpick_HOL.box_fun_and_pair_in_term.\
                             \add_boxed_types_for_var", [T'], []))
-      | _ => exists_subterm (equal (Var z)) t' ? insert (op =) T
+      | _ => exists_subterm (curry (op =) (Var z)) t' ? insert (op =) T
     (* typ list -> typ list -> term -> indexname * typ -> typ *)
     fun box_var_in_def new_Ts old_Ts t (z as (_, T)) =
       case t of
         @{const Trueprop} $ t1 => box_var_in_def new_Ts old_Ts t1 z
       | Const (s0, _) $ t1 $ _ =>
-        if s0 mem [@{const_name "=="}, @{const_name "op ="}] then
+        if s0 = @{const_name "=="} orelse s0 = @{const_name "op ="} then
           let
             val (t', args) = strip_comb t1
             val T' = fastype_of1 (new_Ts, do_term new_Ts old_Ts Neut t')
@@ -2855,7 +2876,8 @@
       | Const (s as @{const_name Eps}, T) => do_description_operator s T
       | Const (s as @{const_name Tha}, T) => do_description_operator s T
       | Const (x as (s, T)) =>
-        Const (s, if s mem [@{const_name converse}, @{const_name trancl}] then
+        Const (s, if s = @{const_name converse}
+                     orelse s = @{const_name trancl} then
                     box_relational_operator_type T
                   else if is_built_in_const fast_descrs x
                           orelse s = @{const_name Sigma} then
@@ -2954,7 +2976,7 @@
       |> map (fn ((x, js, ts), x') => (x, (js, ts, x')))
       |> AList.group (op =)
       |> filter_out (is_equational_fun_surely_complete ext_ctxt o fst)
-      |> map (fn (x, zs) => (x, zs |> (x mem xs) ? cons ([], [], x)))
+      |> map (fn (x, zs) => (x, zs |> member (op =) xs x ? cons ([], [], x)))
     (* special -> int *)
     fun generality (js, _, _) = ~(length js)
     (* special -> special -> bool *)
@@ -3022,7 +3044,7 @@
       case t of
         t1 $ t2 => accum |> fold (add_axioms_for_term depth) [t1, t2]
       | Const (x as (s, T)) =>
-        (if x mem xs orelse is_built_in_const fast_descrs x then
+        (if member (op =) xs x orelse is_built_in_const fast_descrs x then
            accum
          else
            let val accum as (xs, _) = (x :: xs, axs) in
@@ -3175,7 +3197,7 @@
     val T = unbox_type T
     val format = format |> filter (curry (op <) 0)
   in
-    if forall (equal 1) format then
+    if forall (curry (op =) 1) format then
       T
     else
       let
@@ -3226,7 +3248,8 @@
                                 SOME t => do_term t
                               | NONE =>
                                 Var (nth missing_vars
-                                         (find_index (equal j) missing_js)))
+                                         (find_index (curry (op =) j)
+                                                     missing_js)))
                           Ts (0 upto max_j)
            val t = do_const x' |> fst
            val format =
@@ -3300,7 +3323,7 @@
            (t, format_term_type thy def_table formats t)
          end)
       |>> map_types (typ_subst [(@{typ bisim_iterator}, nat_T)] o unbox_type)
-      |>> shorten_const_names_in_term |>> shorten_abs_vars
+      |>> shorten_names_in_term |>> shorten_abs_vars
   in do_const end
 
 (* styp -> string *)