generate problems with type classes
authorblanchet
Mon, 09 Dec 2013 04:03:30 +0100
changeset 54695 a9efdf970720
parent 54694 af9cdb4989c7
child 54696 34496126a60c
generate problems with type classes
src/HOL/TPTP/mash_export.ML
src/HOL/Tools/Nitpick/nitpick_util.ML
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
src/HOL/Tools/Sledgehammer/sledgehammer_util.ML
--- a/src/HOL/TPTP/mash_export.ML	Mon Dec 09 04:03:30 2013 +0100
+++ b/src/HOL/TPTP/mash_export.ML	Mon Dec 09 04:03:30 2013 +0100
@@ -79,9 +79,9 @@
       let
         val name = nickname_of_thm th
         val feats =
-          features_of ctxt (theory_of_thm th) 0 Symtab.empty stature [prop_of th] |> map fst
+          features_of ctxt (theory_of_thm th) 0 Symtab.empty stature false [prop_of th] |> map fst
         val s =
-          encode_str name ^ ": " ^ encode_strs (sort string_ord feats) ^ "\n"
+          encode_str name ^ ": " ^ encode_plain_features (sort_wrt hd feats) ^ "\n"
       in File.append path s end
   in List.app do_fact facts end
 
@@ -161,9 +161,6 @@
           val _ = tracing ("Fact " ^ string_of_int j ^ ": " ^ name)
           val isar_deps = isar_dependencies_of name_tabs th
           val do_query = not (is_bad_query ctxt ho_atp step j th isar_deps)
-          val goal_feats =
-            features_of ctxt (theory_of_thm th) (num_old_facts + j) const_tab stature [prop_of th]
-            |> sort_wrt fst
           val access_facts =
             (if linearize then take (j - 1) new_facts
              else new_facts |> filter_accessible_from th) @ old_facts
@@ -173,11 +170,15 @@
           val parents = if linearize then prevs else parents
           fun extra_features_of (((_, stature), th), weight) =
             [prop_of th]
-            |> features_of ctxt (theory_of_thm th) (num_old_facts + j) const_tab stature
+            |> features_of ctxt (theory_of_thm th) (num_old_facts + j) const_tab stature false
             |> map (apsnd (fn r => weight * extra_feature_factor * r))
           val query =
             if do_query then
               let
+                val goal_feats =
+                  features_of ctxt (theory_of_thm th) (num_old_facts + j) const_tab stature true
+                    [prop_of th]
+                  |> sort_wrt (hd o fst)
                 val query_feats =
                   new_facts
                   |> drop (j - num_extra_feature_facts)
@@ -193,9 +194,13 @@
               end
             else
               ""
+          val nongoal_feats =
+            features_of ctxt (theory_of_thm th) (num_old_facts + j) const_tab stature false
+              [prop_of th]
+            |> map fst |> sort_wrt hd
           val update =
             "! " ^ encode_str name ^ ": " ^ encode_strs parents ^ "; " ^
-            encode_strs (map fst goal_feats) ^ "; " ^ marker ^ " " ^
+            encode_plain_features nongoal_feats ^ "; " ^ marker ^ " " ^
             encode_strs deps ^ "\n"
         in query ^ update end
       else
--- a/src/HOL/Tools/Nitpick/nitpick_util.ML	Mon Dec 09 04:03:30 2013 +0100
+++ b/src/HOL/Tools/Nitpick/nitpick_util.ML	Mon Dec 09 04:03:30 2013 +0100
@@ -173,10 +173,8 @@
                                "indices unordered or out of range")
   in aux 0 js xs end
 
-fun cartesian_product [] _ = []
-  | cartesian_product (x :: xs) yss =
-    map (cons x) yss @ cartesian_product xs yss
-fun n_fold_cartesian_product xss = fold_rev cartesian_product xss [[]]
+fun n_fold_cartesian_product xss = Sledgehammer_Util.n_fold_cartesian_product
+
 fun all_distinct_unordered_pairs_of [] = []
   | all_distinct_unordered_pairs_of (x :: xs) =
     map (pair x) xs @ all_distinct_unordered_pairs_of xs
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Mon Dec 09 04:03:30 2013 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Mon Dec 09 04:03:30 2013 +0100
@@ -30,7 +30,8 @@
   val encode_strs : string list -> string
   val unencode_str : string -> string
   val unencode_strs : string -> string list
-  val encode_features : (string * real) list -> string
+  val encode_plain_features : string list list -> string
+  val encode_features : (string list * real) list -> string
   val extract_suggestions : string -> string * string list
 
   structure MaSh:
@@ -38,13 +39,13 @@
     val unlearn : Proof.context -> bool -> unit
     val learn :
       Proof.context -> bool -> bool
-      -> (string * string list * string list * string list) list -> unit
+      -> (string * string list * string list list * string list) list -> unit
     val relearn :
       Proof.context -> bool -> bool -> (string * string list) list -> unit
     val query :
       Proof.context -> bool -> int
-      -> (string * string list * string list * string list) list
-         * string list * string list * (string * real) list
+      -> (string * string list * string list list * string list) list
+         * string list * string list * (string list * real) list
       -> string list
   end
 
@@ -62,8 +63,8 @@
   val run_prover_for_mash :
     Proof.context -> params -> string -> string -> fact list -> thm -> prover_result
   val features_of :
-    Proof.context -> theory -> int -> int Symtab.table -> stature -> term list ->
-    (string * real) list
+    Proof.context -> theory -> int -> int Symtab.table -> stature -> bool -> term list ->
+    (string list * real) list
   val trim_dependencies : string list -> string list option
   val isar_dependencies_of :
     string Symtab.table * string Symtab.table -> thm -> string list
@@ -219,15 +220,17 @@
   else if r >= 1000000.0 then "1000000"
   else Markup.print_real r
 
-fun encode_feature (name, weight) =
-  encode_str name ^
-  (if Real.== (weight, 1.0) then "" else "=" ^ safe_str_of_real weight)
+val encode_plain_feature = space_implode "|" o map encode_str
 
+fun encode_feature (names, weight) =
+  encode_plain_feature names ^ (if Real.== (weight, 1.0) then "" else "=" ^ safe_str_of_real weight)
+
+val encode_plain_features = map encode_plain_feature #> space_implode " "
 val encode_features = map encode_feature #> space_implode " "
 
-fun str_of_learn (name, parents, feats, deps) =
+fun str_of_learn (name, parents, feats : string list list, deps) =
   "! " ^ encode_str name ^ ": " ^ encode_strs parents ^ "; " ^
-  encode_strs feats ^ "; " ^ encode_strs deps ^ "\n"
+  encode_plain_features feats ^ "; " ^ encode_strs deps ^ "\n"
 
 fun str_of_relearn (name, deps) =
   "p " ^ encode_str name ^ ": " ^ encode_strs deps ^ "\n"
@@ -274,7 +277,8 @@
   end
 
 fun learn _ _ _ [] = ()
-  | learn ctxt overlord save learns =
+  | learn ctxt overlord save (learns : (string * string list * string list list * string list) list) (*##*)
+   =
     let val names = elide_string 1000 (space_implode " " (map #1 learns)) in
       (trace_msg ctxt (fn () => "MaSh learn" ^ (if names = "" then "" else " " ^ names));
        run_mash_tool ctxt overlord ([] |> save ? cons save_models_arg) false
@@ -353,7 +357,7 @@
 
 local
 
-val version = "*** MaSh version 20130820 ***"
+val version = "*** MaSh version 20131206 ***"
 
 exception FILE_VERSION_TOO_NEW of unit
 
@@ -520,11 +524,12 @@
     end
 
 val default_weight = 1.0
-fun free_feature_of s = ("f" ^ s, 40.0 (* FUDGE *))
-fun thy_feature_of s = ("y" ^ s, 8.0 (* FUDGE *))
-fun type_feature_of s = ("t" ^ s, 4.0 (* FUDGE *))
-fun class_feature_of s = ("s" ^ s, 1.0 (* FUDGE *))
-val local_feature = ("local", 16.0 (* FUDGE *))
+fun free_feature_of s = (["f" ^ s], 40.0 (* FUDGE *))
+fun thy_feature_of s = (["y" ^ s], 8.0 (* FUDGE *))
+fun type_feature_of s = (["t" ^ s], 4.0 (* FUDGE *))
+fun var_feature_of s = ([s], 1.0 (* FUDGE *))
+fun class_feature_of s = (["s" ^ s], 1.0 (* FUDGE *))
+val local_feature = (["local"], 16.0 (* FUDGE *))
 
 fun crude_theory_ord p =
   if Theory.subthy p then
@@ -575,7 +580,7 @@
 val pat_var_prefix = "_"
 
 (* try "Long_Name.base_name" for shorter names *)
-fun massage_long_name s = s
+fun massage_long_name s = if s = hd HOLogic.typeS then "T" else s
 
 val crude_str_of_sort =
   space_implode ":" o map massage_long_name o subtract (op =) @{sort type}
@@ -591,9 +596,30 @@
 
 val max_pat_breadth = 10 (* FUDGE *)
 
-fun term_features_of ctxt thy_name num_facts const_tab term_max_depth type_max_depth ts =
+fun keep m xs =
+  let val n = length xs in
+    if n <= m then xs else take (m div 2) xs @ drop (n - (m + 1) div 2) xs
+  end
+
+fun sort_of_type alg T =
+  let
+    val graph = Sorts.classes_of alg
+    fun cls_of S [] = S
+      | cls_of S (cl :: cls) =
+        if Sorts.of_sort alg (T, [cl]) then
+          cls_of (insert (op =) cl S) cls
+        else
+          let val cls' = Sorts.minimize_sort alg (Sorts.super_classes alg cl) in
+            cls_of S (union (op =) cls' cls)
+          end
+  in
+    cls_of [] (Graph.maximals graph)
+  end
+
+fun term_features_of ctxt thy_name num_facts const_tab term_max_depth type_max_depth in_goal ts =
   let
     val thy = Proof_Context.theory_of ctxt
+    val alg = Sign.classes_of thy
 
     val fixes = map snd (Variable.dest_fixes ctxt)
     val classes = Sign.classes_of thy
@@ -612,8 +638,8 @@
       | pattify_type depth (Type (s, U :: Ts)) =
         let
           val T = Type (s, Ts)
-          val ps = take max_pat_breadth (pattify_type depth T)
-          val qs = take max_pat_breadth ("" :: pattify_type (depth - 1) U)
+          val ps = keep max_pat_breadth (pattify_type depth T)
+          val qs = keep max_pat_breadth ("" :: pattify_type (depth - 1) U)
         in map_product (fn p => fn "" => p | q => p ^ "(" ^ q ^ ")") ps qs end
       | pattify_type _ (TFree (_, S)) =
         maybe_singleton_str pat_tvar_prefix (crude_str_of_sort S)
@@ -638,29 +664,54 @@
          let val count = Symtab.lookup const_tab s |> the_default 1 in
            Real.fromInt num_facts / Real.fromInt count (* FUDGE *)
          end)
-    fun pattify_term _ 0 _ = []
-      | pattify_term _ _ (Const (s, _)) =
-        if is_widely_irrelevant_const s then [] else [(massage_long_name s, weight_of_const s)]
+    fun pattify_term _ 0 _ = ([] : (string list * real) list)
+      | pattify_term _ _ (Const (x as (s, _))) =
+        if is_widely_irrelevant_const s then
+          []
+        else
+          let
+            fun strs_of_sort S =
+              S |> (if in_goal then Sorts.complete_sort alg else single o hd)
+                |> map massage_long_name
+            fun strs_of_type_arg (T as Type (s, _)) =
+                massage_long_name s :: (if in_goal then strs_of_sort (sort_of_type alg T) else [])
+              | strs_of_type_arg (TFree (s, S)) = strs_of_sort S
+              | strs_of_type_arg (TVar (s, S)) = strs_of_sort S
+
+            val typargss =
+              these (try (Sign.const_typargs thy) x)
+              |> map strs_of_type_arg
+              |> n_fold_cartesian_product
+              |> keep max_pat_breadth
+            val s' = massage_long_name s
+            val w = weight_of_const s
+
+            fun str_of_type_args [] = ""
+              | str_of_type_args ss = "(" ^ space_implode "," ss ^ ")"
+          in
+            [(map (curry (op ^) s' o str_of_type_args) typargss, w)]
+          end
       | pattify_term _ _ (Free (s, T)) =
         maybe_singleton_str pat_var_prefix (crude_str_of_typ T)
-        |> map (rpair 1.0)
+        |> map var_feature_of
         |> (if member (op =) fixes s then
-              cons (free_feature_of (massage_long_name
-                  (thy_name ^ Long_Name.separator ^ s)))
+              cons (free_feature_of (massage_long_name (thy_name ^ Long_Name.separator ^ s)))
             else
               I)
       | pattify_term _ _ (Var (_, T)) =
-        maybe_singleton_str pat_var_prefix (crude_str_of_typ T) |> map (rpair default_weight)
+        maybe_singleton_str pat_var_prefix (crude_str_of_typ T)
+        |> map var_feature_of
       | pattify_term Ts _ (Bound j) =
-        maybe_singleton_str pat_var_prefix (crude_str_of_typ (nth Ts j)) |> map (rpair default_weight)
+        maybe_singleton_str pat_var_prefix (crude_str_of_typ (nth Ts j))
+        |> map var_feature_of
       | pattify_term Ts depth (t $ u) =
         let
-          val ps = take max_pat_breadth (pattify_term Ts depth t)
-          val qs = take max_pat_breadth (("", default_weight) :: pattify_term Ts (depth - 1) u)
+          val ps = keep max_pat_breadth (pattify_term Ts depth t)
+          val qs = keep max_pat_breadth (([], default_weight) :: pattify_term Ts (depth - 1) u)
         in
-          map_product (fn ppw as (p, pw) =>
-              fn ("", _) => ppw
-               | (q, qw) => (p ^ "(" ^ q ^ ")", pw + qw)) ps qs
+          map_product (fn ppw as (p :: _, pw) =>
+              fn ([], _) => ppw
+               | (q :: _, qw) => ([p ^ "(" ^ q ^ ")"], pw + qw)) ps qs
         end
       | pattify_term _ _ _ = []
     fun add_term_pat Ts = union (eq_fst (op =)) oo pattify_term Ts
@@ -687,10 +738,10 @@
 val type_max_depth = 1
 
 (* TODO: Generate type classes for types? *)
-fun features_of ctxt thy num_facts const_tab (scope, _) ts =
+fun features_of ctxt thy num_facts const_tab (scope, _) in_goal ts =
   let val thy_name = Context.theory_name thy in
     thy_feature_of thy_name ::
-    term_features_of ctxt thy_name num_facts const_tab term_max_depth type_max_depth ts
+    term_features_of ctxt thy_name num_facts const_tab term_max_depth type_max_depth in_goal ts
     |> scope <> Global ? cons local_feature
   end
 
@@ -940,7 +991,7 @@
       thy_name = Context.theory_name (theory_of_thm th)
     fun chained_or_extra_features_of factor (((_, stature), th), weight) =
       [prop_of th]
-      |> features_of ctxt (theory_of_thm th) num_facts const_tab stature
+      |> features_of ctxt (theory_of_thm th) num_facts const_tab stature false
       |> map (apsnd (fn r => weight * factor * r))
 
     val (access_G, suggs) =
@@ -951,7 +1002,7 @@
             let
               val parents = maximal_wrt_access_graph access_G facts
               val goal_feats =
-                features_of ctxt thy num_facts const_tab (Local, General) (concl_t :: hyp_ts)
+                features_of ctxt thy num_facts const_tab (Local, General) true (concl_t :: hyp_ts)
               val chained_feats =
                 chained
                 |> map (rpair 1.0)
@@ -1018,7 +1069,7 @@
     launch_thread (timeout |> the_default one_day) (fn () =>
         let
           val thy = Proof_Context.theory_of ctxt
-          val feats = features_of ctxt thy 0 Symtab.empty (Local, General) [t] |> map fst
+          val feats = features_of ctxt thy 0 Symtab.empty (Local, General) false [t] |> map fst
         in
           peek_state ctxt overlord (fn {access_G, ...} =>
               let
@@ -1117,7 +1168,8 @@
             let
               val name = nickname_of_thm th
               val feats =
-                features_of ctxt (theory_of_thm th) 0 Symtab.empty stature [prop_of th] |> map fst
+                features_of ctxt (theory_of_thm th) 0 Symtab.empty stature false [prop_of th]
+                |> map fst
               val deps = deps_of status th |> these
               val n = n |> not (null deps) ? Integer.add 1
               val learns = (name, parents, feats, deps) :: learns
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_util.ML	Mon Dec 09 04:03:30 2013 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_util.ML	Mon Dec 09 04:03:30 2013 +0100
@@ -9,6 +9,7 @@
   val sledgehammerN : string
   val log2 : real -> real
   val app_hd : ('a -> 'a) -> 'a list -> 'a list
+  val n_fold_cartesian_product : 'a list list -> 'a list list
   val plural_s : int -> string
   val serial_commas : string -> string list -> string list
   val simplify_spaces : string -> string
@@ -48,6 +49,12 @@
 
 fun app_hd f (x :: xs) = f x :: xs
 
+fun cartesian_product [] _ = []
+  | cartesian_product (x :: xs) yss =
+    map (cons x) yss @ cartesian_product xs yss
+
+fun n_fold_cartesian_product xss = fold_rev cartesian_product xss [[]]
+
 fun plural_s n = if n = 1 then "" else "s"
 
 val serial_commas = Try.serial_commas