# HG changeset patch # User blanchet # Date 1404059307 -7200 # Node ID 02c408aed5eecff8c8ec28263c62b23da415f8f2 # Parent 020cea57eaa408e56acbd87bcc58d4f97db02e6a killed Python version of MaSh, now that the SML version works adequately diff -r 020cea57eaa4 -r 02c408aed5ee NEWS --- a/NEWS Sat Jun 28 22:13:23 2014 +0200 +++ b/NEWS Sun Jun 29 18:28:27 2014 +0200 @@ -433,9 +433,7 @@ and increase performance and reliability. - MaSh and MeSh are now used by default together with the traditional MePo (Meng-Paulson) relevance filter. To disable MaSh, set the "MaSh" - system option in Plugin Options / Isabelle / General to "none". Other - allowed values include "sml" (for the default SML engine) and "py" - (for the old Python engine). + system option in Plugin Options / Isabelle / General to "none". - New option: smt_proofs - Renamed options: diff -r 020cea57eaa4 -r 02c408aed5ee src/Doc/Sledgehammer/document/root.tex --- a/src/Doc/Sledgehammer/document/root.tex Sat Jun 28 22:13:23 2014 +0200 +++ b/src/Doc/Sledgehammer/document/root.tex Sun Jun 29 18:28:27 2014 +0200 @@ -1059,14 +1059,11 @@ The MaSh machine learner. Three learning engines are provided: \begin{enum} -\item[\labelitemi] \textbf{\textit{sml\_nb}} (also called \textbf{\textit{sml}} +\item[\labelitemi] \textbf{\textit{nb}} (also called \textbf{\textit{sml}} and \textbf{\textit{yes}}) is a Standard ML implementation of naive Bayes. -\item[\labelitemi] \textbf{\textit{sml\_knn}} is a Standard ML implementation of +\item[\labelitemi] \textbf{\textit{knn}} is a Standard ML implementation of $k$-nearest neighbors. - -\item[\labelitemi] \textbf{\textit{py}} is a Python implementation of naive Bayes. -The program is included with Isabelle as \texttt{mash.py}. \end{enum} In addition, the special value \textit{none} is used to disable machine learning by @@ -1077,10 +1074,7 @@ \texttt{\$ISABELLE\_HOME\_USER/etc/settings} file, or via the ``MaSh'' option under ``Plugins > Plugin Options > Isabelle > General'' in Isabelle/jEdit. Persistent data for both engines is stored in the directory -\texttt{\$ISABELLE\_HOME\_USER/mash}. When switching to the \textit{py} engine, -you will need to invoke the \textit{relearn\_isar} subcommand -(\S\ref{sledgehammer}) to synchronize the persistent databases on the -Python side. +\texttt{\$ISABELLE\_HOME\_USER/mash}. \item[\labelitemi] \textbf{\textit{mesh}:} The MeSh filter, which combines the rankings from MePo and MaSh. diff -r 020cea57eaa4 -r 02c408aed5ee src/HOL/TPTP/MaSh_Export.thy --- a/src/HOL/TPTP/MaSh_Export.thy Sat Jun 28 22:13:23 2014 +0200 +++ b/src/HOL/TPTP/MaSh_Export.thy Sun Jun 29 18:28:27 2014 +0200 @@ -46,22 +46,22 @@ () *} -ML {* Options.put_default @{system_option MaSh} "sml_nb" *} +ML {* Options.put_default @{system_option MaSh} "nb" *} ML {* if do_it then generate_mash_suggestions @{context} params (range, step) thys max_suggestions - (prefix ^ "mash_sml_nb_suggestions") + (prefix ^ "mash_nb_suggestions") else () *} -ML {* Options.put_default @{system_option MaSh} "sml_knn" *} +ML {* Options.put_default @{system_option MaSh} "knn" *} ML {* if do_it then generate_mash_suggestions @{context} params (range, step) thys max_suggestions - (prefix ^ "mash_sml_knn_suggestions") + (prefix ^ "mash_knn_suggestions") else () *} diff -r 020cea57eaa4 -r 02c408aed5ee src/HOL/TPTP/mash_eval.ML --- a/src/HOL/TPTP/mash_eval.ML Sat Jun 28 22:13:23 2014 +0200 +++ b/src/HOL/TPTP/mash_eval.ML Sun Jun 29 18:28:27 2014 +0200 @@ -29,6 +29,7 @@ open Sledgehammer_Prover open Sledgehammer_Prover_ATP open Sledgehammer_Commands +open MaSh_Export val prefix = Library.prefix @@ -38,9 +39,6 @@ val MeSh_ProverN = MeShN ^ "-Prover" val IsarN = "Isar" -fun in_range (from, to) j = - j >= from andalso (to = NONE orelse j <= the to) - fun evaluate_mash_suggestions ctxt params range methods prob_dir_name mepo_file_name mash_isar_file_name mash_prover_file_name mesh_isar_file_name mesh_prover_file_name report_file_name = diff -r 020cea57eaa4 -r 02c408aed5ee src/HOL/TPTP/mash_export.ML --- a/src/HOL/TPTP/mash_export.ML Sat Jun 28 22:13:23 2014 +0200 +++ b/src/HOL/TPTP/mash_export.ML Sun Jun 29 18:28:27 2014 +0200 @@ -9,6 +9,9 @@ sig type params = Sledgehammer_Prover.params + val in_range : int * int option -> int -> bool + val extract_suggestions : string -> string * (string * real) list + val generate_accessibility : Proof.context -> theory list -> string -> unit val generate_features : Proof.context -> theory list -> string -> unit val generate_isar_dependencies : Proof.context -> int * int option -> theory list -> string -> @@ -37,6 +40,18 @@ fun in_range (from, to) j = j >= from andalso (to = NONE orelse j <= the to) +(* The suggested weights do not make much sense. *) +fun extract_suggestion sugg = + (case space_explode "=" sugg of + [name, weight] => SOME (decode_str name, Real.fromString weight |> the_default 1.0) + | [name] => SOME (decode_str name, 1.0) + | _ => NONE) + +fun extract_suggestions line = + (case space_explode ":" line of + [goal, suggs] => (decode_str goal, map_filter extract_suggestion (space_explode " " suggs)) + | _ => ("", [])) + fun has_thm_thy th thy = Context.theory_name thy = Context.theory_name (theory_of_thm th) @@ -265,11 +280,11 @@ #> Sledgehammer_MePo.mepo_suggested_facts ctxt params max_suggs NONE hyp_ts concl_t) fun generate_mash_suggestions ctxt params = - (Sledgehammer_MaSh.mash_unlearn ctxt params; + (Sledgehammer_MaSh.mash_unlearn (); generate_mepo_or_mash_suggestions (fn ctxt => fn thy => fn params as {provers = prover :: _, ...} => fn max_suggs => fn hyp_ts => fn concl_t => - tap (Sledgehammer_MaSh.mash_learn_facts ctxt params prover true 2 false + tap (Sledgehammer_MaSh.mash_learn_facts ctxt params prover 2 false Sledgehammer_Util.one_year) #> Sledgehammer_MaSh.mash_suggested_facts ctxt thy params max_suggs hyp_ts concl_t #> fst) ctxt params) diff -r 020cea57eaa4 -r 02c408aed5ee src/HOL/Tools/Sledgehammer/MaSh/etc/settings --- a/src/HOL/Tools/Sledgehammer/MaSh/etc/settings Sat Jun 28 22:13:23 2014 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,8 +0,0 @@ -# -*- shell-script -*- :mode=shellscript: - -ISABELLE_SLEDGEHAMMER_MASH="$COMPONENT" - -# MASH=yes -if [ -z "$MASH_PORT" ]; then - MASH_PORT=9255 -fi diff -r 020cea57eaa4 -r 02c408aed5ee src/HOL/Tools/Sledgehammer/MaSh/src/ExpandFeatures.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/ExpandFeatures.py Sat Jun 28 22:13:23 2014 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,163 +0,0 @@ -''' -Created on Aug 21, 2013 - -@author: daniel -''' - -from math import log -#from gensim import corpora, models, similarities - -class ExpandFeatures(object): - - def __init__(self,dicts): - self.dicts = dicts - self.featureMap = {} - self.alpha = 0.1 - self.featureCounts = {} - self.counter = 0 - self.corpus = [] -# self.LSIModel = models.lsimodel.LsiModel(self.corpus,num_topics=500) - - def initialize(self,dicts): - self.dicts = dicts - IS = open(dicts.accFile,'r') - for line in IS: - line = line.split(':') - name = line[0] - #print 'name',name - nameId = dicts.nameIdDict[name] - features = dicts.featureDict[nameId] - dependencies = dicts.dependenciesDict[nameId] - x = [self.dicts.idNameDict[d] for d in dependencies] - #print x - self.update(features, dependencies) - self.corpus.append([(x,1) for x in features.keys()]) - IS.close() - print 'x' - #self.LSIModel = models.lsimodel.LsiModel(self.corpus,num_topics=500) - print self.LSIModel - print 'y' - - def update(self,features,dependencies): - self.counter += 1 - self.corpus.append([(x,1) for x in features.keys()]) - self.LSIModel.add_documents([[(x,1) for x in features.keys()]]) - """ - for f in features.iterkeys(): - try: - self.featureCounts[f] += 1 - except: - self.featureCounts[f] = 1 - if self.featureCounts[f] > 100: - continue - try: - self.featureMap[f] = self.featureMap[f].intersection(features.keys()) - except: - self.featureMap[f] = set(features.keys()) - #print 'fOld',len(fMap),self.featureCounts[f],len(dependencies) - - for d in dependencies[1:]: - #print 'dep',self.dicts.idNameDict[d] - dFeatures = self.dicts.featureDict[d] - for df in dFeatures.iterkeys(): - if self.featureCounts.has_key(df): - if self.featureCounts[df] > 20: - continue - else: - print df - try: - fMap[df] += self.alpha * (1.0 - fMap[df]) - except: - fMap[df] = self.alpha - """ - #print 'fNew',len(fMap) - - def expand(self,features): - #print self.corpus[:50] - #print corpus - #tfidfmodel = models.TfidfModel(self.corpus, normalize=True) - #print features.keys() - #tfidfcorpus = [tfidfmodel[x] for x in self.corpus] - #newFeatures = LSI[[(x,1) for x in features.keys()]] - #newFeatures = self.LSIModel[[(x,1) for x in features.keys()]] - #print features - #print newFeatures - #print newFeatures - - """ - newFeatures = dict(features) - for f in features.keys(): - try: - fC = self.featureCounts[f] - except: - fC = 0.5 - newFeatures[f] = log(float(8+self.counter) / fC) - #nrOfFeatures = float(len(features)) - addedCount = 0 - alpha = 0.2 - #""" - - """ - consideredFeatures = [] - while len(newFeatures) < 30: - #alpha = alpha * 0.5 - minF = None - minFrequence = 1000000 - for f in newFeatures.iterkeys(): - if f in consideredFeatures: - continue - try: - if self.featureCounts[f] < minFrequence: - minF = f - except: - pass - if minF == None: - break - # Expand minimal feature - consideredFeatures.append(minF) - for expF in self.featureMap[minF]: - if not newFeatures.has_key(expF): - fC = self.featureCounts[minF] - newFeatures[expF] = alpha*log(float(8+self.counter) / fC) - #print features, newFeatures - #""" - """ - for f in features.iterkeys(): - try: - self.featureCounts[f] += 1 - except: - self.featureCounts[f] = 0 - if self.featureCounts[f] > 10: - continue - addedCount += 1 - try: - fmap = self.featureMap[f] - except: - self.featureMap[f] = {} - fmap = {} - for nf,nv in fmap.iteritems(): - try: - newFeatures[nf] += nv - except: - newFeatures[nf] = nv - if addedCount > 0: - for f,w in newFeatures.iteritems(): - newFeatures[f] = float(w)/addedCount - #""" - """ - deleteF = [] - for f,w in newFeatures.iteritems(): - if w < 0.1: - deleteF.append(f) - for f in deleteF: - del newFeatures[f] - """ - #print 'fold',len(features) - #print 'fnew',len(newFeatures) - #return dict(newFeatures) - return features - -if __name__ == "__main__": - pass - - \ No newline at end of file diff -r 020cea57eaa4 -r 02c408aed5ee src/HOL/Tools/Sledgehammer/MaSh/src/KNN.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/KNN.py Sat Jun 28 22:13:23 2014 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,99 +0,0 @@ -''' -Created on Aug 21, 2013 - -@author: daniel -''' - -from cPickle import dump,load -from numpy import array -from math import sqrt,log - -def cosine(f1,f2): - f1Norm = 0.0 - for f in f1.keys(): - f1Norm += f1[f] * f1[f] - #assert f1Norm = sum(map(lambda x,y: x*y,f1.itervalues(),f1.itervalues())) - f1Norm = sqrt(f1Norm) - - f2Norm = 0.0 - for f in f2.keys(): - f2Norm += f2[f] * f2[f] - f2Norm = sqrt(f2Norm) - - dotProduct = 0.0 - featureIntersection = set(f1.keys()) & set(f2.keys()) - for f in featureIntersection: - dotProduct += f1[f] * f2[f] - cosine = dotProduct / (f1Norm * f2Norm) - return 1.0 - cosine - -def euclidean(f1,f2): - diffSum = 0.0 - featureUnion = set(f1.keys()) | set(f2.keys()) - for f in featureUnion: - try: - f1Val = f1[f] - except: - f1Val = 0.0 - try: - f2Val = f2[f] - except: - f2Val = 0.0 - diff = f1Val - f2Val - diffSum += diff * diff - #if f in f1.keys(): - # diffSum += log(2+self.pointCount/self.featureCounts[f]) * diff * diff - #else: - # diffSum += diff * diff - #print diffSum,f1,f2 - return diffSum - -class KNN(object): - ''' - A basic KNN ranker. - ''' - - def __init__(self,dicts,metric=cosine): - ''' - Constructor - ''' - self.points = dicts.featureDict - self.metric = metric - - def initializeModel(self,_trainData,_dicts): - """ - Build basic model from training data. - """ - pass - - def update(self,dataPoint,features,dependencies): - assert self.points[dataPoint] == features - - def overwrite(self,problemId,newDependencies,dicts): - # Taken care of by dicts - pass - - def delete(self,dataPoint,features,dependencies): - # Taken care of by dicts - pass - - def predict(self,features,accessibles,dicts): - predictions = map(lambda x: self.metric(features,self.points[x]),accessibles) - predictions = array(predictions) - perm = predictions.argsort() - return array(accessibles)[perm],predictions[perm] - - def save(self,fileName): - OStream = open(fileName, 'wb') - dump((self.points,self.metric),OStream) - OStream.close() - - def load(self,fileName): - OStream = open(fileName, 'rb') - self.points,self.metric = load(OStream) - OStream.close() - -if __name__ == '__main__': - pass - - \ No newline at end of file diff -r 020cea57eaa4 -r 02c408aed5ee src/HOL/Tools/Sledgehammer/MaSh/src/KNNs.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/KNNs.py Sat Jun 28 22:13:23 2014 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,105 +0,0 @@ -''' -Created on Aug 21, 2013 - -@author: daniel -''' - -from math import log -from KNN import KNN,cosine -from numpy import array - -class KNNAdaptPointFeatures(KNN): - - def __init__(self,dicts,metric=cosine,alpha = 0.05): - self.points = dicts.featureDict - self.metric = self.euclidean - self.alpha = alpha - self.count = 0 - self.featureCount = {} - - def initializeModel(self,trainData,dicts): - """ - Build basic model from training data. - """ - IS = open(dicts.accFile,'r') - for line in IS: - line = line.split(':') - name = line[0] - nameId = dicts.nameIdDict[name] - features = dicts.featureDict[nameId] - dependencies = dicts.dependenciesDict[nameId] - self.update(nameId, features, dependencies) - IS.close() - - def update(self,dataPoint,features,dependencies): - self.count += 1 - for f in features.iterkeys(): - try: - self.featureCount[f] += 1 - except: - self.featureCount[f] = 1 - for d in dependencies: - dFeatures = self.points[d] - featureUnion = set(dFeatures.keys()) | set(features.keys()) - for f in featureUnion: - try: - pVal = features[f] - except: - pVal = 0.0 - try: - dVal = dFeatures[f] - except: - dVal = 0.0 - newDVal = dVal + self.alpha * (pVal - dVal) - dFeatures[f] = newDVal - - def euclidean(self,f1,f2): - diffSum = 0.0 - f1Set = set(f1.keys()) - featureUnion = f1Set | set(f2.keys()) - for f in featureUnion: - if not self.featureCount.has_key(f): - continue - if self.featureCount[f] == 1: - continue - try: - f1Val = f1[f] - except: - f1Val = 0.0 - try: - f2Val = f2[f] - except: - f2Val = 0.0 - diff = f1Val - f2Val - diffSum += diff * diff - if f in f1Set: - diffSum += log(2+self.count/self.featureCount[f]) * diff * diff - else: - diffSum += diff * diff - #print diffSum,f1,f2 - return diffSum - -class KNNUrban(KNN): - def __init__(self,dicts,metric=cosine,nrOfNeighbours = 40): - self.points = dicts.featureDict - self.metric = metric - self.nrOfNeighbours = nrOfNeighbours # Ignored at the moment - - def predict(self,features,accessibles,dicts): - predictions = map(lambda x: self.metric(features,self.points[x]),accessibles) - pDict = dict(zip(accessibles,predictions)) - for a,p in zip(accessibles,predictions): - aDeps = dicts.dependenciesDict[a] - for d in aDeps: - pDict[d] -= p - predictions = [] - names = [] - for n,p in pDict.items(): - predictions.append(p) - names.append(n) - predictions = array(predictions) - perm = predictions.argsort() - return array(names)[perm],predictions[perm] - - - \ No newline at end of file diff -r 020cea57eaa4 -r 02c408aed5ee src/HOL/Tools/Sledgehammer/MaSh/src/argparse.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/argparse.py Sat Jun 28 22:13:23 2014 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,2357 +0,0 @@ -# Title: HOL/Tools/Sledgehammer/MaSh/src/argparse.py -# -# Argument parser. See copyright notice below. - -# -*- coding: utf-8 -*- - -# Copyright (C) 2006-2009 Steven J. Bethard . -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not -# use this file except in compliance with the License. You may obtain a copy -# of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -"""Command-line parsing library - -This module is an optparse-inspired command-line parsing library that: - - - handles both optional and positional arguments - - produces highly informative usage messages - - supports parsers that dispatch to sub-parsers - -The following is a simple usage example that sums integers from the -command-line and writes the result to a file:: - - parser = argparse.ArgumentParser( - description='sum the integers at the command line') - parser.add_argument( - 'integers', metavar='int', nargs='+', type=int, - help='an integer to be summed') - parser.add_argument( - '--log', default=sys.stdout, type=argparse.FileType('w'), - help='the file where the sum should be written') - args = parser.parse_args() - args.log.write('%s' % sum(args.integers)) - args.log.close() - -The module contains the following public classes: - - - ArgumentParser -- The main entry point for command-line parsing. As the - example above shows, the add_argument() method is used to populate - the parser with actions for optional and positional arguments. Then - the parse_args() method is invoked to convert the args at the - command-line into an object with attributes. - - - ArgumentError -- The exception raised by ArgumentParser objects when - there are errors with the parser's actions. Errors raised while - parsing the command-line are caught by ArgumentParser and emitted - as command-line messages. - - - FileType -- A factory for defining types of files to be created. As the - example above shows, instances of FileType are typically passed as - the type= argument of add_argument() calls. - - - Action -- The base class for parser actions. Typically actions are - selected by passing strings like 'store_true' or 'append_const' to - the action= argument of add_argument(). However, for greater - customization of ArgumentParser actions, subclasses of Action may - be defined and passed as the action= argument. - - - HelpFormatter, RawDescriptionHelpFormatter, RawTextHelpFormatter, - ArgumentDefaultsHelpFormatter -- Formatter classes which - may be passed as the formatter_class= argument to the - ArgumentParser constructor. HelpFormatter is the default, - RawDescriptionHelpFormatter and RawTextHelpFormatter tell the parser - not to change the formatting for help text, and - ArgumentDefaultsHelpFormatter adds information about argument defaults - to the help. - -All other classes in this module are considered implementation details. -(Also note that HelpFormatter and RawDescriptionHelpFormatter are only -considered public as object names -- the API of the formatter objects is -still considered an implementation detail.) -""" - -__version__ = '1.1' -__all__ = [ - 'ArgumentParser', - 'ArgumentError', - 'Namespace', - 'Action', - 'FileType', - 'HelpFormatter', - 'RawDescriptionHelpFormatter', - 'RawTextHelpFormatter', - 'ArgumentDefaultsHelpFormatter', -] - - -import copy as _copy -import os as _os -import re as _re -import sys as _sys -import textwrap as _textwrap - -from gettext import gettext as _ - -try: - _set = set -except NameError: - from sets import Set as _set - -try: - _basestring = basestring -except NameError: - _basestring = str - -try: - _sorted = sorted -except NameError: - - def _sorted(iterable, reverse=False): - result = list(iterable) - result.sort() - if reverse: - result.reverse() - return result - - -def _callable(obj): - return hasattr(obj, '__call__') or hasattr(obj, '__bases__') - -# silence Python 2.6 buggy warnings about Exception.message -if _sys.version_info[:2] == (2, 6): - import warnings - warnings.filterwarnings( - action='ignore', - message='BaseException.message has been deprecated as of Python 2.6', - category=DeprecationWarning, - module='argparse') - - -SUPPRESS = '==SUPPRESS==' - -OPTIONAL = '?' -ZERO_OR_MORE = '*' -ONE_OR_MORE = '+' -PARSER = 'A...' -REMAINDER = '...' - -# ============================= -# Utility functions and classes -# ============================= - -class _AttributeHolder(object): - """Abstract base class that provides __repr__. - - The __repr__ method returns a string in the format:: - ClassName(attr=name, attr=name, ...) - The attributes are determined either by a class-level attribute, - '_kwarg_names', or by inspecting the instance __dict__. - """ - - def __repr__(self): - type_name = type(self).__name__ - arg_strings = [] - for arg in self._get_args(): - arg_strings.append(repr(arg)) - for name, value in self._get_kwargs(): - arg_strings.append('%s=%r' % (name, value)) - return '%s(%s)' % (type_name, ', '.join(arg_strings)) - - def _get_kwargs(self): - return _sorted(self.__dict__.items()) - - def _get_args(self): - return [] - - -def _ensure_value(namespace, name, value): - if getattr(namespace, name, None) is None: - setattr(namespace, name, value) - return getattr(namespace, name) - - -# =============== -# Formatting Help -# =============== - -class HelpFormatter(object): - """Formatter for generating usage messages and argument help strings. - - Only the name of this class is considered a public API. All the methods - provided by the class are considered an implementation detail. - """ - - def __init__(self, - prog, - indent_increment=2, - max_help_position=24, - width=None): - - # default setting for width - if width is None: - try: - width = int(_os.environ['COLUMNS']) - except (KeyError, ValueError): - width = 80 - width -= 2 - - self._prog = prog - self._indent_increment = indent_increment - self._max_help_position = max_help_position - self._width = width - - self._current_indent = 0 - self._level = 0 - self._action_max_length = 0 - - self._root_section = self._Section(self, None) - self._current_section = self._root_section - - self._whitespace_matcher = _re.compile(r'\s+') - self._long_break_matcher = _re.compile(r'\n\n\n+') - - # =============================== - # Section and indentation methods - # =============================== - def _indent(self): - self._current_indent += self._indent_increment - self._level += 1 - - def _dedent(self): - self._current_indent -= self._indent_increment - assert self._current_indent >= 0, 'Indent decreased below 0.' - self._level -= 1 - - class _Section(object): - - def __init__(self, formatter, parent, heading=None): - self.formatter = formatter - self.parent = parent - self.heading = heading - self.items = [] - - def format_help(self): - # format the indented section - if self.parent is not None: - self.formatter._indent() - join = self.formatter._join_parts - for func, args in self.items: - func(*args) - item_help = join([func(*args) for func, args in self.items]) - if self.parent is not None: - self.formatter._dedent() - - # return nothing if the section was empty - if not item_help: - return '' - - # add the heading if the section was non-empty - if self.heading is not SUPPRESS and self.heading is not None: - current_indent = self.formatter._current_indent - heading = '%*s%s:\n' % (current_indent, '', self.heading) - else: - heading = '' - - # join the section-initial newline, the heading and the help - return join(['\n', heading, item_help, '\n']) - - def _add_item(self, func, args): - self._current_section.items.append((func, args)) - - # ======================== - # Message building methods - # ======================== - def start_section(self, heading): - self._indent() - section = self._Section(self, self._current_section, heading) - self._add_item(section.format_help, []) - self._current_section = section - - def end_section(self): - self._current_section = self._current_section.parent - self._dedent() - - def add_text(self, text): - if text is not SUPPRESS and text is not None: - self._add_item(self._format_text, [text]) - - def add_usage(self, usage, actions, groups, prefix=None): - if usage is not SUPPRESS: - args = usage, actions, groups, prefix - self._add_item(self._format_usage, args) - - def add_argument(self, action): - if action.help is not SUPPRESS: - - # find all invocations - get_invocation = self._format_action_invocation - invocations = [get_invocation(action)] - for subaction in self._iter_indented_subactions(action): - invocations.append(get_invocation(subaction)) - - # update the maximum item length - invocation_length = max([len(s) for s in invocations]) - action_length = invocation_length + self._current_indent - self._action_max_length = max(self._action_max_length, - action_length) - - # add the item to the list - self._add_item(self._format_action, [action]) - - def add_arguments(self, actions): - for action in actions: - self.add_argument(action) - - # ======================= - # Help-formatting methods - # ======================= - def format_help(self): - help = self._root_section.format_help() - if help: - help = self._long_break_matcher.sub('\n\n', help) - help = help.strip('\n') + '\n' - return help - - def _join_parts(self, part_strings): - return ''.join([part - for part in part_strings - if part and part is not SUPPRESS]) - - def _format_usage(self, usage, actions, groups, prefix): - if prefix is None: - prefix = _('usage: ') - - # if usage is specified, use that - if usage is not None: - usage = usage % dict(prog=self._prog) - - # if no optionals or positionals are available, usage is just prog - elif usage is None and not actions: - usage = '%(prog)s' % dict(prog=self._prog) - - # if optionals and positionals are available, calculate usage - elif usage is None: - prog = '%(prog)s' % dict(prog=self._prog) - - # split optionals from positionals - optionals = [] - positionals = [] - for action in actions: - if action.option_strings: - optionals.append(action) - else: - positionals.append(action) - - # build full usage string - format = self._format_actions_usage - action_usage = format(optionals + positionals, groups) - usage = ' '.join([s for s in [prog, action_usage] if s]) - - # wrap the usage parts if it's too long - text_width = self._width - self._current_indent - if len(prefix) + len(usage) > text_width: - - # break usage into wrappable parts - part_regexp = r'\(.*?\)+|\[.*?\]+|\S+' - opt_usage = format(optionals, groups) - pos_usage = format(positionals, groups) - opt_parts = _re.findall(part_regexp, opt_usage) - pos_parts = _re.findall(part_regexp, pos_usage) - assert ' '.join(opt_parts) == opt_usage - assert ' '.join(pos_parts) == pos_usage - - # helper for wrapping lines - def get_lines(parts, indent, prefix=None): - lines = [] - line = [] - if prefix is not None: - line_len = len(prefix) - 1 - else: - line_len = len(indent) - 1 - for part in parts: - if line_len + 1 + len(part) > text_width: - lines.append(indent + ' '.join(line)) - line = [] - line_len = len(indent) - 1 - line.append(part) - line_len += len(part) + 1 - if line: - lines.append(indent + ' '.join(line)) - if prefix is not None: - lines[0] = lines[0][len(indent):] - return lines - - # if prog is short, follow it with optionals or positionals - if len(prefix) + len(prog) <= 0.75 * text_width: - indent = ' ' * (len(prefix) + len(prog) + 1) - if opt_parts: - lines = get_lines([prog] + opt_parts, indent, prefix) - lines.extend(get_lines(pos_parts, indent)) - elif pos_parts: - lines = get_lines([prog] + pos_parts, indent, prefix) - else: - lines = [prog] - - # if prog is long, put it on its own line - else: - indent = ' ' * len(prefix) - parts = opt_parts + pos_parts - lines = get_lines(parts, indent) - if len(lines) > 1: - lines = [] - lines.extend(get_lines(opt_parts, indent)) - lines.extend(get_lines(pos_parts, indent)) - lines = [prog] + lines - - # join lines into usage - usage = '\n'.join(lines) - - # prefix with 'usage:' - return '%s%s\n\n' % (prefix, usage) - - def _format_actions_usage(self, actions, groups): - # find group indices and identify actions in groups - group_actions = _set() - inserts = {} - for group in groups: - try: - start = actions.index(group._group_actions[0]) - except ValueError: - continue - else: - end = start + len(group._group_actions) - if actions[start:end] == group._group_actions: - for action in group._group_actions: - group_actions.add(action) - if not group.required: - inserts[start] = '[' - inserts[end] = ']' - else: - inserts[start] = '(' - inserts[end] = ')' - for i in range(start + 1, end): - inserts[i] = '|' - - # collect all actions format strings - parts = [] - for i, action in enumerate(actions): - - # suppressed arguments are marked with None - # remove | separators for suppressed arguments - if action.help is SUPPRESS: - parts.append(None) - if inserts.get(i) == '|': - inserts.pop(i) - elif inserts.get(i + 1) == '|': - inserts.pop(i + 1) - - # produce all arg strings - elif not action.option_strings: - part = self._format_args(action, action.dest) - - # if it's in a group, strip the outer [] - if action in group_actions: - if part[0] == '[' and part[-1] == ']': - part = part[1:-1] - - # add the action string to the list - parts.append(part) - - # produce the first way to invoke the option in brackets - else: - option_string = action.option_strings[0] - - # if the Optional doesn't take a value, format is: - # -s or --long - if action.nargs == 0: - part = '%s' % option_string - - # if the Optional takes a value, format is: - # -s ARGS or --long ARGS - else: - default = action.dest.upper() - args_string = self._format_args(action, default) - part = '%s %s' % (option_string, args_string) - - # make it look optional if it's not required or in a group - if not action.required and action not in group_actions: - part = '[%s]' % part - - # add the action string to the list - parts.append(part) - - # insert things at the necessary indices - for i in _sorted(inserts, reverse=True): - parts[i:i] = [inserts[i]] - - # join all the action items with spaces - text = ' '.join([item for item in parts if item is not None]) - - # clean up separators for mutually exclusive groups - open = r'[\[(]' - close = r'[\])]' - text = _re.sub(r'(%s) ' % open, r'\1', text) - text = _re.sub(r' (%s)' % close, r'\1', text) - text = _re.sub(r'%s *%s' % (open, close), r'', text) - text = _re.sub(r'\(([^|]*)\)', r'\1', text) - text = text.strip() - - # return the text - return text - - def _format_text(self, text): - if '%(prog)' in text: - text = text % dict(prog=self._prog) - text_width = self._width - self._current_indent - indent = ' ' * self._current_indent - return self._fill_text(text, text_width, indent) + '\n\n' - - def _format_action(self, action): - # determine the required width and the entry label - help_position = min(self._action_max_length + 2, - self._max_help_position) - help_width = self._width - help_position - action_width = help_position - self._current_indent - 2 - action_header = self._format_action_invocation(action) - - # ho nelp; start on same line and add a final newline - if not action.help: - tup = self._current_indent, '', action_header - action_header = '%*s%s\n' % tup - - # short action name; start on the same line and pad two spaces - elif len(action_header) <= action_width: - tup = self._current_indent, '', action_width, action_header - action_header = '%*s%-*s ' % tup - indent_first = 0 - - # long action name; start on the next line - else: - tup = self._current_indent, '', action_header - action_header = '%*s%s\n' % tup - indent_first = help_position - - # collect the pieces of the action help - parts = [action_header] - - # if there was help for the action, add lines of help text - if action.help: - help_text = self._expand_help(action) - help_lines = self._split_lines(help_text, help_width) - parts.append('%*s%s\n' % (indent_first, '', help_lines[0])) - for line in help_lines[1:]: - parts.append('%*s%s\n' % (help_position, '', line)) - - # or add a newline if the description doesn't end with one - elif not action_header.endswith('\n'): - parts.append('\n') - - # if there are any sub-actions, add their help as well - for subaction in self._iter_indented_subactions(action): - parts.append(self._format_action(subaction)) - - # return a single string - return self._join_parts(parts) - - def _format_action_invocation(self, action): - if not action.option_strings: - metavar, = self._metavar_formatter(action, action.dest)(1) - return metavar - - else: - parts = [] - - # if the Optional doesn't take a value, format is: - # -s, --long - if action.nargs == 0: - parts.extend(action.option_strings) - - # if the Optional takes a value, format is: - # -s ARGS, --long ARGS - else: - default = action.dest.upper() - args_string = self._format_args(action, default) - for option_string in action.option_strings: - parts.append('%s %s' % (option_string, args_string)) - - return ', '.join(parts) - - def _metavar_formatter(self, action, default_metavar): - if action.metavar is not None: - result = action.metavar - elif action.choices is not None: - choice_strs = [str(choice) for choice in action.choices] - result = '{%s}' % ','.join(choice_strs) - else: - result = default_metavar - - def format(tuple_size): - if isinstance(result, tuple): - return result - else: - return (result, ) * tuple_size - return format - - def _format_args(self, action, default_metavar): - get_metavar = self._metavar_formatter(action, default_metavar) - if action.nargs is None: - result = '%s' % get_metavar(1) - elif action.nargs == OPTIONAL: - result = '[%s]' % get_metavar(1) - elif action.nargs == ZERO_OR_MORE: - result = '[%s [%s ...]]' % get_metavar(2) - elif action.nargs == ONE_OR_MORE: - result = '%s [%s ...]' % get_metavar(2) - elif action.nargs == REMAINDER: - result = '...' - elif action.nargs == PARSER: - result = '%s ...' % get_metavar(1) - else: - formats = ['%s' for _ in range(action.nargs)] - result = ' '.join(formats) % get_metavar(action.nargs) - return result - - def _expand_help(self, action): - params = dict(vars(action), prog=self._prog) - for name in list(params): - if params[name] is SUPPRESS: - del params[name] - for name in list(params): - if hasattr(params[name], '__name__'): - params[name] = params[name].__name__ - if params.get('choices') is not None: - choices_str = ', '.join([str(c) for c in params['choices']]) - params['choices'] = choices_str - return self._get_help_string(action) % params - - def _iter_indented_subactions(self, action): - try: - get_subactions = action._get_subactions - except AttributeError: - pass - else: - self._indent() - for subaction in get_subactions(): - yield subaction - self._dedent() - - def _split_lines(self, text, width): - text = self._whitespace_matcher.sub(' ', text).strip() - return _textwrap.wrap(text, width) - - def _fill_text(self, text, width, indent): - text = self._whitespace_matcher.sub(' ', text).strip() - return _textwrap.fill(text, width, initial_indent=indent, - subsequent_indent=indent) - - def _get_help_string(self, action): - return action.help - - -class RawDescriptionHelpFormatter(HelpFormatter): - """Help message formatter which retains any formatting in descriptions. - - Only the name of this class is considered a public API. All the methods - provided by the class are considered an implementation detail. - """ - - def _fill_text(self, text, width, indent): - return ''.join([indent + line for line in text.splitlines(True)]) - - -class RawTextHelpFormatter(RawDescriptionHelpFormatter): - """Help message formatter which retains formatting of all help text. - - Only the name of this class is considered a public API. All the methods - provided by the class are considered an implementation detail. - """ - - def _split_lines(self, text, width): - return text.splitlines() - - -class ArgumentDefaultsHelpFormatter(HelpFormatter): - """Help message formatter which adds default values to argument help. - - Only the name of this class is considered a public API. All the methods - provided by the class are considered an implementation detail. - """ - - def _get_help_string(self, action): - help = action.help - if '%(default)' not in action.help: - if action.default is not SUPPRESS: - defaulting_nargs = [OPTIONAL, ZERO_OR_MORE] - if action.option_strings or action.nargs in defaulting_nargs: - help += ' (default: %(default)s)' - return help - - -# ===================== -# Options and Arguments -# ===================== - -def _get_action_name(argument): - if argument is None: - return None - elif argument.option_strings: - return '/'.join(argument.option_strings) - elif argument.metavar not in (None, SUPPRESS): - return argument.metavar - elif argument.dest not in (None, SUPPRESS): - return argument.dest - else: - return None - - -class ArgumentError(Exception): - """An error from creating or using an argument (optional or positional). - - The string value of this exception is the message, augmented with - information about the argument that caused it. - """ - - def __init__(self, argument, message): - self.argument_name = _get_action_name(argument) - self.message = message - - def __str__(self): - if self.argument_name is None: - format = '%(message)s' - else: - format = 'argument %(argument_name)s: %(message)s' - return format % dict(message=self.message, - argument_name=self.argument_name) - - -class ArgumentTypeError(Exception): - """An error from trying to convert a command line string to a type.""" - pass - - -# ============== -# Action classes -# ============== - -class Action(_AttributeHolder): - """Information about how to convert command line strings to Python objects. - - Action objects are used by an ArgumentParser to represent the information - needed to parse a single argument from one or more strings from the - command line. The keyword arguments to the Action constructor are also - all attributes of Action instances. - - Keyword Arguments: - - - option_strings -- A list of command-line option strings which - should be associated with this action. - - - dest -- The name of the attribute to hold the created object(s) - - - nargs -- The number of command-line arguments that should be - consumed. By default, one argument will be consumed and a single - value will be produced. Other values include: - - N (an integer) consumes N arguments (and produces a list) - - '?' consumes zero or one arguments - - '*' consumes zero or more arguments (and produces a list) - - '+' consumes one or more arguments (and produces a list) - Note that the difference between the default and nargs=1 is that - with the default, a single value will be produced, while with - nargs=1, a list containing a single value will be produced. - - - const -- The value to be produced if the option is specified and the - option uses an action that takes no values. - - - default -- The value to be produced if the option is not specified. - - - type -- The type which the command-line arguments should be converted - to, should be one of 'string', 'int', 'float', 'complex' or a - callable object that accepts a single string argument. If None, - 'string' is assumed. - - - choices -- A container of values that should be allowed. If not None, - after a command-line argument has been converted to the appropriate - type, an exception will be raised if it is not a member of this - collection. - - - required -- True if the action must always be specified at the - command line. This is only meaningful for optional command-line - arguments. - - - help -- The help string describing the argument. - - - metavar -- The name to be used for the option's argument with the - help string. If None, the 'dest' value will be used as the name. - """ - - def __init__(self, - option_strings, - dest, - nargs=None, - const=None, - default=None, - type=None, - choices=None, - required=False, - help=None, - metavar=None): - self.option_strings = option_strings - self.dest = dest - self.nargs = nargs - self.const = const - self.default = default - self.type = type - self.choices = choices - self.required = required - self.help = help - self.metavar = metavar - - def _get_kwargs(self): - names = [ - 'option_strings', - 'dest', - 'nargs', - 'const', - 'default', - 'type', - 'choices', - 'help', - 'metavar', - ] - return [(name, getattr(self, name)) for name in names] - - def __call__(self, parser, namespace, values, option_string=None): - raise NotImplementedError(_('.__call__() not defined')) - - -class _StoreAction(Action): - - def __init__(self, - option_strings, - dest, - nargs=None, - const=None, - default=None, - type=None, - choices=None, - required=False, - help=None, - metavar=None): - if nargs == 0: - raise ValueError('nargs for store actions must be > 0; if you ' - 'have nothing to store, actions such as store ' - 'true or store const may be more appropriate') - if const is not None and nargs != OPTIONAL: - raise ValueError('nargs must be %r to supply const' % OPTIONAL) - super(_StoreAction, self).__init__( - option_strings=option_strings, - dest=dest, - nargs=nargs, - const=const, - default=default, - type=type, - choices=choices, - required=required, - help=help, - metavar=metavar) - - def __call__(self, parser, namespace, values, option_string=None): - setattr(namespace, self.dest, values) - - -class _StoreConstAction(Action): - - def __init__(self, - option_strings, - dest, - const, - default=None, - required=False, - help=None, - metavar=None): - super(_StoreConstAction, self).__init__( - option_strings=option_strings, - dest=dest, - nargs=0, - const=const, - default=default, - required=required, - help=help) - - def __call__(self, parser, namespace, values, option_string=None): - setattr(namespace, self.dest, self.const) - - -class _StoreTrueAction(_StoreConstAction): - - def __init__(self, - option_strings, - dest, - default=False, - required=False, - help=None): - super(_StoreTrueAction, self).__init__( - option_strings=option_strings, - dest=dest, - const=True, - default=default, - required=required, - help=help) - - -class _StoreFalseAction(_StoreConstAction): - - def __init__(self, - option_strings, - dest, - default=True, - required=False, - help=None): - super(_StoreFalseAction, self).__init__( - option_strings=option_strings, - dest=dest, - const=False, - default=default, - required=required, - help=help) - - -class _AppendAction(Action): - - def __init__(self, - option_strings, - dest, - nargs=None, - const=None, - default=None, - type=None, - choices=None, - required=False, - help=None, - metavar=None): - if nargs == 0: - raise ValueError('nargs for append actions must be > 0; if arg ' - 'strings are not supplying the value to append, ' - 'the append const action may be more appropriate') - if const is not None and nargs != OPTIONAL: - raise ValueError('nargs must be %r to supply const' % OPTIONAL) - super(_AppendAction, self).__init__( - option_strings=option_strings, - dest=dest, - nargs=nargs, - const=const, - default=default, - type=type, - choices=choices, - required=required, - help=help, - metavar=metavar) - - def __call__(self, parser, namespace, values, option_string=None): - items = _copy.copy(_ensure_value(namespace, self.dest, [])) - items.append(values) - setattr(namespace, self.dest, items) - - -class _AppendConstAction(Action): - - def __init__(self, - option_strings, - dest, - const, - default=None, - required=False, - help=None, - metavar=None): - super(_AppendConstAction, self).__init__( - option_strings=option_strings, - dest=dest, - nargs=0, - const=const, - default=default, - required=required, - help=help, - metavar=metavar) - - def __call__(self, parser, namespace, values, option_string=None): - items = _copy.copy(_ensure_value(namespace, self.dest, [])) - items.append(self.const) - setattr(namespace, self.dest, items) - - -class _CountAction(Action): - - def __init__(self, - option_strings, - dest, - default=None, - required=False, - help=None): - super(_CountAction, self).__init__( - option_strings=option_strings, - dest=dest, - nargs=0, - default=default, - required=required, - help=help) - - def __call__(self, parser, namespace, values, option_string=None): - new_count = _ensure_value(namespace, self.dest, 0) + 1 - setattr(namespace, self.dest, new_count) - - -class _HelpAction(Action): - - def __init__(self, - option_strings, - dest=SUPPRESS, - default=SUPPRESS, - help=None): - super(_HelpAction, self).__init__( - option_strings=option_strings, - dest=dest, - default=default, - nargs=0, - help=help) - - def __call__(self, parser, namespace, values, option_string=None): - parser.print_help() - parser.exit() - - -class _VersionAction(Action): - - def __init__(self, - option_strings, - version=None, - dest=SUPPRESS, - default=SUPPRESS, - help=None): - super(_VersionAction, self).__init__( - option_strings=option_strings, - dest=dest, - default=default, - nargs=0, - help=help) - self.version = version - - def __call__(self, parser, namespace, values, option_string=None): - version = self.version - if version is None: - version = parser.version - formatter = parser._get_formatter() - formatter.add_text(version) - parser.exit(message=formatter.format_help()) - - -class _SubParsersAction(Action): - - class _ChoicesPseudoAction(Action): - - def __init__(self, name, help): - sup = super(_SubParsersAction._ChoicesPseudoAction, self) - sup.__init__(option_strings=[], dest=name, help=help) - - def __init__(self, - option_strings, - prog, - parser_class, - dest=SUPPRESS, - help=None, - metavar=None): - - self._prog_prefix = prog - self._parser_class = parser_class - self._name_parser_map = {} - self._choices_actions = [] - - super(_SubParsersAction, self).__init__( - option_strings=option_strings, - dest=dest, - nargs=PARSER, - choices=self._name_parser_map, - help=help, - metavar=metavar) - - def add_parser(self, name, **kwargs): - # set prog from the existing prefix - if kwargs.get('prog') is None: - kwargs['prog'] = '%s %s' % (self._prog_prefix, name) - - # create a pseudo-action to hold the choice help - if 'help' in kwargs: - help = kwargs.pop('help') - choice_action = self._ChoicesPseudoAction(name, help) - self._choices_actions.append(choice_action) - - # create the parser and add it to the map - parser = self._parser_class(**kwargs) - self._name_parser_map[name] = parser - return parser - - def _get_subactions(self): - return self._choices_actions - - def __call__(self, parser, namespace, values, option_string=None): - parser_name = values[0] - arg_strings = values[1:] - - # set the parser name if requested - if self.dest is not SUPPRESS: - setattr(namespace, self.dest, parser_name) - - # select the parser - try: - parser = self._name_parser_map[parser_name] - except KeyError: - tup = parser_name, ', '.join(self._name_parser_map) - msg = _('unknown parser %r (choices: %s)' % tup) - raise ArgumentError(self, msg) - - # parse all the remaining options into the namespace - parser.parse_args(arg_strings, namespace) - - -# ============== -# Type classes -# ============== - -class FileType(object): - """Factory for creating file object types - - Instances of FileType are typically passed as type= arguments to the - ArgumentParser add_argument() method. - - Keyword Arguments: - - mode -- A string indicating how the file is to be opened. Accepts the - same values as the builtin open() function. - - bufsize -- The file's desired buffer size. Accepts the same values as - the builtin open() function. - """ - - def __init__(self, mode='r', bufsize=None): - self._mode = mode - self._bufsize = bufsize - - def __call__(self, string): - # the special argument "-" means sys.std{in,out} - if string == '-': - if 'r' in self._mode: - return _sys.stdin - elif 'w' in self._mode: - return _sys.stdout - else: - msg = _('argument "-" with mode %r' % self._mode) - raise ValueError(msg) - - # all other arguments are used as file names - if self._bufsize: - return open(string, self._mode, self._bufsize) - else: - return open(string, self._mode) - - def __repr__(self): - args = [self._mode, self._bufsize] - args_str = ', '.join([repr(arg) for arg in args if arg is not None]) - return '%s(%s)' % (type(self).__name__, args_str) - -# =========================== -# Optional and Positional Parsing -# =========================== - -class Namespace(_AttributeHolder): - """Simple object for storing attributes. - - Implements equality by attribute names and values, and provides a simple - string representation. - """ - - def __init__(self, **kwargs): - for name in kwargs: - setattr(self, name, kwargs[name]) - - def __eq__(self, other): - return vars(self) == vars(other) - - def __ne__(self, other): - return not (self == other) - - def __contains__(self, key): - return key in self.__dict__ - - -class _ActionsContainer(object): - - def __init__(self, - description, - prefix_chars, - argument_default, - conflict_handler): - super(_ActionsContainer, self).__init__() - - self.description = description - self.argument_default = argument_default - self.prefix_chars = prefix_chars - self.conflict_handler = conflict_handler - - # set up registries - self._registries = {} - - # register actions - self.register('action', None, _StoreAction) - self.register('action', 'store', _StoreAction) - self.register('action', 'store_const', _StoreConstAction) - self.register('action', 'store_true', _StoreTrueAction) - self.register('action', 'store_false', _StoreFalseAction) - self.register('action', 'append', _AppendAction) - self.register('action', 'append_const', _AppendConstAction) - self.register('action', 'count', _CountAction) - self.register('action', 'help', _HelpAction) - self.register('action', 'version', _VersionAction) - self.register('action', 'parsers', _SubParsersAction) - - # raise an exception if the conflict handler is invalid - self._get_handler() - - # action storage - self._actions = [] - self._option_string_actions = {} - - # groups - self._action_groups = [] - self._mutually_exclusive_groups = [] - - # defaults storage - self._defaults = {} - - # determines whether an "option" looks like a negative number - self._negative_number_matcher = _re.compile(r'^-\d+$|^-\d*\.\d+$') - - # whether or not there are any optionals that look like negative - # numbers -- uses a list so it can be shared and edited - self._has_negative_number_optionals = [] - - # ==================== - # Registration methods - # ==================== - def register(self, registry_name, value, object): - registry = self._registries.setdefault(registry_name, {}) - registry[value] = object - - def _registry_get(self, registry_name, value, default=None): - return self._registries[registry_name].get(value, default) - - # ================================== - # Namespace default accessor methods - # ================================== - def set_defaults(self, **kwargs): - self._defaults.update(kwargs) - - # if these defaults match any existing arguments, replace - # the previous default on the object with the new one - for action in self._actions: - if action.dest in kwargs: - action.default = kwargs[action.dest] - - def get_default(self, dest): - for action in self._actions: - if action.dest == dest and action.default is not None: - return action.default - return self._defaults.get(dest, None) - - - # ======================= - # Adding argument actions - # ======================= - def add_argument(self, *args, **kwargs): - """ - add_argument(dest, ..., name=value, ...) - add_argument(option_string, option_string, ..., name=value, ...) - """ - - # if no positional args are supplied or only one is supplied and - # it doesn't look like an option string, parse a positional - # argument - chars = self.prefix_chars - if not args or len(args) == 1 and args[0][0] not in chars: - if args and 'dest' in kwargs: - raise ValueError('dest supplied twice for positional argument') - kwargs = self._get_positional_kwargs(*args, **kwargs) - - # otherwise, we're adding an optional argument - else: - kwargs = self._get_optional_kwargs(*args, **kwargs) - - # if no default was supplied, use the parser-level default - if 'default' not in kwargs: - dest = kwargs['dest'] - if dest in self._defaults: - kwargs['default'] = self._defaults[dest] - elif self.argument_default is not None: - kwargs['default'] = self.argument_default - - # create the action object, and add it to the parser - action_class = self._pop_action_class(kwargs) - if not _callable(action_class): - raise ValueError('unknown action "%s"' % action_class) - action = action_class(**kwargs) - - # raise an error if the action type is not callable - type_func = self._registry_get('type', action.type, action.type) - if not _callable(type_func): - raise ValueError('%r is not callable' % type_func) - - return self._add_action(action) - - def add_argument_group(self, *args, **kwargs): - group = _ArgumentGroup(self, *args, **kwargs) - self._action_groups.append(group) - return group - - def add_mutually_exclusive_group(self, **kwargs): - group = _MutuallyExclusiveGroup(self, **kwargs) - self._mutually_exclusive_groups.append(group) - return group - - def _add_action(self, action): - # resolve any conflicts - self._check_conflict(action) - - # add to actions list - self._actions.append(action) - action.container = self - - # index the action by any option strings it has - for option_string in action.option_strings: - self._option_string_actions[option_string] = action - - # set the flag if any option strings look like negative numbers - for option_string in action.option_strings: - if self._negative_number_matcher.match(option_string): - if not self._has_negative_number_optionals: - self._has_negative_number_optionals.append(True) - - # return the created action - return action - - def _remove_action(self, action): - self._actions.remove(action) - - def _add_container_actions(self, container): - # collect groups by titles - title_group_map = {} - for group in self._action_groups: - if group.title in title_group_map: - msg = _('cannot merge actions - two groups are named %r') - raise ValueError(msg % (group.title)) - title_group_map[group.title] = group - - # map each action to its group - group_map = {} - for group in container._action_groups: - - # if a group with the title exists, use that, otherwise - # create a new group matching the container's group - if group.title not in title_group_map: - title_group_map[group.title] = self.add_argument_group( - title=group.title, - description=group.description, - conflict_handler=group.conflict_handler) - - # map the actions to their new group - for action in group._group_actions: - group_map[action] = title_group_map[group.title] - - # add container's mutually exclusive groups - # NOTE: if add_mutually_exclusive_group ever gains title= and - # description= then this code will need to be expanded as above - for group in container._mutually_exclusive_groups: - mutex_group = self.add_mutually_exclusive_group( - required=group.required) - - # map the actions to their new mutex group - for action in group._group_actions: - group_map[action] = mutex_group - - # add all actions to this container or their group - for action in container._actions: - group_map.get(action, self)._add_action(action) - - def _get_positional_kwargs(self, dest, **kwargs): - # make sure required is not specified - if 'required' in kwargs: - msg = _("'required' is an invalid argument for positionals") - raise TypeError(msg) - - # mark positional arguments as required if at least one is - # always required - if kwargs.get('nargs') not in [OPTIONAL, ZERO_OR_MORE]: - kwargs['required'] = True - if kwargs.get('nargs') == ZERO_OR_MORE and 'default' not in kwargs: - kwargs['required'] = True - - # return the keyword arguments with no option strings - return dict(kwargs, dest=dest, option_strings=[]) - - def _get_optional_kwargs(self, *args, **kwargs): - # determine short and long option strings - option_strings = [] - long_option_strings = [] - for option_string in args: - # error on strings that don't start with an appropriate prefix - if not option_string[0] in self.prefix_chars: - msg = _('invalid option string %r: ' - 'must start with a character %r') - tup = option_string, self.prefix_chars - raise ValueError(msg % tup) - - # strings starting with two prefix characters are long options - option_strings.append(option_string) - if option_string[0] in self.prefix_chars: - if len(option_string) > 1: - if option_string[1] in self.prefix_chars: - long_option_strings.append(option_string) - - # infer destination, '--foo-bar' -> 'foo_bar' and '-x' -> 'x' - dest = kwargs.pop('dest', None) - if dest is None: - if long_option_strings: - dest_option_string = long_option_strings[0] - else: - dest_option_string = option_strings[0] - dest = dest_option_string.lstrip(self.prefix_chars) - if not dest: - msg = _('dest= is required for options like %r') - raise ValueError(msg % option_string) - dest = dest.replace('-', '_') - - # return the updated keyword arguments - return dict(kwargs, dest=dest, option_strings=option_strings) - - def _pop_action_class(self, kwargs, default=None): - action = kwargs.pop('action', default) - return self._registry_get('action', action, action) - - def _get_handler(self): - # determine function from conflict handler string - handler_func_name = '_handle_conflict_%s' % self.conflict_handler - try: - return getattr(self, handler_func_name) - except AttributeError: - msg = _('invalid conflict_resolution value: %r') - raise ValueError(msg % self.conflict_handler) - - def _check_conflict(self, action): - - # find all options that conflict with this option - confl_optionals = [] - for option_string in action.option_strings: - if option_string in self._option_string_actions: - confl_optional = self._option_string_actions[option_string] - confl_optionals.append((option_string, confl_optional)) - - # resolve any conflicts - if confl_optionals: - conflict_handler = self._get_handler() - conflict_handler(action, confl_optionals) - - def _handle_conflict_error(self, action, conflicting_actions): - message = _('conflicting option string(s): %s') - conflict_string = ', '.join([option_string - for option_string, action - in conflicting_actions]) - raise ArgumentError(action, message % conflict_string) - - def _handle_conflict_resolve(self, action, conflicting_actions): - - # remove all conflicting options - for option_string, action in conflicting_actions: - - # remove the conflicting option - action.option_strings.remove(option_string) - self._option_string_actions.pop(option_string, None) - - # if the option now has no option string, remove it from the - # container holding it - if not action.option_strings: - action.container._remove_action(action) - - -class _ArgumentGroup(_ActionsContainer): - - def __init__(self, container, title=None, description=None, **kwargs): - # add any missing keyword arguments by checking the container - update = kwargs.setdefault - update('conflict_handler', container.conflict_handler) - update('prefix_chars', container.prefix_chars) - update('argument_default', container.argument_default) - super_init = super(_ArgumentGroup, self).__init__ - super_init(description=description, **kwargs) - - # group attributes - self.title = title - self._group_actions = [] - - # share most attributes with the container - self._registries = container._registries - self._actions = container._actions - self._option_string_actions = container._option_string_actions - self._defaults = container._defaults - self._has_negative_number_optionals = \ - container._has_negative_number_optionals - - def _add_action(self, action): - action = super(_ArgumentGroup, self)._add_action(action) - self._group_actions.append(action) - return action - - def _remove_action(self, action): - super(_ArgumentGroup, self)._remove_action(action) - self._group_actions.remove(action) - - -class _MutuallyExclusiveGroup(_ArgumentGroup): - - def __init__(self, container, required=False): - super(_MutuallyExclusiveGroup, self).__init__(container) - self.required = required - self._container = container - - def _add_action(self, action): - if action.required: - msg = _('mutually exclusive arguments must be optional') - raise ValueError(msg) - action = self._container._add_action(action) - self._group_actions.append(action) - return action - - def _remove_action(self, action): - self._container._remove_action(action) - self._group_actions.remove(action) - - -class ArgumentParser(_AttributeHolder, _ActionsContainer): - """Object for parsing command line strings into Python objects. - - Keyword Arguments: - - prog -- The name of the program (default: sys.argv[0]) - - usage -- A usage message (default: auto-generated from arguments) - - description -- A description of what the program does - - epilog -- Text following the argument descriptions - - parents -- Parsers whose arguments should be copied into this one - - formatter_class -- HelpFormatter class for printing help messages - - prefix_chars -- Characters that prefix optional arguments - - fromfile_prefix_chars -- Characters that prefix files containing - additional arguments - - argument_default -- The default value for all arguments - - conflict_handler -- String indicating how to handle conflicts - - add_help -- Add a -h/-help option - """ - - def __init__(self, - prog=None, - usage=None, - description=None, - epilog=None, - version=None, - parents=[], - formatter_class=HelpFormatter, - prefix_chars='-', - fromfile_prefix_chars=None, - argument_default=None, - conflict_handler='error', - add_help=True): - - if version is not None: - import warnings - warnings.warn( - """The "version" argument to ArgumentParser is deprecated. """ - """Please use """ - """"add_argument(..., action='version', version="N", ...)" """ - """instead""", DeprecationWarning) - - superinit = super(ArgumentParser, self).__init__ - superinit(description=description, - prefix_chars=prefix_chars, - argument_default=argument_default, - conflict_handler=conflict_handler) - - # default setting for prog - if prog is None: - prog = _os.path.basename(_sys.argv[0]) - - self.prog = prog - self.usage = usage - self.epilog = epilog - self.version = version - self.formatter_class = formatter_class - self.fromfile_prefix_chars = fromfile_prefix_chars - self.add_help = add_help - - add_group = self.add_argument_group - self._positionals = add_group(_('positional arguments')) - self._optionals = add_group(_('optional arguments')) - self._subparsers = None - - # register types - def identity(string): - return string - self.register('type', None, identity) - - # add help and version arguments if necessary - # (using explicit default to override global argument_default) - if self.add_help: - self.add_argument( - '-h', '--help', action='help', default=SUPPRESS, - help=_('show this help message and exit')) - if self.version: - self.add_argument( - '-v', '--version', action='version', default=SUPPRESS, - version=self.version, - help=_("show program's version number and exit")) - - # add parent arguments and defaults - for parent in parents: - self._add_container_actions(parent) - try: - defaults = parent._defaults - except AttributeError: - pass - else: - self._defaults.update(defaults) - - # ======================= - # Pretty __repr__ methods - # ======================= - def _get_kwargs(self): - names = [ - 'prog', - 'usage', - 'description', - 'version', - 'formatter_class', - 'conflict_handler', - 'add_help', - ] - return [(name, getattr(self, name)) for name in names] - - # ================================== - # Optional/Positional adding methods - # ================================== - def add_subparsers(self, **kwargs): - if self._subparsers is not None: - self.error(_('cannot have multiple subparser arguments')) - - # add the parser class to the arguments if it's not present - kwargs.setdefault('parser_class', type(self)) - - if 'title' in kwargs or 'description' in kwargs: - title = _(kwargs.pop('title', 'subcommands')) - description = _(kwargs.pop('description', None)) - self._subparsers = self.add_argument_group(title, description) - else: - self._subparsers = self._positionals - - # prog defaults to the usage message of this parser, skipping - # optional arguments and with no "usage:" prefix - if kwargs.get('prog') is None: - formatter = self._get_formatter() - positionals = self._get_positional_actions() - groups = self._mutually_exclusive_groups - formatter.add_usage(self.usage, positionals, groups, '') - kwargs['prog'] = formatter.format_help().strip() - - # create the parsers action and add it to the positionals list - parsers_class = self._pop_action_class(kwargs, 'parsers') - action = parsers_class(option_strings=[], **kwargs) - self._subparsers._add_action(action) - - # return the created parsers action - return action - - def _add_action(self, action): - if action.option_strings: - self._optionals._add_action(action) - else: - self._positionals._add_action(action) - return action - - def _get_optional_actions(self): - return [action - for action in self._actions - if action.option_strings] - - def _get_positional_actions(self): - return [action - for action in self._actions - if not action.option_strings] - - # ===================================== - # Command line argument parsing methods - # ===================================== - def parse_args(self, args=None, namespace=None): - args, argv = self.parse_known_args(args, namespace) - if argv: - msg = _('unrecognized arguments: %s') - self.error(msg % ' '.join(argv)) - return args - - def parse_known_args(self, args=None, namespace=None): - # args default to the system args - if args is None: - args = _sys.argv[1:] - - # default Namespace built from parser defaults - if namespace is None: - namespace = Namespace() - - # add any action defaults that aren't present - for action in self._actions: - if action.dest is not SUPPRESS: - if not hasattr(namespace, action.dest): - if action.default is not SUPPRESS: - default = action.default - if isinstance(action.default, _basestring): - default = self._get_value(action, default) - setattr(namespace, action.dest, default) - - # add any parser defaults that aren't present - for dest in self._defaults: - if not hasattr(namespace, dest): - setattr(namespace, dest, self._defaults[dest]) - - # parse the arguments and exit if there are any errors - try: - return self._parse_known_args(args, namespace) - except ArgumentError: - err = _sys.exc_info()[1] - self.error(str(err)) - - def _parse_known_args(self, arg_strings, namespace): - # replace arg strings that are file references - if self.fromfile_prefix_chars is not None: - arg_strings = self._read_args_from_files(arg_strings) - - # map all mutually exclusive arguments to the other arguments - # they can't occur with - action_conflicts = {} - for mutex_group in self._mutually_exclusive_groups: - group_actions = mutex_group._group_actions - for i, mutex_action in enumerate(mutex_group._group_actions): - conflicts = action_conflicts.setdefault(mutex_action, []) - conflicts.extend(group_actions[:i]) - conflicts.extend(group_actions[i + 1:]) - - # find all option indices, and determine the arg_string_pattern - # which has an 'O' if there is an option at an index, - # an 'A' if there is an argument, or a '-' if there is a '--' - option_string_indices = {} - arg_string_pattern_parts = [] - arg_strings_iter = iter(arg_strings) - for i, arg_string in enumerate(arg_strings_iter): - - # all args after -- are non-options - if arg_string == '--': - arg_string_pattern_parts.append('-') - for arg_string in arg_strings_iter: - arg_string_pattern_parts.append('A') - - # otherwise, add the arg to the arg strings - # and note the index if it was an option - else: - option_tuple = self._parse_optional(arg_string) - if option_tuple is None: - pattern = 'A' - else: - option_string_indices[i] = option_tuple - pattern = 'O' - arg_string_pattern_parts.append(pattern) - - # join the pieces together to form the pattern - arg_strings_pattern = ''.join(arg_string_pattern_parts) - - # converts arg strings to the appropriate and then takes the action - seen_actions = _set() - seen_non_default_actions = _set() - - def take_action(action, argument_strings, option_string=None): - seen_actions.add(action) - argument_values = self._get_values(action, argument_strings) - - # error if this argument is not allowed with other previously - # seen arguments, assuming that actions that use the default - # value don't really count as "present" - if argument_values is not action.default: - seen_non_default_actions.add(action) - for conflict_action in action_conflicts.get(action, []): - if conflict_action in seen_non_default_actions: - msg = _('not allowed with argument %s') - action_name = _get_action_name(conflict_action) - raise ArgumentError(action, msg % action_name) - - # take the action if we didn't receive a SUPPRESS value - # (e.g. from a default) - if argument_values is not SUPPRESS: - action(self, namespace, argument_values, option_string) - - # function to convert arg_strings into an optional action - def consume_optional(start_index): - - # get the optional identified at this index - option_tuple = option_string_indices[start_index] - action, option_string, explicit_arg = option_tuple - - # identify additional optionals in the same arg string - # (e.g. -xyz is the same as -x -y -z if no args are required) - match_argument = self._match_argument - action_tuples = [] - while True: - - # if we found no optional action, skip it - if action is None: - extras.append(arg_strings[start_index]) - return start_index + 1 - - # if there is an explicit argument, try to match the - # optional's string arguments to only this - if explicit_arg is not None: - arg_count = match_argument(action, 'A') - - # if the action is a single-dash option and takes no - # arguments, try to parse more single-dash options out - # of the tail of the option string - chars = self.prefix_chars - if arg_count == 0 and option_string[1] not in chars: - action_tuples.append((action, [], option_string)) - for char in self.prefix_chars: - option_string = char + explicit_arg[0] - explicit_arg = explicit_arg[1:] or None - optionals_map = self._option_string_actions - if option_string in optionals_map: - action = optionals_map[option_string] - break - else: - msg = _('ignored explicit argument %r') - raise ArgumentError(action, msg % explicit_arg) - - # if the action expect exactly one argument, we've - # successfully matched the option; exit the loop - elif arg_count == 1: - stop = start_index + 1 - args = [explicit_arg] - action_tuples.append((action, args, option_string)) - break - - # error if a double-dash option did not use the - # explicit argument - else: - msg = _('ignored explicit argument %r') - raise ArgumentError(action, msg % explicit_arg) - - # if there is no explicit argument, try to match the - # optional's string arguments with the following strings - # if successful, exit the loop - else: - start = start_index + 1 - selected_patterns = arg_strings_pattern[start:] - arg_count = match_argument(action, selected_patterns) - stop = start + arg_count - args = arg_strings[start:stop] - action_tuples.append((action, args, option_string)) - break - - # add the Optional to the list and return the index at which - # the Optional's string args stopped - assert action_tuples - for action, args, option_string in action_tuples: - take_action(action, args, option_string) - return stop - - # the list of Positionals left to be parsed; this is modified - # by consume_positionals() - positionals = self._get_positional_actions() - - # function to convert arg_strings into positional actions - def consume_positionals(start_index): - # match as many Positionals as possible - match_partial = self._match_arguments_partial - selected_pattern = arg_strings_pattern[start_index:] - arg_counts = match_partial(positionals, selected_pattern) - - # slice off the appropriate arg strings for each Positional - # and add the Positional and its args to the list - for action, arg_count in zip(positionals, arg_counts): - args = arg_strings[start_index: start_index + arg_count] - start_index += arg_count - take_action(action, args) - - # slice off the Positionals that we just parsed and return the - # index at which the Positionals' string args stopped - positionals[:] = positionals[len(arg_counts):] - return start_index - - # consume Positionals and Optionals alternately, until we have - # passed the last option string - extras = [] - start_index = 0 - if option_string_indices: - max_option_string_index = max(option_string_indices) - else: - max_option_string_index = -1 - while start_index <= max_option_string_index: - - # consume any Positionals preceding the next option - next_option_string_index = min([ - index - for index in option_string_indices - if index >= start_index]) - if start_index != next_option_string_index: - positionals_end_index = consume_positionals(start_index) - - # only try to parse the next optional if we didn't consume - # the option string during the positionals parsing - if positionals_end_index > start_index: - start_index = positionals_end_index - continue - else: - start_index = positionals_end_index - - # if we consumed all the positionals we could and we're not - # at the index of an option string, there were extra arguments - if start_index not in option_string_indices: - strings = arg_strings[start_index:next_option_string_index] - extras.extend(strings) - start_index = next_option_string_index - - # consume the next optional and any arguments for it - start_index = consume_optional(start_index) - - # consume any positionals following the last Optional - stop_index = consume_positionals(start_index) - - # if we didn't consume all the argument strings, there were extras - extras.extend(arg_strings[stop_index:]) - - # if we didn't use all the Positional objects, there were too few - # arg strings supplied. - if positionals: - self.error(_('too few arguments')) - - # make sure all required actions were present - for action in self._actions: - if action.required: - if action not in seen_actions: - name = _get_action_name(action) - self.error(_('argument %s is required') % name) - - # make sure all required groups had one option present - for group in self._mutually_exclusive_groups: - if group.required: - for action in group._group_actions: - if action in seen_non_default_actions: - break - - # if no actions were used, report the error - else: - names = [_get_action_name(action) - for action in group._group_actions - if action.help is not SUPPRESS] - msg = _('one of the arguments %s is required') - self.error(msg % ' '.join(names)) - - # return the updated namespace and the extra arguments - return namespace, extras - - def _read_args_from_files(self, arg_strings): - # expand arguments referencing files - new_arg_strings = [] - for arg_string in arg_strings: - - # for regular arguments, just add them back into the list - if arg_string[0] not in self.fromfile_prefix_chars: - new_arg_strings.append(arg_string) - - # replace arguments referencing files with the file content - else: - try: - args_file = open(arg_string[1:]) - try: - arg_strings = [] - for arg_line in args_file.read().splitlines(): - for arg in self.convert_arg_line_to_args(arg_line): - arg_strings.append(arg) - arg_strings = self._read_args_from_files(arg_strings) - new_arg_strings.extend(arg_strings) - finally: - args_file.close() - except IOError: - err = _sys.exc_info()[1] - self.error(str(err)) - - # return the modified argument list - return new_arg_strings - - def convert_arg_line_to_args(self, arg_line): - return [arg_line] - - def _match_argument(self, action, arg_strings_pattern): - # match the pattern for this action to the arg strings - nargs_pattern = self._get_nargs_pattern(action) - match = _re.match(nargs_pattern, arg_strings_pattern) - - # raise an exception if we weren't able to find a match - if match is None: - nargs_errors = { - None: _('expected one argument'), - OPTIONAL: _('expected at most one argument'), - ONE_OR_MORE: _('expected at least one argument'), - } - default = _('expected %s argument(s)') % action.nargs - msg = nargs_errors.get(action.nargs, default) - raise ArgumentError(action, msg) - - # return the number of arguments matched - return len(match.group(1)) - - def _match_arguments_partial(self, actions, arg_strings_pattern): - # progressively shorten the actions list by slicing off the - # final actions until we find a match - result = [] - for i in range(len(actions), 0, -1): - actions_slice = actions[:i] - pattern = ''.join([self._get_nargs_pattern(action) - for action in actions_slice]) - match = _re.match(pattern, arg_strings_pattern) - if match is not None: - result.extend([len(string) for string in match.groups()]) - break - - # return the list of arg string counts - return result - - def _parse_optional(self, arg_string): - # if it's an empty string, it was meant to be a positional - if not arg_string: - return None - - # if it doesn't start with a prefix, it was meant to be positional - if not arg_string[0] in self.prefix_chars: - return None - - # if the option string is present in the parser, return the action - if arg_string in self._option_string_actions: - action = self._option_string_actions[arg_string] - return action, arg_string, None - - # if it's just a single character, it was meant to be positional - if len(arg_string) == 1: - return None - - # if the option string before the "=" is present, return the action - if '=' in arg_string: - option_string, explicit_arg = arg_string.split('=', 1) - if option_string in self._option_string_actions: - action = self._option_string_actions[option_string] - return action, option_string, explicit_arg - - # search through all possible prefixes of the option string - # and all actions in the parser for possible interpretations - option_tuples = self._get_option_tuples(arg_string) - - # if multiple actions match, the option string was ambiguous - if len(option_tuples) > 1: - options = ', '.join([option_string - for action, option_string, explicit_arg in option_tuples]) - tup = arg_string, options - self.error(_('ambiguous option: %s could match %s') % tup) - - # if exactly one action matched, this segmentation is good, - # so return the parsed action - elif len(option_tuples) == 1: - option_tuple, = option_tuples - return option_tuple - - # if it was not found as an option, but it looks like a negative - # number, it was meant to be positional - # unless there are negative-number-like options - if self._negative_number_matcher.match(arg_string): - if not self._has_negative_number_optionals: - return None - - # if it contains a space, it was meant to be a positional - if ' ' in arg_string: - return None - - # it was meant to be an optional but there is no such option - # in this parser (though it might be a valid option in a subparser) - return None, arg_string, None - - def _get_option_tuples(self, option_string): - result = [] - - # option strings starting with two prefix characters are only - # split at the '=' - chars = self.prefix_chars - if option_string[0] in chars and option_string[1] in chars: - if '=' in option_string: - option_prefix, explicit_arg = option_string.split('=', 1) - else: - option_prefix = option_string - explicit_arg = None - for option_string in self._option_string_actions: - if option_string.startswith(option_prefix): - action = self._option_string_actions[option_string] - tup = action, option_string, explicit_arg - result.append(tup) - - # single character options can be concatenated with their arguments - # but multiple character options always have to have their argument - # separate - elif option_string[0] in chars and option_string[1] not in chars: - option_prefix = option_string - explicit_arg = None - short_option_prefix = option_string[:2] - short_explicit_arg = option_string[2:] - - for option_string in self._option_string_actions: - if option_string == short_option_prefix: - action = self._option_string_actions[option_string] - tup = action, option_string, short_explicit_arg - result.append(tup) - elif option_string.startswith(option_prefix): - action = self._option_string_actions[option_string] - tup = action, option_string, explicit_arg - result.append(tup) - - # shouldn't ever get here - else: - self.error(_('unexpected option string: %s') % option_string) - - # return the collected option tuples - return result - - def _get_nargs_pattern(self, action): - # in all examples below, we have to allow for '--' args - # which are represented as '-' in the pattern - nargs = action.nargs - - # the default (None) is assumed to be a single argument - if nargs is None: - nargs_pattern = '(-*A-*)' - - # allow zero or one arguments - elif nargs == OPTIONAL: - nargs_pattern = '(-*A?-*)' - - # allow zero or more arguments - elif nargs == ZERO_OR_MORE: - nargs_pattern = '(-*[A-]*)' - - # allow one or more arguments - elif nargs == ONE_OR_MORE: - nargs_pattern = '(-*A[A-]*)' - - # allow any number of options or arguments - elif nargs == REMAINDER: - nargs_pattern = '([-AO]*)' - - # allow one argument followed by any number of options or arguments - elif nargs == PARSER: - nargs_pattern = '(-*A[-AO]*)' - - # all others should be integers - else: - nargs_pattern = '(-*%s-*)' % '-*'.join('A' * nargs) - - # if this is an optional action, -- is not allowed - if action.option_strings: - nargs_pattern = nargs_pattern.replace('-*', '') - nargs_pattern = nargs_pattern.replace('-', '') - - # return the pattern - return nargs_pattern - - # ======================== - # Value conversion methods - # ======================== - def _get_values(self, action, arg_strings): - # for everything but PARSER args, strip out '--' - if action.nargs not in [PARSER, REMAINDER]: - arg_strings = [s for s in arg_strings if s != '--'] - - # optional argument produces a default when not present - if not arg_strings and action.nargs == OPTIONAL: - if action.option_strings: - value = action.const - else: - value = action.default - if isinstance(value, _basestring): - value = self._get_value(action, value) - self._check_value(action, value) - - # when nargs='*' on a positional, if there were no command-line - # args, use the default if it is anything other than None - elif (not arg_strings and action.nargs == ZERO_OR_MORE and - not action.option_strings): - if action.default is not None: - value = action.default - else: - value = arg_strings - self._check_value(action, value) - - # single argument or optional argument produces a single value - elif len(arg_strings) == 1 and action.nargs in [None, OPTIONAL]: - arg_string, = arg_strings - value = self._get_value(action, arg_string) - self._check_value(action, value) - - # REMAINDER arguments convert all values, checking none - elif action.nargs == REMAINDER: - value = [self._get_value(action, v) for v in arg_strings] - - # PARSER arguments convert all values, but check only the first - elif action.nargs == PARSER: - value = [self._get_value(action, v) for v in arg_strings] - self._check_value(action, value[0]) - - # all other types of nargs produce a list - else: - value = [self._get_value(action, v) for v in arg_strings] - for v in value: - self._check_value(action, v) - - # return the converted value - return value - - def _get_value(self, action, arg_string): - type_func = self._registry_get('type', action.type, action.type) - if not _callable(type_func): - msg = _('%r is not callable') - raise ArgumentError(action, msg % type_func) - - # convert the value to the appropriate type - try: - result = type_func(arg_string) - - # ArgumentTypeErrors indicate errors - except ArgumentTypeError: - name = getattr(action.type, '__name__', repr(action.type)) - msg = str(_sys.exc_info()[1]) - raise ArgumentError(action, msg) - - # TypeErrors or ValueErrors also indicate errors - except (TypeError, ValueError): - name = getattr(action.type, '__name__', repr(action.type)) - msg = _('invalid %s value: %r') - raise ArgumentError(action, msg % (name, arg_string)) - - # return the converted value - return result - - def _check_value(self, action, value): - # converted value must be one of the choices (if specified) - if action.choices is not None and value not in action.choices: - tup = value, ', '.join(map(repr, action.choices)) - msg = _('invalid choice: %r (choose from %s)') % tup - raise ArgumentError(action, msg) - - # ======================= - # Help-formatting methods - # ======================= - def format_usage(self): - formatter = self._get_formatter() - formatter.add_usage(self.usage, self._actions, - self._mutually_exclusive_groups) - return formatter.format_help() - - def format_help(self): - formatter = self._get_formatter() - - # usage - formatter.add_usage(self.usage, self._actions, - self._mutually_exclusive_groups) - - # description - formatter.add_text(self.description) - - # positionals, optionals and user-defined groups - for action_group in self._action_groups: - formatter.start_section(action_group.title) - formatter.add_text(action_group.description) - formatter.add_arguments(action_group._group_actions) - formatter.end_section() - - # epilog - formatter.add_text(self.epilog) - - # determine help from format above - return formatter.format_help() - - def format_version(self): - import warnings - warnings.warn( - 'The format_version method is deprecated -- the "version" ' - 'argument to ArgumentParser is no longer supported.', - DeprecationWarning) - formatter = self._get_formatter() - formatter.add_text(self.version) - return formatter.format_help() - - def _get_formatter(self): - return self.formatter_class(prog=self.prog) - - # ===================== - # Help-printing methods - # ===================== - def print_usage(self, file=None): - if file is None: - file = _sys.stdout - self._print_message(self.format_usage(), file) - - def print_help(self, file=None): - if file is None: - file = _sys.stdout - self._print_message(self.format_help(), file) - - def print_version(self, file=None): - import warnings - warnings.warn( - 'The print_version method is deprecated -- the "version" ' - 'argument to ArgumentParser is no longer supported.', - DeprecationWarning) - self._print_message(self.format_version(), file) - - def _print_message(self, message, file=None): - if message: - if file is None: - file = _sys.stderr - file.write(message) - - # =============== - # Exiting methods - # =============== - def exit(self, status=0, message=None): - if message: - self._print_message(message, _sys.stderr) - _sys.exit(status) - - def error(self, message): - """error(message: string) - - Prints a usage message incorporating the message to stderr and - exits. - - If you override this in a subclass, it should not return -- it - should either exit or raise an exception. - """ - self.print_usage(_sys.stderr) - self.exit(2, _('%s: error: %s\n') % (self.prog, message)) diff -r 020cea57eaa4 -r 02c408aed5ee src/HOL/Tools/Sledgehammer/MaSh/src/compareStats.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/compareStats.py Sat Jun 28 22:13:23 2014 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,65 +0,0 @@ -#!/usr/bin/env python -# Title: HOL/Tools/Sledgehammer/MaSh/src/compareStats.py -# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen -# Copyright 2012 -# -# Tool that compares MaSh statistics and displays a comparison. - -''' -Created on Jul 13, 2012 - -@author: Daniel Kuehlwein -''' - -import sys -from argparse import ArgumentParser,RawDescriptionHelpFormatter -from matplotlib.pyplot import plot,figure,show,legend,xlabel,ylabel,axis,hist -from stats import Statistics - -parser = ArgumentParser(description='Compare Statistics. \n\n\ -Loads different statistics and displays a comparison. Requires the matplotlib module.\n\n\ --------- Example Usage ---------------\n\ -./compareStats.py --statFiles ../tmp/natISANB.stats ../tmp/natATPNB.stats -b 30\n\n\ -Author: Daniel Kuehlwein, July 2012',formatter_class=RawDescriptionHelpFormatter) -parser.add_argument('--statFiles', default=None, nargs='+', - help='The names of the saved statistic files.') -parser.add_argument('-b','--bins',default=50,help="Number of bins for the AUC histogram. Default=50.",type=int) - -def main(argv = sys.argv[1:]): - args = parser.parse_args(argv) - if args.statFiles == None: - print 'Filenames missing.' - sys.exit(-1) - - aucData = [] - aucLabels = [] - for statFile in args.statFiles: - s = Statistics() - s.load(statFile) - avgRecall = [float(x)/s.problems for x in s.recallData] - figure('Recall') - plot(range(s.cutOff),avgRecall,label=statFile) - legend(loc='lower right') - ylabel('Average Recall') - xlabel('Highest ranked premises') - axis([0,s.cutOff,0.0,1.0]) - figure('100%Recall') - plot(range(s.cutOff),s.recall100Data,label=statFile) - legend(loc='lower right') - ylabel('100%Recall') - xlabel('Highest ranked premises') - axis([0,s.cutOff,0,s.problems]) - aucData.append(s.aucData) - aucLabels.append(statFile) - figure('AUC Histogram') - hist(aucData,bins=args.bins,label=aucLabels,histtype='bar') - legend(loc='upper left') - ylabel('Problems') - xlabel('AUC') - - show() - -if __name__ == '__main__': - #args = ['--statFiles','../tmp/natISANB.stats','../tmp/natATPNB.stats','-b','30'] - #sys.exit(main(args)) - sys.exit(main()) diff -r 020cea57eaa4 -r 02c408aed5ee src/HOL/Tools/Sledgehammer/MaSh/src/dictionaries.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/dictionaries.py Sat Jun 28 22:13:23 2014 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,261 +0,0 @@ -# Title: HOL/Tools/Sledgehammer/MaSh/src/dictionaries.py -# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen -# Copyright 2012-2013 -# -# Persistent dictionaries: accessibility, dependencies, and features. - -import sys -from os.path import join -from Queue import Queue -from readData import create_accessible_dict,create_dependencies_dict -from cPickle import load,dump -from exceptions import LookupError - -class Dictionaries(object): - ''' - This class contains all info about name-> id mapping, etc. - ''' - def __init__(self): - ''' - Constructor - ''' - self.nameIdDict = {} - self.idNameDict = {} - self.featureIdDict={} - self.maxNameId = 1 - self.maxFeatureId = 0 - self.featureDict = {} - self.dependenciesDict = {} - self.accessibleDict = {} - self.expandedAccessibles = {} - self.accFile = '' - self.changed = True - # Unnamed facts - self.nameIdDict[''] = 0 - self.idNameDict[0] = 'Unnamed Fact' - - """ - Init functions. nameIdDict, idNameDict, featureIdDict, articleDict get filled! - """ - def init_featureDict(self,featureFile): - self.create_feature_dict(featureFile) - #self.featureDict,self.maxNameId,self.maxFeatureId,self.featureCountDict,self.triggerFeaturesDict,self.featureTriggeredFormulasDict =\ - # create_feature_dict(self.nameIdDict,self.idNameDict,self.maxNameId,self.featureIdDict,self.maxFeatureId,self.featureCountDict,\ - # self.triggerFeaturesDict,self.featureTriggeredFormulasDict,sineFeatures,featureFile) - def init_dependenciesDict(self,depFile): - self.dependenciesDict = create_dependencies_dict(self.nameIdDict,depFile) - def init_accessibleDict(self,accFile): - self.accessibleDict,self.maxNameId = create_accessible_dict(self.nameIdDict,self.idNameDict,self.maxNameId,accFile) - - def init_all(self,args): - self.featureFileName = 'mash_features' - self.accFileName = 'mash_accessibility' - featureFile = join(args.inputDir,self.featureFileName) - depFile = join(args.inputDir,args.depFile) - self.accFile = join(args.inputDir,self.accFileName) - self.init_featureDict(featureFile) - self.init_accessibleDict(self.accFile) - self.init_dependenciesDict(depFile) - self.expandedAccessibles = {} - self.changed = True - - def create_feature_dict(self,inputFile): - self.featureDict = {} - IS = open(inputFile,'r') - for line in IS: - line = line.split(':') - name = line[0] - # Name Id - if self.nameIdDict.has_key(name): - raise LookupError('%s appears twice in the feature file. Aborting.'% name) - sys.exit(-1) - else: - self.nameIdDict[name] = self.maxNameId - self.idNameDict[self.maxNameId] = name - nameId = self.maxNameId - self.maxNameId += 1 - features = self.get_features(line) - # Store results - self.featureDict[nameId] = features - IS.close() - return - - def get_name_id(self,name): - """ - Return the Id for a name. - If it doesn't exist yet, a new entry is created. - """ - if self.nameIdDict.has_key(name): - nameId = self.nameIdDict[name] - else: - self.nameIdDict[name] = self.maxNameId - self.idNameDict[self.maxNameId] = name - nameId = self.maxNameId - self.maxNameId += 1 - self.changed = True - return nameId - - def add_feature(self,featureName): - fMul = featureName.split('|') - fIds = [] - for f in fMul: - if not self.featureIdDict.has_key(f): - self.featureIdDict[f] = self.maxFeatureId - self.maxFeatureId += 1 - self.changed = True - fId = self.featureIdDict[f] - fIds.append(fId) - return fIds - - def get_features(self,line): - featureNames = [f.strip() for f in line[1].split()] - features = {} - for fn in featureNames: - tmp = fn.split('=') - weight = 1.0 - if len(tmp) == 2: - fn = tmp[0] - weight = float(tmp[1]) - fIds = self.add_feature(tmp[0]) - features[fIds[0]] = (weight,fIds[1:]) - #features[fId] = 1.0 ### - return features - - def expand_accessibles(self,acc): - accessibles = set(acc) - unexpandedQueue = Queue() - for a in acc: - if self.expandedAccessibles.has_key(a): - accessibles = accessibles.union(self.expandedAccessibles[a]) - else: - unexpandedQueue.put(a) - while not unexpandedQueue.empty(): - nextUnExp = unexpandedQueue.get() - nextUnExpAcc = self.accessibleDict[nextUnExp] - for a in nextUnExpAcc: - if not a in accessibles: - accessibles = accessibles.union([a]) - if self.expandedAccessibles.has_key(a): - accessibles = accessibles.union(self.expandedAccessibles[a]) - else: - unexpandedQueue.put(a) - return list(accessibles) - - def parse_unExpAcc(self,line): - try: - unExpAcc = [self.nameIdDict[a.strip()] for a in line.split()] - except: - raise LookupError('Cannot find the accessibles:%s. Accessibles need to be introduced before referring to them.' % line) - return unExpAcc - - def parse_fact(self,line): - """ - Parses a single line, extracting accessibles, features, and dependencies. - """ - assert line.startswith('! ') - line = line[2:] - - # line = name:accessibles;features;dependencies - line = line.split(':') - name = line[0].strip() - nameId = self.get_name_id(name) - line = line[1].split(';') - features = self.get_features(line) - self.featureDict[nameId] = features - try: - self.dependenciesDict[nameId] = [self.nameIdDict[d.strip()] for d in line[2].split()] - except: - unknownDeps = [] - for d in line[2].split(): - if not self.nameIdDict.has_key(d): - unknownDeps.append(d) - raise LookupError('Unknown fact used as dependency: %s. Facts need to be introduced before being used as depedency.' % ','.join(unknownDeps)) - self.accessibleDict[nameId] = self.parse_unExpAcc(line[0]) - - self.changed = True - return nameId - - def parse_overwrite(self,line): - """ - Parses a single line, extracts the problemId and the Ids of the dependencies. - """ - assert line.startswith('p ') - line = line[2:] - - # line = name:dependencies - line = line.split(':') - name = line[0].strip() - try: - nameId = self.nameIdDict[name] - except: - raise LookupError('Trying to overwrite dependencies for unknown fact: %s. Facts need to be introduced before overwriting them.' % name) - try: - dependencies = [self.nameIdDict[d.strip()] for d in line[1].split()] - except: - unknownDeps = [] - for d in line[1].split(): - if not self.nameIdDict.has_key(d): - unknownDeps.append(d) - raise LookupError('Unknown fact used as dependency: %s. Facts need to be introduced before being used as depedency.' % ','.join(unknownDeps)) - self.changed = True - return nameId,dependencies - - def parse_problem(self,line): - """ - Parses a problem and returns the features, the accessibles, and any hints. - """ - assert line.startswith('? ') - line = line[2:] - name = None - numberOfPredictions = None - - # How many predictions should be returned: - tmp = line.split('#') - if len(tmp) == 2: - numberOfPredictions = int(tmp[0].strip()) - line = tmp[1] - - # Check whether there is a problem name: - tmp = line.split(':') - if len(tmp) == 2: - name = tmp[0].strip() - line = tmp[1] - - # line = accessibles;features - line = line.split(';') - features = self.get_features(line) - - # Accessible Ids, expand and store the accessibles. - #unExpAcc = [self.nameIdDict[a.strip()] for a in line[0].split()] - unExpAcc = self.parse_unExpAcc(line[0]) - if len(self.expandedAccessibles.keys())>=100: - self.expandedAccessibles = {} - self.changed = True - for accId in unExpAcc: - if not self.expandedAccessibles.has_key(accId): - accIdAcc = self.accessibleDict[accId] - self.expandedAccessibles[accId] = self.expand_accessibles(accIdAcc) - self.changed = True - accessibles = self.expand_accessibles(unExpAcc) - - # Get hints: - if len(line) == 3: - hints = [self.nameIdDict[d.strip()] for d in line[2].split()] - else: - hints = [] - - return name,features,accessibles,hints,numberOfPredictions - - def save(self,fileName): - if self.changed: - dictsStream = open(fileName, 'wb') - dump((self.accessibleDict,self.dependenciesDict,self.expandedAccessibles,self.featureDict,\ - self.featureIdDict,self.idNameDict,self.maxFeatureId,self.maxNameId,self.nameIdDict),dictsStream) - self.changed = False - dictsStream.close() - def load(self,fileName): - dictsStream = open(fileName, 'rb') - self.accessibleDict,self.dependenciesDict,self.expandedAccessibles,self.featureDict,\ - self.featureIdDict,self.idNameDict,self.maxFeatureId,self.maxNameId,self.nameIdDict = load(dictsStream) - self.changed = False - dictsStream.close() diff -r 020cea57eaa4 -r 02c408aed5ee src/HOL/Tools/Sledgehammer/MaSh/src/mash.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py Sat Jun 28 22:13:23 2014 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,149 +0,0 @@ -#!/usr/bin/env python -# Title: HOL/Tools/Sledgehammer/MaSh/src/mash -# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen -# Copyright 2012 - 2013 -# -# Entry point for MaSh (Machine Learning for Sledgehammer). - -''' -MaSh - Machine Learning for Sledgehammer - -MaSh allows to use different machine learning algorithms to predict relevant fact for Sledgehammer. - -Created on July 12, 2012 - -@author: Daniel Kuehlwein -''' - -import socket,sys,time,logging,os -from os.path import realpath,dirname - -from spawnDaemon import spawnDaemon -from parameters import init_parser - -def communicate(data,host,port): - logger = logging.getLogger('communicate') - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - try: - sock.connect((host,port)) - sock.sendall(data+'\n') - received = '' - cont = True - counter = 0 - while cont and counter < 100000: - rec = sock.recv(4096) - if rec.endswith('stop'): - cont = False - received += rec[:-4] - else: - received += rec - counter += 1 - if rec == '': - logger.warning('No response from server. Check server log for details.') - except: - logger.warning('Communication with server failed.') - received = -1 - finally: - sock.close() - return received - -def start_server(host,port): - logger = logging.getLogger('start_server') - logger.info('Starting Server.') - path = dirname(realpath(__file__)) - spawnDaemon(os.path.join(path,'server.py'),os.path.join(path,'server.py'),host,str(port)) - serverIsUp=False - for _i in range(20): - # Test if server is up - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.connect((host,port)) - sock.close() - serverIsUp = True - break - except: - time.sleep(0.5) - if not serverIsUp: - logger.error('Could not start server.') - sys.exit(-1) - return True - -def mash(argv = sys.argv[1:]): - # Initializing command-line arguments - args = init_parser(argv) - # Set up logging - logging.basicConfig(level=logging.DEBUG, - format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s', - datefmt='%d-%m %H:%M:%S', - filename=args.log, - filemode='w') - logger = logging.getLogger('mash') - - if args.quiet: - logger.setLevel(logging.WARNING) - else: - console = logging.StreamHandler(sys.stdout) - console.setLevel(logging.INFO) - formatter = logging.Formatter('# %(message)s') - console.setFormatter(formatter) - logging.getLogger('').addHandler(console) - - if not os.path.exists(args.outputDir): - os.makedirs(args.outputDir) - - # Shutdown commands need not start the server fist. - if args.shutdownServer: - logger.info('Sending shutdown command.') - try: - received = communicate('shutdown',args.host,args.port) - logger.info(received) - except: - pass - return - - # If server is not running, start it. - startedServer = False - received = communicate(' '.join(('ping',args.modelFile,args.dictsFile)),args.host,args.port) - if received == -1: - startedServer = start_server(args.host,args.port) - elif received.startswith('Files do not match'): - logger.error('Filesnames do not match!') - logger.error('Modelfile server: '+ received.split()[-2]) - logger.error('Modelfile argument: '+ args.modelFile) - logger.error('Dictsfile server: '+ received.split()[-1]) - logger.error('Dictsfile argument: '+ args.dictsFile) - return - - if args.init or startedServer: - logger.info('Initializing Server.') - data = "i "+";".join(argv) - received = communicate(data,args.host,args.port) - logger.info(received) - - if not args.inputFile == None: - logger.debug('Using the following settings: %s',args) - # IO Streams - OS = open(args.predictions,'w') - IS = open(args.inputFile,'r') - lineCount = 0 - for line in IS: - lineCount += 1 - if lineCount % 100 == 0: - logger.info('On line %s', lineCount) - received = communicate(line,args.host,args.port) - if not received == '': - OS.write('%s\n' % received) - OS.close() - IS.close() - - # Statistics - if args.statistics: - received = communicate('avgStats',args.host,args.port) - logger.info(received) - - if args.saveModels: - communicate('save',args.host,args.port) - - -if __name__ == "__main__": - sys.exit(mash()) diff -r 020cea57eaa4 -r 02c408aed5ee src/HOL/Tools/Sledgehammer/MaSh/src/parameters.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/parameters.py Sat Jun 28 22:13:23 2014 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,47 +0,0 @@ -import datetime -from argparse import ArgumentParser,RawDescriptionHelpFormatter - -def init_parser(argv): - # Set up command-line parser - parser = ArgumentParser(description='MaSh - Machine Learning for Sledgehammer. \n\n\ - MaSh allows to use different machine learning algorithms to predict relevant facts for Sledgehammer.\n\n\ - --------------- Example Usage ---------------\n\ - First initialize:\n./mash.py -l test.log -o ../tmp/ --init --inputDir ../data/Jinja/ \n\ - Then create predictions:\n./mash.py -i ../data/Jinja/mash_commands -p ../data/Jinja/mash_suggestions -l test.log -o ../tmp/ --statistics\n\ - \n\n\ - Author: Daniel Kuehlwein, July 2012',formatter_class=RawDescriptionHelpFormatter) - parser.add_argument('-i','--inputFile',help='File containing all problems to be solved.') - parser.add_argument('-o','--outputDir', default='../tmp/',help='Directory where all created files are stored. Default=../tmp/.') - parser.add_argument('-p','--predictions',default='../tmp/%s.predictions' % datetime.datetime.now(), - help='File where the predictions stored. Default=../tmp/dateTime.predictions.') - parser.add_argument('--numberOfPredictions',default=500,help="Default number of premises to write in the output. Default=500.",type=int) - - parser.add_argument('--init',default=False,action='store_true',help="Initialize Mash. Requires --inputDir to be defined. Default=False.") - parser.add_argument('--inputDir',\ - help='Directory containing all the input data. MaSh expects the following files: mash_features,mash_dependencies,mash_accessibility') - parser.add_argument('--depFile', default='mash_dependencies', - help='Name of the file with the premise dependencies. The file must be in inputDir. Default = mash_dependencies') - - parser.add_argument('--algorithm',default='nb',help="Which learning algorithm is used. nb = Naive Bayes,KNN,predef=predefined. Default=nb.") - parser.add_argument('--predef',help="File containing the predefined suggestions. Only used when algorithm = predef.") - # NB Parameters - parser.add_argument('--NBDefaultPriorWeight',default=20.0,help="Initializes classifiers with value * p |- p. Default=20.0.",type=float) - parser.add_argument('--NBDefVal',default=-15.0,help="Default value for unknown features. Default=-15.0.",type=float) - parser.add_argument('--NBPosWeight',default=10.0,help="Weight value for positive features. Default=10.0.",type=float) - parser.add_argument('--expandFeatures',default=False,action='store_true',help="Learning-based feature expansion. Default=False.") - - parser.add_argument('--statistics',default=False,action='store_true',help="Create and show statistics for the top CUTOFF predictions.\ - WARNING: This will make the program a lot slower! Default=False.") - parser.add_argument('--saveStats',default=None,help="If defined, stores the statistics in the filename provided.") - parser.add_argument('--cutOff',default=500,help="Option for statistics. Only consider the first cutOff predictions. Default=500.",type=int) - parser.add_argument('-l','--log', default='../tmp/%s.log' % datetime.datetime.now(), help='Log file name. Default=../tmp/dateTime.log') - parser.add_argument('-q','--quiet',default=False,action='store_true',help="If enabled, only print warnings. Default=False.") - parser.add_argument('--modelFile', default='../tmp/model.pickle', help='Model file name. Default=../tmp/model.pickle') - parser.add_argument('--dictsFile', default='../tmp/dict.pickle', help='Dict file name. Default=../tmp/dict.pickle') - - parser.add_argument('--port', default='9255', help='Port of the Mash server. Default=9255',type=int) - parser.add_argument('--host', default='localhost', help='Host of the Mash server. Default=localhost') - parser.add_argument('--shutdownServer',default=False,action='store_true',help="Shutdown server without saving the models.") - parser.add_argument('--saveModels',default=False,action='store_true',help="Server saves the models.") - args = parser.parse_args(argv) - return args diff -r 020cea57eaa4 -r 02c408aed5ee src/HOL/Tools/Sledgehammer/MaSh/src/predefined.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/predefined.py Sat Jun 28 22:13:23 2014 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,68 +0,0 @@ -# Title: HOL/Tools/Sledgehammer/MaSh/src/predefined.py -# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen -# Copyright 2012 -# -# A classifier that uses the Meng-Paulson predictions. - -''' -Created on Jul 11, 2012 - -@author: Daniel Kuehlwein -''' - -from cPickle import dump,load - -class Predefined(object): - ''' - A classifier that uses the Meng-Paulson predictions. - Only used to easily compare statistics between the old Sledgehammer algorithm and the new machine learning ones. - ''' - - def __init__(self,mpPredictionFile): - ''' - Constructor - ''' - self.predictionFile = mpPredictionFile - - def initializeModel(self,_trainData,dicts): - """ - Load predictions. - """ - self.predictions = {} - IS = open(self.predictionFile,'r') - for line in IS: - line = line.split(':') - name = line[0].strip() - predId = dicts.get_name_id(name) - line = line[1].split() - predsTmp = [] - for x in line: - x = x.split('=') - predsTmp.append(x[0]) - preds = [dicts.get_name_id(x.strip())for x in predsTmp] - self.predictions[predId] = preds - IS.close() - - def update(self,dataPoint,features,dependencies): - """ - Updates the Model. - """ - # No Update needed since we assume that we got all predictions - pass - - - def predict(self,problemId): - """ - Return the saved predictions. - """ - return self.predictions[problemId] - - def save(self,fileName): - OStream = open(fileName, 'wb') - dump((self.predictionFile,self.predictions),OStream) - OStream.close() - - def load(self,fileName): - OStream = open(fileName, 'rb') - self.predictionFile,self.predictions = load(OStream) - OStream.close() diff -r 020cea57eaa4 -r 02c408aed5ee src/HOL/Tools/Sledgehammer/MaSh/src/readData.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/readData.py Sat Jun 28 22:13:23 2014 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,58 +0,0 @@ -# Title: HOL/Tools/Sledgehammer/MaSh/src/readData.py -# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen -# Copyright 2012 -# -# All functions to read the Isabelle output. - -''' -All functions to read the Isabelle output. - -Created on July 9, 2012 - -@author: Daniel Kuehlwein -''' - -import sys,logging - -def create_dependencies_dict(nameIdDict,inputFile): - logger = logging.getLogger('create_dependencies_dict') - dependenciesDict = {} - IS = open(inputFile,'r') - for line in IS: - line = line.split(':') - name = line[0] - # Name Id - if not nameIdDict.has_key(name): - logger.warning('%s is missing in nameIdDict. Aborting.',name) - sys.exit(-1) - - nameId = nameIdDict[name] - dependenciesIds = [nameIdDict[f.strip()] for f in line[1].split()] - # Store results, add p proves p - if nameId == 0: - dependenciesDict[nameId] = dependenciesIds - else: - dependenciesDict[nameId] = [nameId] + dependenciesIds - IS.close() - return dependenciesDict - -def create_accessible_dict(nameIdDict,idNameDict,maxNameId,inputFile): - logger = logging.getLogger('create_accessible_dict') - accessibleDict = {} - IS = open(inputFile,'r') - for line in IS: - line = line.split(':') - name = line[0] - # Name Id - if not nameIdDict.has_key(name): - logger.warning('%s is missing in nameIdDict. Adding it as theory.',name) - nameIdDict[name] = maxNameId - idNameDict[maxNameId] = name - nameId = maxNameId - maxNameId += 1 - else: - nameId = nameIdDict[name] - accessibleStrings = line[1].split() - accessibleDict[nameId] = [nameIdDict[a.strip()] for a in accessibleStrings] - IS.close() - return accessibleDict,maxNameId diff -r 020cea57eaa4 -r 02c408aed5ee src/HOL/Tools/Sledgehammer/MaSh/src/server.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/server.py Sat Jun 28 22:13:23 2014 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,234 +0,0 @@ -#!/usr/bin/env python -# Title: HOL/Tools/Sledgehammer/MaSh/src/server.py -# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen -# Copyright 2013 -# -# The MaSh Server. - -import SocketServer,os,string,logging,sys -from multiprocessing import Manager -from threading import Timer -from time import time -from dictionaries import Dictionaries -from parameters import init_parser -from sparseNaiveBayes import sparseNBClassifier -from KNN import KNN,euclidean -from KNNs import KNNAdaptPointFeatures,KNNUrban -#from bayesPlusMetric import sparseNBPlusClassifier -from predefined import Predefined -from ExpandFeatures import ExpandFeatures -from stats import Statistics - - -class ThreadingTCPServer(SocketServer.ThreadingTCPServer): - - def __init__(self, *args, **kwargs): - SocketServer.ThreadingTCPServer.__init__(self,*args, **kwargs) - self.manager = Manager() - self.lock = Manager().Lock() - self.idle_timeout = 28800.0 # 8 hours in seconds - self.idle_timer = Timer(self.idle_timeout, self.shutdown) - self.idle_timer.start() - self.model = None - self.dicts = None - self.callCounter = 0 - - def save(self): - if self.model == None or self.dicts == None: - try: - self.logger.warning('Cannot save nonexisting models.') - except: - pass - return - # Save Models - self.model.save(self.args.modelFile) - self.dicts.save(self.args.dictsFile) - if not self.args.saveStats == None: - statsFile = os.path.join(self.args.outputDir,self.args.saveStats) - self.stats.save(statsFile) - - def save_and_shutdown(self): - self.save() - self.shutdown() - -class MaShHandler(SocketServer.StreamRequestHandler): - - def init(self,argv): - if argv == '': - self.server.args = init_parser([]) - else: - argv = argv.split(';') - self.server.args = init_parser(argv) - - # Set up logging - logging.basicConfig(level=logging.DEBUG, - format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s', - datefmt='%d-%m %H:%M:%S', - filename=self.server.args.log+'server', - filemode='w') - self.server.logger = logging.getLogger('server') - - # Load all data - self.server.dicts = Dictionaries() - if os.path.isfile(self.server.args.dictsFile): - self.server.dicts.load(self.server.args.dictsFile) - #elif not self.server.args.dictsFile == '../tmp/dict.pickle': - # raise IOError('Cannot find dictsFile at %s '% self.server.args.dictsFile) - elif self.server.args.init: - self.server.dicts.init_all(self.server.args) - # Pick model - if self.server.args.algorithm == 'nb': - ###TODO: !! - self.server.model = sparseNBClassifier(self.server.args.NBDefaultPriorWeight,self.server.args.NBPosWeight,self.server.args.NBDefVal) - #self.server.model = sparseNBPlusClassifier(self.server.args.NBDefaultPriorWeight,self.server.args.NBPosWeight,self.server.args.NBDefVal) - elif self.server.args.algorithm == 'KNN': - #self.server.model = KNN(self.server.dicts) - self.server.model = KNNAdaptPointFeatures(self.server.dicts) - elif self.server.args.algorithm == 'predef': - self.server.model = Predefined(self.server.args.predef) - else: # Default case - self.server.model = sparseNBClassifier(self.server.args.NBDefaultPriorWeight,self.server.args.NBPosWeight,self.server.args.NBDefVal) - if self.server.args.expandFeatures: - self.server.expandFeatures = ExpandFeatures(self.server.dicts) - self.server.expandFeatures.initialize(self.server.dicts) - # Create Model - if os.path.isfile(self.server.args.modelFile): - self.server.model.load(self.server.args.modelFile) - #elif not self.server.args.modelFile == '../tmp/model.pickle': - # raise IOError('Cannot find modelFile at %s '% self.server.args.modelFile) - elif self.server.args.init: - trainData = self.server.dicts.featureDict.keys() - self.server.model.initializeModel(trainData,self.server.dicts) - - if self.server.args.statistics: - self.server.stats = Statistics(self.server.args.cutOff) - self.server.statementCounter = 1 - self.server.computeStats = False - - self.server.logger.debug('Initialized in '+str(round(time()-self.startTime,2))+' seconds.') - self.request.sendall('Server initialized in '+str(round(time()-self.startTime,2))+' seconds.') - self.server.callCounter = 1 - - def update(self): - problemId = self.server.dicts.parse_fact(self.data) - # Statistics - if self.server.args.statistics and self.server.computeStats: - self.server.computeStats = False - # Assume '!' comes after '?' - if self.server.args.algorithm == 'predef': - self.server.predictions = self.server.model.predict(problemId) - self.server.stats.update(self.server.predictions,self.server.dicts.dependenciesDict[problemId],self.server.statementCounter) - if not self.server.stats.badPreds == []: - bp = string.join([str(self.server.dicts.idNameDict[x]) for x in self.server.stats.badPreds], ',') - self.server.logger.debug('Poor predictions: %s',bp) - self.server.statementCounter += 1 - - if self.server.args.expandFeatures: - self.server.expandFeatures.update(self.server.dicts.featureDict[problemId],self.server.dicts.dependenciesDict[problemId]) - # Update Dependencies, p proves p - if not problemId == 0: - self.server.dicts.dependenciesDict[problemId] = [problemId]+self.server.dicts.dependenciesDict[problemId] - ###TODO: - self.server.model.update(problemId,self.server.dicts.featureDict[problemId],self.server.dicts.dependenciesDict[problemId]) - #self.server.model.update(problemId,self.server.dicts.featureDict[problemId],self.server.dicts.dependenciesDict[problemId],self.server.dicts) - - def overwrite(self): - # Overwrite old proof. - problemId,newDependencies = self.server.dicts.parse_overwrite(self.data) - newDependencies = [problemId]+newDependencies - self.server.model.overwrite(problemId,newDependencies,self.server.dicts) - self.server.dicts.dependenciesDict[problemId] = newDependencies - - def predict(self): - self.server.computeStats = True - if self.server.args.algorithm == 'predef': - return - name,features,accessibles,hints,numberOfPredictions = self.server.dicts.parse_problem(self.data) - if numberOfPredictions == None: - numberOfPredictions = self.server.args.numberOfPredictions - if not hints == []: - self.server.model.update('hints',features,hints) - if self.server.args.expandFeatures: - features = self.server.expandFeatures.expand(features) - # Create predictions - self.server.logger.debug('Starting computation for line %s',self.server.callCounter) - - self.server.predictions,predictionValues = self.server.model.predict(features,accessibles,self.server.dicts) - assert len(self.server.predictions) == len(predictionValues) - self.server.logger.debug('Time needed: '+str(round(time()-self.startTime,2))) - - # Output - predictionNames = [str(self.server.dicts.idNameDict[p]) for p in self.server.predictions[:numberOfPredictions]] - predictionValues = [str(x) for x in predictionValues[:numberOfPredictions]] - predictionsStringList = ['%s=%s' % (predictionNames[i],predictionValues[i]) for i in range(len(predictionNames))] - predictionsString = string.join(predictionsStringList,' ') - #predictionsString = string.join(predictionNames,' ') - outString = '%s: %s' % (name,predictionsString) - self.request.sendall(outString) - - def shutdown(self,saveModels=True): - self.request.sendall('Shutting down server.') - if saveModels: - self.server.save() - self.server.idle_timer.cancel() - self.server.idle_timer = Timer(0.5, self.server.shutdown) - self.server.idle_timer.start() - - def handle(self): - # self.request is the TCP socket connected to the client - self.server.lock.acquire() - self.data = self.rfile.readline().strip() - try: - # Update idle shutdown timer - self.server.idle_timer.cancel() - self.server.idle_timer = Timer(self.server.idle_timeout, self.server.save_and_shutdown) - self.server.idle_timer.start() - - self.startTime = time() - if self.data == 'shutdown': - self.shutdown() - elif self.data == 'save': - self.server.save() - elif self.data.startswith('ping'): - mFile, dFile = self.data.split()[1:] - if mFile == self.server.args.modelFile and dFile == self.server.args.dictsFile: - self.request.sendall('All good.') - else: - self.request.sendall('Files do not match '+' '.join((self.server.args.modelFile,self.server.args.dictsFile))) - elif self.data.startswith('i'): - self.init(self.data[2:]) - elif self.data.startswith('!'): - self.update() - elif self.data.startswith('p'): - self.overwrite() - elif self.data.startswith('?'): - self.predict() - elif self.data == '': - # Empty Socket - return - elif self.data == 'avgStats': - self.request.sendall(self.server.stats.printAvg()) - else: - self.request.sendall('Unspecified input format: \n%s',self.data) - self.server.callCounter += 1 - self.request.sendall('stop') - except: # catch exceptions - #print 'Caught an error. Check %s for more details' % (self.server.args.log+'server') - logging.exception('') - finally: - self.server.lock.release() - -if __name__ == "__main__": - if not len(sys.argv[1:]) == 2: - print 'No Arguments for HOST and PORT found. Using localhost and 9255' - HOST, PORT = "localhost", 9255 - else: - HOST, PORT = sys.argv[1:] - SocketServer.TCPServer.allow_reuse_address = True - server = ThreadingTCPServer((HOST, int(PORT)), MaShHandler) - server.serve_forever() - - - - - \ No newline at end of file diff -r 020cea57eaa4 -r 02c408aed5ee src/HOL/Tools/Sledgehammer/MaSh/src/singleNaiveBayes.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/singleNaiveBayes.py Sat Jun 28 22:13:23 2014 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,176 +0,0 @@ -# Title: HOL/Tools/Sledgehammer/MaSh/src/singleNaiveBayes.py -# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen -# Copyright 2012 -# -# An updatable sparse naive Bayes classifier. - -''' -Created on Jul 11, 2012 - -@author: Daniel Kuehlwein -''' - -from cPickle import dump,load -from math import log,exp - - -class singleNBClassifier(object): - ''' - An updateable naive Bayes classifier. - ''' - - def __init__(self,defValPos = -7.5,defValNeg = -15.0,posWeight = 10.0): - ''' - Constructor - ''' - self.neg = 0.0 - self.pos = 0.0 - self.counts = {} # Counts is the tuple poscounts,negcounts - self.defValPos = defValPos - self.defValNeg = defValNeg - self.posWeight = posWeight - - def update(self,features,label): - """ - Updates the Model. - - @param label: True or False, True if the features belong to a positive label, false else. - """ - #print label,self.pos,self.neg,self.counts - if label: - self.pos += 1 - else: - self.neg += 1 - - for f,_w in features: - if not self.counts.has_key(f): - if label: - fPosCount = 0.0 - fNegCount = 0.0 - self.counts[f] = [fPosCount,fNegCount] - else: - continue - posCount,negCount = self.counts[f] - if label: - posCount += 1 - else: - negCount += 1 - self.counts[f] = [posCount,negCount] - #print label,self.pos,self.neg,self.counts - - - def delete(self,features,label): - """ - Deletes a single datapoint from the model. - """ - if label: - self.pos -= 1 - else: - self.neg -= 1 - for f,_w in features: - posCount,negCount = self.counts[f] - if label: - posCount -= 1 - else: - negCount -= 1 - self.counts[f] = [posCount,negCount] - - - def overwrite(self,features,labelOld,labelNew): - """ - Deletes the old dependencies of problemId and replaces them with the new ones. Updates the model accordingly. - """ - self.delete(features,labelOld) - self.update(features,labelNew) - - def predict_sparse(self,features): - """ - Returns 1 if the probability of + being the correct label is greater than the probability that - is the correct label. - """ - if self.neg == 0: - return 1 - elif self.pos ==0: - return 0 - logneg = log(self.neg) - lognegprior=log(float(self.neg)/5) - logpos = log(self.pos) - prob = logpos - lognegprior - - for f,_w in features: - if self.counts.has_key(f): - posCount,negCount = self.counts[f] - if posCount > 0: - prob += (log(self.posWeight * posCount) - logpos) - else: - prob += self.defValPos - if negCount > 0: - prob -= (log(negCount) - logneg) - else: - prob -= self.defValNeg - if prob >= 0 : - return 1 - else: - return 0 - - def predict(self,features): - """ - Returns 1 if the probability is greater than 50%. - """ - if self.neg == 0: - return 1 - elif self.pos ==0: - return 0 - defVal = -15.0 - expDefVal = exp(defVal) - - logneg = log(self.neg) - logpos = log(self.pos) - prob = logpos - logneg - - for f in self.counts.keys(): - posCount,negCount = self.counts[f] - if f in features: - if posCount == 0: - prob += defVal - else: - prob += log(float(posCount)/self.pos) - if negCount == 0: - prob -= defVal - else: - prob -= log(float(negCount)/self.neg) - else: - if posCount == self.pos: - prob += log(1-expDefVal) - else: - prob += log(1-float(posCount)/self.pos) - if negCount == self.neg: - prob -= log(1-expDefVal) - else: - prob -= log(1-float(negCount)/self.neg) - - if prob >= 0 : - return 1 - else: - return 0 - - def save(self,fileName): - OStream = open(fileName, 'wb') - dump(self.counts,OStream) - OStream.close() - - def load(self,fileName): - OStream = open(fileName, 'rb') - self.counts = load(OStream) - OStream.close() - -if __name__ == '__main__': - x = singleNBClassifier() - x.update([0], True) - assert x.predict([0]) == 1 - x = singleNBClassifier() - x.update([0], False) - assert x.predict([0]) == 0 - - x.update([0], True) - x.update([1], True) - print x.pos,x.neg,x.predict([0,1]) \ No newline at end of file diff -r 020cea57eaa4 -r 02c408aed5ee src/HOL/Tools/Sledgehammer/MaSh/src/snow.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/snow.py Sat Jun 28 22:13:23 2014 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,135 +0,0 @@ -# Title: HOL/Tools/Sledgehammer/MaSh/src/snow.py -# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen -# Copyright 2012 -# -# Wrapper for SNoW. - -''' - -Created on Jul 12, 2012 - -@author: daniel -''' - -import logging,shlex,subprocess,string,shutil -#from cPickle import load,dump - -class SNoW(object): - ''' - Calls the SNoW framework. - ''' - - def __init__(self): - ''' - Constructor - ''' - self.logger = logging.getLogger('SNoW') - self.SNoWTrainFile = '../tmp/snow.train' - self.SNoWTestFile = '../snow.test' - self.SNoWNetFile = '../tmp/snow.net' - self.defMaxNameId = 100000 - - def initializeModel(self,trainData,dicts): - """ - Build basic model from training data. - """ - # Prepare input files - self.logger.debug('Creating IO Files') - OS = open(self.SNoWTrainFile,'w') - for nameId in trainData: - features = [f+dicts.maxNameId for f,_w in dicts.featureDict[nameId]] - #features = [f+self.defMaxNameId for f,_w in dicts.featureDict[nameId]] - features = map(str,features) - featureString = string.join(features,',') - dependencies = dicts.dependenciesDict[nameId] - dependencies = map(str,dependencies) - dependenciesString = string.join(dependencies,',') - snowString = string.join([featureString,dependenciesString],',')+':\n' - OS.write(snowString) - OS.close() - - # Build Model - self.logger.debug('Building Model START.') - snowTrainCommand = '../bin/snow -train -M+ -I %s -F %s -g- -B :0-%s' % (self.SNoWTrainFile,self.SNoWNetFile,dicts.maxNameId-1) - #print snowTrainCommand - #snowTrainCommand = '../bin/snow -train -M+ -I %s -F %s -g- -B :0-%s' % (self.SNoWTrainFile,self.SNoWNetFile,self.defMaxNameId-1) - args = shlex.split(snowTrainCommand) - p = subprocess.Popen(args,stdout=subprocess.PIPE,stderr=subprocess.STDOUT) - p.wait() - self.logger.debug('Building Model END.') - - def update(self,dataPoint,features,dependencies,dicts): - """ - Updates the Model. - """ - """ - self.logger.debug('Updating Model START') - # Ignore Feature weights - features = [f+self.defMaxNameId for f,_w in features] - - OS = open(self.SNoWTestFile,'w') - features = map(str,features) - featureString = string.join(features, ',') - dependencies = map(str,dependencies) - dependenciesString = string.join(dependencies,',') - snowString = string.join([featureString,dependenciesString],',')+':\n' - OS.write(snowString) - OS.close() - snowTestCommand = '../bin/snow -test -I %s -F %s -o allboth -i+' % (self.SNoWTestFile,self.SNoWNetFile) - args = shlex.split(snowTestCommand) - p = subprocess.Popen(args,stdout=subprocess.PIPE,stderr=subprocess.STDOUT) - (_lines, _stderrdata) = p.communicate() - # Move new net file - src = self.SNoWNetFile+'.new' - dst = self.SNoWNetFile - shutil.move(src, dst) - self.logger.debug('Updating Model END') - """ - # Do nothing, only update at evaluation. Is a lot faster. - pass - - - def predict(self,features,accessibles,dicts): - trainData = dicts.featureDict.keys() - self.initializeModel(trainData, dicts) - - logger = logging.getLogger('predict_SNoW') - # Ignore Feature weights - #features = [f+self.defMaxNameId for f,_w in features] - features = [f+dicts.maxNameId for f,_w in features] - - OS = open(self.SNoWTestFile,'w') - features = map(str,features) - featureString = string.join(features, ',') - snowString = featureString+':' - OS.write(snowString) - OS.close() - - snowTestCommand = '../bin/snow -test -I %s -F %s -o allboth' % (self.SNoWTestFile,self.SNoWNetFile) - args = shlex.split(snowTestCommand) - p = subprocess.Popen(args,stdout=subprocess.PIPE,stderr=subprocess.STDOUT) - (lines, _stderrdata) = p.communicate() - logger.debug('SNoW finished.') - lines = lines.split('\n') - assert lines[9].startswith('Example ') - assert lines[-4] == '' - predictionsCon = [] - predictionsValues = [] - for line in lines[10:-4]: - premiseId = int(line.split()[0][:-1]) - predictionsCon.append(premiseId) - val = line.split()[4] - if val.endswith('*'): - val = float(val[:-1]) - else: - val = float(val) - predictionsValues.append(val) - return predictionsCon,predictionsValues - - def save(self,fileName): - # Nothing to do since we don't update - pass - - def load(self,fileName): - # Nothing to do since we don't update - pass diff -r 020cea57eaa4 -r 02c408aed5ee src/HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py Sat Jun 28 22:13:23 2014 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,179 +0,0 @@ -# Title: HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py -# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen -# Copyright 2012 -# -# An updatable sparse naive Bayes classifier. - -''' -Created on Jul 11, 2012 - -@author: Daniel Kuehlwein -''' -from cPickle import dump,load -from numpy import array -from math import log - -class sparseNBClassifier(object): - ''' - An updateable naive Bayes classifier. - ''' - - def __init__(self,defaultPriorWeight = 20.0,posWeight = 20.0,defVal = -15.0): - ''' - Constructor - ''' - self.counts = {} - self.defaultPriorWeight = defaultPriorWeight - self.posWeight = posWeight - self.defVal = defVal - - def initializeModel(self,trainData,dicts): - """ - Build basic model from training data. - """ - for d in trainData: - dFeatureCounts = {} - # Add p proves p with weight self.defaultPriorWeight - if not self.defaultPriorWeight == 0: - for f in dicts.featureDict[d].iterkeys(): - dFeatureCounts[f] = self.defaultPriorWeight - self.counts[d] = [self.defaultPriorWeight,dFeatureCounts] - - for key,keyDeps in dicts.dependenciesDict.iteritems(): - keyFeatures = dicts.featureDict[key] - for dep in keyDeps: - self.counts[dep][0] += 1 - #depFeatures = dicts.featureDict[key] - for f in keyFeatures.iterkeys(): - if self.counts[dep][1].has_key(f): - self.counts[dep][1][f] += 1 - else: - self.counts[dep][1][f] = 1 - - - def update(self,dataPoint,features,dependencies): - """ - Updates the Model. - """ - if (not self.counts.has_key(dataPoint)) and (not dataPoint == 0): - dFeatureCounts = {} - # Give p |- p a higher weight - if not self.defaultPriorWeight == 0: - for f in features.iterkeys(): - dFeatureCounts[f] = self.defaultPriorWeight - self.counts[dataPoint] = [self.defaultPriorWeight,dFeatureCounts] - for dep in dependencies: - self.counts[dep][0] += 1 - for f in features.iterkeys(): - if self.counts[dep][1].has_key(f): - self.counts[dep][1][f] += 1 - else: - self.counts[dep][1][f] = 1 - - def delete(self,dataPoint,features,dependencies): - """ - Deletes a single datapoint from the model. - """ - for dep in dependencies: - self.counts[dep][0] -= 1 - for f,_w in features.items(): - self.counts[dep][1][f] -= 1 - if self.counts[dep][1][f] == 0: - del self.counts[dep][1][f] - - - def overwrite(self,problemId,newDependencies,dicts): - """ - Deletes the old dependencies of problemId and replaces them with the new ones. Updates the model accordingly. - """ - try: - assert self.counts.has_key(problemId) - except: - raise LookupError('Trying to overwrite dependencies for unknown fact: %s. Facts need to be introduced before overwriting them.' % dicts.idNameDict[problemId]) - oldDeps = dicts.dependenciesDict[problemId] - features = dicts.featureDict[problemId] - self.delete(problemId,features,oldDeps) - self.update(problemId,features,newDependencies) - - def predict(self,features,accessibles,dicts): - """ - For each accessible, predicts the probability of it being useful given the features. - Returns a ranking of the accessibles. - """ - tau = 0.05 # Jasmin, change value here - predictions = [] - observedFeatures = features.keys() - for fVal in features.itervalues(): - _w,alternateF = fVal - observedFeatures += alternateF - - for a in accessibles: - posA = self.counts[a][0] - fA = set(self.counts[a][1].keys()) - fWeightsA = self.counts[a][1] - resultA = log(posA) - for f,fVal in features.iteritems(): - w,alternateF = fVal - # DEBUG - #w = 1.0 - # Test for multiple features - isMatch = False - matchF = None - if f in fA: - isMatch = True - matchF = f - elif len(alternateF) > 0: - inter = set(alternateF).intersection(fA) - if len(inter) > 0: - isMatch = True - for mF in inter: - ### TODO: matchF is randomly selected - matchF = mF - break - - if isMatch: - #if f in fA: - if fWeightsA[matchF] == 0: - resultA += w*self.defVal - else: - assert fWeightsA[matchF] <= posA - resultA += w*log(float(self.posWeight*fWeightsA[matchF])/posA) - else: - resultA += w*self.defVal - if not tau == 0.0: - missingFeatures = list(fA.difference(observedFeatures)) - #sumOfWeights = sum([log(float(fWeightsA[x])/posA) for x in missingFeatures]) # slower - sumOfWeights = sum([log(float(fWeightsA[x])) for x in missingFeatures]) - log(posA) * len(missingFeatures) #DEFAULT - #sumOfWeights = sum([log(float(fWeightsA[x])/self.totalFeatureCounts[x]) for x in missingFeatures]) - log(posA) * len(missingFeatures) - resultA -= tau * sumOfWeights - predictions.append(resultA) - predictions = array(predictions) - perm = (-predictions).argsort() - return array(accessibles)[perm],predictions[perm] - - def save(self,fileName): - OStream = open(fileName, 'wb') - dump((self.counts,self.defaultPriorWeight,self.posWeight,self.defVal),OStream) - OStream.close() - - def load(self,fileName): - OStream = open(fileName, 'rb') - self.counts,self.defaultPriorWeight,self.posWeight,self.defVal = load(OStream) - OStream.close() - - -if __name__ == '__main__': - featureDict = {0:[0,1,2],1:[3,2,1]} - dependenciesDict = {0:[0],1:[0,1]} - libDicts = (featureDict,dependenciesDict,{}) - c = sparseNBClassifier() - c.initializeModel([0,1],libDicts) - c.update(2,[14,1,3],[0,2]) - print c.counts - print c.predict([0,14],[0,1,2]) - c.storeModel('x') - d = sparseNBClassifier() - d.loadModel('x') - print c.counts - print d.counts - print 'Done' diff -r 020cea57eaa4 -r 02c408aed5ee src/HOL/Tools/Sledgehammer/MaSh/src/spawnDaemon.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/spawnDaemon.py Sat Jun 28 22:13:23 2014 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,36 +0,0 @@ -# http://stackoverflow.com/questions/972362/spawning-process-from-python/972383#972383 -import os - -def spawnDaemon(path_to_executable, *args): - """Spawn a completely detached subprocess (i.e., a daemon). - - E.g. for mark: - spawnDaemon("../bin/producenotify.py", "producenotify.py", "xx") - """ - # fork the first time (to make a non-session-leader child process) - try: - pid = os.fork() - except OSError, e: - raise RuntimeError("1st fork failed: %s [%d]" % (e.strerror, e.errno)) - if pid != 0: - # parent (calling) process is all done - return - - # detach from controlling terminal (to make child a session-leader) - os.setsid() - try: - pid = os.fork() - except OSError, e: - raise RuntimeError("2nd fork failed: %s [%d]" % (e.strerror, e.errno)) - raise Exception, "%s [%d]" % (e.strerror, e.errno) - if pid != 0: - # child process is all done - os._exit(0) - - # and finally let's execute the executable for the daemon! - try: - #os.execv(path_to_executable, [path_to_executable]) - os.execv(path_to_executable, args) - except Exception, e: - # oops, we're cut off from the world, let's just give up - os._exit(255) diff -r 020cea57eaa4 -r 02c408aed5ee src/HOL/Tools/Sledgehammer/MaSh/src/stats.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/stats.py Sat Jun 28 22:13:23 2014 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,155 +0,0 @@ -# Title: HOL/Tools/Sledgehammer/MaSh/src/stats.py -# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen -# Copyright 2012 -# -# Statistics collector. - -''' -Created on Jul 9, 2012 - -@author: Daniel Kuehlwein -''' - -import logging,string -from cPickle import load,dump - -class Statistics(object): - ''' - Class for all the statistics - ''' - - def __init__(self,cutOff=500): - ''' - Constructor - ''' - self.logger = logging.getLogger('Statistics') - self.avgAUC = 0.0 - self.avgRecall100 = 0.0 - self.avgAvailable = 0.0 - self.avgDepNr = 0.0 - self.problems = 0.0 - self.cutOff = cutOff - self.recallData = [0]*cutOff - self.recall100Median = [] - self.recall100Data = [0]*cutOff - self.aucData = [] - self.premiseOccurenceCounter = {} - self.firstDepAppearance = {} - self.depAppearances = [] - - def update(self,predictions,dependencies,statementCounter): - """ - Evaluates AUC, dependencies, recall100 and number of available premises of a prediction. - """ - available = len(predictions) - predictions = predictions[:self.cutOff] - dependencies = set(dependencies) - # No Stats for if no dependencies - if len(dependencies) == 0: - self.logger.debug('No Dependencies for statement %s' % statementCounter ) - self.badPreds = [] - return - if len(predictions) < self.cutOff: - for i in range(len(predictions),self.cutOff): - self.recall100Data[i] += 1 - self.recallData[i] += 1 - for d in dependencies: - if self.premiseOccurenceCounter.has_key(d): - self.premiseOccurenceCounter[d] += 1 - else: - self.premiseOccurenceCounter[d] = 1 - if self.firstDepAppearance.has_key(d): - self.depAppearances.append(statementCounter-self.firstDepAppearance[d]) - else: - self.firstDepAppearance[d] = statementCounter - depNr = len(dependencies) - aucSum = 0. - posResults = 0. - positives, negatives = 0, 0 - recall100 = 0.0 - badPreds = [] - depsFound = [] - for index,pId in enumerate(predictions): - if pId in dependencies: #positive - posResults+=1 - positives+=1 - recall100 = index+1 - depsFound.append(pId) - if index > 200: - badPreds.append(pId) - else: - aucSum += posResults - negatives+=1 - # Update Recall and Recall100 stats - if depNr == positives: - self.recall100Data[index] += 1 - if depNr == 0: - self.recallData[index] += 1 - else: - self.recallData[index] += float(positives)/depNr - - if not depNr == positives: - depsFound = set(depsFound) - missing = [] - for dep in dependencies: - if not dep in depsFound: - missing.append(dep) - badPreds.append(dep) - recall100 = len(predictions)+1 - positives+=1 - self.logger.debug('Dependencies missing for %s in cutoff predictions! Estimating Statistics.',\ - string.join([str(dep) for dep in missing],',')) - - if positives == 0 or negatives == 0: - auc = 1.0 - else: - auc = aucSum/(negatives*positives) - - self.aucData.append(auc) - self.avgAUC += auc - self.avgRecall100 += recall100 - self.recall100Median.append(recall100) - self.problems += 1 - self.badPreds = badPreds - self.avgAvailable += available - self.avgDepNr += depNr - self.logger.info('Statement: %s: AUC: %s \t Needed: %s \t Recall100: %s \t Available: %s \t cutOff: %s',\ - statementCounter,round(100*auc,2),depNr,recall100,available,self.cutOff) - - def printAvg(self): - self.logger.info('Average results:') - #self.logger.info('avgAUC: %s \t avgDepNr: %s \t avgRecall100: %s \t cutOff:%s', \ - # round(100*self.avgAUC/self.problems,2),round(self.avgDepNr/self.problems,2),round(self.avgRecall100/self.problems,2),self.cutOff) - # HACK FOR PAPER - assert len(self.aucData) == len(self.recall100Median) - nrDataPoints = len(self.aucData) - if nrDataPoints == 0: - return "No data points" - if nrDataPoints % 2 == 1: - medianAUC = sorted(self.aucData)[nrDataPoints/2 + 1] - else: - medianAUC = float(sorted(self.aucData)[nrDataPoints/2] + sorted(self.aucData)[nrDataPoints/2 + 1])/2 - #nrDataPoints = len(self.recall100Median) - if nrDataPoints % 2 == 1: - medianrecall100 = sorted(self.recall100Median)[nrDataPoints/2 + 1] - else: - medianrecall100 = float(sorted(self.recall100Median)[nrDataPoints/2] + sorted(self.recall100Median)[nrDataPoints/2 + 1])/2 - - returnString = 'avgAUC: %s \t medianAUC: %s \t avgDepNr: %s \t avgRecall100: %s \t medianRecall100: %s \t cutOff: %s' %\ - (round(100*self.avgAUC/self.problems,2),round(100*medianAUC,2),round(self.avgDepNr/self.problems,2),round(self.avgRecall100/self.problems,2),round(medianrecall100,2),self.cutOff) - self.logger.info(returnString) - return returnString - - """ - self.logger.info('avgAUC: %s \t medianAUC: %s \t avgDepNr: %s \t avgRecall100: %s \t medianRecall100: %s \t cutOff:%s', \ - round(100*self.avgAUC/self.problems,2),round(100*medianAUC,2),round(self.avgDepNr/self.problems,2),round(self.avgRecall100/self.problems,2),round(medianrecall100,2),self.cutOff) - """ - - def save(self,fileName): - oStream = open(fileName, 'wb') - dump((self.avgAUC,self.avgRecall100,self.avgAvailable,self.avgDepNr,self.problems,self.cutOff,self.recallData,self.recall100Data,self.aucData,self.premiseOccurenceCounter),oStream) - oStream.close() - def load(self,fileName): - iStream = open(fileName, 'rb') - self.avgAUC,self.avgRecall100,self.avgAvailable,self.avgDepNr,self.problems,self.cutOff,self.recallData,self.recall100Data,self.aucData,self.premiseOccurenceCounter = load(iStream) - iStream.close() diff -r 020cea57eaa4 -r 02c408aed5ee src/HOL/Tools/Sledgehammer/MaSh/src/tester.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/tester.py Sat Jun 28 22:13:23 2014 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,281 +0,0 @@ -''' -Created on Jan 11, 2013 - -Searches for the best parameters. - -@author: Daniel Kuehlwein -''' - -import logging,sys,os -from multiprocessing import Process,Queue,current_process,cpu_count -from mash import mash - -def worker(inQueue, outQueue): - for func, args in iter(inQueue.get, 'STOP'): - result = func(*args) - #print '%s says that %s%s = %s' % (current_process().name, func.__name__, args, result) - outQueue.put(result) - -def run_mash(runId,inputDir,logFile,predictionFile,predef,\ - learnTheories,theoryDefValPos,theoryDefValNeg,theoryPosWeight,\ - NBDefaultPriorWeight,NBDefVal,NBPosWeight,\ - sineFeatures,sineWeight,quiet=True): - # Init - runId = str(runId) - predictionFile = predictionFile + runId - args = ['--statistics','--init','--inputDir',inputDir,'--log',logFile,'--theoryFile','../tmp/t'+runId,'--modelFile','../tmp/m'+runId,'--dictsFile','../tmp/d'+runId, - '--theoryDefValPos',str(theoryDefValPos),'--theoryDefValNeg',str(theoryDefValNeg),'--theoryPosWeight',str(theoryPosWeight),\ - '--NBDefaultPriorWeight',str(NBDefaultPriorWeight),'--NBDefVal',str(NBDefVal),'--NBPosWeight',str(NBPosWeight)] - if learnTheories: - args += ['--learnTheories'] - if sineFeatures: - args += ['--sineFeatures','--sineWeight',str(sineWeight)] - if not predef == '': - args += ['--predef',predef] - if quit: - args += ['-q'] - #print args - mash(args) - # Run - args = ['-i',inputFile,'-p',predictionFile,'--statistics','--cutOff','1024','--log',logFile,'--theoryFile','../tmp/t'+runId,'--modelFile','../tmp/m'+runId,'--dictsFile','../tmp/d'+runId,\ - '--theoryDefValPos',str(theoryDefValPos),'--theoryDefValNeg',str(theoryDefValNeg),'--theoryPosWeight',str(theoryPosWeight),\ - '--NBDefaultPriorWeight',str(NBDefaultPriorWeight),'--NBDefVal',str(NBDefVal),'--NBPosWeight',str(NBPosWeight)] - if learnTheories: - args += ['--learnTheories'] - if sineFeatures: - args += ['--sineFeatures','--sineWeight',str(sineWeight)] - if not predef == '': - args += ['--predef',predef] - if quit: - args += ['-q'] - #print args - mash(args) - - # Get Results - IS = open(logFile,'r') - lines = IS.readlines() - tmpRes = lines[-1].split() - avgAuc = tmpRes[5] - medianAuc = tmpRes[7] - avgRecall100 = tmpRes[11] - medianRecall100 = tmpRes[13] - tmpTheoryRes = lines[-3].split() - if learnTheories: - avgTheoryPrecision = tmpTheoryRes[5] - avgTheoryRecall100 = tmpTheoryRes[7] - avgTheoryRecall = tmpTheoryRes[9] - avgTheoryPredictedPercent = tmpTheoryRes[11] - else: - avgTheoryPrecision = 'NA' - avgTheoryRecall100 = 'NA' - avgTheoryRecall = 'NA' - avgTheoryPredictedPercent = 'NA' - IS.close() - - # Delete old models - os.remove(logFile) - os.remove(predictionFile) - if learnTheories: - os.remove('../tmp/t'+runId) - os.remove('../tmp/m'+runId) - os.remove('../tmp/d'+runId) - - outFile = open('tester','a') - #print 'avgAuc %s avgRecall100 %s avgTheoryPrecision %s avgTheoryRecall100 %s avgTheoryRecall %s avgTheoryPredictedPercent %s' - outFile.write('\t'.join([str(learnTheories),str(theoryDefValPos),str(theoryDefValNeg),str(theoryPosWeight),\ - str(NBDefaultPriorWeight),str(NBDefVal),str(NBPosWeight),str(sineFeatures),str(sineWeight),\ - str(avgAuc),str(medianAuc),str(avgRecall100),str(medianRecall100),str(avgTheoryPrecision),\ - str(avgTheoryRecall100),str(avgTheoryRecall),str(avgTheoryPredictedPercent)])+'\n') - outFile.close() - print learnTheories,'\t',theoryDefValPos,'\t',theoryDefValNeg,'\t',theoryPosWeight,'\t',\ - NBDefaultPriorWeight,'\t',NBDefVal,'\t',NBPosWeight,'\t',\ - sineFeatures,'\t',sineWeight,'\t',\ - avgAuc,'\t',medianAuc,'\t',avgRecall100,'\t',medianRecall100,'\t',\ - avgTheoryPrecision,'\t',avgTheoryRecall100,'\t',avgTheoryRecall,'\t',avgTheoryPredictedPercent - return learnTheories,theoryDefValPos,theoryDefValNeg,theoryPosWeight,\ - NBDefaultPriorWeight,NBDefVal,NBPosWeight,\ - sineFeatures,sineWeight,\ - avgAuc,avgRecall100,avgTheoryPrecision,avgTheoryRecall100,avgTheoryRecall,avgTheoryPredictedPercent - -def update_best_params(avgRecall100,bestAvgRecall100,\ - bestNBDefaultPriorWeight,bestNBDefVal,bestNBPosWeight,bestSineFeatures,bestSineWeight,\ - bestlearnTheories,besttheoryDefValPos,besttheoryDefValNeg,besttheoryPosWeight,\ - NBDefaultPriorWeight,NBDefVal,NBPosWeight,sineFeatures,sineWeight,\ - learnTheories,theoryDefValPos,theoryDefValNeg,theoryPosWeight): - if avgRecall100 > bestAvgRecall100: - bestAvgRecall100 = avgRecall100 - bestNBDefaultPriorWeight = NBDefaultPriorWeight - bestNBDefVal = NBDefVal - bestNBPosWeight = NBPosWeight - bestSineFeatures = sineFeatures - bestSineWeight = sineWeight - return bestlearnTheories,besttheoryDefValPos,besttheoryDefValNeg,besttheoryPosWeight,bestNBDefaultPriorWeight,bestNBDefVal,bestNBPosWeight,bestSineFeatures,bestSineWeight - -if __name__ == '__main__': - cores = cpu_count() - #cores = 1 - # Options - depFile = 'mash_dependencies' - predef = '' - outputDir = '../tmp/' - numberOfPredictions = 1024 - - learnTheoriesRange = [True,False] - theoryDefValPosRange = [-x for x in range(1,20)] - theoryDefValNegRange = [-x for x in range(1,20)] - theoryPosWeightRange = [x for x in range(1,10)] - - NBDefaultPriorWeightRange = [10*x for x in range(10)] - NBDefValRange = [-x for x in range(1,20)] - NBPosWeightRange = [10*x for x in range(1,10)] - sineFeaturesRange = [True,False] - sineWeightRange = [0.1,0.25,0.5,0.75,1.0] - - """ - # Test 1 - inputFile = '../data/20121227b/Auth/mash_commands' - inputDir = '../data/20121227b/Auth/' - predictionFile = '../tmp/auth.pred' - logFile = '../tmp/auth.log' - learnTheories = True - theoryDefValPos = -7.5 - theoryDefValNeg = -15.0 - theoryPosWeight = 10.0 - NBDefaultPriorWeight = 20.0 - NBDefVal =- 15.0 - NBPosWeight = 10.0 - sineFeatures = True - sineWeight = 0.5 - - task_queue = Queue() - done_queue = Queue() - - runs = 0 - for inputDir in ['../data/20121227b/Auth/']: - problemId = inputDir.split('/')[-2] - inputFile = os.path.join(inputDir,'mash_commands') - predictionFile = os.path.join('../tmp/',problemId+'.pred') - logFile = os.path.join('../tmp/',problemId+'.log') - learnTheories = True - theoryDefValPos = -7.5 - theoryDefValNeg = -15.0 - theoryPosWeight = 10.0 - - bestAvgRecall100 = 0.0 - bestNBDefaultPriorWeight = 1.0 - bestNBDefVal = 1.0 - bestNBPosWeight = 1.0 - bestSineFeatures = False - bestSineWeight = 0.0 - bestlearnTheories = True - besttheoryDefValPos = 1.0 - besttheoryDefValNeg = -15.0 - besttheoryPosWeight = 5.0 - for theoryPosWeight in theoryPosWeightRange: - for theoryDefValNeg in theoryDefValNegRange: - for NBDefaultPriorWeight in NBDefaultPriorWeightRange: - for NBDefVal in NBDefValRange: - for NBPosWeight in NBPosWeightRange: - for sineFeatures in sineFeaturesRange: - if sineFeatures: - for sineWeight in sineWeightRange: - localLogFile = logFile+str(runs) - task_queue.put((run_mash,(runs,inputDir, localLogFile, predictionFile, learnTheories, theoryDefValPos, theoryDefValNeg, theoryPosWeight, NBDefaultPriorWeight, NBDefVal, NBPosWeight, sineFeatures, sineWeight))) - runs += 1 - else: - localLogFile = logFile+str(runs) - task_queue.put((run_mash,(runs,inputDir, localLogFile, predictionFile, learnTheories, theoryDefValPos, theoryDefValNeg, theoryPosWeight, NBDefaultPriorWeight, NBDefVal, NBPosWeight, sineFeatures, sineWeight))) - runs += 1 - # Start worker processes - processes = [] - for _i in range(cores): - process = Process(target=worker, args=(task_queue, done_queue)) - process.start() - processes.append(process) - - for _i in range(runs): - learnTheories,theoryDefValPos,theoryDefValNeg,theoryPosWeight,\ - NBDefaultPriorWeight,NBDefVal,NBPosWeight,\ - sineFeatures,sineWeight,\ - avgAuc,avgRecall100,avgTheoryPrecision,avgTheoryRecall100,avgTheoryRecall,avgTheoryPredictedPercent = done_queue.get() - bestlearnTheories,besttheoryDefValPos,besttheoryDefValNeg,besttheoryPosWeight,bestNBDefaultPriorWeight,bestNBDefVal,bestNBPosWeight,bestSineFeatures,bestSineWeight = update_best_params(avgRecall100,bestAvgRecall100,\ - bestNBDefaultPriorWeight,bestNBDefVal,bestNBPosWeight,bestSineFeatures,bestSineWeight,\ - bestlearnTheories,besttheoryDefValPos,besttheoryDefValNeg,besttheoryPosWeight,\ - NBDefaultPriorWeight,NBDefVal,NBPosWeight,sineFeatures,sineWeight,\ - learnTheories,theoryDefValPos,theoryDefValNeg,theoryPosWeight) - print 'bestAvgRecall100 %s bestNBDefaultPriorWeight %s bestNBDefVal %s bestNBPosWeight %s bestSineFeatures %s bestSineWeight %s',bestAvgRecall100,bestNBDefaultPriorWeight,bestNBDefVal,bestNBPosWeight,bestSineFeatures,bestSineWeight - - """ - # Test 2 - #inputDir = '../data/20130118/Jinja' - inputDir = '../data/notheory/Prob' - inputFile = inputDir+'/mash_commands' - #inputFile = inputDir+'/mash_prover_commands' - - #depFile = 'mash_prover_dependencies' - depFile = 'mash_dependencies' - outputDir = '../tmp/' - numberOfPredictions = 1024 - predictionFile = '../tmp/auth.pred' - logFile = '../tmp/auth.log' - learnTheories = False - theoryDefValPos = -7.5 - theoryDefValNeg = -10.0 - theoryPosWeight = 2.0 - NBDefaultPriorWeight = 20.0 - NBDefVal =- 15.0 - NBPosWeight = 10.0 - sineFeatures = False - sineWeight = 0.5 - quiet = False - print inputDir - - #predef = inputDir+'/mash_prover_suggestions' - predef = inputDir+'/mash_suggestions' - print 'Mash Isar' - run_mash(0,inputDir,logFile,predictionFile,predef,\ - learnTheories,theoryDefValPos,theoryDefValNeg,theoryPosWeight,\ - NBDefaultPriorWeight,NBDefVal,NBPosWeight,\ - sineFeatures,sineWeight,quiet=quiet) - - #""" - predef = inputDir+'/mesh_suggestions' - #predef = inputDir+'/mesh_prover_suggestions' - print 'Mesh Isar' - run_mash(0,inputDir,logFile,predictionFile,predef,\ - learnTheories,theoryDefValPos,theoryDefValNeg,theoryPosWeight,\ - NBDefaultPriorWeight,NBDefVal,NBPosWeight,\ - sineFeatures,sineWeight,quiet=quiet) - #""" - predef = inputDir+'/mepo_suggestions' - print 'Mepo Isar' - run_mash(0,inputDir,logFile,predictionFile,predef,\ - learnTheories,theoryDefValPos,theoryDefValNeg,theoryPosWeight,\ - NBDefaultPriorWeight,NBDefVal,NBPosWeight,\ - sineFeatures,sineWeight,quiet=quiet) - - """ - inputFile = inputDir+'/mash_prover_commands' - depFile = 'mash_prover_dependencies' - - predef = inputDir+'/mash_prover_suggestions' - print 'Mash Prover Isar' - run_mash(0,inputDir,logFile,predictionFile,predef,\ - learnTheories,theoryDefValPos,theoryDefValNeg,theoryPosWeight,\ - NBDefaultPriorWeight,NBDefVal,NBPosWeight,\ - sineFeatures,sineWeight,quiet=quiet) - - predef = inputDir+'/mesh_prover_suggestions' - print 'Mesh Prover Isar' - run_mash(0,inputDir,logFile,predictionFile,predef,\ - learnTheories,theoryDefValPos,theoryDefValNeg,theoryPosWeight,\ - NBDefaultPriorWeight,NBDefVal,NBPosWeight,\ - sineFeatures,sineWeight,quiet=quiet) - - predef = inputDir+'/mepo_suggestions' - print 'Mepo Prover Isar' - run_mash(0,inputDir,logFile,predictionFile,predef,\ - learnTheories,theoryDefValPos,theoryDefValNeg,theoryPosWeight,\ - NBDefaultPriorWeight,NBDefVal,NBPosWeight,\ - sineFeatures,sineWeight,quiet=quiet) - #""" \ No newline at end of file diff -r 020cea57eaa4 -r 02c408aed5ee src/HOL/Tools/Sledgehammer/MaSh/src/theoryModels.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/theoryModels.py Sat Jun 28 22:13:23 2014 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,151 +0,0 @@ -# Title: HOL/Tools/Sledgehammer/MaSh/src/theoryModels.py -# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen -# Copyright 2012 -# -# An updatable sparse naive Bayes classifier. - -''' -Created on Dec 26, 2012 - -@author: Daniel Kuehlwein -''' - -from singleNaiveBayes import singleNBClassifier -from cPickle import load,dump -import sys,logging - -class TheoryModels(object): - ''' - MetaClass for all the theory models. - ''' - - - def __init__(self,defValPos = -7.5,defValNeg = -15.0,posWeight = 10.0): - ''' - Constructor - ''' - self.theoryModels = {} - # Model Params - self.defValPos = defValPos - self.defValNeg = defValNeg - self.posWeight = posWeight - self.theoryDict = {} - self.accessibleTheories = set([]) - self.currentTheory = None - - def init(self,depFile,dicts): - logger = logging.getLogger('TheoryModels') - IS = open(depFile,'r') - for line in IS: - line = line.split(':') - name = line[0] - theory = name.split('.')[0] - # Name Id - if not dicts.nameIdDict.has_key(name): - logger.warning('%s is missing in nameIdDict. Aborting.',name) - sys.exit(-1) - - nameId = dicts.nameIdDict[name] - features = dicts.featureDict[nameId] - if not self.theoryDict.has_key(theory): - assert not theory == self.currentTheory - if not self.currentTheory == None: - self.accessibleTheories.add(self.currentTheory) - self.currentTheory = theory - self.theoryDict[theory] = set([nameId]) - theoryModel = singleNBClassifier(self.defValPos,self.defValNeg,self.posWeight) - self.theoryModels[theory] = theoryModel - else: - self.theoryDict[theory] = self.theoryDict[theory].union([nameId]) - - # Find the actually used theories - usedtheories = [] - dependencies = line[1].split() - if len(dependencies) == 0: - continue - for dep in dependencies: - depId = dicts.nameIdDict[dep.strip()] - deptheory = dep.split('.')[0] - usedtheories.append(deptheory) - if not self.theoryDict.has_key(deptheory): - self.theoryDict[deptheory] = set([depId]) - else: - self.theoryDict[deptheory] = self.theoryDict[deptheory].union([depId]) - - # Update theoryModels - self.theoryModels[self.currentTheory].update(features,self.currentTheory in usedtheories) - for a in self.accessibleTheories: - self.theoryModels[a].update(dicts.featureDict[nameId],a in usedtheories) - IS.close() - - def overwrite(self,problemId,newDependencies,dicts): - features = dicts.featureDict[problemId] - unExpAccessibles = dicts.accessibleDict[problemId] - accessibles = dicts.expand_accessibles(unExpAccessibles) - accTheories = [] - for x in accessibles: - xArt = (dicts.idNameDict[x]).split('.')[0] - accTheories.append(xArt) - oldTheories = set([x.split('.')[0] for x in dicts.dependenciesDict[problemId]]) - newTheories = set([x.split('.')[0] for x in newDependencies]) - for a in self.accTheories: - self.theoryModels[a].overwrite(features,a in oldTheories,a in newTheories) - - def delete(self,problemId,features,dependencies,dicts): - tmp = [dicts.idNameDict[x] for x in dependencies] - usedTheories = set([x.split('.')[0] for x in tmp]) - for a in self.accessibleTheories: - self.theoryModels[a].delete(features,a in usedTheories) - - def update(self,problemId,features,dependencies,dicts): - # TODO: Implicit assumption that self.accessibleTheories contains all accessible theories! - currentTheory = (dicts.idNameDict[problemId]).split('.')[0] - # Create new theory model, if there is a new theory - if not self.theoryDict.has_key(currentTheory): - assert not currentTheory == self.currentTheory - if not currentTheory == None: - self.theoryDict[currentTheory] = [] - self.currentTheory = currentTheory - theoryModel = singleNBClassifier(self.defValPos,self.defValNeg,self.posWeight) - self.theoryModels[currentTheory] = theoryModel - self.accessibleTheories.add(self.currentTheory) - self.update_with_acc(problemId,features,dependencies,dicts,self.accessibleTheories) - - def update_with_acc(self,problemId,features,dependencies,dicts,accessibleTheories): - # Find the actually used theories - tmp = [dicts.idNameDict[x] for x in dependencies] - usedTheories = set([x.split('.')[0] for x in tmp]) - if not len(usedTheories) == 0: - for a in accessibleTheories: - self.theoryModels[a].update(features,a in usedTheories) - - def predict(self,features,accessibles,dicts): - """ - Predicts the relevant theories. Returns the predicted theories and a list of all accessible premises in these theories. - """ - self.accessibleTheories = set([(dicts.idNameDict[x]).split('.')[0] for x in accessibles]) - - # Predict Theories - predictedTheories = [self.currentTheory] - for a in self.accessibleTheories: - if self.theoryModels[a].predict_sparse(features): - #if theoryModels[a].predict(dicts.featureDict[nameId]): - predictedTheories.append(a) - predictedTheories = set(predictedTheories) - - # Delete accessibles in unpredicted theories - newAcc = [] - for x in accessibles: - xArt = (dicts.idNameDict[x]).split('.')[0] - if xArt in predictedTheories: - newAcc.append(x) - return predictedTheories,newAcc - - def save(self,fileName): - outStream = open(fileName, 'wb') - dump((self.currentTheory,self.accessibleTheories,self.theoryModels,self.theoryDict,self.defValPos,self.defValNeg,self.posWeight),outStream) - outStream.close() - def load(self,fileName): - inStream = open(fileName, 'rb') - self.currentTheory,self.accessibleTheories,self.theoryModels,self.theoryDict,self.defValPos,self.defValNeg,self.posWeight = load(inStream) - inStream.close() \ No newline at end of file diff -r 020cea57eaa4 -r 02c408aed5ee src/HOL/Tools/Sledgehammer/MaSh/src/theoryStats.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/theoryStats.py Sat Jun 28 22:13:23 2014 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,70 +0,0 @@ -# Title: HOL/Tools/Sledgehammer/MaSh/src/theoryStats.py -# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen -# Copyright 2012 -# -# An updatable sparse naive Bayes classifier. - -''' -Created on Dec 26, 2012 - -@author: Daniel Kuehlwein -''' - -from cPickle import load,dump -import logging,string - -class TheoryStatistics(object): - ''' - Stores statistics for theory lvl predictions - ''' - - - def __init__(self): - ''' - Constructor - ''' - self.logger = logging.getLogger('TheoryStatistics') - self.count = 0 - self.precision = 0.0 - self.recall100 = 0 - self.recall = 0.0 - self.predicted = 0.0 - self.predictedPercent = 0.0 - - def update(self,currentTheory,predictedTheories,usedTheories,nrAvailableTheories): - self.count += 1 - allPredTheories = predictedTheories.union([currentTheory]) - if set(usedTheories).issubset(allPredTheories): - self.recall100 += 1 - localPredicted = len(allPredTheories) - self.predicted += localPredicted - localPredictedPercent = float(localPredicted)/nrAvailableTheories - self.predictedPercent += localPredictedPercent - localPrec = float(len(set(usedTheories).intersection(allPredTheories))) / localPredicted - self.precision += localPrec - if len(set(usedTheories)) == 0: - localRecall = 1.0 - else: - localRecall = float(len(set(usedTheories).intersection(allPredTheories))) / len(set(usedTheories)) - self.recall += localRecall - self.logger.info('Theory prediction results:') - self.logger.info('Problem: %s \t Recall100: %s \t Precision: %s \t Recall: %s \t PredictedTeoriesPercent: %s PredictedTeories: %s',\ - self.count,self.recall100,round(localPrec,2),round(localRecall,2),round(localPredictedPercent,2),localPredicted) - - def printAvg(self): - self.logger.info('Average theory results:') - self.logger.info('avgPrecision: %s \t avgRecall100: %s \t avgRecall: %s \t avgPredictedPercent: %s \t avgPredicted: %s', \ - round(self.precision/self.count,2),\ - round(float(self.recall100)/self.count,2),\ - round(self.recall/self.count,2),\ - round(self.predictedPercent /self.count,2),\ - round(self.predicted /self.count,2)) - - def save(self,fileName): - oStream = open(fileName, 'wb') - dump((self.count,self.precision,self.recall100,self.recall,self.predicted),oStream) - oStream.close() - def load(self,fileName): - iStream = open(fileName, 'rb') - self.count,self.precision,self.recall100,self.recall,self.predicted = load(iStream) - iStream.close() \ No newline at end of file diff -r 020cea57eaa4 -r 02c408aed5ee src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Sat Jun 28 22:13:23 2014 +0200 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Sun Jun 29 18:28:27 2014 +0200 @@ -33,19 +33,16 @@ val decode_str : string -> string val decode_strs : string -> string list val encode_features : (string * real) list -> string - val extract_suggestions : string -> string * (string * real) list datatype mash_engine = - MaSh_Py - | MaSh_SML_kNN - | MaSh_SML_kNN_Ext - | MaSh_SML_NB - | MaSh_SML_NB_Ext + MaSh_kNN + | MaSh_kNN_Ext + | MaSh_NB + | MaSh_NB_Ext val is_mash_enabled : unit -> bool val the_mash_engine : unit -> mash_engine - val mash_unlearn : Proof.context -> params -> unit val nickname_of_thm : thm -> string val find_suggested_facts : Proof.context -> ('b * thm) list -> string list -> ('b * thm) list val mesh_facts : ('a * 'a -> bool) -> int -> (real * (('a * real) list * 'a list)) list -> 'a list @@ -69,18 +66,20 @@ ('b * thm) list -> ('b * thm) list -> ('b * thm) list * ('b * thm) list val mash_suggested_facts : Proof.context -> theory -> params -> int -> term list -> term -> raw_fact list -> fact list * fact list + + val mash_unlearn : unit -> unit val mash_learn_proof : Proof.context -> params -> term -> ('a * thm) list -> thm list -> unit - val mash_learn_facts : Proof.context -> params -> string -> bool -> int -> bool -> Time.time -> + val mash_learn_facts : Proof.context -> params -> string -> int -> bool -> Time.time -> raw_fact list -> string val mash_learn : Proof.context -> params -> fact_override -> thm list -> bool -> unit - val mash_can_suggest_facts : Proof.context -> bool -> bool + val mash_can_suggest_facts : Proof.context -> bool val generous_max_suggestions : int -> int val mepo_weight : real val mash_weight : real val relevant_facts : Proof.context -> params -> string -> int -> fact_override -> term list -> term -> raw_fact list -> (string * fact list) list - val kill_learners : Proof.context -> params -> unit + val kill_learners : unit -> unit val running_learners : unit -> unit end; @@ -140,84 +139,28 @@ end datatype mash_engine = - MaSh_Py -| MaSh_SML_kNN -| MaSh_SML_kNN_Ext -| MaSh_SML_NB -| MaSh_SML_NB_Ext + MaSh_kNN +| MaSh_kNN_Ext +| MaSh_NB +| MaSh_NB_Ext fun mash_engine () = let val flag1 = Options.default_string @{system_option MaSh} in (case if flag1 <> "none" (* default *) then flag1 else getenv "MASH" of - "yes" => SOME MaSh_SML_NB - | "py" => SOME MaSh_Py - | "sml" => SOME MaSh_SML_NB - | "sml_knn" => SOME MaSh_SML_kNN - | "sml_knn_ext" => SOME MaSh_SML_kNN_Ext - | "sml_nb" => SOME MaSh_SML_NB - | "sml_nb_ext" => SOME MaSh_SML_NB_Ext + "yes" => SOME MaSh_NB + | "sml" => SOME MaSh_NB + | "knn" => SOME MaSh_kNN + | "knn_ext" => SOME MaSh_kNN_Ext + | "nb" => SOME MaSh_NB + | "nb_ext" => SOME MaSh_NB_Ext | _ => NONE) end val is_mash_enabled = is_some o mash_engine -val the_mash_engine = the_default MaSh_SML_NB o mash_engine +val the_mash_engine = the_default MaSh_NB o mash_engine -(*** Low-level communication with the Python version of MaSh ***) - -val save_models_arg = "--saveModels" -val shutdown_server_arg = "--shutdownServer" - -fun wipe_out_file file = ignore (try (File.rm o Path.explode) file) - -fun write_file banner (xs, f) path = - (case banner of SOME s => File.write path s | NONE => (); - xs |> chunk_list 500 |> List.app (File.append path o implode o map f)) - handle IO.Io _ => () - -fun run_mash_tool ctxt overlord extra_args background write_cmds read_suggs = - let - val (temp_dir, serial) = - if overlord then (getenv "ISABELLE_HOME_USER", "") - else (getenv "ISABELLE_TMP", serial_string ()) - val log_file = temp_dir ^ "/mash_log" ^ serial - val err_file = temp_dir ^ "/mash_err" ^ serial - val sugg_file = temp_dir ^ "/mash_suggs" ^ serial - val sugg_path = Path.explode sugg_file - val cmd_file = temp_dir ^ "/mash_commands" ^ serial - val cmd_path = Path.explode cmd_file - val model_dir = File.shell_path (mash_state_dir ()) - - val command = - "cd \"$ISABELLE_SLEDGEHAMMER_MASH\"/src; \ - \PYTHONDONTWRITEBYTECODE=y ./mash.py\ - \ --quiet\ - \ --port=$MASH_PORT\ - \ --outputDir " ^ model_dir ^ - " --modelFile=" ^ model_dir ^ "/model.pickle\ - \ --dictsFile=" ^ model_dir ^ "/dict.pickle\ - \ --log " ^ log_file ^ - " --inputFile " ^ cmd_file ^ - " --predictions " ^ sugg_file ^ - (if extra_args = [] then "" else " " ^ space_implode " " extra_args) ^ " >& " ^ err_file ^ - (if background then " &" else "") - - fun run_on () = - (Isabelle_System.bash command - |> tap (fn _ => - (case try File.read (Path.explode err_file) |> the_default "" of - "" => trace_msg ctxt (K "Done") - | s => warning ("MaSh error: " ^ elide_string 1000 s))); - read_suggs (fn () => try File.read_lines sugg_path |> these)) - - fun clean_up () = - if overlord then () else List.app wipe_out_file [err_file, sugg_file, cmd_file] - in - write_file (SOME "") ([], K "") sugg_path; - write_file (SOME "") write_cmds cmd_path; - trace_msg ctxt (fn () => "Running " ^ command); - with_cleanup clean_up run_on () - end +(*** Maintenance of the persistent, string-based state ***) fun meta_char c = if Char.isAlphaNum c orelse c = #"_" orelse c = #"." orelse c = #"(" orelse c = #")" orelse @@ -246,70 +189,10 @@ val encode_features = map encode_feature #> space_implode " " -fun str_of_learn (name, parents, feats, deps) = - "! " ^ encode_str name ^ ": " ^ encode_strs parents ^ "; " ^ encode_strs feats ^ "; " ^ - encode_strs deps ^ "\n" - -fun str_of_relearn (name, deps) = "p " ^ encode_str name ^ ": " ^ encode_strs deps ^ "\n" - -fun str_of_query max_suggs (learns, parents, feats) = - implode (map str_of_learn learns) ^ - "? " ^ string_of_int max_suggs ^ " # " ^ encode_strs parents ^ "; " ^ encode_features feats ^ "\n" - -(* The suggested weights do not make much sense. *) -fun extract_suggestion sugg = - (case space_explode "=" sugg of - [name, weight] => SOME (decode_str name, Real.fromString weight |> the_default 1.0) - | [name] => SOME (decode_str name, 1.0) - | _ => NONE) - -fun extract_suggestions line = - (case space_explode ":" line of - [goal, suggs] => (decode_str goal, map_filter extract_suggestion (space_explode " " suggs)) - | _ => ("", [])) - -structure MaSh_Py = -struct - -fun shutdown ctxt overlord = - (trace_msg ctxt (K "MaSh_Py shutdown"); - run_mash_tool ctxt overlord [shutdown_server_arg] false ([], K "") (K ())) -fun save ctxt overlord = - (trace_msg ctxt (K "MaSh_Py save"); - run_mash_tool ctxt overlord [save_models_arg] true ([], K "") (K ())) - -fun unlearn ctxt overlord = - (trace_msg ctxt (K "MaSh_Py unlearn"); - shutdown ctxt overlord; - wipe_out_mash_state_dir ()) - -fun learn _ _ _ [] = () - | learn ctxt overlord save learns = - (trace_msg ctxt (fn () => - "MaSh_Py learn {" ^ elide_string 1000 (space_implode " " (map #1 learns)) ^ "}"); - run_mash_tool ctxt overlord ([] |> save ? cons save_models_arg) false (learns, str_of_learn) - (K ())) +(*** Isabelle-agnostic machine learning ***) -fun relearn _ _ _ [] = () - | relearn ctxt overlord save relearns = - (trace_msg ctxt (fn () => "MaSh_Py relearn " ^ - elide_string 1000 (space_implode " " (map #1 relearns))); - run_mash_tool ctxt overlord ([] |> save ? cons save_models_arg) false - (relearns, str_of_relearn) (K ())) - -fun query ctxt overlord max_suggs (query as (_, _, feats)) = - (trace_msg ctxt (fn () => "MaSh_Py query " ^ encode_features feats); - run_mash_tool ctxt overlord [] false ([query], str_of_query max_suggs) (fn suggs => - (case suggs () of [] => [] | suggs => snd (extract_suggestions (List.last suggs)))) - handle List.Empty => []) - -end; - - -(*** Standard ML version of MaSh ***) - -structure MaSh_SML = +structure MaSh = struct fun heap cmp bnd al a = @@ -542,25 +425,6 @@ end (* experimental *) -fun naive_bayes_py ctxt overlord num_facts depss featss max_suggs goal_feats = - let - fun name_of_fact j = "f" ^ string_of_int j - fun fact_of_name s = the (Int.fromString (unprefix "f" s)) - fun name_of_feature j = "F" ^ string_of_int j - fun parents_of j = if j = 0 then [] else [name_of_fact (j - 1)] - - val learns = map (fn j => (name_of_fact j, parents_of j, - map name_of_feature (Vector.sub (featss, j)), - map name_of_fact (Vector.sub (depss, j)))) (0 upto num_facts - 1) - val parents' = parents_of num_facts - in - MaSh_Py.unlearn ctxt overlord; - OS.Process.sleep (seconds 2.0); (* hack *) - MaSh_Py.query ctxt overlord max_suggs (learns, parents', goal_feats) - |> map (apfst fact_of_name) - end - -(* experimental *) fun external_tool tool max_suggs learns goal_feats = let val ser = string_of_int (serial ()) (* poor person's attempt at thread-safety *) @@ -600,17 +464,17 @@ val naive_bayes_ext = external_tool "predict/nbayes" fun query_external ctxt engine max_suggs learns goal_feats = - (trace_msg ctxt (fn () => "MaSh_SML query external " ^ encode_features goal_feats); + (trace_msg ctxt (fn () => "MaSh query external " ^ encode_features goal_feats); (case engine of - MaSh_SML_kNN_Ext => k_nearest_neighbors_ext max_suggs learns goal_feats - | MaSh_SML_NB_Ext => naive_bayes_ext max_suggs learns goal_feats)) + MaSh_kNN_Ext => k_nearest_neighbors_ext max_suggs learns goal_feats + | MaSh_NB_Ext => naive_bayes_ext max_suggs learns goal_feats)) fun query_internal ctxt engine num_facts num_feats (fact_names, featss, depss) (freqs as (_, _, dffreq)) visible_facts max_suggs goal_feats int_goal_feats = - (trace_msg ctxt (fn () => "MaSh_SML query internal " ^ encode_features goal_feats ^ " from {" ^ + (trace_msg ctxt (fn () => "MaSh query internal " ^ encode_features goal_feats ^ " from {" ^ elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] fact_names)) ^ "}"); (case engine of - MaSh_SML_kNN => + MaSh_kNN => let val feat_facts = Array.array (num_feats, []) val _ = @@ -620,7 +484,7 @@ in k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts int_goal_feats end - | MaSh_SML_NB => naive_bayes freqs num_facts max_suggs visible_facts int_goal_feats) + | MaSh_NB => naive_bayes freqs num_facts max_suggs visible_facts int_goal_feats) |> map (curry Vector.sub fact_names o fst)) end; @@ -698,7 +562,7 @@ Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns)] in ((fact_names, featss, depss), - MaSh_SML.learn_facts freqs0 num_facts0 num_facts num_feats depss featss) + MaSh.learn_facts freqs0 num_facts0 num_facts num_feats depss featss) end fun reorder_learns (num_facts, fact_tab) learns = @@ -734,7 +598,7 @@ | _ => NONE) | _ => NONE) -fun load_state ctxt overlord (time_state as (memory_time, _)) = +fun load_state ctxt (time_state as (memory_time, _)) = let val path = mash_state_file () in (case try OS.FileSys.modTime (Path.implode path) of NONE => time_state @@ -758,11 +622,7 @@ EQUAL => try_graph ctxt "loading state" empty_G_etc (fn () => fold extract_line_and_add_node node_lines empty_G_etc) - | LESS => - (* cannot parse old file *) - (if the_mash_engine () = MaSh_Py then MaSh_Py.unlearn ctxt overlord - else wipe_out_mash_state_dir (); - empty_G_etc) + | LESS => (wipe_out_mash_state_dir (); empty_G_etc) (* cannot parse old file *) | GREATER => raise FILE_VERSION_TOO_NEW ()) val (ffds, freqs) = @@ -794,7 +654,9 @@ SOME names => (NONE, fold (append_entry o Graph.get_entry access_G) names []) | NONE => (SOME (version ^ "\n"), Graph.fold append_entry access_G [])) in - write_file banner (entries, str_of_entry) path; + (case banner of SOME s => File.write path s | NONE => (); + entries |> chunk_list 500 |> List.app (File.append path o implode o map str_of_entry)) + handle IO.Io _ => (); trace_msg ctxt (fn () => "Saved fact graph (" ^ graph_info access_G ^ (case dirty_facts of @@ -808,25 +670,19 @@ in -fun map_state ctxt overlord f = - Synchronized.change global_state (load_state ctxt overlord ##> f #> save_state ctxt) +fun map_state ctxt f = + Synchronized.change global_state (load_state ctxt ##> f #> save_state ctxt) handle FILE_VERSION_TOO_NEW () => () -fun peek_state ctxt overlord f = - Synchronized.change_result global_state (perhaps (try (load_state ctxt overlord)) #> `snd #>> f) +fun peek_state ctxt = + Synchronized.change_result global_state (perhaps (try (load_state ctxt)) #> `snd) -fun clear_state ctxt overlord = - (* "MaSh_Py.unlearn" also removes the state file *) +fun clear_state () = Synchronized.change global_state (fn _ => - (if the_mash_engine () = MaSh_Py then MaSh_Py.unlearn ctxt overlord - else wipe_out_mash_state_dir (); - (Time.zeroTime, empty_state))) + (wipe_out_mash_state_dir (); (Time.zeroTime, empty_state))) end -fun mash_unlearn ctxt ({overlord, ...} : params) = - (clear_state ctxt overlord; Output.urgent_message "Reset MaSh.") - (*** Isabelle helpers ***) @@ -1284,7 +1140,7 @@ (mesh_facts (eq_snd (gen_eq_thm ctxt)) max_facts mess, unknown) end -fun mash_suggested_facts ctxt thy ({debug, overlord, ...} : params) max_suggs hyp_ts concl_t facts = +fun mash_suggested_facts ctxt thy ({debug, ...} : params) max_suggs hyp_ts concl_t facts = let val thy_name = Context.theory_name thy val engine = the_mash_engine () @@ -1323,50 +1179,40 @@ (parents, feats) end - val ((access_G, ((num_facts, fact_tab), (num_feats, feat_tab)), ffds, freqs), py_suggs) = - peek_state ctxt overlord (fn {access_G, xtabs, ffds, freqs, ...} => - ((access_G, xtabs, ffds, freqs), - if Graph.is_empty access_G then - (trace_msg ctxt (K "Nothing has been learned yet"); []) - else if engine = MaSh_Py then - let val (parents, feats) = query_args access_G in - MaSh_Py.query ctxt overlord max_suggs ([], parents, feats) - |> map fst - end - else - [])) + val {access_G, xtabs = ((num_facts, fact_tab), (num_feats, feat_tab)), ffds, freqs, ...} = + peek_state ctxt - val sml_suggs = - if engine = MaSh_Py then - [] - else - let - val (parents, goal_feats) = query_args access_G - val visible_facts = map_filter (Symtab.lookup fact_tab) (Graph.all_preds access_G parents) - in - if engine = MaSh_SML_kNN_Ext orelse engine = MaSh_SML_NB_Ext then - let - val learns = - Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G - in - MaSh_SML.query_external ctxt engine max_suggs learns goal_feats - end - else - let - val int_goal_feats = - map_filter (fn (s, w) => Option.map (rpair w) (Symtab.lookup feat_tab s)) goal_feats - in - MaSh_SML.query_internal ctxt engine num_facts num_feats ffds freqs visible_facts - max_suggs goal_feats int_goal_feats - end - end + val suggs = + let + val (parents, goal_feats) = query_args access_G + val visible_facts = map_filter (Symtab.lookup fact_tab) (Graph.all_preds access_G parents) + in + if engine = MaSh_kNN_Ext orelse engine = MaSh_NB_Ext then + let + val learns = + Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G + in + MaSh.query_external ctxt engine max_suggs learns goal_feats + end + else + let + val int_goal_feats = + map_filter (fn (s, w) => Option.map (rpair w) (Symtab.lookup feat_tab s)) goal_feats + in + MaSh.query_internal ctxt engine num_facts num_feats ffds freqs visible_facts max_suggs + goal_feats int_goal_feats + end + end val unknown = filter_out (is_fact_in_graph access_G o snd) facts in - find_mash_suggestions ctxt max_suggs (py_suggs @ sml_suggs) facts chained unknown + find_mash_suggestions ctxt max_suggs suggs facts chained unknown |> pairself (map fact_of_raw_fact) end +fun mash_unlearn () = + (clear_state (); Output.urgent_message "Reset MaSh.") + fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (access_G, (fact_xtab, feat_xtab)) = let fun maybe_learn_from from (accum as (parents, access_G)) = @@ -1413,35 +1259,30 @@ fun learned_proof_name () = Date.fmt ".%Y%m%d.%H%M%S." (Date.fromTimeLocal (Time.now ())) ^ serial_string () -fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) t facts used_ths = +fun mash_learn_proof ctxt ({timeout, ...} : params) t facts used_ths = if not (null used_ths) andalso is_mash_enabled () then launch_thread timeout (fn () => let val thy = Proof_Context.theory_of ctxt val feats = features_of ctxt thy (Local, General) [t] in - map_state ctxt overlord - (fn state as {access_G, xtabs as ((num_facts0, _), _), ffds, freqs, dirty_facts} => + map_state ctxt + (fn {access_G, xtabs as ((num_facts0, _), _), ffds, freqs, dirty_facts} => let val parents = maximal_wrt_access_graph access_G facts val deps = used_ths |> filter (is_fact_in_graph access_G) |> map nickname_of_thm + + val name = learned_proof_name () + val (access_G', xtabs', rev_learns) = + add_node Automatic_Proof name parents feats deps (access_G, xtabs, []) + + val (ffds', freqs') = + recompute_ffds_freqs_from_learns (rev rev_learns) xtabs' num_facts0 ffds freqs in - if the_mash_engine () = MaSh_Py then - (MaSh_Py.learn ctxt overlord true [("", parents, feats, deps)]; state) - else - let - val name = learned_proof_name () - val (access_G', xtabs', rev_learns) = - add_node Automatic_Proof name parents feats deps (access_G, xtabs, []) - - val (ffds', freqs') = - recompute_ffds_freqs_from_learns (rev rev_learns) xtabs' num_facts0 ffds freqs - in - {access_G = access_G', xtabs = xtabs', ffds = ffds', freqs = freqs', - dirty_facts = Option.map (cons name) dirty_facts} - end + {access_G = access_G', xtabs = xtabs', ffds = ffds', freqs = freqs', + dirty_facts = Option.map (cons name) dirty_facts} end); (true, "") end) @@ -1453,14 +1294,13 @@ val commit_timeout = seconds 30.0 (* The timeout is understood in a very relaxed fashion. *) -fun mash_learn_facts ctxt (params as {debug, verbose, overlord, ...}) prover save auto_level - run_prover learn_timeout facts = +fun mash_learn_facts ctxt (params as {debug, verbose, ...}) prover auto_level run_prover + learn_timeout facts = let val timer = Timer.startRealTimer () fun next_commit_time () = Time.+ (Timer.checkRealTimer timer, commit_timeout) - val engine = the_mash_engine () - val {access_G, ...} = peek_state ctxt overlord I + val {access_G, ...} = peek_state ctxt val is_in_access_G = is_fact_in_graph access_G o snd val no_new_facts = forall is_in_access_G facts in @@ -1511,18 +1351,13 @@ else recompute_ffds_freqs_from_access_G access_G xtabs in - if engine = MaSh_Py then - (MaSh_Py.learn ctxt overlord (save andalso null relearns) learns; - MaSh_Py.relearn ctxt overlord save relearns) - else - (); {access_G = access_G, xtabs = xtabs, ffds = ffds', freqs = freqs', dirty_facts = dirty_facts} end fun commit last learns relearns flops = (if debug andalso auto_level = 0 then Output.urgent_message "Committing..." else (); - map_state ctxt overlord (do_commit (rev learns) relearns flops); + map_state ctxt (do_commit (rev learns) relearns flops); if not last andalso auto_level = 0 then let val num_proofs = length learns + length relearns in Output.urgent_message ("Learned " ^ string_of_int num_proofs ^ " " ^ @@ -1633,7 +1468,7 @@ val prover = hd provers fun learn auto_level run_prover = - mash_learn_facts ctxt params prover true auto_level run_prover one_year facts + mash_learn_facts ctxt params prover auto_level run_prover one_year facts |> Output.urgent_message in if run_prover then @@ -1650,8 +1485,8 @@ learn 0 false) end -fun mash_can_suggest_facts ctxt overlord = - not (Graph.is_empty (#access_G (peek_state ctxt overlord I))) +fun mash_can_suggest_facts ctxt = + not (Graph.is_empty (#access_G (peek_state ctxt))) (* Generate more suggestions than requested, because some might be thrown out later for various reasons (e.g., duplicates). *) @@ -1666,7 +1501,7 @@ and Try. *) val min_secs_for_learning = 15 -fun relevant_facts ctxt (params as {verbose, overlord, learn, fact_filter, timeout, ...}) prover +fun relevant_facts ctxt (params as {verbose, learn, fact_filter, timeout, ...}) prover max_facts ({add, only, ...} : fact_override) hyp_ts concl_t facts = if not (subset (op =) (the_list fact_filter, fact_filters)) then error ("Unknown fact filter: " ^ quote (the fact_filter) ^ ".") @@ -1689,42 +1524,41 @@ else ()); launch_thread timeout - (fn () => (true, mash_learn_facts ctxt params prover true 2 false timeout facts)) + (fn () => (true, mash_learn_facts ctxt params prover 2 false timeout facts)) end else () fun maybe_learn () = - if is_mash_enabled () andalso learn then + if learn then let - val {access_G, xtabs = ((num_facts0, _), _), ...} = peek_state ctxt overlord I + val {access_G, xtabs = ((num_facts0, _), _), ...} = peek_state ctxt val is_in_access_G = is_fact_in_graph access_G o snd val min_num_facts_to_learn = length facts - num_facts0 in if min_num_facts_to_learn <= max_facts_to_learn_before_query then (case length (filter_out is_in_access_G facts) of - 0 => false + 0 => () | num_facts_to_learn => if num_facts_to_learn <= max_facts_to_learn_before_query then - (mash_learn_facts ctxt params prover false 2 false timeout facts - |> (fn "" => () | s => Output.urgent_message (MaShN ^ ": " ^ s)); - true) + mash_learn_facts ctxt params prover 2 false timeout facts + |> (fn "" => () | s => Output.urgent_message (MaShN ^ ": " ^ s)) else - (maybe_launch_thread num_facts_to_learn; false)) + maybe_launch_thread num_facts_to_learn) else - (maybe_launch_thread min_num_facts_to_learn; false) + maybe_launch_thread min_num_facts_to_learn end else - false + () - val (save, effective_fact_filter) = + val effective_fact_filter = (case fact_filter of - SOME ff => (ff <> mepoN andalso maybe_learn (), ff) + SOME ff => ff | NONE => if is_mash_enabled () then - (maybe_learn (), if mash_can_suggest_facts ctxt overlord then meshN else mepoN) + (maybe_learn (); if mash_can_suggest_facts ctxt then meshN else mepoN) else - (false, mepoN)) + mepoN) val unique_facts = drop_duplicate_facts facts val add_ths = Attrib.eval_thms ctxt add @@ -1754,7 +1588,6 @@ |> Par_List.map (apsnd (fn f => f ())) val mesh = mesh_facts (eq_snd (gen_eq_thm ctxt)) max_facts mess |> add_and_take in - if the_mash_engine () = MaSh_Py andalso save then MaSh_Py.save ctxt overlord else (); (case (fact_filter, mess) of (NONE, [(_, (mepo, _)), (_, (mash, _))]) => [(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take), @@ -1762,9 +1595,7 @@ | _ => [(effective_fact_filter, mesh)]) end -fun kill_learners ctxt ({overlord, ...} : params) = - (Async_Manager.kill_threads MaShN "learner"; - if the_mash_engine () = MaSh_Py then MaSh_Py.shutdown ctxt overlord else ()) +fun kill_learners () = Async_Manager.kill_threads MaShN "learner" fun running_learners () = Async_Manager.running_threads MaShN "learner" diff -r 020cea57eaa4 -r 02c408aed5ee src/HOL/Tools/etc/options --- a/src/HOL/Tools/etc/options Sat Jun 28 22:13:23 2014 +0200 +++ b/src/HOL/Tools/etc/options Sun Jun 29 18:28:27 2014 +0200 @@ -36,4 +36,4 @@ -- "status of Z3 activation for non-commercial use (yes, no, unknown)" public option MaSh : string = "sml" - -- "machine learning engine to use by Sledgehammer (sml, sml_knn, sml_nb, py, none)" + -- "machine learning engine to use by Sledgehammer (sml, nb, knn, none)"