src/HOL/Tools/Sledgehammer/sledgehammer_tptp_format.ML
changeset 37520 9fc2ae73c5ca
parent 37519 fd1a5ece77c0
child 37577 5379f41a1322
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_tptp_format.ML	Wed Jun 23 15:35:18 2010 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_tptp_format.ML	Wed Jun 23 16:28:12 2010 +0200
@@ -29,6 +29,8 @@
 open Sledgehammer_FOL_Clause
 open Sledgehammer_HOL_Clause
 
+type const_info = {min_arity: int, max_arity: int, sub_level: bool}
+
 val clause_prefix = "cls_"
 val arity_clause_prefix = "clsarity_"
 
@@ -44,15 +46,19 @@
     in (s ^ paren_pack ss, pool) end
 
 (* True if the constant ever appears outside of the top-level position in
-   literals. If false, the constant always receives all of its arguments and is
-   used as a predicate. *)
-fun needs_hBOOL explicit_apply const_needs_hBOOL c =
-  explicit_apply orelse the_default false (Symtab.lookup const_needs_hBOOL c)
+   literals, or if it appears with different arities (e.g., because of different
+   type instantiations). If false, the constant always receives all of its
+   arguments and is used as a predicate. *)
+fun needs_hBOOL NONE _ = true
+  | needs_hBOOL (SOME the_const_tab) c =
+    case Symtab.lookup the_const_tab c of
+      SOME ({min_arity, max_arity, sub_level} : const_info) =>
+      sub_level orelse min_arity < max_arity
+    | NONE => false
 
-fun head_needs_hBOOL explicit_apply const_needs_hBOOL
-                     (CombConst ((c, _), _, _)) =
-    needs_hBOOL explicit_apply const_needs_hBOOL c
-  | head_needs_hBOOL _ _ _ = true
+fun head_needs_hBOOL const_tab (CombConst ((c, _), _, _)) =
+    needs_hBOOL const_tab c
+  | head_needs_hBOOL _ _ = true
 
 fun wrap_type full_types (s, ty) pool =
   if full_types then
@@ -62,11 +68,9 @@
   else
     (s, pool)
 
-fun wrap_type_if full_types explicit_apply const_needs_hBOOL (head, s, tp) =
-  if head_needs_hBOOL explicit_apply const_needs_hBOOL head then
-    wrap_type full_types (s, tp)
-  else
-    pair s
+fun wrap_type_if (full_types, const_tab) (head, s, tp) =
+  if head_needs_hBOOL const_tab head then wrap_type full_types (s, tp)
+  else pair s
 
 fun apply ss = "hAPP" ^ paren_pack ss;
 
@@ -75,15 +79,19 @@
 
 fun string_apply (v, args) = rev_apply (v, rev args)
 
-fun min_arity_of const_min_arity = the_default 0 o Symtab.lookup const_min_arity
+fun min_arity_of NONE _ = 0
+  | min_arity_of (SOME the_const_tab) c =
+    case Symtab.lookup the_const_tab c of
+      SOME ({min_arity, ...} : const_info) => min_arity
+    | NONE => 0
 
 (* Apply an operator to the argument strings, using either the "apply" operator
    or direct function application. *)
-fun string_of_application full_types const_min_arity
+fun string_of_application (full_types, const_tab)
                           (CombConst ((s, s'), _, tvars), args) pool =
     let
       val s = if s = "equal" then "c_fequal" else s
-      val nargs = min_arity_of const_min_arity s
+      val nargs = min_arity_of const_tab s
       val args1 = List.take (args, nargs)
         handle Subscript =>
                raise Fail (quote s ^ " has arity " ^ Int.toString nargs ^
@@ -93,31 +101,24 @@
                           else pool_map string_of_fol_type tvars pool
       val (s, pool) = nice_name (s, s') pool
     in (string_apply (s ^ paren_pack (args1 @ targs), args2), pool) end
-  | string_of_application _ _ (CombVar (name, _), args) pool =
+  | string_of_application _ (CombVar (name, _), args) pool =
     nice_name name pool |>> (fn s => string_apply (s, args))
 
-fun string_of_combterm (params as (full_types, explicit_apply, const_min_arity,
-                                   const_needs_hBOOL)) t pool =
+fun string_of_combterm params t pool =
   let
     val (head, args) = strip_combterm_comb t
     val (ss, pool) = pool_map (string_of_combterm params) args pool
-    val (s, pool) =
-      string_of_application full_types const_min_arity (head, ss) pool
-  in
-    wrap_type_if full_types explicit_apply const_needs_hBOOL
-                 (head, s, type_of_combterm t) pool
-  end
+    val (s, pool) = string_of_application params (head, ss) pool
+  in wrap_type_if params (head, s, type_of_combterm t) pool end
 
 (*Boolean-valued terms are here converted to literals.*)
 fun boolify params c =
   string_of_combterm params c #>> prefix "hBOOL" o paren_pack o single
 
-fun string_of_predicate (params as (_, explicit_apply, _, const_needs_hBOOL))
-                        t =
+fun string_of_predicate (params as (_, const_tab)) t =
   case #1 (strip_combterm_comb t) of
     CombConst ((s, _), _, _) =>
-    (if needs_hBOOL explicit_apply const_needs_hBOOL s then boolify
-     else string_of_combterm) params t
+    (if needs_hBOOL const_tab s then boolify else string_of_combterm) params t
   | _ => boolify params t
 
 fun tptp_of_equality params pos (t1, t2) =
@@ -188,36 +189,34 @@
   tptp_cnf (arity_clause_prefix ^ ascii_of axiom_name) "axiom"
            (tptp_raw_clause (map tptp_of_arity_literal (conclLit :: premLits)))
 
-(*Find the minimal arity of each function mentioned in the term. Also, note which uses
-  are not at top level, to see if hBOOL is needed.*)
-fun count_constants_term toplev t (const_min_arity, const_needs_hBOOL) =
+(* Find the minimal arity of each function mentioned in the term. Also, note
+   which uses are not at top level, to see if "hBOOL" is needed. *)
+fun count_constants_term top_level t the_const_tab =
   let
     val (head, args) = strip_combterm_comb t
     val n = length args
-    val (const_min_arity, const_needs_hBOOL) =
-      (const_min_arity, const_needs_hBOOL)
-      |> fold (count_constants_term false) args
+    val the_const_tab = the_const_tab |> fold (count_constants_term false) args
   in
     case head of
-      CombConst ((a, _), _, _) => (*predicate or function version of "equal"?*)
-        let val a = if a="equal" andalso not toplev then "c_fequal" else a
-        in
-          (const_min_arity |> Symtab.map_default (a, n) (Integer.min n),
-           const_needs_hBOOL |> not toplev ? Symtab.update (a, true))
-        end
-      | _ => (const_min_arity, const_needs_hBOOL)
+      CombConst ((a, _), _, ty) =>
+      (* Predicate or function version of "equal"? *)
+      let val a = if a = "equal" andalso not top_level then "c_fequal" else a in
+        the_const_tab
+        |> Symtab.map_default
+               (a, {min_arity = n, max_arity = n, sub_level = false})
+               (fn {min_arity, max_arity, sub_level} =>
+                   {min_arity = Int.min (n, min_arity),
+                    max_arity = Int.max (n, max_arity),
+                    sub_level = sub_level orelse not top_level})
+      end
+      | _ => the_const_tab
   end
-fun count_constants_lit (Literal (_, t)) = count_constants_term true t
+fun count_constants_literal (Literal (_, t)) = count_constants_term true t
 fun count_constants_clause (HOLClause {literals, ...}) =
-  fold count_constants_lit literals
-fun count_constants explicit_apply
-                    (conjectures, _, extra_clauses, helper_clauses, _, _) =
-  (Symtab.empty, Symtab.empty)
-  |> (if explicit_apply then
-        I
-      else
-        fold (fold count_constants_clause)
-             [conjectures, extra_clauses, helper_clauses])
+  fold count_constants_literal literals
+fun count_constants (conjectures, _, extra_clauses, helper_clauses, _, _) =
+  fold (fold count_constants_clause)
+       [conjectures, extra_clauses, helper_clauses]
 
 fun write_tptp_file readable_names full_types explicit_apply file clauses =
   let
@@ -228,10 +227,9 @@
     val pool = empty_name_pool readable_names
     val (conjectures, axclauses, _, helper_clauses,
       classrel_clauses, arity_clauses) = clauses
-    val (const_min_arity, const_needs_hBOOL) =
-      count_constants explicit_apply clauses
-    val params = (full_types, explicit_apply, const_min_arity,
-                  const_needs_hBOOL)
+    val const_tab = if explicit_apply then NONE
+                    else SOME (Symtab.empty |> count_constants clauses)
+    val params = (full_types, const_tab)
     val ((conjecture_clss, tfree_litss), pool) =
       pool_map (tptp_clause params) conjectures pool |>> ListPair.unzip
     val tfree_clss = map tptp_tfree_clause (fold (union (op =)) tfree_litss [])