# HG changeset patch # User wenzelm # Date 1401447282 -7200 # Node ID 4874411752fe41c9e6cebd109669146f4f6a3040 # Parent 2f620ef839ee9fced439160f0b88237c8a654f57# Parent a406e15c3cf7122ff93740ede493ff817d2ed71e merged diff -r a406e15c3cf7 -r 4874411752fe src/HOL/BNF_Examples/ListF.thy --- a/src/HOL/BNF_Examples/ListF.thy Fri May 30 11:02:02 2014 +0200 +++ b/src/HOL/BNF_Examples/ListF.thy Fri May 30 12:54:42 2014 +0200 @@ -14,6 +14,7 @@ datatype_new 'a listF (map: mapF rel: relF) = NilF (defaults tlF: NilF) | Conss (hdF: 'a) (tlF: "'a listF") + datatype_compat listF definition Singll ("[[_]]") where diff -r a406e15c3cf7 -r 4874411752fe src/HOL/Library/IArray.thy --- a/src/HOL/Library/IArray.thy Fri May 30 11:02:02 2014 +0200 +++ b/src/HOL/Library/IArray.thy Fri May 30 12:54:42 2014 +0200 @@ -86,10 +86,14 @@ constant IArray.tabulate \ (SML) "Vector.tabulate" primrec sub' :: "'a iarray \ integer \ 'a" where -"sub' (as, n) = IArray.list_of as ! nat_of_integer n" +[code del]: "sub' (as, n) = IArray.list_of as ! nat_of_integer n" hide_const (open) sub' lemma [code]: + "IArray.sub' (IArray as, n) = as ! nat_of_integer n" + by simp + +lemma [code]: "as !! n = IArray.sub' (as, integer_of_nat n)" by simp @@ -97,10 +101,14 @@ constant IArray.sub' \ (SML) "Vector.sub" definition length' :: "'a iarray \ integer" where -[simp]: "length' as = integer_of_nat (List.length (IArray.list_of as))" +[code del, simp]: "length' as = integer_of_nat (List.length (IArray.list_of as))" hide_const (open) length' lemma [code]: + "IArray.length' (IArray as) = integer_of_nat (List.length as)" + by simp + +lemma [code]: "IArray.length as = nat_of_integer (IArray.length' as)" by simp diff -r a406e15c3cf7 -r 4874411752fe src/HOL/List.thy --- a/src/HOL/List.thy Fri May 30 11:02:02 2014 +0200 +++ b/src/HOL/List.thy Fri May 30 12:54:42 2014 +0200 @@ -8,9 +8,10 @@ imports Presburger Code_Numeral Quotient Lifting_Set Lifting_Option Lifting_Product begin -datatype_new (set: 'a) list (map: map rel: list_all2) = +datatype_new (set: 'a) list (map: map rel: list_all2) = Nil (defaults tl: "[]") ("[]") | Cons (hd: 'a) (tl: "'a list") (infixr "#" 65) + datatype_compat list lemma [case_names Nil Cons, cases type: list]: diff -r a406e15c3cf7 -r 4874411752fe src/HOL/Option.thy --- a/src/HOL/Option.thy Fri May 30 11:02:02 2014 +0200 +++ b/src/HOL/Option.thy Fri May 30 12:54:42 2014 +0200 @@ -11,6 +11,7 @@ datatype_new 'a option = None | Some (the: 'a) + datatype_compat option lemma [case_names None Some, cases type: option]: diff -r a406e15c3cf7 -r 4874411752fe src/HOL/TPTP/MaSh_Export.thy --- a/src/HOL/TPTP/MaSh_Export.thy Fri May 30 11:02:02 2014 +0200 +++ b/src/HOL/TPTP/MaSh_Export.thy Fri May 30 12:54:42 2014 +0200 @@ -33,7 +33,6 @@ val prover = hd provers val range = (1, NONE) val step = 1 -val linearize = false val max_suggestions = 1024 val dir = "List" val prefix = "/tmp/" ^ dir ^ "/" @@ -46,10 +45,39 @@ () *} +ML {* Options.put_default @{system_option MaSh} "sml_nb" *} + +ML {* +if do_it then + generate_mash_suggestions @{context} params (range, step) thys max_suggestions + (prefix ^ "mash_sml_nb_suggestions") +else + () +*} + +ML {* Options.put_default @{system_option MaSh} "sml_knn" *} + ML {* if do_it then - generate_accessibility @{context} thys linearize - (prefix ^ "mash_accessibility") + generate_mash_suggestions @{context} params (range, step) thys max_suggestions + (prefix ^ "mash_sml_knn_suggestions") +else + () +*} + +ML {* Options.put_default @{system_option MaSh} "py" *} + +ML {* +if do_it then + generate_mash_suggestions @{context} params (range, step) thys max_suggestions + (prefix ^ "mash_py_suggestions") +else + () +*} + +ML {* +if do_it then + generate_accessibility @{context} thys (prefix ^ "mash_accessibility") else () *} @@ -63,24 +91,23 @@ ML {* if do_it then - generate_isar_dependencies @{context} range thys linearize - (prefix ^ "mash_dependencies") + generate_isar_dependencies @{context} range thys (prefix ^ "mash_dependencies") else () *} ML {* if do_it then - generate_isar_commands @{context} prover (range, step) thys linearize - max_suggestions (prefix ^ "mash_commands") + generate_isar_commands @{context} prover (range, step) thys max_suggestions + (prefix ^ "mash_commands") else () *} ML {* if do_it then - generate_mepo_suggestions @{context} params (range, step) thys linearize - max_suggestions (prefix ^ "mepo_suggestions") + generate_mepo_suggestions @{context} params (range, step) thys max_suggestions + (prefix ^ "mepo_suggestions") else () *} @@ -88,23 +115,22 @@ ML {* if do_it then generate_mesh_suggestions max_suggestions (prefix ^ "mash_suggestions") - (prefix ^ "mepo_suggestions") (prefix ^ "mesh_suggestions") + (prefix ^ "mepo_suggestions") (prefix ^ "mesh_suggestions") else () *} ML {* if do_it then - generate_prover_dependencies @{context} params range thys linearize - (prefix ^ "mash_prover_dependencies") + generate_prover_dependencies @{context} params range thys (prefix ^ "mash_prover_dependencies") else () *} ML {* if do_it then - generate_prover_commands @{context} params (range, step) thys linearize - max_suggestions (prefix ^ "mash_prover_commands") + generate_prover_commands @{context} params (range, step) thys max_suggestions + (prefix ^ "mash_prover_commands") else () *} @@ -112,7 +138,7 @@ ML {* if do_it then generate_mesh_suggestions max_suggestions (prefix ^ "mash_prover_suggestions") - (prefix ^ "mepo_suggestions") (prefix ^ "mesh_prover_suggestions") + (prefix ^ "mepo_suggestions") (prefix ^ "mesh_prover_suggestions") else () *} diff -r a406e15c3cf7 -r 4874411752fe src/HOL/TPTP/mash_eval.ML --- a/src/HOL/TPTP/mash_eval.ML Fri May 30 11:02:02 2014 +0200 +++ b/src/HOL/TPTP/mash_eval.ML Fri May 30 12:54:42 2014 +0200 @@ -97,7 +97,7 @@ mesh_isar_line), mesh_prover_line)) = if in_range range j then let - val get_suggs = extract_suggestions ##> take max_suggs + val get_suggs = extract_suggestions ##> (take max_suggs #> map fst) val (name1, mepo_suggs) = get_suggs mepo_line val (name2, mash_isar_suggs) = get_suggs mash_isar_line val (name3, mash_prover_suggs) = get_suggs mash_prover_line diff -r a406e15c3cf7 -r 4874411752fe src/HOL/TPTP/mash_export.ML --- a/src/HOL/TPTP/mash_export.ML Fri May 30 11:02:02 2014 +0200 +++ b/src/HOL/TPTP/mash_export.ML Fri May 30 12:54:42 2014 +0200 @@ -9,18 +9,20 @@ sig type params = Sledgehammer_Prover.params - val generate_accessibility : Proof.context -> theory list -> bool -> string -> unit + val generate_accessibility : Proof.context -> theory list -> string -> unit val generate_features : Proof.context -> theory list -> string -> unit - val generate_isar_dependencies : Proof.context -> int * int option -> theory list -> bool -> - string -> unit + val generate_isar_dependencies : Proof.context -> int * int option -> theory list -> string -> + unit val generate_prover_dependencies : Proof.context -> params -> int * int option -> theory list -> - bool -> string -> unit + string -> unit val generate_isar_commands : Proof.context -> string -> (int * int option) * int -> theory list -> - bool -> int -> string -> unit + int -> string -> unit val generate_prover_commands : Proof.context -> params -> (int * int option) * int -> - theory list -> bool -> int -> string -> unit + theory list -> int -> string -> unit val generate_mepo_suggestions : Proof.context -> params -> (int * int option) * int -> - theory list -> bool -> int -> string -> unit + theory list -> int -> string -> unit + val generate_mash_suggestions : Proof.context -> params -> (int * int option) * int -> + theory list -> int -> string -> unit val generate_mesh_suggestions : int -> string -> string -> string -> unit end; @@ -48,17 +50,14 @@ fun filter_accessible_from th = filter (fn (_, th') => thm_less (th', th)) -fun generate_accessibility ctxt thys linearize file_name = +fun generate_accessibility ctxt thys file_name = let - val path = file_name |> Path.explode - val _ = File.write path "" + val path = Path.explode file_name fun do_fact (parents, fact) prevs = - let - val parents = if linearize then prevs else parents - val s = encode_str fact ^ ": " ^ encode_strs parents ^ "\n" - val _ = File.append path s - in [fact] end + let val s = encode_str fact ^ ": " ^ encode_strs parents ^ "\n" in + File.append path s; [fact] + end val facts = all_facts ctxt @@ -66,7 +65,8 @@ |> attach_parents_to_facts [] |> map (apsnd (nickname_of_thm o snd)) in - fold do_fact facts []; () + File.write path ""; + ignore (fold do_fact facts []) end fun generate_features ctxt thys file_name = @@ -74,13 +74,16 @@ val path = file_name |> Path.explode val _ = File.write path "" val facts = all_facts ctxt |> filter_out (has_thys thys o snd) + fun do_fact ((_, stature), th) = let val name = nickname_of_thm th val feats = features_of ctxt (theory_of_thm th) 0 Symtab.empty stature [prop_of th] |> map fst val s = encode_str name ^ ": " ^ encode_strs (sort string_ord feats) ^ "\n" - in File.append path s end + in + File.append path s + end in List.app do_fact facts end @@ -111,10 +114,10 @@ | NONE => (omitted_marker, [])) end -fun generate_isar_or_prover_dependencies ctxt params_opt range thys linearize file_name = +fun generate_isar_or_prover_dependencies ctxt params_opt range thys file_name = let - val path = file_name |> Path.explode - val facts = all_facts ctxt |> filter_out (has_thys thys o snd) + val path = Path.explode file_name + val facts = filter_out (has_thys thys o snd) (all_facts ctxt) val name_tabs = build_name_tables nickname_of_thm facts fun do_fact (j, (_, th)) = @@ -122,11 +125,11 @@ let val name = nickname_of_thm th val _ = tracing ("Fact " ^ string_of_int j ^ ": " ^ name) - val access_facts = - if linearize then take (j - 1) facts else facts |> filter_accessible_from th - val (marker, deps) = - smart_dependencies_of ctxt params_opt access_facts name_tabs th NONE - in encode_str name ^ ": " ^ marker ^ " " ^ encode_strs deps ^ "\n" end + val access_facts = filter_accessible_from th facts + val (marker, deps) = smart_dependencies_of ctxt params_opt access_facts name_tabs th NONE + in + encode_str name ^ ": " ^ marker ^ " " ^ encode_strs deps ^ "\n" + end else "" @@ -147,8 +150,7 @@ null isar_deps orelse is_blacklisted_or_something ctxt ho_atp (Thm.get_name_hint th) -fun generate_isar_or_prover_commands ctxt prover params_opt (range, step) thys linearize max_suggs - file_name = +fun generate_isar_or_prover_commands ctxt prover params_opt (range, step) thys max_suggs file_name = let val ho_atp = is_ho_atp ctxt prover val path = file_name |> Path.explode @@ -166,17 +168,15 @@ val goal_feats = features_of ctxt (theory_of_thm th) (num_old_facts + j) const_tab stature [prop_of th] |> sort_wrt fst - val access_facts = - (if linearize then take (j - 1) new_facts else new_facts |> filter_accessible_from th) @ - old_facts + val access_facts = filter_accessible_from th new_facts @ old_facts val (marker, deps) = - smart_dependencies_of ctxt params_opt access_facts name_tabs th - (SOME isar_deps) - val parents = if linearize then prevs else parents + smart_dependencies_of ctxt params_opt access_facts name_tabs th (SOME isar_deps) + fun extra_features_of (((_, stature), th), weight) = [prop_of th] |> features_of ctxt (theory_of_thm th) (num_old_facts + j) const_tab stature |> map (apsnd (fn r => weight * extra_feature_factor * r)) + val query = if do_query then let @@ -189,9 +189,8 @@ |> map extra_features_of |> rpair goal_feats |-> fold (union (eq_fst (op =))) in - "? " ^ string_of_int max_suggs ^ " # " ^ - encode_str name ^ ": " ^ encode_strs parents ^ "; " ^ - encode_features query_feats ^ "\n" + "? " ^ string_of_int max_suggs ^ " # " ^ encode_str name ^ ": " ^ + encode_strs parents ^ "; " ^ encode_features query_feats ^ "\n" end else "" @@ -201,9 +200,9 @@ in query ^ update end else "" + val new_facts = - new_facts |> attach_parents_to_facts old_facts - |> map (`(nickname_of_thm o snd o snd)) + new_facts |> attach_parents_to_facts old_facts |> map (`(nickname_of_thm o snd o snd)) val hd_prevs = map (nickname_of_thm o snd) (the_list (try List.last old_facts)) val prevss = hd_prevs :: map (single o fst) new_facts |> split_last |> fst val hd_const_tabs = fold (add_const_counts o prop_of o snd) old_facts Symtab.empty @@ -220,8 +219,8 @@ fun generate_prover_commands ctxt (params as {provers = prover :: _, ...}) = generate_isar_or_prover_commands ctxt prover (SOME params) -fun generate_mepo_suggestions ctxt (params as {provers = prover :: _, ...}) (range, step) thys - linearize max_suggs file_name = +fun generate_mepo_or_mash_suggestions mepo_or_mash_suggested_facts ctxt + (params as {provers = prover :: _, ...}) (range, step) thys max_suggs file_name = let val ho_atp = is_ho_atp ctxt prover val path = file_name |> Path.explode @@ -244,11 +243,12 @@ let val suggs = old_facts - |> not linearize ? filter_accessible_from th - |> Sledgehammer_Fact.drop_duplicate_facts - |> Sledgehammer_MePo.mepo_suggested_facts ctxt params max_suggs NONE hyp_ts concl_t + |> filter_accessible_from th + |> mepo_or_mash_suggested_facts ctxt params max_suggs hyp_ts concl_t |> map (nickname_of_thm o snd) - in encode_str name ^ ": " ^ encode_strs suggs ^ "\n" end + in + encode_str name ^ ": " ^ encode_strs suggs ^ "\n" + end end else "" @@ -260,6 +260,22 @@ File.write_list path lines end +val generate_mepo_suggestions = + generate_mepo_or_mash_suggestions + (fn ctxt => fn params => fn max_suggs => fn hyp_ts => fn concl_t => + Sledgehammer_Fact.drop_duplicate_facts + #> Sledgehammer_MePo.mepo_suggested_facts ctxt params max_suggs NONE hyp_ts concl_t) + +fun generate_mash_suggestions ctxt params = + (Sledgehammer_MaSh.mash_unlearn ctxt params; + generate_mepo_or_mash_suggestions + (fn ctxt => fn params as {provers = prover :: _, ...} => fn max_suggs => fn hyp_ts => + fn concl_t => + tap (Sledgehammer_MaSh.mash_learn_facts ctxt params prover true 2 false + Sledgehammer_Util.one_year) + #> Sledgehammer_MaSh.mash_suggested_facts ctxt params max_suggs hyp_ts concl_t + #> fst) ctxt params) + fun generate_mesh_suggestions max_suggs mash_file_name mepo_file_name mesh_file_name = let val mesh_path = Path.explode mesh_file_name @@ -269,10 +285,10 @@ let val (name, mash_suggs) = extract_suggestions mash_line - ||> weight_facts_steeply + ||> (map fst #> weight_facts_steeply) val (name', mepo_suggs) = extract_suggestions mepo_line - ||> weight_facts_steeply + ||> (map fst #> weight_facts_steeply) val _ = if name = name' then () else error "Input files out of sync." val mess = [(mepo_weight, (mepo_suggs, [])), @@ -284,10 +300,8 @@ val mash_lines = Path.explode mash_file_name |> File.read_lines val mepo_lines = Path.explode mepo_file_name |> File.read_lines in - if length mash_lines = length mepo_lines then - List.app do_fact (mash_lines ~~ mepo_lines) - else - warning "Skipped: MaSh file missing or out of sync with MePo file." + if length mash_lines = length mepo_lines then List.app do_fact (mash_lines ~~ mepo_lines) + else warning "Skipped: MaSh file missing or out of sync with MePo file." end end; diff -r a406e15c3cf7 -r 4874411752fe src/HOL/Tools/Sledgehammer/MaSh/src/server.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/server.py Fri May 30 11:02:02 2014 +0200 +++ b/src/HOL/Tools/Sledgehammer/MaSh/src/server.py Fri May 30 12:54:42 2014 +0200 @@ -159,10 +159,10 @@ # Output predictionNames = [str(self.server.dicts.idNameDict[p]) for p in self.server.predictions[:numberOfPredictions]] - #predictionValues = [str(x) for x in predictionValues[:numberOfPredictions]] - #predictionsStringList = ['%s=%s' % (predictionNames[i],predictionValues[i]) for i in range(len(predictionNames))] - #predictionsString = string.join(predictionsStringList,' ') - predictionsString = string.join(predictionNames,' ') + predictionValues = [str(x) for x in predictionValues[:numberOfPredictions]] + predictionsStringList = ['%s=%s' % (predictionNames[i],predictionValues[i]) for i in range(len(predictionNames))] + predictionsString = string.join(predictionsStringList,' ') + #predictionsString = string.join(predictionNames,' ') outString = '%s: %s' % (name,predictionsString) self.request.sendall(outString) diff -r a406e15c3cf7 -r 4874411752fe src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Fri May 30 11:02:02 2014 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Fri May 30 12:54:42 2014 +0200 @@ -32,9 +32,12 @@ val decode_str : string -> string val decode_strs : string -> string list val encode_features : (string * real) list -> string - val extract_suggestions : string -> string * string list + val extract_suggestions : string -> string * (string * real) list - datatype mash_engine = MaSh_Py | MaSh_SML_kNN | MaSh_SML_NB + datatype mash_engine = MaSh_Py | MaSh_SML_kNN | MaSh_SML_NB | MaSh_SML_NB_Py + + val is_mash_enabled : unit -> bool + val the_mash_engine : unit -> mash_engine structure MaSh_Py : sig @@ -45,12 +48,18 @@ val query : Proof.context -> bool -> int -> (string * string list * string list * string list) list * string list * string list * (string * real) list -> - string list + (string * real) list end structure MaSh_SML : sig - val query : Proof.context -> mash_engine -> string list -> int -> + val k_nearest_neighbors : int -> int -> (int -> int list) -> (int -> (int * real) list) -> + int -> (int * real) list -> (int * real) list + val naive_bayes : int -> int -> (int -> int list) -> (int -> int list) -> int -> int -> + (int * real) list -> (int * real) list + val naive_bayes_py : Proof.context -> bool -> int -> int -> (int -> int list) -> + (int -> int list) -> int -> int -> (int * real) list -> (int * real) list + val query : Proof.context -> bool -> mash_engine -> string list -> int -> (string * (string * real) list * string list) list * string list * (string * real) list -> string list end @@ -82,6 +91,8 @@ val mash_suggested_facts : Proof.context -> params -> int -> term list -> term -> raw_fact list -> fact list * fact list val mash_learn_proof : Proof.context -> params -> term -> ('a * thm) list -> thm list -> unit + val mash_learn_facts : Proof.context -> params -> string -> bool -> int -> bool -> Time.time -> + raw_fact list -> string val mash_learn : Proof.context -> params -> fact_override -> thm list -> bool -> unit val mash_can_suggest_facts : Proof.context -> bool -> bool @@ -135,7 +146,7 @@ () end -datatype mash_engine = MaSh_Py | MaSh_SML_kNN | MaSh_SML_NB +datatype mash_engine = MaSh_Py | MaSh_SML_kNN | MaSh_SML_NB | MaSh_SML_NB_Py fun mash_engine () = let val flag1 = Options.default_string @{system_option MaSh} in @@ -145,6 +156,7 @@ | "sml" => SOME MaSh_SML_NB | "sml_knn" => SOME MaSh_SML_kNN | "sml_nb" => SOME MaSh_SML_NB + | "sml_nb_py" => SOME MaSh_SML_NB_Py | _ => NONE) end @@ -157,7 +169,7 @@ val save_models_arg = "--saveModels" val shutdown_server_arg = "--shutdownServer" -fun wipe_out_file file = (try (File.rm o Path.explode) file; ()) +fun wipe_out_file file = ignore (try (File.rm o Path.explode) file) fun write_file banner (xs, f) path = (case banner of SOME s => File.write path s | NONE => (); @@ -258,8 +270,8 @@ (* The suggested weights do not make much sense. *) fun extract_suggestion sugg = (case space_explode "=" sugg of - [name, _ (* weight *)] => SOME (decode_str name) - | [name] => SOME (decode_str name) + [name, weight] => SOME (decode_str name, Real.fromString weight |> the_default 1.0) + | [name] => SOME (decode_str name, 1.0) | _ => NONE) fun extract_suggestions line = @@ -445,14 +457,11 @@ ret [] (Integer.max 0 (num_visible_facts - max_suggs)) end -val nb_tau = 0.02 (* FUDGE *) -val nb_pos_weight = 2.0 (* FUDGE *) -val nb_def_val = ~15.0 (* FUDGE *) val nb_def_prior_weight = 20 (* FUDGE *) (* TODO: Either use IDF or don't use it. See commented out code portions below. *) -fun naive_bayes_learn num_facts get_deps get_th_feats num_feats = +fun naive_bayes_learn num_facts get_deps get_feats num_feats = let val tfreq = Array.array (num_facts, 0) val sfreq = Array.array (num_facts, Inttab.empty) @@ -477,11 +486,17 @@ end fun for i = - if i = num_facts then () else (learn i (get_th_feats i) (get_deps i); for (i + 1)) + if i = num_facts then () else (learn i (get_feats i) (get_deps i); for (i + 1)) in for 0; (Array.vector tfreq, Array.vector sfreq (*, Array.vector dffreq *)) end +val nb_kuehlwein_style = false + +val nb_tau = if nb_kuehlwein_style then 0.05 else 0.02 (* FUDGE *) +val nb_pos_weight = if nb_kuehlwein_style then 20.0 else 2.0 (* FUDGE *) +val nb_def_val = ~15.0 (* FUDGE *) + fun naive_bayes_query _ (* num_facts *) num_visible_facts max_suggs feats (tfreq, sfreq (*, dffreq*)) = let (* @@ -503,8 +518,12 @@ val (res, sfh) = fold fold_feats feats (Math.ln tfreq, Vector.sub (sfreq, i)) - fun fold_sfh (f, sf) sow = - sow + tfidf f * (Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq)) + val fold_sfh = + if nb_kuehlwein_style then + (fn (f, sf) => fn sow => sow - tfidf f * (tfreq / Math.ln (Real.fromInt sf))) + else + (fn (f, sf) => fn sow => + sow + tfidf f * Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq)) val sum_of_weights = Inttab.fold fold_sfh sfh 0.0 in @@ -520,15 +539,34 @@ ret [] (Integer.max 0 (num_visible_facts - max_suggs)) end -fun naive_bayes num_facts num_visible_facts get_deps get_th_feats num_feats max_suggs feats = - naive_bayes_learn num_facts get_deps get_th_feats num_feats +fun naive_bayes num_facts num_visible_facts get_deps get_feats num_feats max_suggs feats = + naive_bayes_learn num_facts get_deps get_feats num_feats |> naive_bayes_query num_facts num_visible_facts max_suggs feats +(* experimental *) +fun naive_bayes_py ctxt overlord num_facts num_visible_facts get_deps get_feats num_feats max_suggs + feats = + let + fun name_of_fact j = "f" ^ string_of_int j + fun fact_of_name s = the (Int.fromString (unprefix "f" s)) + fun name_of_feature j = "F" ^ string_of_int j + fun parents_of j = if j = 0 then [] else [name_of_fact (j - 1)] + + val learns = map (fn j => (name_of_fact j, parents_of j, map name_of_feature (get_feats j), + map name_of_fact (get_deps j))) (0 upto num_facts - 1) + val parents' = parents_of num_visible_facts + val feats' = map (apfst name_of_feature) feats + in + MaSh_Py.unlearn ctxt overlord; + MaSh_Py.query ctxt overlord max_suggs (learns, [], parents', feats') + |> map (apfst fact_of_name) + end + fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys) fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i))) -fun query ctxt engine visible_facts max_suggs (learns, hints, feats) = +fun query ctxt overlord engine visible_facts max_suggs (learns, hints, feats) = let val visible_fact_set = Symtab.make_set visible_facts @@ -586,8 +624,8 @@ val get_unweighted_feats = curry Vector.sub unweighted_feats_ary val feats' = map (apfst (the_default ~1 o Symtab.lookup feat_tab)) feats in - naive_bayes num_facts num_visible_facts get_deps get_unweighted_feats num_feats max_suggs - feats' + (if engine = MaSh_SML_NB then naive_bayes else naive_bayes_py ctxt overlord) + num_facts num_visible_facts get_deps get_unweighted_feats num_feats max_suggs feats' end) |> map (curry Vector.sub fact_vec o fst) end @@ -1285,6 +1323,7 @@ if engine = MaSh_Py then let val (parents, hints, feats) = query_args access_G in MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats) + |> map fst end else [])) @@ -1299,7 +1338,7 @@ val learns = Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G in - MaSh_SML.query ctxt engine visible_facts max_facts (learns, hints, feats) + MaSh_SML.query ctxt overlord engine visible_facts max_facts (learns, hints, feats) end val unknown = filter_out (is_fact_in_graph access_G o snd) facts @@ -1574,7 +1613,7 @@ (* Generate more suggestions than requested, because some might be thrown out later for various reasons. *) -fun generous_max_suggestions max_facts = max_facts (*### 11 * max_facts div 10 + 20 *) +fun generous_max_suggestions max_facts = 11 * max_facts div 10 + 20 val mepo_weight = 0.5 val mash_weight = 0.5 diff -r a406e15c3cf7 -r 4874411752fe src/Pure/Isar/class.ML --- a/src/Pure/Isar/class.ML Fri May 30 11:02:02 2014 +0200 +++ b/src/Pure/Isar/class.ML Fri May 30 12:54:42 2014 +0200 @@ -316,14 +316,9 @@ local -fun target_extension f class b_mx_rhs lthy = - let - val phi = morphism (Proof_Context.theory_of lthy) class; - in - lthy - |> Local_Theory.raw_theory (f class phi b_mx_rhs) - |> synchronize_class_syntax_target class - end; +fun target_extension f class b_mx_rhs = + Local_Theory.raw_theory (fn thy => f class (morphism thy class) b_mx_rhs thy) + #> synchronize_class_syntax_target class; fun class_const class prmode (b, rhs) = Generic_Target.locale_declaration class true (fn phi => @@ -344,28 +339,39 @@ end) #> Generic_Target.const (fn (this, other) => other <> 0 andalso this <> other) prmode ((b, NoSyn), rhs); -fun global_const (type_params, term_params) class phi ((b, mx), rhs) thy = +fun dangling_params_for lthy class (type_params, term_params) = let - val class_params = map fst (these_params thy [class]); - val additional_params = - subtract (fn (v, Free (w, _)) => v = w | _ => false) class_params term_params; - val context_params = map (Morphism.term phi) (type_params @ additional_params); + val class_param_names = + map fst (these_params (Proof_Context.theory_of lthy) [class]); + val dangling_term_params = + subtract (fn (v, Free (w, _)) => v = w | _ => false) class_param_names term_params; + in type_params @ dangling_term_params end; + +fun global_def (b, eq) thy = + thy + |> Thm.add_def_global false false (b, eq) + |>> (Thm.varifyT_global o snd) + |-> (fn def_thm => Global_Theory.store_thm (b, def_thm) + #> snd + #> pair def_thm); + +fun global_const dangling_params class phi ((b, mx), rhs) thy = + let + val dangling_params' = map (Morphism.term phi) dangling_params; val b' = Morphism.binding phi b; val b_def = Morphism.binding phi (Binding.suffix_name "_dict" b'); + val rhs' = Morphism.term phi rhs; val c' = Sign.full_name thy b'; - val rhs' = Morphism.term phi rhs; - val ty' = map Term.fastype_of context_params ---> Term.fastype_of rhs'; - val def_eq = Logic.mk_equals (list_comb (Const (c', ty'), context_params), rhs') + val ty' = map Term.fastype_of dangling_params' ---> Term.fastype_of rhs'; + val def_eq = Logic.mk_equals (list_comb (Const (c', ty'), dangling_params'), rhs') |> map_types Type.strip_sorts; in thy |> Sign.declare_const_global ((b', Type.strip_sorts ty'), mx) |> snd - |> Thm.add_def_global false false (b_def, def_eq) - |>> apsnd Thm.varifyT_global - |-> (fn (_, def_thm) => Global_Theory.store_thm (b_def, def_thm) - #> snd - #> null context_params ? register_operation class (c', (rhs', SOME (Thm.symmetric def_thm)))) + |> global_def (b_def, def_eq) + |-> (fn def_thm => + null dangling_params' ? register_operation class (c', (rhs', SOME (Thm.symmetric def_thm)))) |> Sign.add_const_constraint (c', SOME ty') end; @@ -373,8 +379,8 @@ let val unchecks = these_unchecks thy [class]; val b' = Morphism.binding phi b; + val rhs' = Pattern.rewrite_term thy unchecks [] rhs; val c' = Sign.full_name thy b'; - val rhs' = Pattern.rewrite_term thy unchecks [] rhs; val ty' = Term.fastype_of rhs'; in thy @@ -387,9 +393,14 @@ in -fun const class ((b, mx), lhs) params = - class_const class Syntax.mode_default (b, lhs) - #> target_extension (global_const params) class ((b, mx), lhs); +fun const class ((b, mx), lhs) params lthy = + let + val dangling_params = dangling_params_for lthy class params; + in + lthy + |> class_const class Syntax.mode_default (b, lhs) + |> target_extension (global_const dangling_params) class ((b, mx), lhs) + end; fun abbrev class prmode ((b, mx), lhs) rhs' = class_const class prmode (b, lhs) diff -r a406e15c3cf7 -r 4874411752fe src/Pure/Isar/generic_target.ML --- a/src/Pure/Isar/generic_target.ML Fri May 30 11:02:02 2014 +0200 +++ b/src/Pure/Isar/generic_target.ML Fri May 30 12:54:42 2014 +0200 @@ -21,7 +21,7 @@ (Attrib.binding * (thm list * Args.src list) list) list -> local_theory -> local_theory val background_abbrev: binding * term -> term list -> local_theory -> (term * term) * local_theory - val abbrev: (string * bool -> binding * mixfix -> term * term -> + val abbrev: (string * bool -> binding * mixfix -> term -> term list -> local_theory -> local_theory) -> string * bool -> (binding * mixfix) * term -> local_theory -> (term * term) * local_theory val background_declaration: declaration -> local_theory -> local_theory @@ -41,7 +41,7 @@ (Attrib.binding * (thm list * Args.src list) list) list -> (Attrib.binding * (thm list * Args.src list) list) list -> local_theory -> local_theory - val theory_abbrev: Syntax.mode -> (binding * mixfix) -> term * term -> term list -> + val theory_abbrev: Syntax.mode -> (binding * mixfix) -> term -> term list -> local_theory -> local_theory val theory_declaration: declaration -> local_theory -> local_theory val theory_registration: string * morphism -> (morphism * bool) option -> morphism -> @@ -216,7 +216,7 @@ val mx' = check_mixfix lthy (b, extra_tfrees) mx; in lthy - |> target_abbrev prmode (b, mx') (global_rhs, rhs') params + |> target_abbrev prmode (b, mx') global_rhs params |> Proof_Context.add_abbrev Print_Mode.internal (b, rhs) |> snd |> Local_Defs.fixed_abbrev ((b, NoSyn), rhs) end; @@ -331,7 +331,7 @@ ctxt |> Attrib.local_notes kind (Element.transform_facts (Local_Theory.standard_morphism lthy ctxt) local_facts) |> snd)); -fun theory_abbrev prmode (b, mx) (global_rhs, _) params = +fun theory_abbrev prmode (b, mx) global_rhs params = Local_Theory.background_theory_result (Sign.add_abbrev (#1 prmode) (b, global_rhs) #-> (fn (lhs, _) => (* FIXME type_params!? *) diff -r a406e15c3cf7 -r 4874411752fe src/Pure/Isar/named_target.ML --- a/src/Pure/Isar/named_target.ML Fri May 30 11:02:02 2014 +0200 +++ b/src/Pure/Isar/named_target.ML Fri May 30 12:54:42 2014 +0200 @@ -75,11 +75,11 @@ (* abbrev *) -fun locale_abbrev locale prmode (b, mx) (global_rhs, _) params = +fun locale_abbrev locale prmode (b, mx) global_rhs params = Generic_Target.background_abbrev (b, global_rhs) params #-> (fn (lhs, _) => Generic_Target.locale_const locale prmode ((b, mx), lhs)); -fun class_abbrev class prmode (b, mx) (global_rhs, _) params = +fun class_abbrev class prmode (b, mx) global_rhs params = Generic_Target.background_abbrev (b, global_rhs) params #-> (fn (lhs, rhs) => Class.abbrev class prmode ((b, mx), lhs) rhs);