src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 57432 78d7fbe9b203
parent 57431 02c408aed5ee
child 57458 419180c354c0
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Sun Jun 29 18:28:27 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Sun Jun 29 18:28:27 2014 +0200
@@ -32,7 +32,6 @@
   val encode_strs : string list -> string
   val decode_str : string -> string
   val decode_strs : string -> string list
-  val encode_features : (string * real) list -> string
 
   datatype mash_engine =
     MaSh_kNN
@@ -62,8 +61,8 @@
   val extra_feature_factor : real
   val weight_facts_smoothly : 'a list -> ('a * real) list
   val weight_facts_steeply : 'a list -> ('a * real) list
-  val find_mash_suggestions : Proof.context -> int -> string list -> ('b * thm) list ->
-    ('b * thm) list -> ('b * thm) list -> ('b * thm) list * ('b * thm) list
+  val find_mash_suggestions : Proof.context -> int -> string list -> ('a * thm) list ->
+    ('a * thm) list -> ('a * thm) list -> ('a * thm) list * ('a * thm) list
   val mash_suggested_facts : Proof.context -> theory -> params -> int -> term list -> term ->
     raw_fact list -> fact list * fact list
 
@@ -72,8 +71,8 @@
   val mash_learn_facts : Proof.context -> params -> string -> int -> bool -> Time.time ->
     raw_fact list -> string
   val mash_learn : Proof.context -> params -> fact_override -> thm list -> bool -> unit
+  val mash_can_suggest_facts : Proof.context -> bool
 
-  val mash_can_suggest_facts : Proof.context -> bool
   val generous_max_suggestions : int -> int
   val mepo_weight : real
   val mash_weight : real
@@ -160,36 +159,6 @@
 val the_mash_engine = the_default MaSh_NB o mash_engine
 
 
-(*** Maintenance of the persistent, string-based state ***)
-
-fun meta_char c =
-  if Char.isAlphaNum c orelse c = #"_" orelse c = #"." orelse c = #"(" orelse c = #")" orelse
-     c = #"," then
-    String.str c
-  else
-    (* fixed width, in case more digits follow *)
-    "%" ^ stringN_of_int 3 (Char.ord c)
-
-fun unmeta_chars accum [] = String.implode (rev accum)
-  | unmeta_chars accum (#"%" :: d1 :: d2 :: d3 :: cs) =
-    (case Int.fromString (String.implode [d1, d2, d3]) of
-      SOME n => unmeta_chars (Char.chr n :: accum) cs
-    | NONE => "" (* error *))
-  | unmeta_chars _ (#"%" :: _) = "" (* error *)
-  | unmeta_chars accum (c :: cs) = unmeta_chars (c :: accum) cs
-
-val encode_str = String.translate meta_char
-val decode_str = String.explode #> unmeta_chars []
-
-val encode_strs = map encode_str #> space_implode " "
-val decode_strs = space_explode " " #> filter_out (curry (op =) "") #> map decode_str
-
-fun encode_feature (names, weight) =
-  encode_str names ^ (if Real.== (weight, 1.0) then "" else "=" ^ Real.toString weight)
-
-val encode_features = map encode_feature #> space_implode " "
-
-
 (*** Isabelle-agnostic machine learning ***)
 
 structure MaSh =
@@ -464,14 +433,14 @@
 val naive_bayes_ext = external_tool "predict/nbayes"
 
 fun query_external ctxt engine max_suggs learns goal_feats =
-  (trace_msg ctxt (fn () => "MaSh query external " ^ encode_features goal_feats);
+  (trace_msg ctxt (fn () => "MaSh query external " ^ commas (map fst goal_feats));
    (case engine of
      MaSh_kNN_Ext => k_nearest_neighbors_ext max_suggs learns goal_feats
    | MaSh_NB_Ext => naive_bayes_ext max_suggs learns goal_feats))
 
 fun query_internal ctxt engine num_facts num_feats (fact_names, featss, depss)
     (freqs as (_, _, dffreq)) visible_facts max_suggs goal_feats int_goal_feats =
-  (trace_msg ctxt (fn () => "MaSh query internal " ^ encode_features goal_feats ^ " from {" ^
+  (trace_msg ctxt (fn () => "MaSh query internal " ^ commas (map fst goal_feats) ^ " from {" ^
      elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] fact_names)) ^ "}");
    (case engine of
      MaSh_kNN =>
@@ -490,7 +459,29 @@
 end;
 
 
-(*** Middle-level communication with MaSh ***)
+(*** Persistent, stringly-typed state ***)
+
+fun meta_char c =
+  if Char.isAlphaNum c orelse c = #"_" orelse c = #"." orelse c = #"(" orelse c = #")" orelse
+     c = #"," then
+    String.str c
+  else
+    (* fixed width, in case more digits follow *)
+    "%" ^ stringN_of_int 3 (Char.ord c)
+
+fun unmeta_chars accum [] = String.implode (rev accum)
+  | unmeta_chars accum (#"%" :: d1 :: d2 :: d3 :: cs) =
+    (case Int.fromString (String.implode [d1, d2, d3]) of
+      SOME n => unmeta_chars (Char.chr n :: accum) cs
+    | NONE => "" (* error *))
+  | unmeta_chars _ (#"%" :: _) = "" (* error *)
+  | unmeta_chars accum (c :: cs) = unmeta_chars (c :: accum) cs
+
+val encode_str = String.translate meta_char
+val decode_str = String.explode #> unmeta_chars []
+
+val encode_strs = map encode_str #> space_implode " "
+val decode_strs = space_explode " " #> filter_out (curry (op =) "") #> map decode_str
 
 datatype proof_kind = Isar_Proof | Automatic_Proof | Isar_Proof_wegen_Prover_Flop