--- a/NEWS Sat Jun 28 22:13:23 2014 +0200
+++ b/NEWS Sun Jun 29 18:28:27 2014 +0200
@@ -433,9 +433,7 @@
and increase performance and reliability.
- MaSh and MeSh are now used by default together with the traditional
MePo (Meng-Paulson) relevance filter. To disable MaSh, set the "MaSh"
- system option in Plugin Options / Isabelle / General to "none". Other
- allowed values include "sml" (for the default SML engine) and "py"
- (for the old Python engine).
+ system option in Plugin Options / Isabelle / General to "none".
- New option:
smt_proofs
- Renamed options:
--- a/src/Doc/Sledgehammer/document/root.tex Sat Jun 28 22:13:23 2014 +0200
+++ b/src/Doc/Sledgehammer/document/root.tex Sun Jun 29 18:28:27 2014 +0200
@@ -1059,14 +1059,11 @@
-The MaSh machine learner. Three learning engines are provided:
+The MaSh machine learner. Two learning engines are provided:
\begin{enum}
-\item[\labelitemi] \textbf{\textit{sml\_nb}} (also called \textbf{\textit{sml}}
+\item[\labelitemi] \textbf{\textit{nb}} (also called \textbf{\textit{sml}}
and \textbf{\textit{yes}}) is a Standard ML implementation of naive Bayes.
-\item[\labelitemi] \textbf{\textit{sml\_knn}} is a Standard ML implementation of
+\item[\labelitemi] \textbf{\textit{knn}} is a Standard ML implementation of
$k$-nearest neighbors.
-
-\item[\labelitemi] \textbf{\textit{py}} is a Python implementation of naive Bayes.
-The program is included with Isabelle as \texttt{mash.py}.
\end{enum}
In addition, the special value \textit{none} is used to disable machine learning by
@@ -1077,10 +1074,7 @@
\texttt{\$ISABELLE\_HOME\_USER/etc/settings} file, or via the ``MaSh'' option
under ``Plugins > Plugin Options > Isabelle > General'' in Isabelle/jEdit.
Persistent data for both engines is stored in the directory
-\texttt{\$ISABELLE\_HOME\_USER/mash}. When switching to the \textit{py} engine,
-you will need to invoke the \textit{relearn\_isar} subcommand
-(\S\ref{sledgehammer}) to synchronize the persistent databases on the
-Python side.
+\texttt{\$ISABELLE\_HOME\_USER/mash}.
\item[\labelitemi] \textbf{\textit{mesh}:} The MeSh filter, which combines the
rankings from MePo and MaSh.
--- a/src/HOL/TPTP/MaSh_Export.thy Sat Jun 28 22:13:23 2014 +0200
+++ b/src/HOL/TPTP/MaSh_Export.thy Sun Jun 29 18:28:27 2014 +0200
@@ -46,22 +46,22 @@
()
*}
-ML {* Options.put_default @{system_option MaSh} "sml_nb" *}
+ML {* Options.put_default @{system_option MaSh} "nb" *}
ML {*
if do_it then
generate_mash_suggestions @{context} params (range, step) thys max_suggestions
- (prefix ^ "mash_sml_nb_suggestions")
+ (prefix ^ "mash_nb_suggestions")
else
()
*}
-ML {* Options.put_default @{system_option MaSh} "sml_knn" *}
+ML {* Options.put_default @{system_option MaSh} "knn" *}
ML {*
if do_it then
generate_mash_suggestions @{context} params (range, step) thys max_suggestions
- (prefix ^ "mash_sml_knn_suggestions")
+ (prefix ^ "mash_knn_suggestions")
else
()
*}
--- a/src/HOL/TPTP/mash_eval.ML Sat Jun 28 22:13:23 2014 +0200
+++ b/src/HOL/TPTP/mash_eval.ML Sun Jun 29 18:28:27 2014 +0200
@@ -29,6 +29,7 @@
open Sledgehammer_Prover
open Sledgehammer_Prover_ATP
open Sledgehammer_Commands
+open MaSh_Export
val prefix = Library.prefix
@@ -38,9 +39,6 @@
val MeSh_ProverN = MeShN ^ "-Prover"
val IsarN = "Isar"
-fun in_range (from, to) j =
- j >= from andalso (to = NONE orelse j <= the to)
-
fun evaluate_mash_suggestions ctxt params range methods prob_dir_name mepo_file_name
mash_isar_file_name mash_prover_file_name mesh_isar_file_name mesh_prover_file_name
report_file_name =
--- a/src/HOL/TPTP/mash_export.ML Sat Jun 28 22:13:23 2014 +0200
+++ b/src/HOL/TPTP/mash_export.ML Sun Jun 29 18:28:27 2014 +0200
@@ -9,6 +9,9 @@
sig
type params = Sledgehammer_Prover.params
+ val in_range : int * int option -> int -> bool
+ val extract_suggestions : string -> string * (string * real) list
+
val generate_accessibility : Proof.context -> theory list -> string -> unit
val generate_features : Proof.context -> theory list -> string -> unit
val generate_isar_dependencies : Proof.context -> int * int option -> theory list -> string ->
@@ -37,6 +40,18 @@
fun in_range (from, to) j =
j >= from andalso (to = NONE orelse j <= the to)
+(* Parses "name[=weight]", default weight 1.0. The suggested weights do not make much sense. *)
+fun extract_suggestion sugg =
+ (case space_explode "=" sugg of
+ [name, weight] => SOME (decode_str name, Real.fromString weight |> the_default 1.0)
+ | [name] => SOME (decode_str name, 1.0)
+ | _ => NONE)
+
+fun extract_suggestions line =
+ (case space_explode ":" line of
+ [goal, suggs] => (decode_str goal, map_filter extract_suggestion (space_explode " " suggs))
+ | _ => ("", [])) (* line does not split into exactly two ":"-separated fields *)
+
fun has_thm_thy th thy =
Context.theory_name thy = Context.theory_name (theory_of_thm th)
@@ -265,11 +280,11 @@
#> Sledgehammer_MePo.mepo_suggested_facts ctxt params max_suggs NONE hyp_ts concl_t)
fun generate_mash_suggestions ctxt params =
- (Sledgehammer_MaSh.mash_unlearn ctxt params;
+ (Sledgehammer_MaSh.mash_unlearn ();
generate_mepo_or_mash_suggestions
(fn ctxt => fn thy => fn params as {provers = prover :: _, ...} => fn max_suggs => fn hyp_ts =>
fn concl_t =>
- tap (Sledgehammer_MaSh.mash_learn_facts ctxt params prover true 2 false
+ tap (Sledgehammer_MaSh.mash_learn_facts ctxt params prover 2 false
Sledgehammer_Util.one_year)
#> Sledgehammer_MaSh.mash_suggested_facts ctxt thy params max_suggs hyp_ts concl_t
#> fst) ctxt params)
--- a/src/HOL/Tools/Sledgehammer/MaSh/etc/settings Sat Jun 28 22:13:23 2014 +0200
+++ /dev/null Thu Jan 01 00:00:00 1970 +0000
@@ -1,8 +0,0 @@
-# -*- shell-script -*- :mode=shellscript:
-
-ISABELLE_SLEDGEHAMMER_MASH="$COMPONENT"
-
-# MASH=yes
-if [ -z "$MASH_PORT" ]; then
- MASH_PORT=9255
-fi
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/ExpandFeatures.py Sat Jun 28 22:13:23 2014 +0200
+++ /dev/null Thu Jan 01 00:00:00 1970 +0000
@@ -1,163 +0,0 @@
-'''
-Created on Aug 21, 2013
-
-@author: daniel
-'''
-
-from math import log
-#from gensim import corpora, models, similarities
-
-class ExpandFeatures(object):
-
- def __init__(self,dicts):
- self.dicts = dicts
- self.featureMap = {}
- self.alpha = 0.1
- self.featureCounts = {}
- self.counter = 0
- self.corpus = []
-# self.LSIModel = models.lsimodel.LsiModel(self.corpus,num_topics=500)
-
- def initialize(self,dicts):
- self.dicts = dicts
- IS = open(dicts.accFile,'r')
- for line in IS:
- line = line.split(':')
- name = line[0]
- #print 'name',name
- nameId = dicts.nameIdDict[name]
- features = dicts.featureDict[nameId]
- dependencies = dicts.dependenciesDict[nameId]
- x = [self.dicts.idNameDict[d] for d in dependencies]
- #print x
- self.update(features, dependencies)
- self.corpus.append([(x,1) for x in features.keys()])
- IS.close()
- print 'x'
- #self.LSIModel = models.lsimodel.LsiModel(self.corpus,num_topics=500)
- print self.LSIModel
- print 'y'
-
- def update(self,features,dependencies):
- self.counter += 1
- self.corpus.append([(x,1) for x in features.keys()])
- self.LSIModel.add_documents([[(x,1) for x in features.keys()]])
- """
- for f in features.iterkeys():
- try:
- self.featureCounts[f] += 1
- except:
- self.featureCounts[f] = 1
- if self.featureCounts[f] > 100:
- continue
- try:
- self.featureMap[f] = self.featureMap[f].intersection(features.keys())
- except:
- self.featureMap[f] = set(features.keys())
- #print 'fOld',len(fMap),self.featureCounts[f],len(dependencies)
-
- for d in dependencies[1:]:
- #print 'dep',self.dicts.idNameDict[d]
- dFeatures = self.dicts.featureDict[d]
- for df in dFeatures.iterkeys():
- if self.featureCounts.has_key(df):
- if self.featureCounts[df] > 20:
- continue
- else:
- print df
- try:
- fMap[df] += self.alpha * (1.0 - fMap[df])
- except:
- fMap[df] = self.alpha
- """
- #print 'fNew',len(fMap)
-
- def expand(self,features):
- #print self.corpus[:50]
- #print corpus
- #tfidfmodel = models.TfidfModel(self.corpus, normalize=True)
- #print features.keys()
- #tfidfcorpus = [tfidfmodel[x] for x in self.corpus]
- #newFeatures = LSI[[(x,1) for x in features.keys()]]
- #newFeatures = self.LSIModel[[(x,1) for x in features.keys()]]
- #print features
- #print newFeatures
- #print newFeatures
-
- """
- newFeatures = dict(features)
- for f in features.keys():
- try:
- fC = self.featureCounts[f]
- except:
- fC = 0.5
- newFeatures[f] = log(float(8+self.counter) / fC)
- #nrOfFeatures = float(len(features))
- addedCount = 0
- alpha = 0.2
- #"""
-
- """
- consideredFeatures = []
- while len(newFeatures) < 30:
- #alpha = alpha * 0.5
- minF = None
- minFrequence = 1000000
- for f in newFeatures.iterkeys():
- if f in consideredFeatures:
- continue
- try:
- if self.featureCounts[f] < minFrequence:
- minF = f
- except:
- pass
- if minF == None:
- break
- # Expand minimal feature
- consideredFeatures.append(minF)
- for expF in self.featureMap[minF]:
- if not newFeatures.has_key(expF):
- fC = self.featureCounts[minF]
- newFeatures[expF] = alpha*log(float(8+self.counter) / fC)
- #print features, newFeatures
- #"""
- """
- for f in features.iterkeys():
- try:
- self.featureCounts[f] += 1
- except:
- self.featureCounts[f] = 0
- if self.featureCounts[f] > 10:
- continue
- addedCount += 1
- try:
- fmap = self.featureMap[f]
- except:
- self.featureMap[f] = {}
- fmap = {}
- for nf,nv in fmap.iteritems():
- try:
- newFeatures[nf] += nv
- except:
- newFeatures[nf] = nv
- if addedCount > 0:
- for f,w in newFeatures.iteritems():
- newFeatures[f] = float(w)/addedCount
- #"""
- """
- deleteF = []
- for f,w in newFeatures.iteritems():
- if w < 0.1:
- deleteF.append(f)
- for f in deleteF:
- del newFeatures[f]
- """
- #print 'fold',len(features)
- #print 'fnew',len(newFeatures)
- #return dict(newFeatures)
- return features
-
-if __name__ == "__main__":
- pass
-
-
\ No newline at end of file
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/KNN.py Sat Jun 28 22:13:23 2014 +0200
+++ /dev/null Thu Jan 01 00:00:00 1970 +0000
@@ -1,99 +0,0 @@
-'''
-Created on Aug 21, 2013
-
-@author: daniel
-'''
-
-from cPickle import dump,load
-from numpy import array
-from math import sqrt,log
-
-def cosine(f1,f2):
- f1Norm = 0.0
- for f in f1.keys():
- f1Norm += f1[f] * f1[f]
- #assert f1Norm = sum(map(lambda x,y: x*y,f1.itervalues(),f1.itervalues()))
- f1Norm = sqrt(f1Norm)
-
- f2Norm = 0.0
- for f in f2.keys():
- f2Norm += f2[f] * f2[f]
- f2Norm = sqrt(f2Norm)
-
- dotProduct = 0.0
- featureIntersection = set(f1.keys()) & set(f2.keys())
- for f in featureIntersection:
- dotProduct += f1[f] * f2[f]
- cosine = dotProduct / (f1Norm * f2Norm)
- return 1.0 - cosine
-
-def euclidean(f1,f2):
- diffSum = 0.0
- featureUnion = set(f1.keys()) | set(f2.keys())
- for f in featureUnion:
- try:
- f1Val = f1[f]
- except:
- f1Val = 0.0
- try:
- f2Val = f2[f]
- except:
- f2Val = 0.0
- diff = f1Val - f2Val
- diffSum += diff * diff
- #if f in f1.keys():
- # diffSum += log(2+self.pointCount/self.featureCounts[f]) * diff * diff
- #else:
- # diffSum += diff * diff
- #print diffSum,f1,f2
- return diffSum
-
-class KNN(object):
- '''
- A basic KNN ranker.
- '''
-
- def __init__(self,dicts,metric=cosine):
- '''
- Constructor
- '''
- self.points = dicts.featureDict
- self.metric = metric
-
- def initializeModel(self,_trainData,_dicts):
- """
- Build basic model from training data.
- """
- pass
-
- def update(self,dataPoint,features,dependencies):
- assert self.points[dataPoint] == features
-
- def overwrite(self,problemId,newDependencies,dicts):
- # Taken care of by dicts
- pass
-
- def delete(self,dataPoint,features,dependencies):
- # Taken care of by dicts
- pass
-
- def predict(self,features,accessibles,dicts):
- predictions = map(lambda x: self.metric(features,self.points[x]),accessibles)
- predictions = array(predictions)
- perm = predictions.argsort()
- return array(accessibles)[perm],predictions[perm]
-
- def save(self,fileName):
- OStream = open(fileName, 'wb')
- dump((self.points,self.metric),OStream)
- OStream.close()
-
- def load(self,fileName):
- OStream = open(fileName, 'rb')
- self.points,self.metric = load(OStream)
- OStream.close()
-
-if __name__ == '__main__':
- pass
-
-
\ No newline at end of file
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/KNNs.py Sat Jun 28 22:13:23 2014 +0200
+++ /dev/null Thu Jan 01 00:00:00 1970 +0000
@@ -1,105 +0,0 @@
-'''
-Created on Aug 21, 2013
-
-@author: daniel
-'''
-
-from math import log
-from KNN import KNN,cosine
-from numpy import array
-
-class KNNAdaptPointFeatures(KNN):
-
- def __init__(self,dicts,metric=cosine,alpha = 0.05):
- self.points = dicts.featureDict
- self.metric = self.euclidean
- self.alpha = alpha
- self.count = 0
- self.featureCount = {}
-
- def initializeModel(self,trainData,dicts):
- """
- Build basic model from training data.
- """
- IS = open(dicts.accFile,'r')
- for line in IS:
- line = line.split(':')
- name = line[0]
- nameId = dicts.nameIdDict[name]
- features = dicts.featureDict[nameId]
- dependencies = dicts.dependenciesDict[nameId]
- self.update(nameId, features, dependencies)
- IS.close()
-
- def update(self,dataPoint,features,dependencies):
- self.count += 1
- for f in features.iterkeys():
- try:
- self.featureCount[f] += 1
- except:
- self.featureCount[f] = 1
- for d in dependencies:
- dFeatures = self.points[d]
- featureUnion = set(dFeatures.keys()) | set(features.keys())
- for f in featureUnion:
- try:
- pVal = features[f]
- except:
- pVal = 0.0
- try:
- dVal = dFeatures[f]
- except:
- dVal = 0.0
- newDVal = dVal + self.alpha * (pVal - dVal)
- dFeatures[f] = newDVal
-
- def euclidean(self,f1,f2):
- diffSum = 0.0
- f1Set = set(f1.keys())
- featureUnion = f1Set | set(f2.keys())
- for f in featureUnion:
- if not self.featureCount.has_key(f):
- continue
- if self.featureCount[f] == 1:
- continue
- try:
- f1Val = f1[f]
- except:
- f1Val = 0.0
- try:
- f2Val = f2[f]
- except:
- f2Val = 0.0
- diff = f1Val - f2Val
- diffSum += diff * diff
- if f in f1Set:
- diffSum += log(2+self.count/self.featureCount[f]) * diff * diff
- else:
- diffSum += diff * diff
- #print diffSum,f1,f2
- return diffSum
-
-class KNNUrban(KNN):
- def __init__(self,dicts,metric=cosine,nrOfNeighbours = 40):
- self.points = dicts.featureDict
- self.metric = metric
- self.nrOfNeighbours = nrOfNeighbours # Ignored at the moment
-
- def predict(self,features,accessibles,dicts):
- predictions = map(lambda x: self.metric(features,self.points[x]),accessibles)
- pDict = dict(zip(accessibles,predictions))
- for a,p in zip(accessibles,predictions):
- aDeps = dicts.dependenciesDict[a]
- for d in aDeps:
- pDict[d] -= p
- predictions = []
- names = []
- for n,p in pDict.items():
- predictions.append(p)
- names.append(n)
- predictions = array(predictions)
- perm = predictions.argsort()
- return array(names)[perm],predictions[perm]
-
-
-
\ No newline at end of file
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/argparse.py Sat Jun 28 22:13:23 2014 +0200
+++ /dev/null Thu Jan 01 00:00:00 1970 +0000
@@ -1,2357 +0,0 @@
-# Title: HOL/Tools/Sledgehammer/MaSh/src/argparse.py
-#
-# Argument parser. See copyright notice below.
-
-# -*- coding: utf-8 -*-
-
-# Copyright (C) 2006-2009 Steven J. Bethard <steven.bethard@gmail.com>.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may not
-# use this file except in compliance with the License. You may obtain a copy
-# of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-"""Command-line parsing library
-
-This module is an optparse-inspired command-line parsing library that:
-
- - handles both optional and positional arguments
- - produces highly informative usage messages
- - supports parsers that dispatch to sub-parsers
-
-The following is a simple usage example that sums integers from the
-command-line and writes the result to a file::
-
- parser = argparse.ArgumentParser(
- description='sum the integers at the command line')
- parser.add_argument(
- 'integers', metavar='int', nargs='+', type=int,
- help='an integer to be summed')
- parser.add_argument(
- '--log', default=sys.stdout, type=argparse.FileType('w'),
- help='the file where the sum should be written')
- args = parser.parse_args()
- args.log.write('%s' % sum(args.integers))
- args.log.close()
-
-The module contains the following public classes:
-
- - ArgumentParser -- The main entry point for command-line parsing. As the
- example above shows, the add_argument() method is used to populate
- the parser with actions for optional and positional arguments. Then
- the parse_args() method is invoked to convert the args at the
- command-line into an object with attributes.
-
- - ArgumentError -- The exception raised by ArgumentParser objects when
- there are errors with the parser's actions. Errors raised while
- parsing the command-line are caught by ArgumentParser and emitted
- as command-line messages.
-
- - FileType -- A factory for defining types of files to be created. As the
- example above shows, instances of FileType are typically passed as
- the type= argument of add_argument() calls.
-
- - Action -- The base class for parser actions. Typically actions are
- selected by passing strings like 'store_true' or 'append_const' to
- the action= argument of add_argument(). However, for greater
- customization of ArgumentParser actions, subclasses of Action may
- be defined and passed as the action= argument.
-
- - HelpFormatter, RawDescriptionHelpFormatter, RawTextHelpFormatter,
- ArgumentDefaultsHelpFormatter -- Formatter classes which
- may be passed as the formatter_class= argument to the
- ArgumentParser constructor. HelpFormatter is the default,
- RawDescriptionHelpFormatter and RawTextHelpFormatter tell the parser
- not to change the formatting for help text, and
- ArgumentDefaultsHelpFormatter adds information about argument defaults
- to the help.
-
-All other classes in this module are considered implementation details.
-(Also note that HelpFormatter and RawDescriptionHelpFormatter are only
-considered public as object names -- the API of the formatter objects is
-still considered an implementation detail.)
-"""
-
-__version__ = '1.1'
-__all__ = [
- 'ArgumentParser',
- 'ArgumentError',
- 'Namespace',
- 'Action',
- 'FileType',
- 'HelpFormatter',
- 'RawDescriptionHelpFormatter',
- 'RawTextHelpFormatter',
- 'ArgumentDefaultsHelpFormatter',
-]
-
-
-import copy as _copy
-import os as _os
-import re as _re
-import sys as _sys
-import textwrap as _textwrap
-
-from gettext import gettext as _
-
-try:
- _set = set
-except NameError:
- from sets import Set as _set
-
-try:
- _basestring = basestring
-except NameError:
- _basestring = str
-
-try:
- _sorted = sorted
-except NameError:
-
- def _sorted(iterable, reverse=False):
- result = list(iterable)
- result.sort()
- if reverse:
- result.reverse()
- return result
-
-
-def _callable(obj):
- return hasattr(obj, '__call__') or hasattr(obj, '__bases__')
-
-# silence Python 2.6 buggy warnings about Exception.message
-if _sys.version_info[:2] == (2, 6):
- import warnings
- warnings.filterwarnings(
- action='ignore',
- message='BaseException.message has been deprecated as of Python 2.6',
- category=DeprecationWarning,
- module='argparse')
-
-
-SUPPRESS = '==SUPPRESS=='
-
-OPTIONAL = '?'
-ZERO_OR_MORE = '*'
-ONE_OR_MORE = '+'
-PARSER = 'A...'
-REMAINDER = '...'
-
-# =============================
-# Utility functions and classes
-# =============================
-
-class _AttributeHolder(object):
- """Abstract base class that provides __repr__.
-
- The __repr__ method returns a string in the format::
- ClassName(attr=name, attr=name, ...)
- The attributes are determined either by a class-level attribute,
- '_kwarg_names', or by inspecting the instance __dict__.
- """
-
- def __repr__(self):
- type_name = type(self).__name__
- arg_strings = []
- for arg in self._get_args():
- arg_strings.append(repr(arg))
- for name, value in self._get_kwargs():
- arg_strings.append('%s=%r' % (name, value))
- return '%s(%s)' % (type_name, ', '.join(arg_strings))
-
- def _get_kwargs(self):
- return _sorted(self.__dict__.items())
-
- def _get_args(self):
- return []
-
-
-def _ensure_value(namespace, name, value):
- if getattr(namespace, name, None) is None:
- setattr(namespace, name, value)
- return getattr(namespace, name)
-
-
-# ===============
-# Formatting Help
-# ===============
-
-class HelpFormatter(object):
- """Formatter for generating usage messages and argument help strings.
-
- Only the name of this class is considered a public API. All the methods
- provided by the class are considered an implementation detail.
- """
-
- def __init__(self,
- prog,
- indent_increment=2,
- max_help_position=24,
- width=None):
-
- # default setting for width
- if width is None:
- try:
- width = int(_os.environ['COLUMNS'])
- except (KeyError, ValueError):
- width = 80
- width -= 2
-
- self._prog = prog
- self._indent_increment = indent_increment
- self._max_help_position = max_help_position
- self._width = width
-
- self._current_indent = 0
- self._level = 0
- self._action_max_length = 0
-
- self._root_section = self._Section(self, None)
- self._current_section = self._root_section
-
- self._whitespace_matcher = _re.compile(r'\s+')
- self._long_break_matcher = _re.compile(r'\n\n\n+')
-
- # ===============================
- # Section and indentation methods
- # ===============================
- def _indent(self):
- self._current_indent += self._indent_increment
- self._level += 1
-
- def _dedent(self):
- self._current_indent -= self._indent_increment
- assert self._current_indent >= 0, 'Indent decreased below 0.'
- self._level -= 1
-
- class _Section(object):
-
- def __init__(self, formatter, parent, heading=None):
- self.formatter = formatter
- self.parent = parent
- self.heading = heading
- self.items = []
-
- def format_help(self):
- # format the indented section
- if self.parent is not None:
- self.formatter._indent()
- join = self.formatter._join_parts
- for func, args in self.items:
- func(*args)
- item_help = join([func(*args) for func, args in self.items])
- if self.parent is not None:
- self.formatter._dedent()
-
- # return nothing if the section was empty
- if not item_help:
- return ''
-
- # add the heading if the section was non-empty
- if self.heading is not SUPPRESS and self.heading is not None:
- current_indent = self.formatter._current_indent
- heading = '%*s%s:\n' % (current_indent, '', self.heading)
- else:
- heading = ''
-
- # join the section-initial newline, the heading and the help
- return join(['\n', heading, item_help, '\n'])
-
- def _add_item(self, func, args):
- self._current_section.items.append((func, args))
-
- # ========================
- # Message building methods
- # ========================
- def start_section(self, heading):
- self._indent()
- section = self._Section(self, self._current_section, heading)
- self._add_item(section.format_help, [])
- self._current_section = section
-
- def end_section(self):
- self._current_section = self._current_section.parent
- self._dedent()
-
- def add_text(self, text):
- if text is not SUPPRESS and text is not None:
- self._add_item(self._format_text, [text])
-
- def add_usage(self, usage, actions, groups, prefix=None):
- if usage is not SUPPRESS:
- args = usage, actions, groups, prefix
- self._add_item(self._format_usage, args)
-
- def add_argument(self, action):
- if action.help is not SUPPRESS:
-
- # find all invocations
- get_invocation = self._format_action_invocation
- invocations = [get_invocation(action)]
- for subaction in self._iter_indented_subactions(action):
- invocations.append(get_invocation(subaction))
-
- # update the maximum item length
- invocation_length = max([len(s) for s in invocations])
- action_length = invocation_length + self._current_indent
- self._action_max_length = max(self._action_max_length,
- action_length)
-
- # add the item to the list
- self._add_item(self._format_action, [action])
-
- def add_arguments(self, actions):
- for action in actions:
- self.add_argument(action)
-
- # =======================
- # Help-formatting methods
- # =======================
- def format_help(self):
- help = self._root_section.format_help()
- if help:
- help = self._long_break_matcher.sub('\n\n', help)
- help = help.strip('\n') + '\n'
- return help
-
- def _join_parts(self, part_strings):
- return ''.join([part
- for part in part_strings
- if part and part is not SUPPRESS])
-
- def _format_usage(self, usage, actions, groups, prefix):
- if prefix is None:
- prefix = _('usage: ')
-
- # if usage is specified, use that
- if usage is not None:
- usage = usage % dict(prog=self._prog)
-
- # if no optionals or positionals are available, usage is just prog
- elif usage is None and not actions:
- usage = '%(prog)s' % dict(prog=self._prog)
-
- # if optionals and positionals are available, calculate usage
- elif usage is None:
- prog = '%(prog)s' % dict(prog=self._prog)
-
- # split optionals from positionals
- optionals = []
- positionals = []
- for action in actions:
- if action.option_strings:
- optionals.append(action)
- else:
- positionals.append(action)
-
- # build full usage string
- format = self._format_actions_usage
- action_usage = format(optionals + positionals, groups)
- usage = ' '.join([s for s in [prog, action_usage] if s])
-
- # wrap the usage parts if it's too long
- text_width = self._width - self._current_indent
- if len(prefix) + len(usage) > text_width:
-
- # break usage into wrappable parts
- part_regexp = r'\(.*?\)+|\[.*?\]+|\S+'
- opt_usage = format(optionals, groups)
- pos_usage = format(positionals, groups)
- opt_parts = _re.findall(part_regexp, opt_usage)
- pos_parts = _re.findall(part_regexp, pos_usage)
- assert ' '.join(opt_parts) == opt_usage
- assert ' '.join(pos_parts) == pos_usage
-
- # helper for wrapping lines
- def get_lines(parts, indent, prefix=None):
- lines = []
- line = []
- if prefix is not None:
- line_len = len(prefix) - 1
- else:
- line_len = len(indent) - 1
- for part in parts:
- if line_len + 1 + len(part) > text_width:
- lines.append(indent + ' '.join(line))
- line = []
- line_len = len(indent) - 1
- line.append(part)
- line_len += len(part) + 1
- if line:
- lines.append(indent + ' '.join(line))
- if prefix is not None:
- lines[0] = lines[0][len(indent):]
- return lines
-
- # if prog is short, follow it with optionals or positionals
- if len(prefix) + len(prog) <= 0.75 * text_width:
- indent = ' ' * (len(prefix) + len(prog) + 1)
- if opt_parts:
- lines = get_lines([prog] + opt_parts, indent, prefix)
- lines.extend(get_lines(pos_parts, indent))
- elif pos_parts:
- lines = get_lines([prog] + pos_parts, indent, prefix)
- else:
- lines = [prog]
-
- # if prog is long, put it on its own line
- else:
- indent = ' ' * len(prefix)
- parts = opt_parts + pos_parts
- lines = get_lines(parts, indent)
- if len(lines) > 1:
- lines = []
- lines.extend(get_lines(opt_parts, indent))
- lines.extend(get_lines(pos_parts, indent))
- lines = [prog] + lines
-
- # join lines into usage
- usage = '\n'.join(lines)
-
- # prefix with 'usage:'
- return '%s%s\n\n' % (prefix, usage)
-
- def _format_actions_usage(self, actions, groups):
- # find group indices and identify actions in groups
- group_actions = _set()
- inserts = {}
- for group in groups:
- try:
- start = actions.index(group._group_actions[0])
- except ValueError:
- continue
- else:
- end = start + len(group._group_actions)
- if actions[start:end] == group._group_actions:
- for action in group._group_actions:
- group_actions.add(action)
- if not group.required:
- inserts[start] = '['
- inserts[end] = ']'
- else:
- inserts[start] = '('
- inserts[end] = ')'
- for i in range(start + 1, end):
- inserts[i] = '|'
-
- # collect all actions format strings
- parts = []
- for i, action in enumerate(actions):
-
- # suppressed arguments are marked with None
- # remove | separators for suppressed arguments
- if action.help is SUPPRESS:
- parts.append(None)
- if inserts.get(i) == '|':
- inserts.pop(i)
- elif inserts.get(i + 1) == '|':
- inserts.pop(i + 1)
-
- # produce all arg strings
- elif not action.option_strings:
- part = self._format_args(action, action.dest)
-
- # if it's in a group, strip the outer []
- if action in group_actions:
- if part[0] == '[' and part[-1] == ']':
- part = part[1:-1]
-
- # add the action string to the list
- parts.append(part)
-
- # produce the first way to invoke the option in brackets
- else:
- option_string = action.option_strings[0]
-
- # if the Optional doesn't take a value, format is:
- # -s or --long
- if action.nargs == 0:
- part = '%s' % option_string
-
- # if the Optional takes a value, format is:
- # -s ARGS or --long ARGS
- else:
- default = action.dest.upper()
- args_string = self._format_args(action, default)
- part = '%s %s' % (option_string, args_string)
-
- # make it look optional if it's not required or in a group
- if not action.required and action not in group_actions:
- part = '[%s]' % part
-
- # add the action string to the list
- parts.append(part)
-
- # insert things at the necessary indices
- for i in _sorted(inserts, reverse=True):
- parts[i:i] = [inserts[i]]
-
- # join all the action items with spaces
- text = ' '.join([item for item in parts if item is not None])
-
- # clean up separators for mutually exclusive groups
- open = r'[\[(]'
- close = r'[\])]'
- text = _re.sub(r'(%s) ' % open, r'\1', text)
- text = _re.sub(r' (%s)' % close, r'\1', text)
- text = _re.sub(r'%s *%s' % (open, close), r'', text)
- text = _re.sub(r'\(([^|]*)\)', r'\1', text)
- text = text.strip()
-
- # return the text
- return text
-
- def _format_text(self, text):
- if '%(prog)' in text:
- text = text % dict(prog=self._prog)
- text_width = self._width - self._current_indent
- indent = ' ' * self._current_indent
- return self._fill_text(text, text_width, indent) + '\n\n'
-
- def _format_action(self, action):
- # determine the required width and the entry label
- help_position = min(self._action_max_length + 2,
- self._max_help_position)
- help_width = self._width - help_position
- action_width = help_position - self._current_indent - 2
- action_header = self._format_action_invocation(action)
-
- # ho nelp; start on same line and add a final newline
- if not action.help:
- tup = self._current_indent, '', action_header
- action_header = '%*s%s\n' % tup
-
- # short action name; start on the same line and pad two spaces
- elif len(action_header) <= action_width:
- tup = self._current_indent, '', action_width, action_header
- action_header = '%*s%-*s ' % tup
- indent_first = 0
-
- # long action name; start on the next line
- else:
- tup = self._current_indent, '', action_header
- action_header = '%*s%s\n' % tup
- indent_first = help_position
-
- # collect the pieces of the action help
- parts = [action_header]
-
- # if there was help for the action, add lines of help text
- if action.help:
- help_text = self._expand_help(action)
- help_lines = self._split_lines(help_text, help_width)
- parts.append('%*s%s\n' % (indent_first, '', help_lines[0]))
- for line in help_lines[1:]:
- parts.append('%*s%s\n' % (help_position, '', line))
-
- # or add a newline if the description doesn't end with one
- elif not action_header.endswith('\n'):
- parts.append('\n')
-
- # if there are any sub-actions, add their help as well
- for subaction in self._iter_indented_subactions(action):
- parts.append(self._format_action(subaction))
-
- # return a single string
- return self._join_parts(parts)
-
- def _format_action_invocation(self, action):
- if not action.option_strings:
- metavar, = self._metavar_formatter(action, action.dest)(1)
- return metavar
-
- else:
- parts = []
-
- # if the Optional doesn't take a value, format is:
- # -s, --long
- if action.nargs == 0:
- parts.extend(action.option_strings)
-
- # if the Optional takes a value, format is:
- # -s ARGS, --long ARGS
- else:
- default = action.dest.upper()
- args_string = self._format_args(action, default)
- for option_string in action.option_strings:
- parts.append('%s %s' % (option_string, args_string))
-
- return ', '.join(parts)
-
- def _metavar_formatter(self, action, default_metavar):
- if action.metavar is not None:
- result = action.metavar
- elif action.choices is not None:
- choice_strs = [str(choice) for choice in action.choices]
- result = '{%s}' % ','.join(choice_strs)
- else:
- result = default_metavar
-
- def format(tuple_size):
- if isinstance(result, tuple):
- return result
- else:
- return (result, ) * tuple_size
- return format
-
- def _format_args(self, action, default_metavar):
- get_metavar = self._metavar_formatter(action, default_metavar)
- if action.nargs is None:
- result = '%s' % get_metavar(1)
- elif action.nargs == OPTIONAL:
- result = '[%s]' % get_metavar(1)
- elif action.nargs == ZERO_OR_MORE:
- result = '[%s [%s ...]]' % get_metavar(2)
- elif action.nargs == ONE_OR_MORE:
- result = '%s [%s ...]' % get_metavar(2)
- elif action.nargs == REMAINDER:
- result = '...'
- elif action.nargs == PARSER:
- result = '%s ...' % get_metavar(1)
- else:
- formats = ['%s' for _ in range(action.nargs)]
- result = ' '.join(formats) % get_metavar(action.nargs)
- return result
-
- def _expand_help(self, action):
- params = dict(vars(action), prog=self._prog)
- for name in list(params):
- if params[name] is SUPPRESS:
- del params[name]
- for name in list(params):
- if hasattr(params[name], '__name__'):
- params[name] = params[name].__name__
- if params.get('choices') is not None:
- choices_str = ', '.join([str(c) for c in params['choices']])
- params['choices'] = choices_str
- return self._get_help_string(action) % params
-
- def _iter_indented_subactions(self, action):
- try:
- get_subactions = action._get_subactions
- except AttributeError:
- pass
- else:
- self._indent()
- for subaction in get_subactions():
- yield subaction
- self._dedent()
-
- def _split_lines(self, text, width):
- text = self._whitespace_matcher.sub(' ', text).strip()
- return _textwrap.wrap(text, width)
-
- def _fill_text(self, text, width, indent):
- text = self._whitespace_matcher.sub(' ', text).strip()
- return _textwrap.fill(text, width, initial_indent=indent,
- subsequent_indent=indent)
-
- def _get_help_string(self, action):
- return action.help
-
-
-class RawDescriptionHelpFormatter(HelpFormatter):
- """Help message formatter which retains any formatting in descriptions.
-
- Only the name of this class is considered a public API. All the methods
- provided by the class are considered an implementation detail.
- """
-
- def _fill_text(self, text, width, indent):
- return ''.join([indent + line for line in text.splitlines(True)])
-
-
-class RawTextHelpFormatter(RawDescriptionHelpFormatter):
- """Help message formatter which retains formatting of all help text.
-
- Only the name of this class is considered a public API. All the methods
- provided by the class are considered an implementation detail.
- """
-
- def _split_lines(self, text, width):
- return text.splitlines()
-
-
-class ArgumentDefaultsHelpFormatter(HelpFormatter):
- """Help message formatter which adds default values to argument help.
-
- Only the name of this class is considered a public API. All the methods
- provided by the class are considered an implementation detail.
- """
-
- def _get_help_string(self, action):
- help = action.help
- if '%(default)' not in action.help:
- if action.default is not SUPPRESS:
- defaulting_nargs = [OPTIONAL, ZERO_OR_MORE]
- if action.option_strings or action.nargs in defaulting_nargs:
- help += ' (default: %(default)s)'
- return help
-
-
-# =====================
-# Options and Arguments
-# =====================
-
-def _get_action_name(argument):
- if argument is None:
- return None
- elif argument.option_strings:
- return '/'.join(argument.option_strings)
- elif argument.metavar not in (None, SUPPRESS):
- return argument.metavar
- elif argument.dest not in (None, SUPPRESS):
- return argument.dest
- else:
- return None
-
-
-class ArgumentError(Exception):
- """An error from creating or using an argument (optional or positional).
-
- The string value of this exception is the message, augmented with
- information about the argument that caused it.
- """
-
- def __init__(self, argument, message):
- self.argument_name = _get_action_name(argument)
- self.message = message
-
- def __str__(self):
- if self.argument_name is None:
- format = '%(message)s'
- else:
- format = 'argument %(argument_name)s: %(message)s'
- return format % dict(message=self.message,
- argument_name=self.argument_name)
-
-
-class ArgumentTypeError(Exception):
- """An error from trying to convert a command line string to a type."""
- pass
-
-
-# ==============
-# Action classes
-# ==============
-
-class Action(_AttributeHolder):
- """Information about how to convert command line strings to Python objects.
-
- Action objects are used by an ArgumentParser to represent the information
- needed to parse a single argument from one or more strings from the
- command line. The keyword arguments to the Action constructor are also
- all attributes of Action instances.
-
- Keyword Arguments:
-
- - option_strings -- A list of command-line option strings which
- should be associated with this action.
-
- - dest -- The name of the attribute to hold the created object(s)
-
- - nargs -- The number of command-line arguments that should be
- consumed. By default, one argument will be consumed and a single
- value will be produced. Other values include:
- - N (an integer) consumes N arguments (and produces a list)
- - '?' consumes zero or one arguments
- - '*' consumes zero or more arguments (and produces a list)
- - '+' consumes one or more arguments (and produces a list)
- Note that the difference between the default and nargs=1 is that
- with the default, a single value will be produced, while with
- nargs=1, a list containing a single value will be produced.
-
- - const -- The value to be produced if the option is specified and the
- option uses an action that takes no values.
-
- - default -- The value to be produced if the option is not specified.
-
- - type -- The type which the command-line arguments should be converted
- to, should be one of 'string', 'int', 'float', 'complex' or a
- callable object that accepts a single string argument. If None,
- 'string' is assumed.
-
- - choices -- A container of values that should be allowed. If not None,
- after a command-line argument has been converted to the appropriate
- type, an exception will be raised if it is not a member of this
- collection.
-
- - required -- True if the action must always be specified at the
- command line. This is only meaningful for optional command-line
- arguments.
-
- - help -- The help string describing the argument.
-
- - metavar -- The name to be used for the option's argument with the
- help string. If None, the 'dest' value will be used as the name.
- """
-
- def __init__(self,
- option_strings,
- dest,
- nargs=None,
- const=None,
- default=None,
- type=None,
- choices=None,
- required=False,
- help=None,
- metavar=None):
- self.option_strings = option_strings
- self.dest = dest
- self.nargs = nargs
- self.const = const
- self.default = default
- self.type = type
- self.choices = choices
- self.required = required
- self.help = help
- self.metavar = metavar
-
- def _get_kwargs(self):
- names = [
- 'option_strings',
- 'dest',
- 'nargs',
- 'const',
- 'default',
- 'type',
- 'choices',
- 'help',
- 'metavar',
- ]
- return [(name, getattr(self, name)) for name in names]
-
- def __call__(self, parser, namespace, values, option_string=None):
- raise NotImplementedError(_('.__call__() not defined'))
-
-
-class _StoreAction(Action):
-
- def __init__(self,
- option_strings,
- dest,
- nargs=None,
- const=None,
- default=None,
- type=None,
- choices=None,
- required=False,
- help=None,
- metavar=None):
- if nargs == 0:
- raise ValueError('nargs for store actions must be > 0; if you '
- 'have nothing to store, actions such as store '
- 'true or store const may be more appropriate')
- if const is not None and nargs != OPTIONAL:
- raise ValueError('nargs must be %r to supply const' % OPTIONAL)
- super(_StoreAction, self).__init__(
- option_strings=option_strings,
- dest=dest,
- nargs=nargs,
- const=const,
- default=default,
- type=type,
- choices=choices,
- required=required,
- help=help,
- metavar=metavar)
-
- def __call__(self, parser, namespace, values, option_string=None):
- setattr(namespace, self.dest, values)
-
-
-class _StoreConstAction(Action):
-
- def __init__(self,
- option_strings,
- dest,
- const,
- default=None,
- required=False,
- help=None,
- metavar=None):
- super(_StoreConstAction, self).__init__(
- option_strings=option_strings,
- dest=dest,
- nargs=0,
- const=const,
- default=default,
- required=required,
- help=help)
-
- def __call__(self, parser, namespace, values, option_string=None):
- setattr(namespace, self.dest, self.const)
-
-
-class _StoreTrueAction(_StoreConstAction):
-
- def __init__(self,
- option_strings,
- dest,
- default=False,
- required=False,
- help=None):
- super(_StoreTrueAction, self).__init__(
- option_strings=option_strings,
- dest=dest,
- const=True,
- default=default,
- required=required,
- help=help)
-
-
-class _StoreFalseAction(_StoreConstAction):
-
- def __init__(self,
- option_strings,
- dest,
- default=True,
- required=False,
- help=None):
- super(_StoreFalseAction, self).__init__(
- option_strings=option_strings,
- dest=dest,
- const=False,
- default=default,
- required=required,
- help=help)
-
-
-class _AppendAction(Action):
-
- def __init__(self,
- option_strings,
- dest,
- nargs=None,
- const=None,
- default=None,
- type=None,
- choices=None,
- required=False,
- help=None,
- metavar=None):
- if nargs == 0:
- raise ValueError('nargs for append actions must be > 0; if arg '
- 'strings are not supplying the value to append, '
- 'the append const action may be more appropriate')
- if const is not None and nargs != OPTIONAL:
- raise ValueError('nargs must be %r to supply const' % OPTIONAL)
- super(_AppendAction, self).__init__(
- option_strings=option_strings,
- dest=dest,
- nargs=nargs,
- const=const,
- default=default,
- type=type,
- choices=choices,
- required=required,
- help=help,
- metavar=metavar)
-
- def __call__(self, parser, namespace, values, option_string=None):
- items = _copy.copy(_ensure_value(namespace, self.dest, []))
- items.append(values)
- setattr(namespace, self.dest, items)
-
-
-class _AppendConstAction(Action):
-
- def __init__(self,
- option_strings,
- dest,
- const,
- default=None,
- required=False,
- help=None,
- metavar=None):
- super(_AppendConstAction, self).__init__(
- option_strings=option_strings,
- dest=dest,
- nargs=0,
- const=const,
- default=default,
- required=required,
- help=help,
- metavar=metavar)
-
- def __call__(self, parser, namespace, values, option_string=None):
- items = _copy.copy(_ensure_value(namespace, self.dest, []))
- items.append(self.const)
- setattr(namespace, self.dest, items)
-
-
-class _CountAction(Action):
-
- def __init__(self,
- option_strings,
- dest,
- default=None,
- required=False,
- help=None):
- super(_CountAction, self).__init__(
- option_strings=option_strings,
- dest=dest,
- nargs=0,
- default=default,
- required=required,
- help=help)
-
- def __call__(self, parser, namespace, values, option_string=None):
- new_count = _ensure_value(namespace, self.dest, 0) + 1
- setattr(namespace, self.dest, new_count)
-
-
-class _HelpAction(Action):
-
- def __init__(self,
- option_strings,
- dest=SUPPRESS,
- default=SUPPRESS,
- help=None):
- super(_HelpAction, self).__init__(
- option_strings=option_strings,
- dest=dest,
- default=default,
- nargs=0,
- help=help)
-
- def __call__(self, parser, namespace, values, option_string=None):
- parser.print_help()
- parser.exit()
-
-
-class _VersionAction(Action):
-
- def __init__(self,
- option_strings,
- version=None,
- dest=SUPPRESS,
- default=SUPPRESS,
- help=None):
- super(_VersionAction, self).__init__(
- option_strings=option_strings,
- dest=dest,
- default=default,
- nargs=0,
- help=help)
- self.version = version
-
- def __call__(self, parser, namespace, values, option_string=None):
- version = self.version
- if version is None:
- version = parser.version
- formatter = parser._get_formatter()
- formatter.add_text(version)
- parser.exit(message=formatter.format_help())
-
-
-class _SubParsersAction(Action):
-
- class _ChoicesPseudoAction(Action):
-
- def __init__(self, name, help):
- sup = super(_SubParsersAction._ChoicesPseudoAction, self)
- sup.__init__(option_strings=[], dest=name, help=help)
-
- def __init__(self,
- option_strings,
- prog,
- parser_class,
- dest=SUPPRESS,
- help=None,
- metavar=None):
-
- self._prog_prefix = prog
- self._parser_class = parser_class
- self._name_parser_map = {}
- self._choices_actions = []
-
- super(_SubParsersAction, self).__init__(
- option_strings=option_strings,
- dest=dest,
- nargs=PARSER,
- choices=self._name_parser_map,
- help=help,
- metavar=metavar)
-
- def add_parser(self, name, **kwargs):
- # set prog from the existing prefix
- if kwargs.get('prog') is None:
- kwargs['prog'] = '%s %s' % (self._prog_prefix, name)
-
- # create a pseudo-action to hold the choice help
- if 'help' in kwargs:
- help = kwargs.pop('help')
- choice_action = self._ChoicesPseudoAction(name, help)
- self._choices_actions.append(choice_action)
-
- # create the parser and add it to the map
- parser = self._parser_class(**kwargs)
- self._name_parser_map[name] = parser
- return parser
-
- def _get_subactions(self):
- return self._choices_actions
-
- def __call__(self, parser, namespace, values, option_string=None):
- parser_name = values[0]
- arg_strings = values[1:]
-
- # set the parser name if requested
- if self.dest is not SUPPRESS:
- setattr(namespace, self.dest, parser_name)
-
- # select the parser
- try:
- parser = self._name_parser_map[parser_name]
- except KeyError:
- tup = parser_name, ', '.join(self._name_parser_map)
- msg = _('unknown parser %r (choices: %s)' % tup)
- raise ArgumentError(self, msg)
-
- # parse all the remaining options into the namespace
- parser.parse_args(arg_strings, namespace)
-
-
-# ==============
-# Type classes
-# ==============
-
-class FileType(object):
- """Factory for creating file object types
-
- Instances of FileType are typically passed as type= arguments to the
- ArgumentParser add_argument() method.
-
- Keyword Arguments:
- - mode -- A string indicating how the file is to be opened. Accepts the
- same values as the builtin open() function.
- - bufsize -- The file's desired buffer size. Accepts the same values as
- the builtin open() function.
- """
-
- def __init__(self, mode='r', bufsize=None):
- self._mode = mode
- self._bufsize = bufsize
-
- def __call__(self, string):
- # the special argument "-" means sys.std{in,out}
- if string == '-':
- if 'r' in self._mode:
- return _sys.stdin
- elif 'w' in self._mode:
- return _sys.stdout
- else:
- msg = _('argument "-" with mode %r' % self._mode)
- raise ValueError(msg)
-
- # all other arguments are used as file names
- if self._bufsize:
- return open(string, self._mode, self._bufsize)
- else:
- return open(string, self._mode)
-
- def __repr__(self):
- args = [self._mode, self._bufsize]
- args_str = ', '.join([repr(arg) for arg in args if arg is not None])
- return '%s(%s)' % (type(self).__name__, args_str)
-
-# ===========================
-# Optional and Positional Parsing
-# ===========================
-
-class Namespace(_AttributeHolder):
- """Simple object for storing attributes.
-
- Implements equality by attribute names and values, and provides a simple
- string representation.
- """
-
- def __init__(self, **kwargs):
- for name in kwargs:
- setattr(self, name, kwargs[name])
-
- def __eq__(self, other):
- return vars(self) == vars(other)
-
- def __ne__(self, other):
- return not (self == other)
-
- def __contains__(self, key):
- return key in self.__dict__
-
-
-class _ActionsContainer(object):
-
- def __init__(self,
- description,
- prefix_chars,
- argument_default,
- conflict_handler):
- super(_ActionsContainer, self).__init__()
-
- self.description = description
- self.argument_default = argument_default
- self.prefix_chars = prefix_chars
- self.conflict_handler = conflict_handler
-
- # set up registries
- self._registries = {}
-
- # register actions
- self.register('action', None, _StoreAction)
- self.register('action', 'store', _StoreAction)
- self.register('action', 'store_const', _StoreConstAction)
- self.register('action', 'store_true', _StoreTrueAction)
- self.register('action', 'store_false', _StoreFalseAction)
- self.register('action', 'append', _AppendAction)
- self.register('action', 'append_const', _AppendConstAction)
- self.register('action', 'count', _CountAction)
- self.register('action', 'help', _HelpAction)
- self.register('action', 'version', _VersionAction)
- self.register('action', 'parsers', _SubParsersAction)
-
- # raise an exception if the conflict handler is invalid
- self._get_handler()
-
- # action storage
- self._actions = []
- self._option_string_actions = {}
-
- # groups
- self._action_groups = []
- self._mutually_exclusive_groups = []
-
- # defaults storage
- self._defaults = {}
-
- # determines whether an "option" looks like a negative number
- self._negative_number_matcher = _re.compile(r'^-\d+$|^-\d*\.\d+$')
-
- # whether or not there are any optionals that look like negative
- # numbers -- uses a list so it can be shared and edited
- self._has_negative_number_optionals = []
-
- # ====================
- # Registration methods
- # ====================
- def register(self, registry_name, value, object):
- registry = self._registries.setdefault(registry_name, {})
- registry[value] = object
-
- def _registry_get(self, registry_name, value, default=None):
- return self._registries[registry_name].get(value, default)
-
- # ==================================
- # Namespace default accessor methods
- # ==================================
- def set_defaults(self, **kwargs):
- self._defaults.update(kwargs)
-
- # if these defaults match any existing arguments, replace
- # the previous default on the object with the new one
- for action in self._actions:
- if action.dest in kwargs:
- action.default = kwargs[action.dest]
-
- def get_default(self, dest):
- for action in self._actions:
- if action.dest == dest and action.default is not None:
- return action.default
- return self._defaults.get(dest, None)
-
-
- # =======================
- # Adding argument actions
- # =======================
- def add_argument(self, *args, **kwargs):
- """
- add_argument(dest, ..., name=value, ...)
- add_argument(option_string, option_string, ..., name=value, ...)
- """
-
- # if no positional args are supplied or only one is supplied and
- # it doesn't look like an option string, parse a positional
- # argument
- chars = self.prefix_chars
- if not args or len(args) == 1 and args[0][0] not in chars:
- if args and 'dest' in kwargs:
- raise ValueError('dest supplied twice for positional argument')
- kwargs = self._get_positional_kwargs(*args, **kwargs)
-
- # otherwise, we're adding an optional argument
- else:
- kwargs = self._get_optional_kwargs(*args, **kwargs)
-
- # if no default was supplied, use the parser-level default
- if 'default' not in kwargs:
- dest = kwargs['dest']
- if dest in self._defaults:
- kwargs['default'] = self._defaults[dest]
- elif self.argument_default is not None:
- kwargs['default'] = self.argument_default
-
- # create the action object, and add it to the parser
- action_class = self._pop_action_class(kwargs)
- if not _callable(action_class):
- raise ValueError('unknown action "%s"' % action_class)
- action = action_class(**kwargs)
-
- # raise an error if the action type is not callable
- type_func = self._registry_get('type', action.type, action.type)
- if not _callable(type_func):
- raise ValueError('%r is not callable' % type_func)
-
- return self._add_action(action)
-
- def add_argument_group(self, *args, **kwargs):
- group = _ArgumentGroup(self, *args, **kwargs)
- self._action_groups.append(group)
- return group
-
- def add_mutually_exclusive_group(self, **kwargs):
- group = _MutuallyExclusiveGroup(self, **kwargs)
- self._mutually_exclusive_groups.append(group)
- return group
-
- def _add_action(self, action):
- # resolve any conflicts
- self._check_conflict(action)
-
- # add to actions list
- self._actions.append(action)
- action.container = self
-
- # index the action by any option strings it has
- for option_string in action.option_strings:
- self._option_string_actions[option_string] = action
-
- # set the flag if any option strings look like negative numbers
- for option_string in action.option_strings:
- if self._negative_number_matcher.match(option_string):
- if not self._has_negative_number_optionals:
- self._has_negative_number_optionals.append(True)
-
- # return the created action
- return action
-
- def _remove_action(self, action):
- self._actions.remove(action)
-
- def _add_container_actions(self, container):
- # collect groups by titles
- title_group_map = {}
- for group in self._action_groups:
- if group.title in title_group_map:
- msg = _('cannot merge actions - two groups are named %r')
- raise ValueError(msg % (group.title))
- title_group_map[group.title] = group
-
- # map each action to its group
- group_map = {}
- for group in container._action_groups:
-
- # if a group with the title exists, use that, otherwise
- # create a new group matching the container's group
- if group.title not in title_group_map:
- title_group_map[group.title] = self.add_argument_group(
- title=group.title,
- description=group.description,
- conflict_handler=group.conflict_handler)
-
- # map the actions to their new group
- for action in group._group_actions:
- group_map[action] = title_group_map[group.title]
-
- # add container's mutually exclusive groups
- # NOTE: if add_mutually_exclusive_group ever gains title= and
- # description= then this code will need to be expanded as above
- for group in container._mutually_exclusive_groups:
- mutex_group = self.add_mutually_exclusive_group(
- required=group.required)
-
- # map the actions to their new mutex group
- for action in group._group_actions:
- group_map[action] = mutex_group
-
- # add all actions to this container or their group
- for action in container._actions:
- group_map.get(action, self)._add_action(action)
-
- def _get_positional_kwargs(self, dest, **kwargs):
- # make sure required is not specified
- if 'required' in kwargs:
- msg = _("'required' is an invalid argument for positionals")
- raise TypeError(msg)
-
- # mark positional arguments as required if at least one is
- # always required
- if kwargs.get('nargs') not in [OPTIONAL, ZERO_OR_MORE]:
- kwargs['required'] = True
- if kwargs.get('nargs') == ZERO_OR_MORE and 'default' not in kwargs:
- kwargs['required'] = True
-
- # return the keyword arguments with no option strings
- return dict(kwargs, dest=dest, option_strings=[])
-
- def _get_optional_kwargs(self, *args, **kwargs):
- # determine short and long option strings
- option_strings = []
- long_option_strings = []
- for option_string in args:
- # error on strings that don't start with an appropriate prefix
- if not option_string[0] in self.prefix_chars:
- msg = _('invalid option string %r: '
- 'must start with a character %r')
- tup = option_string, self.prefix_chars
- raise ValueError(msg % tup)
-
- # strings starting with two prefix characters are long options
- option_strings.append(option_string)
- if option_string[0] in self.prefix_chars:
- if len(option_string) > 1:
- if option_string[1] in self.prefix_chars:
- long_option_strings.append(option_string)
-
- # infer destination, '--foo-bar' -> 'foo_bar' and '-x' -> 'x'
- dest = kwargs.pop('dest', None)
- if dest is None:
- if long_option_strings:
- dest_option_string = long_option_strings[0]
- else:
- dest_option_string = option_strings[0]
- dest = dest_option_string.lstrip(self.prefix_chars)
- if not dest:
- msg = _('dest= is required for options like %r')
- raise ValueError(msg % option_string)
- dest = dest.replace('-', '_')
-
- # return the updated keyword arguments
- return dict(kwargs, dest=dest, option_strings=option_strings)
-
- def _pop_action_class(self, kwargs, default=None):
- action = kwargs.pop('action', default)
- return self._registry_get('action', action, action)
-
- def _get_handler(self):
- # determine function from conflict handler string
- handler_func_name = '_handle_conflict_%s' % self.conflict_handler
- try:
- return getattr(self, handler_func_name)
- except AttributeError:
- msg = _('invalid conflict_resolution value: %r')
- raise ValueError(msg % self.conflict_handler)
-
- def _check_conflict(self, action):
-
- # find all options that conflict with this option
- confl_optionals = []
- for option_string in action.option_strings:
- if option_string in self._option_string_actions:
- confl_optional = self._option_string_actions[option_string]
- confl_optionals.append((option_string, confl_optional))
-
- # resolve any conflicts
- if confl_optionals:
- conflict_handler = self._get_handler()
- conflict_handler(action, confl_optionals)
-
- def _handle_conflict_error(self, action, conflicting_actions):
- message = _('conflicting option string(s): %s')
- conflict_string = ', '.join([option_string
- for option_string, action
- in conflicting_actions])
- raise ArgumentError(action, message % conflict_string)
-
- def _handle_conflict_resolve(self, action, conflicting_actions):
-
- # remove all conflicting options
- for option_string, action in conflicting_actions:
-
- # remove the conflicting option
- action.option_strings.remove(option_string)
- self._option_string_actions.pop(option_string, None)
-
- # if the option now has no option string, remove it from the
- # container holding it
- if not action.option_strings:
- action.container._remove_action(action)
-
-
-class _ArgumentGroup(_ActionsContainer):
-
- def __init__(self, container, title=None, description=None, **kwargs):
- # add any missing keyword arguments by checking the container
- update = kwargs.setdefault
- update('conflict_handler', container.conflict_handler)
- update('prefix_chars', container.prefix_chars)
- update('argument_default', container.argument_default)
- super_init = super(_ArgumentGroup, self).__init__
- super_init(description=description, **kwargs)
-
- # group attributes
- self.title = title
- self._group_actions = []
-
- # share most attributes with the container
- self._registries = container._registries
- self._actions = container._actions
- self._option_string_actions = container._option_string_actions
- self._defaults = container._defaults
- self._has_negative_number_optionals = \
- container._has_negative_number_optionals
-
- def _add_action(self, action):
- action = super(_ArgumentGroup, self)._add_action(action)
- self._group_actions.append(action)
- return action
-
- def _remove_action(self, action):
- super(_ArgumentGroup, self)._remove_action(action)
- self._group_actions.remove(action)
-
-
-class _MutuallyExclusiveGroup(_ArgumentGroup):
-
- def __init__(self, container, required=False):
- super(_MutuallyExclusiveGroup, self).__init__(container)
- self.required = required
- self._container = container
-
- def _add_action(self, action):
- if action.required:
- msg = _('mutually exclusive arguments must be optional')
- raise ValueError(msg)
- action = self._container._add_action(action)
- self._group_actions.append(action)
- return action
-
- def _remove_action(self, action):
- self._container._remove_action(action)
- self._group_actions.remove(action)
-
-
-class ArgumentParser(_AttributeHolder, _ActionsContainer):
- """Object for parsing command line strings into Python objects.
-
- Keyword Arguments:
- - prog -- The name of the program (default: sys.argv[0])
- - usage -- A usage message (default: auto-generated from arguments)
- - description -- A description of what the program does
- - epilog -- Text following the argument descriptions
- - parents -- Parsers whose arguments should be copied into this one
- - formatter_class -- HelpFormatter class for printing help messages
- - prefix_chars -- Characters that prefix optional arguments
- - fromfile_prefix_chars -- Characters that prefix files containing
- additional arguments
- - argument_default -- The default value for all arguments
- - conflict_handler -- String indicating how to handle conflicts
- - add_help -- Add a -h/-help option
- """
-
- def __init__(self,
- prog=None,
- usage=None,
- description=None,
- epilog=None,
- version=None,
- parents=[],
- formatter_class=HelpFormatter,
- prefix_chars='-',
- fromfile_prefix_chars=None,
- argument_default=None,
- conflict_handler='error',
- add_help=True):
-
- if version is not None:
- import warnings
- warnings.warn(
- """The "version" argument to ArgumentParser is deprecated. """
- """Please use """
- """"add_argument(..., action='version', version="N", ...)" """
- """instead""", DeprecationWarning)
-
- superinit = super(ArgumentParser, self).__init__
- superinit(description=description,
- prefix_chars=prefix_chars,
- argument_default=argument_default,
- conflict_handler=conflict_handler)
-
- # default setting for prog
- if prog is None:
- prog = _os.path.basename(_sys.argv[0])
-
- self.prog = prog
- self.usage = usage
- self.epilog = epilog
- self.version = version
- self.formatter_class = formatter_class
- self.fromfile_prefix_chars = fromfile_prefix_chars
- self.add_help = add_help
-
- add_group = self.add_argument_group
- self._positionals = add_group(_('positional arguments'))
- self._optionals = add_group(_('optional arguments'))
- self._subparsers = None
-
- # register types
- def identity(string):
- return string
- self.register('type', None, identity)
-
- # add help and version arguments if necessary
- # (using explicit default to override global argument_default)
- if self.add_help:
- self.add_argument(
- '-h', '--help', action='help', default=SUPPRESS,
- help=_('show this help message and exit'))
- if self.version:
- self.add_argument(
- '-v', '--version', action='version', default=SUPPRESS,
- version=self.version,
- help=_("show program's version number and exit"))
-
- # add parent arguments and defaults
- for parent in parents:
- self._add_container_actions(parent)
- try:
- defaults = parent._defaults
- except AttributeError:
- pass
- else:
- self._defaults.update(defaults)
-
- # =======================
- # Pretty __repr__ methods
- # =======================
- def _get_kwargs(self):
- names = [
- 'prog',
- 'usage',
- 'description',
- 'version',
- 'formatter_class',
- 'conflict_handler',
- 'add_help',
- ]
- return [(name, getattr(self, name)) for name in names]
-
- # ==================================
- # Optional/Positional adding methods
- # ==================================
- def add_subparsers(self, **kwargs):
- if self._subparsers is not None:
- self.error(_('cannot have multiple subparser arguments'))
-
- # add the parser class to the arguments if it's not present
- kwargs.setdefault('parser_class', type(self))
-
- if 'title' in kwargs or 'description' in kwargs:
- title = _(kwargs.pop('title', 'subcommands'))
- description = _(kwargs.pop('description', None))
- self._subparsers = self.add_argument_group(title, description)
- else:
- self._subparsers = self._positionals
-
- # prog defaults to the usage message of this parser, skipping
- # optional arguments and with no "usage:" prefix
- if kwargs.get('prog') is None:
- formatter = self._get_formatter()
- positionals = self._get_positional_actions()
- groups = self._mutually_exclusive_groups
- formatter.add_usage(self.usage, positionals, groups, '')
- kwargs['prog'] = formatter.format_help().strip()
-
- # create the parsers action and add it to the positionals list
- parsers_class = self._pop_action_class(kwargs, 'parsers')
- action = parsers_class(option_strings=[], **kwargs)
- self._subparsers._add_action(action)
-
- # return the created parsers action
- return action
-
- def _add_action(self, action):
- if action.option_strings:
- self._optionals._add_action(action)
- else:
- self._positionals._add_action(action)
- return action
-
- def _get_optional_actions(self):
- return [action
- for action in self._actions
- if action.option_strings]
-
- def _get_positional_actions(self):
- return [action
- for action in self._actions
- if not action.option_strings]
-
- # =====================================
- # Command line argument parsing methods
- # =====================================
- def parse_args(self, args=None, namespace=None):
- args, argv = self.parse_known_args(args, namespace)
- if argv:
- msg = _('unrecognized arguments: %s')
- self.error(msg % ' '.join(argv))
- return args
-
- def parse_known_args(self, args=None, namespace=None):
- # args default to the system args
- if args is None:
- args = _sys.argv[1:]
-
- # default Namespace built from parser defaults
- if namespace is None:
- namespace = Namespace()
-
- # add any action defaults that aren't present
- for action in self._actions:
- if action.dest is not SUPPRESS:
- if not hasattr(namespace, action.dest):
- if action.default is not SUPPRESS:
- default = action.default
- if isinstance(action.default, _basestring):
- default = self._get_value(action, default)
- setattr(namespace, action.dest, default)
-
- # add any parser defaults that aren't present
- for dest in self._defaults:
- if not hasattr(namespace, dest):
- setattr(namespace, dest, self._defaults[dest])
-
- # parse the arguments and exit if there are any errors
- try:
- return self._parse_known_args(args, namespace)
- except ArgumentError:
- err = _sys.exc_info()[1]
- self.error(str(err))
-
- def _parse_known_args(self, arg_strings, namespace):
- # replace arg strings that are file references
- if self.fromfile_prefix_chars is not None:
- arg_strings = self._read_args_from_files(arg_strings)
-
- # map all mutually exclusive arguments to the other arguments
- # they can't occur with
- action_conflicts = {}
- for mutex_group in self._mutually_exclusive_groups:
- group_actions = mutex_group._group_actions
- for i, mutex_action in enumerate(mutex_group._group_actions):
- conflicts = action_conflicts.setdefault(mutex_action, [])
- conflicts.extend(group_actions[:i])
- conflicts.extend(group_actions[i + 1:])
-
- # find all option indices, and determine the arg_string_pattern
- # which has an 'O' if there is an option at an index,
- # an 'A' if there is an argument, or a '-' if there is a '--'
- option_string_indices = {}
- arg_string_pattern_parts = []
- arg_strings_iter = iter(arg_strings)
- for i, arg_string in enumerate(arg_strings_iter):
-
- # all args after -- are non-options
- if arg_string == '--':
- arg_string_pattern_parts.append('-')
- for arg_string in arg_strings_iter:
- arg_string_pattern_parts.append('A')
-
- # otherwise, add the arg to the arg strings
- # and note the index if it was an option
- else:
- option_tuple = self._parse_optional(arg_string)
- if option_tuple is None:
- pattern = 'A'
- else:
- option_string_indices[i] = option_tuple
- pattern = 'O'
- arg_string_pattern_parts.append(pattern)
-
- # join the pieces together to form the pattern
- arg_strings_pattern = ''.join(arg_string_pattern_parts)
-
- # converts arg strings to the appropriate and then takes the action
- seen_actions = _set()
- seen_non_default_actions = _set()
-
- def take_action(action, argument_strings, option_string=None):
- seen_actions.add(action)
- argument_values = self._get_values(action, argument_strings)
-
- # error if this argument is not allowed with other previously
- # seen arguments, assuming that actions that use the default
- # value don't really count as "present"
- if argument_values is not action.default:
- seen_non_default_actions.add(action)
- for conflict_action in action_conflicts.get(action, []):
- if conflict_action in seen_non_default_actions:
- msg = _('not allowed with argument %s')
- action_name = _get_action_name(conflict_action)
- raise ArgumentError(action, msg % action_name)
-
- # take the action if we didn't receive a SUPPRESS value
- # (e.g. from a default)
- if argument_values is not SUPPRESS:
- action(self, namespace, argument_values, option_string)
-
- # function to convert arg_strings into an optional action
- def consume_optional(start_index):
-
- # get the optional identified at this index
- option_tuple = option_string_indices[start_index]
- action, option_string, explicit_arg = option_tuple
-
- # identify additional optionals in the same arg string
- # (e.g. -xyz is the same as -x -y -z if no args are required)
- match_argument = self._match_argument
- action_tuples = []
- while True:
-
- # if we found no optional action, skip it
- if action is None:
- extras.append(arg_strings[start_index])
- return start_index + 1
-
- # if there is an explicit argument, try to match the
- # optional's string arguments to only this
- if explicit_arg is not None:
- arg_count = match_argument(action, 'A')
-
- # if the action is a single-dash option and takes no
- # arguments, try to parse more single-dash options out
- # of the tail of the option string
- chars = self.prefix_chars
- if arg_count == 0 and option_string[1] not in chars:
- action_tuples.append((action, [], option_string))
- for char in self.prefix_chars:
- option_string = char + explicit_arg[0]
- explicit_arg = explicit_arg[1:] or None
- optionals_map = self._option_string_actions
- if option_string in optionals_map:
- action = optionals_map[option_string]
- break
- else:
- msg = _('ignored explicit argument %r')
- raise ArgumentError(action, msg % explicit_arg)
-
- # if the action expect exactly one argument, we've
- # successfully matched the option; exit the loop
- elif arg_count == 1:
- stop = start_index + 1
- args = [explicit_arg]
- action_tuples.append((action, args, option_string))
- break
-
- # error if a double-dash option did not use the
- # explicit argument
- else:
- msg = _('ignored explicit argument %r')
- raise ArgumentError(action, msg % explicit_arg)
-
- # if there is no explicit argument, try to match the
- # optional's string arguments with the following strings
- # if successful, exit the loop
- else:
- start = start_index + 1
- selected_patterns = arg_strings_pattern[start:]
- arg_count = match_argument(action, selected_patterns)
- stop = start + arg_count
- args = arg_strings[start:stop]
- action_tuples.append((action, args, option_string))
- break
-
- # add the Optional to the list and return the index at which
- # the Optional's string args stopped
- assert action_tuples
- for action, args, option_string in action_tuples:
- take_action(action, args, option_string)
- return stop
-
- # the list of Positionals left to be parsed; this is modified
- # by consume_positionals()
- positionals = self._get_positional_actions()
-
- # function to convert arg_strings into positional actions
- def consume_positionals(start_index):
- # match as many Positionals as possible
- match_partial = self._match_arguments_partial
- selected_pattern = arg_strings_pattern[start_index:]
- arg_counts = match_partial(positionals, selected_pattern)
-
- # slice off the appropriate arg strings for each Positional
- # and add the Positional and its args to the list
- for action, arg_count in zip(positionals, arg_counts):
- args = arg_strings[start_index: start_index + arg_count]
- start_index += arg_count
- take_action(action, args)
-
- # slice off the Positionals that we just parsed and return the
- # index at which the Positionals' string args stopped
- positionals[:] = positionals[len(arg_counts):]
- return start_index
-
- # consume Positionals and Optionals alternately, until we have
- # passed the last option string
- extras = []
- start_index = 0
- if option_string_indices:
- max_option_string_index = max(option_string_indices)
- else:
- max_option_string_index = -1
- while start_index <= max_option_string_index:
-
- # consume any Positionals preceding the next option
- next_option_string_index = min([
- index
- for index in option_string_indices
- if index >= start_index])
- if start_index != next_option_string_index:
- positionals_end_index = consume_positionals(start_index)
-
- # only try to parse the next optional if we didn't consume
- # the option string during the positionals parsing
- if positionals_end_index > start_index:
- start_index = positionals_end_index
- continue
- else:
- start_index = positionals_end_index
-
- # if we consumed all the positionals we could and we're not
- # at the index of an option string, there were extra arguments
- if start_index not in option_string_indices:
- strings = arg_strings[start_index:next_option_string_index]
- extras.extend(strings)
- start_index = next_option_string_index
-
- # consume the next optional and any arguments for it
- start_index = consume_optional(start_index)
-
- # consume any positionals following the last Optional
- stop_index = consume_positionals(start_index)
-
- # if we didn't consume all the argument strings, there were extras
- extras.extend(arg_strings[stop_index:])
-
- # if we didn't use all the Positional objects, there were too few
- # arg strings supplied.
- if positionals:
- self.error(_('too few arguments'))
-
- # make sure all required actions were present
- for action in self._actions:
- if action.required:
- if action not in seen_actions:
- name = _get_action_name(action)
- self.error(_('argument %s is required') % name)
-
- # make sure all required groups had one option present
- for group in self._mutually_exclusive_groups:
- if group.required:
- for action in group._group_actions:
- if action in seen_non_default_actions:
- break
-
- # if no actions were used, report the error
- else:
- names = [_get_action_name(action)
- for action in group._group_actions
- if action.help is not SUPPRESS]
- msg = _('one of the arguments %s is required')
- self.error(msg % ' '.join(names))
-
- # return the updated namespace and the extra arguments
- return namespace, extras
-
- def _read_args_from_files(self, arg_strings):
- # expand arguments referencing files
- new_arg_strings = []
- for arg_string in arg_strings:
-
- # for regular arguments, just add them back into the list
- if arg_string[0] not in self.fromfile_prefix_chars:
- new_arg_strings.append(arg_string)
-
- # replace arguments referencing files with the file content
- else:
- try:
- args_file = open(arg_string[1:])
- try:
- arg_strings = []
- for arg_line in args_file.read().splitlines():
- for arg in self.convert_arg_line_to_args(arg_line):
- arg_strings.append(arg)
- arg_strings = self._read_args_from_files(arg_strings)
- new_arg_strings.extend(arg_strings)
- finally:
- args_file.close()
- except IOError:
- err = _sys.exc_info()[1]
- self.error(str(err))
-
- # return the modified argument list
- return new_arg_strings
-
- def convert_arg_line_to_args(self, arg_line):
- return [arg_line]
-
- def _match_argument(self, action, arg_strings_pattern):
- # match the pattern for this action to the arg strings
- nargs_pattern = self._get_nargs_pattern(action)
- match = _re.match(nargs_pattern, arg_strings_pattern)
-
- # raise an exception if we weren't able to find a match
- if match is None:
- nargs_errors = {
- None: _('expected one argument'),
- OPTIONAL: _('expected at most one argument'),
- ONE_OR_MORE: _('expected at least one argument'),
- }
- default = _('expected %s argument(s)') % action.nargs
- msg = nargs_errors.get(action.nargs, default)
- raise ArgumentError(action, msg)
-
- # return the number of arguments matched
- return len(match.group(1))
-
- def _match_arguments_partial(self, actions, arg_strings_pattern):
- # progressively shorten the actions list by slicing off the
- # final actions until we find a match
- result = []
- for i in range(len(actions), 0, -1):
- actions_slice = actions[:i]
- pattern = ''.join([self._get_nargs_pattern(action)
- for action in actions_slice])
- match = _re.match(pattern, arg_strings_pattern)
- if match is not None:
- result.extend([len(string) for string in match.groups()])
- break
-
- # return the list of arg string counts
- return result
-
- def _parse_optional(self, arg_string):
- # if it's an empty string, it was meant to be a positional
- if not arg_string:
- return None
-
- # if it doesn't start with a prefix, it was meant to be positional
- if not arg_string[0] in self.prefix_chars:
- return None
-
- # if the option string is present in the parser, return the action
- if arg_string in self._option_string_actions:
- action = self._option_string_actions[arg_string]
- return action, arg_string, None
-
- # if it's just a single character, it was meant to be positional
- if len(arg_string) == 1:
- return None
-
- # if the option string before the "=" is present, return the action
- if '=' in arg_string:
- option_string, explicit_arg = arg_string.split('=', 1)
- if option_string in self._option_string_actions:
- action = self._option_string_actions[option_string]
- return action, option_string, explicit_arg
-
- # search through all possible prefixes of the option string
- # and all actions in the parser for possible interpretations
- option_tuples = self._get_option_tuples(arg_string)
-
- # if multiple actions match, the option string was ambiguous
- if len(option_tuples) > 1:
- options = ', '.join([option_string
- for action, option_string, explicit_arg in option_tuples])
- tup = arg_string, options
- self.error(_('ambiguous option: %s could match %s') % tup)
-
- # if exactly one action matched, this segmentation is good,
- # so return the parsed action
- elif len(option_tuples) == 1:
- option_tuple, = option_tuples
- return option_tuple
-
- # if it was not found as an option, but it looks like a negative
- # number, it was meant to be positional
- # unless there are negative-number-like options
- if self._negative_number_matcher.match(arg_string):
- if not self._has_negative_number_optionals:
- return None
-
- # if it contains a space, it was meant to be a positional
- if ' ' in arg_string:
- return None
-
- # it was meant to be an optional but there is no such option
- # in this parser (though it might be a valid option in a subparser)
- return None, arg_string, None
-
- def _get_option_tuples(self, option_string):
- result = []
-
- # option strings starting with two prefix characters are only
- # split at the '='
- chars = self.prefix_chars
- if option_string[0] in chars and option_string[1] in chars:
- if '=' in option_string:
- option_prefix, explicit_arg = option_string.split('=', 1)
- else:
- option_prefix = option_string
- explicit_arg = None
- for option_string in self._option_string_actions:
- if option_string.startswith(option_prefix):
- action = self._option_string_actions[option_string]
- tup = action, option_string, explicit_arg
- result.append(tup)
-
- # single character options can be concatenated with their arguments
- # but multiple character options always have to have their argument
- # separate
- elif option_string[0] in chars and option_string[1] not in chars:
- option_prefix = option_string
- explicit_arg = None
- short_option_prefix = option_string[:2]
- short_explicit_arg = option_string[2:]
-
- for option_string in self._option_string_actions:
- if option_string == short_option_prefix:
- action = self._option_string_actions[option_string]
- tup = action, option_string, short_explicit_arg
- result.append(tup)
- elif option_string.startswith(option_prefix):
- action = self._option_string_actions[option_string]
- tup = action, option_string, explicit_arg
- result.append(tup)
-
- # shouldn't ever get here
- else:
- self.error(_('unexpected option string: %s') % option_string)
-
- # return the collected option tuples
- return result
-
- def _get_nargs_pattern(self, action):
- # in all examples below, we have to allow for '--' args
- # which are represented as '-' in the pattern
- nargs = action.nargs
-
- # the default (None) is assumed to be a single argument
- if nargs is None:
- nargs_pattern = '(-*A-*)'
-
- # allow zero or one arguments
- elif nargs == OPTIONAL:
- nargs_pattern = '(-*A?-*)'
-
- # allow zero or more arguments
- elif nargs == ZERO_OR_MORE:
- nargs_pattern = '(-*[A-]*)'
-
- # allow one or more arguments
- elif nargs == ONE_OR_MORE:
- nargs_pattern = '(-*A[A-]*)'
-
- # allow any number of options or arguments
- elif nargs == REMAINDER:
- nargs_pattern = '([-AO]*)'
-
- # allow one argument followed by any number of options or arguments
- elif nargs == PARSER:
- nargs_pattern = '(-*A[-AO]*)'
-
- # all others should be integers
- else:
- nargs_pattern = '(-*%s-*)' % '-*'.join('A' * nargs)
-
- # if this is an optional action, -- is not allowed
- if action.option_strings:
- nargs_pattern = nargs_pattern.replace('-*', '')
- nargs_pattern = nargs_pattern.replace('-', '')
-
- # return the pattern
- return nargs_pattern
-
- # ========================
- # Value conversion methods
- # ========================
- def _get_values(self, action, arg_strings):
- # for everything but PARSER args, strip out '--'
- if action.nargs not in [PARSER, REMAINDER]:
- arg_strings = [s for s in arg_strings if s != '--']
-
- # optional argument produces a default when not present
- if not arg_strings and action.nargs == OPTIONAL:
- if action.option_strings:
- value = action.const
- else:
- value = action.default
- if isinstance(value, _basestring):
- value = self._get_value(action, value)
- self._check_value(action, value)
-
- # when nargs='*' on a positional, if there were no command-line
- # args, use the default if it is anything other than None
- elif (not arg_strings and action.nargs == ZERO_OR_MORE and
- not action.option_strings):
- if action.default is not None:
- value = action.default
- else:
- value = arg_strings
- self._check_value(action, value)
-
- # single argument or optional argument produces a single value
- elif len(arg_strings) == 1 and action.nargs in [None, OPTIONAL]:
- arg_string, = arg_strings
- value = self._get_value(action, arg_string)
- self._check_value(action, value)
-
- # REMAINDER arguments convert all values, checking none
- elif action.nargs == REMAINDER:
- value = [self._get_value(action, v) for v in arg_strings]
-
- # PARSER arguments convert all values, but check only the first
- elif action.nargs == PARSER:
- value = [self._get_value(action, v) for v in arg_strings]
- self._check_value(action, value[0])
-
- # all other types of nargs produce a list
- else:
- value = [self._get_value(action, v) for v in arg_strings]
- for v in value:
- self._check_value(action, v)
-
- # return the converted value
- return value
-
- def _get_value(self, action, arg_string):
- type_func = self._registry_get('type', action.type, action.type)
- if not _callable(type_func):
- msg = _('%r is not callable')
- raise ArgumentError(action, msg % type_func)
-
- # convert the value to the appropriate type
- try:
- result = type_func(arg_string)
-
- # ArgumentTypeErrors indicate errors
- except ArgumentTypeError:
- name = getattr(action.type, '__name__', repr(action.type))
- msg = str(_sys.exc_info()[1])
- raise ArgumentError(action, msg)
-
- # TypeErrors or ValueErrors also indicate errors
- except (TypeError, ValueError):
- name = getattr(action.type, '__name__', repr(action.type))
- msg = _('invalid %s value: %r')
- raise ArgumentError(action, msg % (name, arg_string))
-
- # return the converted value
- return result
-
- def _check_value(self, action, value):
- # converted value must be one of the choices (if specified)
- if action.choices is not None and value not in action.choices:
- tup = value, ', '.join(map(repr, action.choices))
- msg = _('invalid choice: %r (choose from %s)') % tup
- raise ArgumentError(action, msg)
-
- # =======================
- # Help-formatting methods
- # =======================
- def format_usage(self):
- formatter = self._get_formatter()
- formatter.add_usage(self.usage, self._actions,
- self._mutually_exclusive_groups)
- return formatter.format_help()
-
- def format_help(self):
- formatter = self._get_formatter()
-
- # usage
- formatter.add_usage(self.usage, self._actions,
- self._mutually_exclusive_groups)
-
- # description
- formatter.add_text(self.description)
-
- # positionals, optionals and user-defined groups
- for action_group in self._action_groups:
- formatter.start_section(action_group.title)
- formatter.add_text(action_group.description)
- formatter.add_arguments(action_group._group_actions)
- formatter.end_section()
-
- # epilog
- formatter.add_text(self.epilog)
-
- # determine help from format above
- return formatter.format_help()
-
- def format_version(self):
- import warnings
- warnings.warn(
- 'The format_version method is deprecated -- the "version" '
- 'argument to ArgumentParser is no longer supported.',
- DeprecationWarning)
- formatter = self._get_formatter()
- formatter.add_text(self.version)
- return formatter.format_help()
-
- def _get_formatter(self):
- return self.formatter_class(prog=self.prog)
-
- # =====================
- # Help-printing methods
- # =====================
- def print_usage(self, file=None):
- if file is None:
- file = _sys.stdout
- self._print_message(self.format_usage(), file)
-
- def print_help(self, file=None):
- if file is None:
- file = _sys.stdout
- self._print_message(self.format_help(), file)
-
- def print_version(self, file=None):
- import warnings
- warnings.warn(
- 'The print_version method is deprecated -- the "version" '
- 'argument to ArgumentParser is no longer supported.',
- DeprecationWarning)
- self._print_message(self.format_version(), file)
-
- def _print_message(self, message, file=None):
- if message:
- if file is None:
- file = _sys.stderr
- file.write(message)
-
- # ===============
- # Exiting methods
- # ===============
- def exit(self, status=0, message=None):
- if message:
- self._print_message(message, _sys.stderr)
- _sys.exit(status)
-
- def error(self, message):
- """error(message: string)
-
- Prints a usage message incorporating the message to stderr and
- exits.
-
- If you override this in a subclass, it should not return -- it
- should either exit or raise an exception.
- """
- self.print_usage(_sys.stderr)
- self.exit(2, _('%s: error: %s\n') % (self.prog, message))
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/compareStats.py Sat Jun 28 22:13:23 2014 +0200
+++ /dev/null Thu Jan 01 00:00:00 1970 +0000
@@ -1,65 +0,0 @@
-#!/usr/bin/env python
-# Title: HOL/Tools/Sledgehammer/MaSh/src/compareStats.py
-# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen
-# Copyright 2012
-#
-# Tool that compares MaSh statistics and displays a comparison.
-
-'''
-Created on Jul 13, 2012
-
-@author: Daniel Kuehlwein
-'''
-
-import sys
-from argparse import ArgumentParser,RawDescriptionHelpFormatter
-from matplotlib.pyplot import plot,figure,show,legend,xlabel,ylabel,axis,hist
-from stats import Statistics
-
-parser = ArgumentParser(description='Compare Statistics. \n\n\
-Loads different statistics and displays a comparison. Requires the matplotlib module.\n\n\
--------- Example Usage ---------------\n\
-./compareStats.py --statFiles ../tmp/natISANB.stats ../tmp/natATPNB.stats -b 30\n\n\
-Author: Daniel Kuehlwein, July 2012',formatter_class=RawDescriptionHelpFormatter)
-parser.add_argument('--statFiles', default=None, nargs='+',
- help='The names of the saved statistic files.')
-parser.add_argument('-b','--bins',default=50,help="Number of bins for the AUC histogram. Default=50.",type=int)
-
-def main(argv = sys.argv[1:]):
- args = parser.parse_args(argv)
- if args.statFiles == None:
- print 'Filenames missing.'
- sys.exit(-1)
-
- aucData = []
- aucLabels = []
- for statFile in args.statFiles:
- s = Statistics()
- s.load(statFile)
- avgRecall = [float(x)/s.problems for x in s.recallData]
- figure('Recall')
- plot(range(s.cutOff),avgRecall,label=statFile)
- legend(loc='lower right')
- ylabel('Average Recall')
- xlabel('Highest ranked premises')
- axis([0,s.cutOff,0.0,1.0])
- figure('100%Recall')
- plot(range(s.cutOff),s.recall100Data,label=statFile)
- legend(loc='lower right')
- ylabel('100%Recall')
- xlabel('Highest ranked premises')
- axis([0,s.cutOff,0,s.problems])
- aucData.append(s.aucData)
- aucLabels.append(statFile)
- figure('AUC Histogram')
- hist(aucData,bins=args.bins,label=aucLabels,histtype='bar')
- legend(loc='upper left')
- ylabel('Problems')
- xlabel('AUC')
-
- show()
-
-if __name__ == '__main__':
- #args = ['--statFiles','../tmp/natISANB.stats','../tmp/natATPNB.stats','-b','30']
- #sys.exit(main(args))
- sys.exit(main())
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/dictionaries.py Sat Jun 28 22:13:23 2014 +0200
+++ /dev/null Thu Jan 01 00:00:00 1970 +0000
@@ -1,261 +0,0 @@
-# Title: HOL/Tools/Sledgehammer/MaSh/src/dictionaries.py
-# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen
-# Copyright 2012-2013
-#
-# Persistent dictionaries: accessibility, dependencies, and features.
-
-import sys
-from os.path import join
-from Queue import Queue
-from readData import create_accessible_dict,create_dependencies_dict
-from cPickle import load,dump
-from exceptions import LookupError
-
-class Dictionaries(object):
- '''
- This class contains all info about name-> id mapping, etc.
- '''
- def __init__(self):
- '''
- Constructor
- '''
- self.nameIdDict = {}
- self.idNameDict = {}
- self.featureIdDict={}
- self.maxNameId = 1
- self.maxFeatureId = 0
- self.featureDict = {}
- self.dependenciesDict = {}
- self.accessibleDict = {}
- self.expandedAccessibles = {}
- self.accFile = ''
- self.changed = True
- # Unnamed facts
- self.nameIdDict[''] = 0
- self.idNameDict[0] = 'Unnamed Fact'
-
- """
- Init functions. nameIdDict, idNameDict, featureIdDict, articleDict get filled!
- """
- def init_featureDict(self,featureFile):
- self.create_feature_dict(featureFile)
- #self.featureDict,self.maxNameId,self.maxFeatureId,self.featureCountDict,self.triggerFeaturesDict,self.featureTriggeredFormulasDict =\
- # create_feature_dict(self.nameIdDict,self.idNameDict,self.maxNameId,self.featureIdDict,self.maxFeatureId,self.featureCountDict,\
- # self.triggerFeaturesDict,self.featureTriggeredFormulasDict,sineFeatures,featureFile)
- def init_dependenciesDict(self,depFile):
- self.dependenciesDict = create_dependencies_dict(self.nameIdDict,depFile)
- def init_accessibleDict(self,accFile):
- self.accessibleDict,self.maxNameId = create_accessible_dict(self.nameIdDict,self.idNameDict,self.maxNameId,accFile)
-
- def init_all(self,args):
- self.featureFileName = 'mash_features'
- self.accFileName = 'mash_accessibility'
- featureFile = join(args.inputDir,self.featureFileName)
- depFile = join(args.inputDir,args.depFile)
- self.accFile = join(args.inputDir,self.accFileName)
- self.init_featureDict(featureFile)
- self.init_accessibleDict(self.accFile)
- self.init_dependenciesDict(depFile)
- self.expandedAccessibles = {}
- self.changed = True
-
- def create_feature_dict(self,inputFile):
- self.featureDict = {}
- IS = open(inputFile,'r')
- for line in IS:
- line = line.split(':')
- name = line[0]
- # Name Id
- if self.nameIdDict.has_key(name):
- raise LookupError('%s appears twice in the feature file. Aborting.'% name)
- sys.exit(-1)
- else:
- self.nameIdDict[name] = self.maxNameId
- self.idNameDict[self.maxNameId] = name
- nameId = self.maxNameId
- self.maxNameId += 1
- features = self.get_features(line)
- # Store results
- self.featureDict[nameId] = features
- IS.close()
- return
-
- def get_name_id(self,name):
- """
- Return the Id for a name.
- If it doesn't exist yet, a new entry is created.
- """
- if self.nameIdDict.has_key(name):
- nameId = self.nameIdDict[name]
- else:
- self.nameIdDict[name] = self.maxNameId
- self.idNameDict[self.maxNameId] = name
- nameId = self.maxNameId
- self.maxNameId += 1
- self.changed = True
- return nameId
-
- def add_feature(self,featureName):
- fMul = featureName.split('|')
- fIds = []
- for f in fMul:
- if not self.featureIdDict.has_key(f):
- self.featureIdDict[f] = self.maxFeatureId
- self.maxFeatureId += 1
- self.changed = True
- fId = self.featureIdDict[f]
- fIds.append(fId)
- return fIds
-
- def get_features(self,line):
- featureNames = [f.strip() for f in line[1].split()]
- features = {}
- for fn in featureNames:
- tmp = fn.split('=')
- weight = 1.0
- if len(tmp) == 2:
- fn = tmp[0]
- weight = float(tmp[1])
- fIds = self.add_feature(tmp[0])
- features[fIds[0]] = (weight,fIds[1:])
- #features[fId] = 1.0 ###
- return features
-
- def expand_accessibles(self,acc):
- accessibles = set(acc)
- unexpandedQueue = Queue()
- for a in acc:
- if self.expandedAccessibles.has_key(a):
- accessibles = accessibles.union(self.expandedAccessibles[a])
- else:
- unexpandedQueue.put(a)
- while not unexpandedQueue.empty():
- nextUnExp = unexpandedQueue.get()
- nextUnExpAcc = self.accessibleDict[nextUnExp]
- for a in nextUnExpAcc:
- if not a in accessibles:
- accessibles = accessibles.union([a])
- if self.expandedAccessibles.has_key(a):
- accessibles = accessibles.union(self.expandedAccessibles[a])
- else:
- unexpandedQueue.put(a)
- return list(accessibles)
-
- def parse_unExpAcc(self,line):
- try:
- unExpAcc = [self.nameIdDict[a.strip()] for a in line.split()]
- except:
- raise LookupError('Cannot find the accessibles:%s. Accessibles need to be introduced before referring to them.' % line)
- return unExpAcc
-
- def parse_fact(self,line):
- """
- Parses a single line, extracting accessibles, features, and dependencies.
- """
- assert line.startswith('! ')
- line = line[2:]
-
- # line = name:accessibles;features;dependencies
- line = line.split(':')
- name = line[0].strip()
- nameId = self.get_name_id(name)
- line = line[1].split(';')
- features = self.get_features(line)
- self.featureDict[nameId] = features
- try:
- self.dependenciesDict[nameId] = [self.nameIdDict[d.strip()] for d in line[2].split()]
- except:
- unknownDeps = []
- for d in line[2].split():
- if not self.nameIdDict.has_key(d):
- unknownDeps.append(d)
- raise LookupError('Unknown fact used as dependency: %s. Facts need to be introduced before being used as depedency.' % ','.join(unknownDeps))
- self.accessibleDict[nameId] = self.parse_unExpAcc(line[0])
-
- self.changed = True
- return nameId
-
- def parse_overwrite(self,line):
- """
- Parses a single line, extracts the problemId and the Ids of the dependencies.
- """
- assert line.startswith('p ')
- line = line[2:]
-
- # line = name:dependencies
- line = line.split(':')
- name = line[0].strip()
- try:
- nameId = self.nameIdDict[name]
- except:
- raise LookupError('Trying to overwrite dependencies for unknown fact: %s. Facts need to be introduced before overwriting them.' % name)
- try:
- dependencies = [self.nameIdDict[d.strip()] for d in line[1].split()]
- except:
- unknownDeps = []
- for d in line[1].split():
- if not self.nameIdDict.has_key(d):
- unknownDeps.append(d)
- raise LookupError('Unknown fact used as dependency: %s. Facts need to be introduced before being used as depedency.' % ','.join(unknownDeps))
- self.changed = True
- return nameId,dependencies
-
- def parse_problem(self,line):
- """
- Parses a problem and returns the features, the accessibles, and any hints.
- """
- assert line.startswith('? ')
- line = line[2:]
- name = None
- numberOfPredictions = None
-
- # How many predictions should be returned:
- tmp = line.split('#')
- if len(tmp) == 2:
- numberOfPredictions = int(tmp[0].strip())
- line = tmp[1]
-
- # Check whether there is a problem name:
- tmp = line.split(':')
- if len(tmp) == 2:
- name = tmp[0].strip()
- line = tmp[1]
-
- # line = accessibles;features
- line = line.split(';')
- features = self.get_features(line)
-
- # Accessible Ids, expand and store the accessibles.
- #unExpAcc = [self.nameIdDict[a.strip()] for a in line[0].split()]
- unExpAcc = self.parse_unExpAcc(line[0])
- if len(self.expandedAccessibles.keys())>=100:
- self.expandedAccessibles = {}
- self.changed = True
- for accId in unExpAcc:
- if not self.expandedAccessibles.has_key(accId):
- accIdAcc = self.accessibleDict[accId]
- self.expandedAccessibles[accId] = self.expand_accessibles(accIdAcc)
- self.changed = True
- accessibles = self.expand_accessibles(unExpAcc)
-
- # Get hints:
- if len(line) == 3:
- hints = [self.nameIdDict[d.strip()] for d in line[2].split()]
- else:
- hints = []
-
- return name,features,accessibles,hints,numberOfPredictions
-
- def save(self,fileName):
- if self.changed:
- dictsStream = open(fileName, 'wb')
- dump((self.accessibleDict,self.dependenciesDict,self.expandedAccessibles,self.featureDict,\
- self.featureIdDict,self.idNameDict,self.maxFeatureId,self.maxNameId,self.nameIdDict),dictsStream)
- self.changed = False
- dictsStream.close()
- def load(self,fileName):
- dictsStream = open(fileName, 'rb')
- self.accessibleDict,self.dependenciesDict,self.expandedAccessibles,self.featureDict,\
- self.featureIdDict,self.idNameDict,self.maxFeatureId,self.maxNameId,self.nameIdDict = load(dictsStream)
- self.changed = False
- dictsStream.close()
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py Sat Jun 28 22:13:23 2014 +0200
+++ /dev/null Thu Jan 01 00:00:00 1970 +0000
@@ -1,149 +0,0 @@
-#!/usr/bin/env python
-# Title: HOL/Tools/Sledgehammer/MaSh/src/mash
-# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen
-# Copyright 2012 - 2013
-#
-# Entry point for MaSh (Machine Learning for Sledgehammer).
-
-'''
-MaSh - Machine Learning for Sledgehammer
-
-MaSh allows to use different machine learning algorithms to predict relevant fact for Sledgehammer.
-
-Created on July 12, 2012
-
-@author: Daniel Kuehlwein
-'''
-
-import socket,sys,time,logging,os
-from os.path import realpath,dirname
-
-from spawnDaemon import spawnDaemon
-from parameters import init_parser
-
-def communicate(data,host,port):
- logger = logging.getLogger('communicate')
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- try:
- sock.connect((host,port))
- sock.sendall(data+'\n')
- received = ''
- cont = True
- counter = 0
- while cont and counter < 100000:
- rec = sock.recv(4096)
- if rec.endswith('stop'):
- cont = False
- received += rec[:-4]
- else:
- received += rec
- counter += 1
- if rec == '':
- logger.warning('No response from server. Check server log for details.')
- except:
- logger.warning('Communication with server failed.')
- received = -1
- finally:
- sock.close()
- return received
-
-def start_server(host,port):
- logger = logging.getLogger('start_server')
- logger.info('Starting Server.')
- path = dirname(realpath(__file__))
- spawnDaemon(os.path.join(path,'server.py'),os.path.join(path,'server.py'),host,str(port))
- serverIsUp=False
- for _i in range(20):
- # Test if server is up
- try:
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- sock.connect((host,port))
- sock.close()
- serverIsUp = True
- break
- except:
- time.sleep(0.5)
- if not serverIsUp:
- logger.error('Could not start server.')
- sys.exit(-1)
- return True
-
-def mash(argv = sys.argv[1:]):
- # Initializing command-line arguments
- args = init_parser(argv)
- # Set up logging
- logging.basicConfig(level=logging.DEBUG,
- format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
- datefmt='%d-%m %H:%M:%S',
- filename=args.log,
- filemode='w')
- logger = logging.getLogger('mash')
-
- if args.quiet:
- logger.setLevel(logging.WARNING)
- else:
- console = logging.StreamHandler(sys.stdout)
- console.setLevel(logging.INFO)
- formatter = logging.Formatter('# %(message)s')
- console.setFormatter(formatter)
- logging.getLogger('').addHandler(console)
-
- if not os.path.exists(args.outputDir):
- os.makedirs(args.outputDir)
-
- # Shutdown commands need not start the server fist.
- if args.shutdownServer:
- logger.info('Sending shutdown command.')
- try:
- received = communicate('shutdown',args.host,args.port)
- logger.info(received)
- except:
- pass
- return
-
- # If server is not running, start it.
- startedServer = False
- received = communicate(' '.join(('ping',args.modelFile,args.dictsFile)),args.host,args.port)
- if received == -1:
- startedServer = start_server(args.host,args.port)
- elif received.startswith('Files do not match'):
- logger.error('Filesnames do not match!')
- logger.error('Modelfile server: '+ received.split()[-2])
- logger.error('Modelfile argument: '+ args.modelFile)
- logger.error('Dictsfile server: '+ received.split()[-1])
- logger.error('Dictsfile argument: '+ args.dictsFile)
- return
-
- if args.init or startedServer:
- logger.info('Initializing Server.')
- data = "i "+";".join(argv)
- received = communicate(data,args.host,args.port)
- logger.info(received)
-
- if not args.inputFile == None:
- logger.debug('Using the following settings: %s',args)
- # IO Streams
- OS = open(args.predictions,'w')
- IS = open(args.inputFile,'r')
- lineCount = 0
- for line in IS:
- lineCount += 1
- if lineCount % 100 == 0:
- logger.info('On line %s', lineCount)
- received = communicate(line,args.host,args.port)
- if not received == '':
- OS.write('%s\n' % received)
- OS.close()
- IS.close()
-
- # Statistics
- if args.statistics:
- received = communicate('avgStats',args.host,args.port)
- logger.info(received)
-
- if args.saveModels:
- communicate('save',args.host,args.port)
-
-
-if __name__ == "__main__":
- sys.exit(mash())
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/parameters.py Sat Jun 28 22:13:23 2014 +0200
+++ /dev/null Thu Jan 01 00:00:00 1970 +0000
@@ -1,47 +0,0 @@
-import datetime
-from argparse import ArgumentParser,RawDescriptionHelpFormatter
-
-def init_parser(argv):
- # Set up command-line parser
- parser = ArgumentParser(description='MaSh - Machine Learning for Sledgehammer. \n\n\
- MaSh allows to use different machine learning algorithms to predict relevant facts for Sledgehammer.\n\n\
- --------------- Example Usage ---------------\n\
- First initialize:\n./mash.py -l test.log -o ../tmp/ --init --inputDir ../data/Jinja/ \n\
- Then create predictions:\n./mash.py -i ../data/Jinja/mash_commands -p ../data/Jinja/mash_suggestions -l test.log -o ../tmp/ --statistics\n\
- \n\n\
- Author: Daniel Kuehlwein, July 2012',formatter_class=RawDescriptionHelpFormatter)
- parser.add_argument('-i','--inputFile',help='File containing all problems to be solved.')
- parser.add_argument('-o','--outputDir', default='../tmp/',help='Directory where all created files are stored. Default=../tmp/.')
- parser.add_argument('-p','--predictions',default='../tmp/%s.predictions' % datetime.datetime.now(),
- help='File where the predictions stored. Default=../tmp/dateTime.predictions.')
- parser.add_argument('--numberOfPredictions',default=500,help="Default number of premises to write in the output. Default=500.",type=int)
-
- parser.add_argument('--init',default=False,action='store_true',help="Initialize Mash. Requires --inputDir to be defined. Default=False.")
- parser.add_argument('--inputDir',\
- help='Directory containing all the input data. MaSh expects the following files: mash_features,mash_dependencies,mash_accessibility')
- parser.add_argument('--depFile', default='mash_dependencies',
- help='Name of the file with the premise dependencies. The file must be in inputDir. Default = mash_dependencies')
-
- parser.add_argument('--algorithm',default='nb',help="Which learning algorithm is used. nb = Naive Bayes,KNN,predef=predefined. Default=nb.")
- parser.add_argument('--predef',help="File containing the predefined suggestions. Only used when algorithm = predef.")
- # NB Parameters
- parser.add_argument('--NBDefaultPriorWeight',default=20.0,help="Initializes classifiers with value * p |- p. Default=20.0.",type=float)
- parser.add_argument('--NBDefVal',default=-15.0,help="Default value for unknown features. Default=-15.0.",type=float)
- parser.add_argument('--NBPosWeight',default=10.0,help="Weight value for positive features. Default=10.0.",type=float)
- parser.add_argument('--expandFeatures',default=False,action='store_true',help="Learning-based feature expansion. Default=False.")
-
- parser.add_argument('--statistics',default=False,action='store_true',help="Create and show statistics for the top CUTOFF predictions.\
- WARNING: This will make the program a lot slower! Default=False.")
- parser.add_argument('--saveStats',default=None,help="If defined, stores the statistics in the filename provided.")
- parser.add_argument('--cutOff',default=500,help="Option for statistics. Only consider the first cutOff predictions. Default=500.",type=int)
- parser.add_argument('-l','--log', default='../tmp/%s.log' % datetime.datetime.now(), help='Log file name. Default=../tmp/dateTime.log')
- parser.add_argument('-q','--quiet',default=False,action='store_true',help="If enabled, only print warnings. Default=False.")
- parser.add_argument('--modelFile', default='../tmp/model.pickle', help='Model file name. Default=../tmp/model.pickle')
- parser.add_argument('--dictsFile', default='../tmp/dict.pickle', help='Dict file name. Default=../tmp/dict.pickle')
-
- parser.add_argument('--port', default='9255', help='Port of the Mash server. Default=9255',type=int)
- parser.add_argument('--host', default='localhost', help='Host of the Mash server. Default=localhost')
- parser.add_argument('--shutdownServer',default=False,action='store_true',help="Shutdown server without saving the models.")
- parser.add_argument('--saveModels',default=False,action='store_true',help="Server saves the models.")
- args = parser.parse_args(argv)
- return args
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/predefined.py Sat Jun 28 22:13:23 2014 +0200
+++ /dev/null Thu Jan 01 00:00:00 1970 +0000
@@ -1,68 +0,0 @@
-# Title: HOL/Tools/Sledgehammer/MaSh/src/predefined.py
-# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen
-# Copyright 2012
-#
-# A classifier that uses the Meng-Paulson predictions.
-
-'''
-Created on Jul 11, 2012
-
-@author: Daniel Kuehlwein
-'''
-
-from cPickle import dump,load
-
-class Predefined(object):
- '''
- A classifier that uses the Meng-Paulson predictions.
- Only used to easily compare statistics between the old Sledgehammer algorithm and the new machine learning ones.
- '''
-
- def __init__(self,mpPredictionFile):
- '''
- Constructor
- '''
- self.predictionFile = mpPredictionFile
-
- def initializeModel(self,_trainData,dicts):
- """
- Load predictions.
- """
- self.predictions = {}
- IS = open(self.predictionFile,'r')
- for line in IS:
- line = line.split(':')
- name = line[0].strip()
- predId = dicts.get_name_id(name)
- line = line[1].split()
- predsTmp = []
- for x in line:
- x = x.split('=')
- predsTmp.append(x[0])
- preds = [dicts.get_name_id(x.strip())for x in predsTmp]
- self.predictions[predId] = preds
- IS.close()
-
- def update(self,dataPoint,features,dependencies):
- """
- Updates the Model.
- """
- # No Update needed since we assume that we got all predictions
- pass
-
-
- def predict(self,problemId):
- """
- Return the saved predictions.
- """
- return self.predictions[problemId]
-
- def save(self,fileName):
- OStream = open(fileName, 'wb')
- dump((self.predictionFile,self.predictions),OStream)
- OStream.close()
-
- def load(self,fileName):
- OStream = open(fileName, 'rb')
- self.predictionFile,self.predictions = load(OStream)
- OStream.close()
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/readData.py Sat Jun 28 22:13:23 2014 +0200
+++ /dev/null Thu Jan 01 00:00:00 1970 +0000
@@ -1,58 +0,0 @@
-# Title: HOL/Tools/Sledgehammer/MaSh/src/readData.py
-# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen
-# Copyright 2012
-#
-# All functions to read the Isabelle output.
-
-'''
-All functions to read the Isabelle output.
-
-Created on July 9, 2012
-
-@author: Daniel Kuehlwein
-'''
-
-import sys,logging
-
-def create_dependencies_dict(nameIdDict,inputFile):
- logger = logging.getLogger('create_dependencies_dict')
- dependenciesDict = {}
- IS = open(inputFile,'r')
- for line in IS:
- line = line.split(':')
- name = line[0]
- # Name Id
- if not nameIdDict.has_key(name):
- logger.warning('%s is missing in nameIdDict. Aborting.',name)
- sys.exit(-1)
-
- nameId = nameIdDict[name]
- dependenciesIds = [nameIdDict[f.strip()] for f in line[1].split()]
- # Store results, add p proves p
- if nameId == 0:
- dependenciesDict[nameId] = dependenciesIds
- else:
- dependenciesDict[nameId] = [nameId] + dependenciesIds
- IS.close()
- return dependenciesDict
-
-def create_accessible_dict(nameIdDict,idNameDict,maxNameId,inputFile):
- logger = logging.getLogger('create_accessible_dict')
- accessibleDict = {}
- IS = open(inputFile,'r')
- for line in IS:
- line = line.split(':')
- name = line[0]
- # Name Id
- if not nameIdDict.has_key(name):
- logger.warning('%s is missing in nameIdDict. Adding it as theory.',name)
- nameIdDict[name] = maxNameId
- idNameDict[maxNameId] = name
- nameId = maxNameId
- maxNameId += 1
- else:
- nameId = nameIdDict[name]
- accessibleStrings = line[1].split()
- accessibleDict[nameId] = [nameIdDict[a.strip()] for a in accessibleStrings]
- IS.close()
- return accessibleDict,maxNameId
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/server.py Sat Jun 28 22:13:23 2014 +0200
+++ /dev/null Thu Jan 01 00:00:00 1970 +0000
@@ -1,234 +0,0 @@
-#!/usr/bin/env python
-# Title: HOL/Tools/Sledgehammer/MaSh/src/server.py
-# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen
-# Copyright 2013
-#
-# The MaSh Server.
-
-import SocketServer,os,string,logging,sys
-from multiprocessing import Manager
-from threading import Timer
-from time import time
-from dictionaries import Dictionaries
-from parameters import init_parser
-from sparseNaiveBayes import sparseNBClassifier
-from KNN import KNN,euclidean
-from KNNs import KNNAdaptPointFeatures,KNNUrban
-#from bayesPlusMetric import sparseNBPlusClassifier
-from predefined import Predefined
-from ExpandFeatures import ExpandFeatures
-from stats import Statistics
-
-
-class ThreadingTCPServer(SocketServer.ThreadingTCPServer):
-
- def __init__(self, *args, **kwargs):
- SocketServer.ThreadingTCPServer.__init__(self,*args, **kwargs)
- self.manager = Manager()
- self.lock = Manager().Lock()
- self.idle_timeout = 28800.0 # 8 hours in seconds
- self.idle_timer = Timer(self.idle_timeout, self.shutdown)
- self.idle_timer.start()
- self.model = None
- self.dicts = None
- self.callCounter = 0
-
- def save(self):
- if self.model == None or self.dicts == None:
- try:
- self.logger.warning('Cannot save nonexisting models.')
- except:
- pass
- return
- # Save Models
- self.model.save(self.args.modelFile)
- self.dicts.save(self.args.dictsFile)
- if not self.args.saveStats == None:
- statsFile = os.path.join(self.args.outputDir,self.args.saveStats)
- self.stats.save(statsFile)
-
- def save_and_shutdown(self):
- self.save()
- self.shutdown()
-
-class MaShHandler(SocketServer.StreamRequestHandler):
-
- def init(self,argv):
- if argv == '':
- self.server.args = init_parser([])
- else:
- argv = argv.split(';')
- self.server.args = init_parser(argv)
-
- # Set up logging
- logging.basicConfig(level=logging.DEBUG,
- format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
- datefmt='%d-%m %H:%M:%S',
- filename=self.server.args.log+'server',
- filemode='w')
- self.server.logger = logging.getLogger('server')
-
- # Load all data
- self.server.dicts = Dictionaries()
- if os.path.isfile(self.server.args.dictsFile):
- self.server.dicts.load(self.server.args.dictsFile)
- #elif not self.server.args.dictsFile == '../tmp/dict.pickle':
- # raise IOError('Cannot find dictsFile at %s '% self.server.args.dictsFile)
- elif self.server.args.init:
- self.server.dicts.init_all(self.server.args)
- # Pick model
- if self.server.args.algorithm == 'nb':
- ###TODO: !!
- self.server.model = sparseNBClassifier(self.server.args.NBDefaultPriorWeight,self.server.args.NBPosWeight,self.server.args.NBDefVal)
- #self.server.model = sparseNBPlusClassifier(self.server.args.NBDefaultPriorWeight,self.server.args.NBPosWeight,self.server.args.NBDefVal)
- elif self.server.args.algorithm == 'KNN':
- #self.server.model = KNN(self.server.dicts)
- self.server.model = KNNAdaptPointFeatures(self.server.dicts)
- elif self.server.args.algorithm == 'predef':
- self.server.model = Predefined(self.server.args.predef)
- else: # Default case
- self.server.model = sparseNBClassifier(self.server.args.NBDefaultPriorWeight,self.server.args.NBPosWeight,self.server.args.NBDefVal)
- if self.server.args.expandFeatures:
- self.server.expandFeatures = ExpandFeatures(self.server.dicts)
- self.server.expandFeatures.initialize(self.server.dicts)
- # Create Model
- if os.path.isfile(self.server.args.modelFile):
- self.server.model.load(self.server.args.modelFile)
- #elif not self.server.args.modelFile == '../tmp/model.pickle':
- # raise IOError('Cannot find modelFile at %s '% self.server.args.modelFile)
- elif self.server.args.init:
- trainData = self.server.dicts.featureDict.keys()
- self.server.model.initializeModel(trainData,self.server.dicts)
-
- if self.server.args.statistics:
- self.server.stats = Statistics(self.server.args.cutOff)
- self.server.statementCounter = 1
- self.server.computeStats = False
-
- self.server.logger.debug('Initialized in '+str(round(time()-self.startTime,2))+' seconds.')
- self.request.sendall('Server initialized in '+str(round(time()-self.startTime,2))+' seconds.')
- self.server.callCounter = 1
-
- def update(self):
- problemId = self.server.dicts.parse_fact(self.data)
- # Statistics
- if self.server.args.statistics and self.server.computeStats:
- self.server.computeStats = False
- # Assume '!' comes after '?'
- if self.server.args.algorithm == 'predef':
- self.server.predictions = self.server.model.predict(problemId)
- self.server.stats.update(self.server.predictions,self.server.dicts.dependenciesDict[problemId],self.server.statementCounter)
- if not self.server.stats.badPreds == []:
- bp = string.join([str(self.server.dicts.idNameDict[x]) for x in self.server.stats.badPreds], ',')
- self.server.logger.debug('Poor predictions: %s',bp)
- self.server.statementCounter += 1
-
- if self.server.args.expandFeatures:
- self.server.expandFeatures.update(self.server.dicts.featureDict[problemId],self.server.dicts.dependenciesDict[problemId])
- # Update Dependencies, p proves p
- if not problemId == 0:
- self.server.dicts.dependenciesDict[problemId] = [problemId]+self.server.dicts.dependenciesDict[problemId]
- ###TODO:
- self.server.model.update(problemId,self.server.dicts.featureDict[problemId],self.server.dicts.dependenciesDict[problemId])
- #self.server.model.update(problemId,self.server.dicts.featureDict[problemId],self.server.dicts.dependenciesDict[problemId],self.server.dicts)
-
- def overwrite(self):
- # Overwrite old proof.
- problemId,newDependencies = self.server.dicts.parse_overwrite(self.data)
- newDependencies = [problemId]+newDependencies
- self.server.model.overwrite(problemId,newDependencies,self.server.dicts)
- self.server.dicts.dependenciesDict[problemId] = newDependencies
-
- def predict(self):
- self.server.computeStats = True
- if self.server.args.algorithm == 'predef':
- return
- name,features,accessibles,hints,numberOfPredictions = self.server.dicts.parse_problem(self.data)
- if numberOfPredictions == None:
- numberOfPredictions = self.server.args.numberOfPredictions
- if not hints == []:
- self.server.model.update('hints',features,hints)
- if self.server.args.expandFeatures:
- features = self.server.expandFeatures.expand(features)
- # Create predictions
- self.server.logger.debug('Starting computation for line %s',self.server.callCounter)
-
- self.server.predictions,predictionValues = self.server.model.predict(features,accessibles,self.server.dicts)
- assert len(self.server.predictions) == len(predictionValues)
- self.server.logger.debug('Time needed: '+str(round(time()-self.startTime,2)))
-
- # Output
- predictionNames = [str(self.server.dicts.idNameDict[p]) for p in self.server.predictions[:numberOfPredictions]]
- predictionValues = [str(x) for x in predictionValues[:numberOfPredictions]]
- predictionsStringList = ['%s=%s' % (predictionNames[i],predictionValues[i]) for i in range(len(predictionNames))]
- predictionsString = string.join(predictionsStringList,' ')
- #predictionsString = string.join(predictionNames,' ')
- outString = '%s: %s' % (name,predictionsString)
- self.request.sendall(outString)
-
- def shutdown(self,saveModels=True):
- self.request.sendall('Shutting down server.')
- if saveModels:
- self.server.save()
- self.server.idle_timer.cancel()
- self.server.idle_timer = Timer(0.5, self.server.shutdown)
- self.server.idle_timer.start()
-
- def handle(self):
- # self.request is the TCP socket connected to the client
- self.server.lock.acquire()
- self.data = self.rfile.readline().strip()
- try:
- # Update idle shutdown timer
- self.server.idle_timer.cancel()
- self.server.idle_timer = Timer(self.server.idle_timeout, self.server.save_and_shutdown)
- self.server.idle_timer.start()
-
- self.startTime = time()
- if self.data == 'shutdown':
- self.shutdown()
- elif self.data == 'save':
- self.server.save()
- elif self.data.startswith('ping'):
- mFile, dFile = self.data.split()[1:]
- if mFile == self.server.args.modelFile and dFile == self.server.args.dictsFile:
- self.request.sendall('All good.')
- else:
- self.request.sendall('Files do not match '+' '.join((self.server.args.modelFile,self.server.args.dictsFile)))
- elif self.data.startswith('i'):
- self.init(self.data[2:])
- elif self.data.startswith('!'):
- self.update()
- elif self.data.startswith('p'):
- self.overwrite()
- elif self.data.startswith('?'):
- self.predict()
- elif self.data == '':
- # Empty Socket
- return
- elif self.data == 'avgStats':
- self.request.sendall(self.server.stats.printAvg())
- else:
- self.request.sendall('Unspecified input format: \n%s',self.data)
- self.server.callCounter += 1
- self.request.sendall('stop')
- except: # catch exceptions
- #print 'Caught an error. Check %s for more details' % (self.server.args.log+'server')
- logging.exception('')
- finally:
- self.server.lock.release()
-
-if __name__ == "__main__":
- if not len(sys.argv[1:]) == 2:
- print 'No Arguments for HOST and PORT found. Using localhost and 9255'
- HOST, PORT = "localhost", 9255
- else:
- HOST, PORT = sys.argv[1:]
- SocketServer.TCPServer.allow_reuse_address = True
- server = ThreadingTCPServer((HOST, int(PORT)), MaShHandler)
- server.serve_forever()
-
-
-
-
-
\ No newline at end of file
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/singleNaiveBayes.py Sat Jun 28 22:13:23 2014 +0200
+++ /dev/null Thu Jan 01 00:00:00 1970 +0000
@@ -1,176 +0,0 @@
-# Title: HOL/Tools/Sledgehammer/MaSh/src/singleNaiveBayes.py
-# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen
-# Copyright 2012
-#
-# An updatable sparse naive Bayes classifier.
-
-'''
-Created on Jul 11, 2012
-
-@author: Daniel Kuehlwein
-'''
-
-from cPickle import dump,load
-from math import log,exp
-
-
-class singleNBClassifier(object):
- '''
- An updateable naive Bayes classifier.
- '''
-
- def __init__(self,defValPos = -7.5,defValNeg = -15.0,posWeight = 10.0):
- '''
- Constructor
- '''
- self.neg = 0.0
- self.pos = 0.0
- self.counts = {} # Counts is the tuple poscounts,negcounts
- self.defValPos = defValPos
- self.defValNeg = defValNeg
- self.posWeight = posWeight
-
- def update(self,features,label):
- """
- Updates the Model.
-
- @param label: True or False, True if the features belong to a positive label, false else.
- """
- #print label,self.pos,self.neg,self.counts
- if label:
- self.pos += 1
- else:
- self.neg += 1
-
- for f,_w in features:
- if not self.counts.has_key(f):
- if label:
- fPosCount = 0.0
- fNegCount = 0.0
- self.counts[f] = [fPosCount,fNegCount]
- else:
- continue
- posCount,negCount = self.counts[f]
- if label:
- posCount += 1
- else:
- negCount += 1
- self.counts[f] = [posCount,negCount]
- #print label,self.pos,self.neg,self.counts
-
-
- def delete(self,features,label):
- """
- Deletes a single datapoint from the model.
- """
- if label:
- self.pos -= 1
- else:
- self.neg -= 1
- for f,_w in features:
- posCount,negCount = self.counts[f]
- if label:
- posCount -= 1
- else:
- negCount -= 1
- self.counts[f] = [posCount,negCount]
-
-
- def overwrite(self,features,labelOld,labelNew):
- """
- Deletes the old dependencies of problemId and replaces them with the new ones. Updates the model accordingly.
- """
- self.delete(features,labelOld)
- self.update(features,labelNew)
-
- def predict_sparse(self,features):
- """
- Returns 1 if the probability of + being the correct label is greater than the probability that - is the correct label.
- """
- if self.neg == 0:
- return 1
- elif self.pos ==0:
- return 0
- logneg = log(self.neg)
- lognegprior=log(float(self.neg)/5)
- logpos = log(self.pos)
- prob = logpos - lognegprior
-
- for f,_w in features:
- if self.counts.has_key(f):
- posCount,negCount = self.counts[f]
- if posCount > 0:
- prob += (log(self.posWeight * posCount) - logpos)
- else:
- prob += self.defValPos
- if negCount > 0:
- prob -= (log(negCount) - logneg)
- else:
- prob -= self.defValNeg
- if prob >= 0 :
- return 1
- else:
- return 0
-
- def predict(self,features):
- """
- Returns 1 if the probability is greater than 50%.
- """
- if self.neg == 0:
- return 1
- elif self.pos ==0:
- return 0
- defVal = -15.0
- expDefVal = exp(defVal)
-
- logneg = log(self.neg)
- logpos = log(self.pos)
- prob = logpos - logneg
-
- for f in self.counts.keys():
- posCount,negCount = self.counts[f]
- if f in features:
- if posCount == 0:
- prob += defVal
- else:
- prob += log(float(posCount)/self.pos)
- if negCount == 0:
- prob -= defVal
- else:
- prob -= log(float(negCount)/self.neg)
- else:
- if posCount == self.pos:
- prob += log(1-expDefVal)
- else:
- prob += log(1-float(posCount)/self.pos)
- if negCount == self.neg:
- prob -= log(1-expDefVal)
- else:
- prob -= log(1-float(negCount)/self.neg)
-
- if prob >= 0 :
- return 1
- else:
- return 0
-
- def save(self,fileName):
- OStream = open(fileName, 'wb')
- dump(self.counts,OStream)
- OStream.close()
-
- def load(self,fileName):
- OStream = open(fileName, 'rb')
- self.counts = load(OStream)
- OStream.close()
-
-if __name__ == '__main__':
- x = singleNBClassifier()
- x.update([0], True)
- assert x.predict([0]) == 1
- x = singleNBClassifier()
- x.update([0], False)
- assert x.predict([0]) == 0
-
- x.update([0], True)
- x.update([1], True)
- print x.pos,x.neg,x.predict([0,1])
\ No newline at end of file
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/snow.py Sat Jun 28 22:13:23 2014 +0200
+++ /dev/null Thu Jan 01 00:00:00 1970 +0000
@@ -1,135 +0,0 @@
-# Title: HOL/Tools/Sledgehammer/MaSh/src/snow.py
-# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen
-# Copyright 2012
-#
-# Wrapper for SNoW.
-
-'''
-
-Created on Jul 12, 2012
-
-@author: daniel
-'''
-
-import logging,shlex,subprocess,string,shutil
-#from cPickle import load,dump
-
-class SNoW(object):
- '''
- Calls the SNoW framework.
- '''
-
- def __init__(self):
- '''
- Constructor
- '''
- self.logger = logging.getLogger('SNoW')
- self.SNoWTrainFile = '../tmp/snow.train'
- self.SNoWTestFile = '../snow.test'
- self.SNoWNetFile = '../tmp/snow.net'
- self.defMaxNameId = 100000
-
- def initializeModel(self,trainData,dicts):
- """
- Build basic model from training data.
- """
- # Prepare input files
- self.logger.debug('Creating IO Files')
- OS = open(self.SNoWTrainFile,'w')
- for nameId in trainData:
- features = [f+dicts.maxNameId for f,_w in dicts.featureDict[nameId]]
- #features = [f+self.defMaxNameId for f,_w in dicts.featureDict[nameId]]
- features = map(str,features)
- featureString = string.join(features,',')
- dependencies = dicts.dependenciesDict[nameId]
- dependencies = map(str,dependencies)
- dependenciesString = string.join(dependencies,',')
- snowString = string.join([featureString,dependenciesString],',')+':\n'
- OS.write(snowString)
- OS.close()
-
- # Build Model
- self.logger.debug('Building Model START.')
- snowTrainCommand = '../bin/snow -train -M+ -I %s -F %s -g- -B :0-%s' % (self.SNoWTrainFile,self.SNoWNetFile,dicts.maxNameId-1)
- #print snowTrainCommand
- #snowTrainCommand = '../bin/snow -train -M+ -I %s -F %s -g- -B :0-%s' % (self.SNoWTrainFile,self.SNoWNetFile,self.defMaxNameId-1)
- args = shlex.split(snowTrainCommand)
- p = subprocess.Popen(args,stdout=subprocess.PIPE,stderr=subprocess.STDOUT)
- p.wait()
- self.logger.debug('Building Model END.')
-
- def update(self,dataPoint,features,dependencies,dicts):
- """
- Updates the Model.
- """
- """
- self.logger.debug('Updating Model START')
- # Ignore Feature weights
- features = [f+self.defMaxNameId for f,_w in features]
-
- OS = open(self.SNoWTestFile,'w')
- features = map(str,features)
- featureString = string.join(features, ',')
- dependencies = map(str,dependencies)
- dependenciesString = string.join(dependencies,',')
- snowString = string.join([featureString,dependenciesString],',')+':\n'
- OS.write(snowString)
- OS.close()
- snowTestCommand = '../bin/snow -test -I %s -F %s -o allboth -i+' % (self.SNoWTestFile,self.SNoWNetFile)
- args = shlex.split(snowTestCommand)
- p = subprocess.Popen(args,stdout=subprocess.PIPE,stderr=subprocess.STDOUT)
- (_lines, _stderrdata) = p.communicate()
- # Move new net file
- src = self.SNoWNetFile+'.new'
- dst = self.SNoWNetFile
- shutil.move(src, dst)
- self.logger.debug('Updating Model END')
- """
- # Do nothing, only update at evaluation. Is a lot faster.
- pass
-
-
- def predict(self,features,accessibles,dicts):
- trainData = dicts.featureDict.keys()
- self.initializeModel(trainData, dicts)
-
- logger = logging.getLogger('predict_SNoW')
- # Ignore Feature weights
- #features = [f+self.defMaxNameId for f,_w in features]
- features = [f+dicts.maxNameId for f,_w in features]
-
- OS = open(self.SNoWTestFile,'w')
- features = map(str,features)
- featureString = string.join(features, ',')
- snowString = featureString+':'
- OS.write(snowString)
- OS.close()
-
- snowTestCommand = '../bin/snow -test -I %s -F %s -o allboth' % (self.SNoWTestFile,self.SNoWNetFile)
- args = shlex.split(snowTestCommand)
- p = subprocess.Popen(args,stdout=subprocess.PIPE,stderr=subprocess.STDOUT)
- (lines, _stderrdata) = p.communicate()
- logger.debug('SNoW finished.')
- lines = lines.split('\n')
- assert lines[9].startswith('Example ')
- assert lines[-4] == ''
- predictionsCon = []
- predictionsValues = []
- for line in lines[10:-4]:
- premiseId = int(line.split()[0][:-1])
- predictionsCon.append(premiseId)
- val = line.split()[4]
- if val.endswith('*'):
- val = float(val[:-1])
- else:
- val = float(val)
- predictionsValues.append(val)
- return predictionsCon,predictionsValues
-
- def save(self,fileName):
- # Nothing to do since we don't update
- pass
-
- def load(self,fileName):
- # Nothing to do since we don't update
- pass
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py Sat Jun 28 22:13:23 2014 +0200
+++ /dev/null Thu Jan 01 00:00:00 1970 +0000
@@ -1,179 +0,0 @@
-# Title: HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py
-# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen
-# Copyright 2012
-#
-# An updatable sparse naive Bayes classifier.
-
-'''
-Created on Jul 11, 2012
-
-@author: Daniel Kuehlwein
-'''
-from cPickle import dump,load
-from numpy import array
-from math import log
-
-class sparseNBClassifier(object):
- '''
- An updateable naive Bayes classifier.
- '''
-
- def __init__(self,defaultPriorWeight = 20.0,posWeight = 20.0,defVal = -15.0):
- '''
- Constructor
- '''
- self.counts = {}
- self.defaultPriorWeight = defaultPriorWeight
- self.posWeight = posWeight
- self.defVal = defVal
-
- def initializeModel(self,trainData,dicts):
- """
- Build basic model from training data.
- """
- for d in trainData:
- dFeatureCounts = {}
- # Add p proves p with weight self.defaultPriorWeight
- if not self.defaultPriorWeight == 0:
- for f in dicts.featureDict[d].iterkeys():
- dFeatureCounts[f] = self.defaultPriorWeight
- self.counts[d] = [self.defaultPriorWeight,dFeatureCounts]
-
- for key,keyDeps in dicts.dependenciesDict.iteritems():
- keyFeatures = dicts.featureDict[key]
- for dep in keyDeps:
- self.counts[dep][0] += 1
- #depFeatures = dicts.featureDict[key]
- for f in keyFeatures.iterkeys():
- if self.counts[dep][1].has_key(f):
- self.counts[dep][1][f] += 1
- else:
- self.counts[dep][1][f] = 1
-
-
- def update(self,dataPoint,features,dependencies):
- """
- Updates the Model.
- """
- if (not self.counts.has_key(dataPoint)) and (not dataPoint == 0):
- dFeatureCounts = {}
- # Give p |- p a higher weight
- if not self.defaultPriorWeight == 0:
- for f in features.iterkeys():
- dFeatureCounts[f] = self.defaultPriorWeight
- self.counts[dataPoint] = [self.defaultPriorWeight,dFeatureCounts]
- for dep in dependencies:
- self.counts[dep][0] += 1
- for f in features.iterkeys():
- if self.counts[dep][1].has_key(f):
- self.counts[dep][1][f] += 1
- else:
- self.counts[dep][1][f] = 1
-
- def delete(self,dataPoint,features,dependencies):
- """
- Deletes a single datapoint from the model.
- """
- for dep in dependencies:
- self.counts[dep][0] -= 1
- for f,_w in features.items():
- self.counts[dep][1][f] -= 1
- if self.counts[dep][1][f] == 0:
- del self.counts[dep][1][f]
-
-
- def overwrite(self,problemId,newDependencies,dicts):
- """
- Deletes the old dependencies of problemId and replaces them with the new ones. Updates the model accordingly.
- """
- try:
- assert self.counts.has_key(problemId)
- except:
- raise LookupError('Trying to overwrite dependencies for unknown fact: %s. Facts need to be introduced before overwriting them.' % dicts.idNameDict[problemId])
- oldDeps = dicts.dependenciesDict[problemId]
- features = dicts.featureDict[problemId]
- self.delete(problemId,features,oldDeps)
- self.update(problemId,features,newDependencies)
-
- def predict(self,features,accessibles,dicts):
- """
- For each accessible, predicts the probability of it being useful given the features.
- Returns a ranking of the accessibles.
- """
- tau = 0.05 # Jasmin, change value here
- predictions = []
- observedFeatures = features.keys()
- for fVal in features.itervalues():
- _w,alternateF = fVal
- observedFeatures += alternateF
-
- for a in accessibles:
- posA = self.counts[a][0]
- fA = set(self.counts[a][1].keys())
- fWeightsA = self.counts[a][1]
- resultA = log(posA)
- for f,fVal in features.iteritems():
- w,alternateF = fVal
- # DEBUG
- #w = 1.0
- # Test for multiple features
- isMatch = False
- matchF = None
- if f in fA:
- isMatch = True
- matchF = f
- elif len(alternateF) > 0:
- inter = set(alternateF).intersection(fA)
- if len(inter) > 0:
- isMatch = True
- for mF in inter:
- ### TODO: matchF is randomly selected
- matchF = mF
- break
-
- if isMatch:
- #if f in fA:
- if fWeightsA[matchF] == 0:
- resultA += w*self.defVal
- else:
- assert fWeightsA[matchF] <= posA
- resultA += w*log(float(self.posWeight*fWeightsA[matchF])/posA)
- else:
- resultA += w*self.defVal
- if not tau == 0.0:
- missingFeatures = list(fA.difference(observedFeatures))
- #sumOfWeights = sum([log(float(fWeightsA[x])/posA) for x in missingFeatures]) # slower
- sumOfWeights = sum([log(float(fWeightsA[x])) for x in missingFeatures]) - log(posA) * len(missingFeatures) #DEFAULT
- #sumOfWeights = sum([log(float(fWeightsA[x])/self.totalFeatureCounts[x]) for x in missingFeatures]) - log(posA) * len(missingFeatures)
- resultA -= tau * sumOfWeights
- predictions.append(resultA)
- predictions = array(predictions)
- perm = (-predictions).argsort()
- return array(accessibles)[perm],predictions[perm]
-
- def save(self,fileName):
- OStream = open(fileName, 'wb')
- dump((self.counts,self.defaultPriorWeight,self.posWeight,self.defVal),OStream)
- OStream.close()
-
- def load(self,fileName):
- OStream = open(fileName, 'rb')
- self.counts,self.defaultPriorWeight,self.posWeight,self.defVal = load(OStream)
- OStream.close()
-
-
-if __name__ == '__main__':
- featureDict = {0:[0,1,2],1:[3,2,1]}
- dependenciesDict = {0:[0],1:[0,1]}
- libDicts = (featureDict,dependenciesDict,{})
- c = sparseNBClassifier()
- c.initializeModel([0,1],libDicts)
- c.update(2,[14,1,3],[0,2])
- print c.counts
- print c.predict([0,14],[0,1,2])
- c.storeModel('x')
- d = sparseNBClassifier()
- d.loadModel('x')
- print c.counts
- print d.counts
- print 'Done'
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/spawnDaemon.py Sat Jun 28 22:13:23 2014 +0200
+++ /dev/null Thu Jan 01 00:00:00 1970 +0000
@@ -1,36 +0,0 @@
-# http://stackoverflow.com/questions/972362/spawning-process-from-python/972383#972383
-import os
-
-def spawnDaemon(path_to_executable, *args):
- """Spawn a completely detached subprocess (i.e., a daemon).
-
- E.g. for mark:
- spawnDaemon("../bin/producenotify.py", "producenotify.py", "xx")
- """
- # fork the first time (to make a non-session-leader child process)
- try:
- pid = os.fork()
- except OSError, e:
- raise RuntimeError("1st fork failed: %s [%d]" % (e.strerror, e.errno))
- if pid != 0:
- # parent (calling) process is all done
- return
-
- # detach from controlling terminal (to make child a session-leader)
- os.setsid()
- try:
- pid = os.fork()
- except OSError, e:
- raise RuntimeError("2nd fork failed: %s [%d]" % (e.strerror, e.errno))
- raise Exception, "%s [%d]" % (e.strerror, e.errno)
- if pid != 0:
- # child process is all done
- os._exit(0)
-
- # and finally let's execute the executable for the daemon!
- try:
- #os.execv(path_to_executable, [path_to_executable])
- os.execv(path_to_executable, args)
- except Exception, e:
- # oops, we're cut off from the world, let's just give up
- os._exit(255)
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/stats.py Sat Jun 28 22:13:23 2014 +0200
+++ /dev/null Thu Jan 01 00:00:00 1970 +0000
@@ -1,155 +0,0 @@
-# Title: HOL/Tools/Sledgehammer/MaSh/src/stats.py
-# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen
-# Copyright 2012
-#
-# Statistics collector.
-
-'''
-Created on Jul 9, 2012
-
-@author: Daniel Kuehlwein
-'''
-
-import logging,string
-from cPickle import load,dump
-
-class Statistics(object):
- '''
- Class for all the statistics
- '''
-
- def __init__(self,cutOff=500):
- '''
- Constructor
- '''
- self.logger = logging.getLogger('Statistics')
- self.avgAUC = 0.0
- self.avgRecall100 = 0.0
- self.avgAvailable = 0.0
- self.avgDepNr = 0.0
- self.problems = 0.0
- self.cutOff = cutOff
- self.recallData = [0]*cutOff
- self.recall100Median = []
- self.recall100Data = [0]*cutOff
- self.aucData = []
- self.premiseOccurenceCounter = {}
- self.firstDepAppearance = {}
- self.depAppearances = []
-
- def update(self,predictions,dependencies,statementCounter):
- """
- Evaluates AUC, dependencies, recall100 and number of available premises of a prediction.
- """
- available = len(predictions)
- predictions = predictions[:self.cutOff]
- dependencies = set(dependencies)
- # No Stats for if no dependencies
- if len(dependencies) == 0:
- self.logger.debug('No Dependencies for statement %s' % statementCounter )
- self.badPreds = []
- return
- if len(predictions) < self.cutOff:
- for i in range(len(predictions),self.cutOff):
- self.recall100Data[i] += 1
- self.recallData[i] += 1
- for d in dependencies:
- if self.premiseOccurenceCounter.has_key(d):
- self.premiseOccurenceCounter[d] += 1
- else:
- self.premiseOccurenceCounter[d] = 1
- if self.firstDepAppearance.has_key(d):
- self.depAppearances.append(statementCounter-self.firstDepAppearance[d])
- else:
- self.firstDepAppearance[d] = statementCounter
- depNr = len(dependencies)
- aucSum = 0.
- posResults = 0.
- positives, negatives = 0, 0
- recall100 = 0.0
- badPreds = []
- depsFound = []
- for index,pId in enumerate(predictions):
- if pId in dependencies: #positive
- posResults+=1
- positives+=1
- recall100 = index+1
- depsFound.append(pId)
- if index > 200:
- badPreds.append(pId)
- else:
- aucSum += posResults
- negatives+=1
- # Update Recall and Recall100 stats
- if depNr == positives:
- self.recall100Data[index] += 1
- if depNr == 0:
- self.recallData[index] += 1
- else:
- self.recallData[index] += float(positives)/depNr
-
- if not depNr == positives:
- depsFound = set(depsFound)
- missing = []
- for dep in dependencies:
- if not dep in depsFound:
- missing.append(dep)
- badPreds.append(dep)
- recall100 = len(predictions)+1
- positives+=1
- self.logger.debug('Dependencies missing for %s in cutoff predictions! Estimating Statistics.',\
- string.join([str(dep) for dep in missing],','))
-
- if positives == 0 or negatives == 0:
- auc = 1.0
- else:
- auc = aucSum/(negatives*positives)
-
- self.aucData.append(auc)
- self.avgAUC += auc
- self.avgRecall100 += recall100
- self.recall100Median.append(recall100)
- self.problems += 1
- self.badPreds = badPreds
- self.avgAvailable += available
- self.avgDepNr += depNr
- self.logger.info('Statement: %s: AUC: %s \t Needed: %s \t Recall100: %s \t Available: %s \t cutOff: %s',\
- statementCounter,round(100*auc,2),depNr,recall100,available,self.cutOff)
-
- def printAvg(self):
- self.logger.info('Average results:')
- #self.logger.info('avgAUC: %s \t avgDepNr: %s \t avgRecall100: %s \t cutOff:%s', \
- # round(100*self.avgAUC/self.problems,2),round(self.avgDepNr/self.problems,2),round(self.avgRecall100/self.problems,2),self.cutOff)
- # HACK FOR PAPER
- assert len(self.aucData) == len(self.recall100Median)
- nrDataPoints = len(self.aucData)
- if nrDataPoints == 0:
- return "No data points"
- if nrDataPoints % 2 == 1:
- medianAUC = sorted(self.aucData)[nrDataPoints/2 + 1]
- else:
- medianAUC = float(sorted(self.aucData)[nrDataPoints/2] + sorted(self.aucData)[nrDataPoints/2 + 1])/2
- #nrDataPoints = len(self.recall100Median)
- if nrDataPoints % 2 == 1:
- medianrecall100 = sorted(self.recall100Median)[nrDataPoints/2 + 1]
- else:
- medianrecall100 = float(sorted(self.recall100Median)[nrDataPoints/2] + sorted(self.recall100Median)[nrDataPoints/2 + 1])/2
-
- returnString = 'avgAUC: %s \t medianAUC: %s \t avgDepNr: %s \t avgRecall100: %s \t medianRecall100: %s \t cutOff: %s' %\
- (round(100*self.avgAUC/self.problems,2),round(100*medianAUC,2),round(self.avgDepNr/self.problems,2),round(self.avgRecall100/self.problems,2),round(medianrecall100,2),self.cutOff)
- self.logger.info(returnString)
- return returnString
-
- """
- self.logger.info('avgAUC: %s \t medianAUC: %s \t avgDepNr: %s \t avgRecall100: %s \t medianRecall100: %s \t cutOff:%s', \
- round(100*self.avgAUC/self.problems,2),round(100*medianAUC,2),round(self.avgDepNr/self.problems,2),round(self.avgRecall100/self.problems,2),round(medianrecall100,2),self.cutOff)
- """
-
- def save(self,fileName):
- oStream = open(fileName, 'wb')
- dump((self.avgAUC,self.avgRecall100,self.avgAvailable,self.avgDepNr,self.problems,self.cutOff,self.recallData,self.recall100Data,self.aucData,self.premiseOccurenceCounter),oStream)
- oStream.close()
- def load(self,fileName):
- iStream = open(fileName, 'rb')
- self.avgAUC,self.avgRecall100,self.avgAvailable,self.avgDepNr,self.problems,self.cutOff,self.recallData,self.recall100Data,self.aucData,self.premiseOccurenceCounter = load(iStream)
- iStream.close()
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/tester.py Sat Jun 28 22:13:23 2014 +0200
+++ /dev/null Thu Jan 01 00:00:00 1970 +0000
@@ -1,281 +0,0 @@
-'''
-Created on Jan 11, 2013
-
-Searches for the best parameters.
-
-@author: Daniel Kuehlwein
-'''
-
-import logging,sys,os
-from multiprocessing import Process,Queue,current_process,cpu_count
-from mash import mash
-
-def worker(inQueue, outQueue):
- for func, args in iter(inQueue.get, 'STOP'):
- result = func(*args)
- #print '%s says that %s%s = %s' % (current_process().name, func.__name__, args, result)
- outQueue.put(result)
-
-def run_mash(runId,inputDir,logFile,predictionFile,predef,\
- learnTheories,theoryDefValPos,theoryDefValNeg,theoryPosWeight,\
- NBDefaultPriorWeight,NBDefVal,NBPosWeight,\
- sineFeatures,sineWeight,quiet=True):
- # Init
- runId = str(runId)
- predictionFile = predictionFile + runId
- args = ['--statistics','--init','--inputDir',inputDir,'--log',logFile,'--theoryFile','../tmp/t'+runId,'--modelFile','../tmp/m'+runId,'--dictsFile','../tmp/d'+runId,
- '--theoryDefValPos',str(theoryDefValPos),'--theoryDefValNeg',str(theoryDefValNeg),'--theoryPosWeight',str(theoryPosWeight),\
- '--NBDefaultPriorWeight',str(NBDefaultPriorWeight),'--NBDefVal',str(NBDefVal),'--NBPosWeight',str(NBPosWeight)]
- if learnTheories:
- args += ['--learnTheories']
- if sineFeatures:
- args += ['--sineFeatures','--sineWeight',str(sineWeight)]
- if not predef == '':
- args += ['--predef',predef]
- if quit:
- args += ['-q']
- #print args
- mash(args)
- # Run
- args = ['-i',inputFile,'-p',predictionFile,'--statistics','--cutOff','1024','--log',logFile,'--theoryFile','../tmp/t'+runId,'--modelFile','../tmp/m'+runId,'--dictsFile','../tmp/d'+runId,\
- '--theoryDefValPos',str(theoryDefValPos),'--theoryDefValNeg',str(theoryDefValNeg),'--theoryPosWeight',str(theoryPosWeight),\
- '--NBDefaultPriorWeight',str(NBDefaultPriorWeight),'--NBDefVal',str(NBDefVal),'--NBPosWeight',str(NBPosWeight)]
- if learnTheories:
- args += ['--learnTheories']
- if sineFeatures:
- args += ['--sineFeatures','--sineWeight',str(sineWeight)]
- if not predef == '':
- args += ['--predef',predef]
- if quit:
- args += ['-q']
- #print args
- mash(args)
-
- # Get Results
- IS = open(logFile,'r')
- lines = IS.readlines()
- tmpRes = lines[-1].split()
- avgAuc = tmpRes[5]
- medianAuc = tmpRes[7]
- avgRecall100 = tmpRes[11]
- medianRecall100 = tmpRes[13]
- tmpTheoryRes = lines[-3].split()
- if learnTheories:
- avgTheoryPrecision = tmpTheoryRes[5]
- avgTheoryRecall100 = tmpTheoryRes[7]
- avgTheoryRecall = tmpTheoryRes[9]
- avgTheoryPredictedPercent = tmpTheoryRes[11]
- else:
- avgTheoryPrecision = 'NA'
- avgTheoryRecall100 = 'NA'
- avgTheoryRecall = 'NA'
- avgTheoryPredictedPercent = 'NA'
- IS.close()
-
- # Delete old models
- os.remove(logFile)
- os.remove(predictionFile)
- if learnTheories:
- os.remove('../tmp/t'+runId)
- os.remove('../tmp/m'+runId)
- os.remove('../tmp/d'+runId)
-
- outFile = open('tester','a')
- #print 'avgAuc %s avgRecall100 %s avgTheoryPrecision %s avgTheoryRecall100 %s avgTheoryRecall %s avgTheoryPredictedPercent %s'
- outFile.write('\t'.join([str(learnTheories),str(theoryDefValPos),str(theoryDefValNeg),str(theoryPosWeight),\
- str(NBDefaultPriorWeight),str(NBDefVal),str(NBPosWeight),str(sineFeatures),str(sineWeight),\
- str(avgAuc),str(medianAuc),str(avgRecall100),str(medianRecall100),str(avgTheoryPrecision),\
- str(avgTheoryRecall100),str(avgTheoryRecall),str(avgTheoryPredictedPercent)])+'\n')
- outFile.close()
- print learnTheories,'\t',theoryDefValPos,'\t',theoryDefValNeg,'\t',theoryPosWeight,'\t',\
- NBDefaultPriorWeight,'\t',NBDefVal,'\t',NBPosWeight,'\t',\
- sineFeatures,'\t',sineWeight,'\t',\
- avgAuc,'\t',medianAuc,'\t',avgRecall100,'\t',medianRecall100,'\t',\
- avgTheoryPrecision,'\t',avgTheoryRecall100,'\t',avgTheoryRecall,'\t',avgTheoryPredictedPercent
- return learnTheories,theoryDefValPos,theoryDefValNeg,theoryPosWeight,\
- NBDefaultPriorWeight,NBDefVal,NBPosWeight,\
- sineFeatures,sineWeight,\
- avgAuc,avgRecall100,avgTheoryPrecision,avgTheoryRecall100,avgTheoryRecall,avgTheoryPredictedPercent
-
-def update_best_params(avgRecall100,bestAvgRecall100,\
- bestNBDefaultPriorWeight,bestNBDefVal,bestNBPosWeight,bestSineFeatures,bestSineWeight,\
- bestlearnTheories,besttheoryDefValPos,besttheoryDefValNeg,besttheoryPosWeight,\
- NBDefaultPriorWeight,NBDefVal,NBPosWeight,sineFeatures,sineWeight,\
- learnTheories,theoryDefValPos,theoryDefValNeg,theoryPosWeight):
- if avgRecall100 > bestAvgRecall100:
- bestAvgRecall100 = avgRecall100
- bestNBDefaultPriorWeight = NBDefaultPriorWeight
- bestNBDefVal = NBDefVal
- bestNBPosWeight = NBPosWeight
- bestSineFeatures = sineFeatures
- bestSineWeight = sineWeight
- return bestlearnTheories,besttheoryDefValPos,besttheoryDefValNeg,besttheoryPosWeight,bestNBDefaultPriorWeight,bestNBDefVal,bestNBPosWeight,bestSineFeatures,bestSineWeight
-
-if __name__ == '__main__':
- cores = cpu_count()
- #cores = 1
- # Options
- depFile = 'mash_dependencies'
- predef = ''
- outputDir = '../tmp/'
- numberOfPredictions = 1024
-
- learnTheoriesRange = [True,False]
- theoryDefValPosRange = [-x for x in range(1,20)]
- theoryDefValNegRange = [-x for x in range(1,20)]
- theoryPosWeightRange = [x for x in range(1,10)]
-
- NBDefaultPriorWeightRange = [10*x for x in range(10)]
- NBDefValRange = [-x for x in range(1,20)]
- NBPosWeightRange = [10*x for x in range(1,10)]
- sineFeaturesRange = [True,False]
- sineWeightRange = [0.1,0.25,0.5,0.75,1.0]
-
- """
- # Test 1
- inputFile = '../data/20121227b/Auth/mash_commands'
- inputDir = '../data/20121227b/Auth/'
- predictionFile = '../tmp/auth.pred'
- logFile = '../tmp/auth.log'
- learnTheories = True
- theoryDefValPos = -7.5
- theoryDefValNeg = -15.0
- theoryPosWeight = 10.0
- NBDefaultPriorWeight = 20.0
- NBDefVal =- 15.0
- NBPosWeight = 10.0
- sineFeatures = True
- sineWeight = 0.5
-
- task_queue = Queue()
- done_queue = Queue()
-
- runs = 0
- for inputDir in ['../data/20121227b/Auth/']:
- problemId = inputDir.split('/')[-2]
- inputFile = os.path.join(inputDir,'mash_commands')
- predictionFile = os.path.join('../tmp/',problemId+'.pred')
- logFile = os.path.join('../tmp/',problemId+'.log')
- learnTheories = True
- theoryDefValPos = -7.5
- theoryDefValNeg = -15.0
- theoryPosWeight = 10.0
-
- bestAvgRecall100 = 0.0
- bestNBDefaultPriorWeight = 1.0
- bestNBDefVal = 1.0
- bestNBPosWeight = 1.0
- bestSineFeatures = False
- bestSineWeight = 0.0
- bestlearnTheories = True
- besttheoryDefValPos = 1.0
- besttheoryDefValNeg = -15.0
- besttheoryPosWeight = 5.0
- for theoryPosWeight in theoryPosWeightRange:
- for theoryDefValNeg in theoryDefValNegRange:
- for NBDefaultPriorWeight in NBDefaultPriorWeightRange:
- for NBDefVal in NBDefValRange:
- for NBPosWeight in NBPosWeightRange:
- for sineFeatures in sineFeaturesRange:
- if sineFeatures:
- for sineWeight in sineWeightRange:
- localLogFile = logFile+str(runs)
- task_queue.put((run_mash,(runs,inputDir, localLogFile, predictionFile, learnTheories, theoryDefValPos, theoryDefValNeg, theoryPosWeight, NBDefaultPriorWeight, NBDefVal, NBPosWeight, sineFeatures, sineWeight)))
- runs += 1
- else:
- localLogFile = logFile+str(runs)
- task_queue.put((run_mash,(runs,inputDir, localLogFile, predictionFile, learnTheories, theoryDefValPos, theoryDefValNeg, theoryPosWeight, NBDefaultPriorWeight, NBDefVal, NBPosWeight, sineFeatures, sineWeight)))
- runs += 1
- # Start worker processes
- processes = []
- for _i in range(cores):
- process = Process(target=worker, args=(task_queue, done_queue))
- process.start()
- processes.append(process)
-
- for _i in range(runs):
- learnTheories,theoryDefValPos,theoryDefValNeg,theoryPosWeight,\
- NBDefaultPriorWeight,NBDefVal,NBPosWeight,\
- sineFeatures,sineWeight,\
- avgAuc,avgRecall100,avgTheoryPrecision,avgTheoryRecall100,avgTheoryRecall,avgTheoryPredictedPercent = done_queue.get()
- bestlearnTheories,besttheoryDefValPos,besttheoryDefValNeg,besttheoryPosWeight,bestNBDefaultPriorWeight,bestNBDefVal,bestNBPosWeight,bestSineFeatures,bestSineWeight = update_best_params(avgRecall100,bestAvgRecall100,\
- bestNBDefaultPriorWeight,bestNBDefVal,bestNBPosWeight,bestSineFeatures,bestSineWeight,\
- bestlearnTheories,besttheoryDefValPos,besttheoryDefValNeg,besttheoryPosWeight,\
- NBDefaultPriorWeight,NBDefVal,NBPosWeight,sineFeatures,sineWeight,\
- learnTheories,theoryDefValPos,theoryDefValNeg,theoryPosWeight)
- print 'bestAvgRecall100 %s bestNBDefaultPriorWeight %s bestNBDefVal %s bestNBPosWeight %s bestSineFeatures %s bestSineWeight %s',bestAvgRecall100,bestNBDefaultPriorWeight,bestNBDefVal,bestNBPosWeight,bestSineFeatures,bestSineWeight
-
- """
- # Test 2
- #inputDir = '../data/20130118/Jinja'
- inputDir = '../data/notheory/Prob'
- inputFile = inputDir+'/mash_commands'
- #inputFile = inputDir+'/mash_prover_commands'
-
- #depFile = 'mash_prover_dependencies'
- depFile = 'mash_dependencies'
- outputDir = '../tmp/'
- numberOfPredictions = 1024
- predictionFile = '../tmp/auth.pred'
- logFile = '../tmp/auth.log'
- learnTheories = False
- theoryDefValPos = -7.5
- theoryDefValNeg = -10.0
- theoryPosWeight = 2.0
- NBDefaultPriorWeight = 20.0
- NBDefVal =- 15.0
- NBPosWeight = 10.0
- sineFeatures = False
- sineWeight = 0.5
- quiet = False
- print inputDir
-
- #predef = inputDir+'/mash_prover_suggestions'
- predef = inputDir+'/mash_suggestions'
- print 'Mash Isar'
- run_mash(0,inputDir,logFile,predictionFile,predef,\
- learnTheories,theoryDefValPos,theoryDefValNeg,theoryPosWeight,\
- NBDefaultPriorWeight,NBDefVal,NBPosWeight,\
- sineFeatures,sineWeight,quiet=quiet)
-
- #"""
- predef = inputDir+'/mesh_suggestions'
- #predef = inputDir+'/mesh_prover_suggestions'
- print 'Mesh Isar'
- run_mash(0,inputDir,logFile,predictionFile,predef,\
- learnTheories,theoryDefValPos,theoryDefValNeg,theoryPosWeight,\
- NBDefaultPriorWeight,NBDefVal,NBPosWeight,\
- sineFeatures,sineWeight,quiet=quiet)
- #"""
- predef = inputDir+'/mepo_suggestions'
- print 'Mepo Isar'
- run_mash(0,inputDir,logFile,predictionFile,predef,\
- learnTheories,theoryDefValPos,theoryDefValNeg,theoryPosWeight,\
- NBDefaultPriorWeight,NBDefVal,NBPosWeight,\
- sineFeatures,sineWeight,quiet=quiet)
-
- """
- inputFile = inputDir+'/mash_prover_commands'
- depFile = 'mash_prover_dependencies'
-
- predef = inputDir+'/mash_prover_suggestions'
- print 'Mash Prover Isar'
- run_mash(0,inputDir,logFile,predictionFile,predef,\
- learnTheories,theoryDefValPos,theoryDefValNeg,theoryPosWeight,\
- NBDefaultPriorWeight,NBDefVal,NBPosWeight,\
- sineFeatures,sineWeight,quiet=quiet)
-
- predef = inputDir+'/mesh_prover_suggestions'
- print 'Mesh Prover Isar'
- run_mash(0,inputDir,logFile,predictionFile,predef,\
- learnTheories,theoryDefValPos,theoryDefValNeg,theoryPosWeight,\
- NBDefaultPriorWeight,NBDefVal,NBPosWeight,\
- sineFeatures,sineWeight,quiet=quiet)
-
- predef = inputDir+'/mepo_suggestions'
- print 'Mepo Prover Isar'
- run_mash(0,inputDir,logFile,predictionFile,predef,\
- learnTheories,theoryDefValPos,theoryDefValNeg,theoryPosWeight,\
- NBDefaultPriorWeight,NBDefVal,NBPosWeight,\
- sineFeatures,sineWeight,quiet=quiet)
- #"""
\ No newline at end of file
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/theoryModels.py Sat Jun 28 22:13:23 2014 +0200
+++ /dev/null Thu Jan 01 00:00:00 1970 +0000
@@ -1,151 +0,0 @@
-# Title: HOL/Tools/Sledgehammer/MaSh/src/theoryModels.py
-# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen
-# Copyright 2012
-#
-# An updatable sparse naive Bayes classifier.
-
-'''
-Created on Dec 26, 2012
-
-@author: Daniel Kuehlwein
-'''
-
-from singleNaiveBayes import singleNBClassifier
-from cPickle import load,dump
-import sys,logging
-
-class TheoryModels(object):
- '''
- MetaClass for all the theory models.
- '''
-
-
- def __init__(self,defValPos = -7.5,defValNeg = -15.0,posWeight = 10.0):
- '''
- Constructor
- '''
- self.theoryModels = {}
- # Model Params
- self.defValPos = defValPos
- self.defValNeg = defValNeg
- self.posWeight = posWeight
- self.theoryDict = {}
- self.accessibleTheories = set([])
- self.currentTheory = None
-
- def init(self,depFile,dicts):
- logger = logging.getLogger('TheoryModels')
- IS = open(depFile,'r')
- for line in IS:
- line = line.split(':')
- name = line[0]
- theory = name.split('.')[0]
- # Name Id
- if not dicts.nameIdDict.has_key(name):
- logger.warning('%s is missing in nameIdDict. Aborting.',name)
- sys.exit(-1)
-
- nameId = dicts.nameIdDict[name]
- features = dicts.featureDict[nameId]
- if not self.theoryDict.has_key(theory):
- assert not theory == self.currentTheory
- if not self.currentTheory == None:
- self.accessibleTheories.add(self.currentTheory)
- self.currentTheory = theory
- self.theoryDict[theory] = set([nameId])
- theoryModel = singleNBClassifier(self.defValPos,self.defValNeg,self.posWeight)
- self.theoryModels[theory] = theoryModel
- else:
- self.theoryDict[theory] = self.theoryDict[theory].union([nameId])
-
- # Find the actually used theories
- usedtheories = []
- dependencies = line[1].split()
- if len(dependencies) == 0:
- continue
- for dep in dependencies:
- depId = dicts.nameIdDict[dep.strip()]
- deptheory = dep.split('.')[0]
- usedtheories.append(deptheory)
- if not self.theoryDict.has_key(deptheory):
- self.theoryDict[deptheory] = set([depId])
- else:
- self.theoryDict[deptheory] = self.theoryDict[deptheory].union([depId])
-
- # Update theoryModels
- self.theoryModels[self.currentTheory].update(features,self.currentTheory in usedtheories)
- for a in self.accessibleTheories:
- self.theoryModels[a].update(dicts.featureDict[nameId],a in usedtheories)
- IS.close()
-
- def overwrite(self,problemId,newDependencies,dicts):
- features = dicts.featureDict[problemId]
- unExpAccessibles = dicts.accessibleDict[problemId]
- accessibles = dicts.expand_accessibles(unExpAccessibles)
- accTheories = []
- for x in accessibles:
- xArt = (dicts.idNameDict[x]).split('.')[0]
- accTheories.append(xArt)
- oldTheories = set([x.split('.')[0] for x in dicts.dependenciesDict[problemId]])
- newTheories = set([x.split('.')[0] for x in newDependencies])
- for a in self.accTheories:
- self.theoryModels[a].overwrite(features,a in oldTheories,a in newTheories)
-
- def delete(self,problemId,features,dependencies,dicts):
- tmp = [dicts.idNameDict[x] for x in dependencies]
- usedTheories = set([x.split('.')[0] for x in tmp])
- for a in self.accessibleTheories:
- self.theoryModels[a].delete(features,a in usedTheories)
-
- def update(self,problemId,features,dependencies,dicts):
- # TODO: Implicit assumption that self.accessibleTheories contains all accessible theories!
- currentTheory = (dicts.idNameDict[problemId]).split('.')[0]
- # Create new theory model, if there is a new theory
- if not self.theoryDict.has_key(currentTheory):
- assert not currentTheory == self.currentTheory
- if not currentTheory == None:
- self.theoryDict[currentTheory] = []
- self.currentTheory = currentTheory
- theoryModel = singleNBClassifier(self.defValPos,self.defValNeg,self.posWeight)
- self.theoryModels[currentTheory] = theoryModel
- self.accessibleTheories.add(self.currentTheory)
- self.update_with_acc(problemId,features,dependencies,dicts,self.accessibleTheories)
-
- def update_with_acc(self,problemId,features,dependencies,dicts,accessibleTheories):
- # Find the actually used theories
- tmp = [dicts.idNameDict[x] for x in dependencies]
- usedTheories = set([x.split('.')[0] for x in tmp])
- if not len(usedTheories) == 0:
- for a in accessibleTheories:
- self.theoryModels[a].update(features,a in usedTheories)
-
- def predict(self,features,accessibles,dicts):
- """
- Predicts the relevant theories. Returns the predicted theories and a list of all accessible premises in these theories.
- """
- self.accessibleTheories = set([(dicts.idNameDict[x]).split('.')[0] for x in accessibles])
-
- # Predict Theories
- predictedTheories = [self.currentTheory]
- for a in self.accessibleTheories:
- if self.theoryModels[a].predict_sparse(features):
- #if theoryModels[a].predict(dicts.featureDict[nameId]):
- predictedTheories.append(a)
- predictedTheories = set(predictedTheories)
-
- # Delete accessibles in unpredicted theories
- newAcc = []
- for x in accessibles:
- xArt = (dicts.idNameDict[x]).split('.')[0]
- if xArt in predictedTheories:
- newAcc.append(x)
- return predictedTheories,newAcc
-
- def save(self,fileName):
- outStream = open(fileName, 'wb')
- dump((self.currentTheory,self.accessibleTheories,self.theoryModels,self.theoryDict,self.defValPos,self.defValNeg,self.posWeight),outStream)
- outStream.close()
- def load(self,fileName):
- inStream = open(fileName, 'rb')
- self.currentTheory,self.accessibleTheories,self.theoryModels,self.theoryDict,self.defValPos,self.defValNeg,self.posWeight = load(inStream)
- inStream.close()
\ No newline at end of file
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/theoryStats.py Sat Jun 28 22:13:23 2014 +0200
+++ /dev/null Thu Jan 01 00:00:00 1970 +0000
@@ -1,70 +0,0 @@
-# Title: HOL/Tools/Sledgehammer/MaSh/src/theoryStats.py
-# Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen
-# Copyright 2012
-#
-# An updatable sparse naive Bayes classifier.
-
-'''
-Created on Dec 26, 2012
-
-@author: Daniel Kuehlwein
-'''
-
-from cPickle import load,dump
-import logging,string
-
-class TheoryStatistics(object):
- '''
- Stores statistics for theory lvl predictions
- '''
-
-
- def __init__(self):
- '''
- Constructor
- '''
- self.logger = logging.getLogger('TheoryStatistics')
- self.count = 0
- self.precision = 0.0
- self.recall100 = 0
- self.recall = 0.0
- self.predicted = 0.0
- self.predictedPercent = 0.0
-
- def update(self,currentTheory,predictedTheories,usedTheories,nrAvailableTheories):
- self.count += 1
- allPredTheories = predictedTheories.union([currentTheory])
- if set(usedTheories).issubset(allPredTheories):
- self.recall100 += 1
- localPredicted = len(allPredTheories)
- self.predicted += localPredicted
- localPredictedPercent = float(localPredicted)/nrAvailableTheories
- self.predictedPercent += localPredictedPercent
- localPrec = float(len(set(usedTheories).intersection(allPredTheories))) / localPredicted
- self.precision += localPrec
- if len(set(usedTheories)) == 0:
- localRecall = 1.0
- else:
- localRecall = float(len(set(usedTheories).intersection(allPredTheories))) / len(set(usedTheories))
- self.recall += localRecall
- self.logger.info('Theory prediction results:')
- self.logger.info('Problem: %s \t Recall100: %s \t Precision: %s \t Recall: %s \t PredictedTeoriesPercent: %s PredictedTeories: %s',\
- self.count,self.recall100,round(localPrec,2),round(localRecall,2),round(localPredictedPercent,2),localPredicted)
-
- def printAvg(self):
- self.logger.info('Average theory results:')
- self.logger.info('avgPrecision: %s \t avgRecall100: %s \t avgRecall: %s \t avgPredictedPercent: %s \t avgPredicted: %s', \
- round(self.precision/self.count,2),\
- round(float(self.recall100)/self.count,2),\
- round(self.recall/self.count,2),\
- round(self.predictedPercent /self.count,2),\
- round(self.predicted /self.count,2))
-
- def save(self,fileName):
- oStream = open(fileName, 'wb')
- dump((self.count,self.precision,self.recall100,self.recall,self.predicted),oStream)
- oStream.close()
- def load(self,fileName):
- iStream = open(fileName, 'rb')
- self.count,self.precision,self.recall100,self.recall,self.predicted = load(iStream)
- iStream.close()
\ No newline at end of file
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Sat Jun 28 22:13:23 2014 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Sun Jun 29 18:28:27 2014 +0200
@@ -33,19 +33,16 @@
val decode_str : string -> string
val decode_strs : string -> string list
val encode_features : (string * real) list -> string
- val extract_suggestions : string -> string * (string * real) list
datatype mash_engine =
- MaSh_Py
- | MaSh_SML_kNN
- | MaSh_SML_kNN_Ext
- | MaSh_SML_NB
- | MaSh_SML_NB_Ext
+ MaSh_kNN
+ | MaSh_kNN_Ext
+ | MaSh_NB
+ | MaSh_NB_Ext
val is_mash_enabled : unit -> bool
val the_mash_engine : unit -> mash_engine
- val mash_unlearn : Proof.context -> params -> unit
val nickname_of_thm : thm -> string
val find_suggested_facts : Proof.context -> ('b * thm) list -> string list -> ('b * thm) list
val mesh_facts : ('a * 'a -> bool) -> int -> (real * (('a * real) list * 'a list)) list -> 'a list
@@ -69,18 +66,20 @@
('b * thm) list -> ('b * thm) list -> ('b * thm) list * ('b * thm) list
val mash_suggested_facts : Proof.context -> theory -> params -> int -> term list -> term ->
raw_fact list -> fact list * fact list
+
+ val mash_unlearn : unit -> unit
val mash_learn_proof : Proof.context -> params -> term -> ('a * thm) list -> thm list -> unit
- val mash_learn_facts : Proof.context -> params -> string -> bool -> int -> bool -> Time.time ->
+ val mash_learn_facts : Proof.context -> params -> string -> int -> bool -> Time.time ->
raw_fact list -> string
val mash_learn : Proof.context -> params -> fact_override -> thm list -> bool -> unit
- val mash_can_suggest_facts : Proof.context -> bool -> bool
+ val mash_can_suggest_facts : Proof.context -> bool
val generous_max_suggestions : int -> int
val mepo_weight : real
val mash_weight : real
val relevant_facts : Proof.context -> params -> string -> int -> fact_override -> term list ->
term -> raw_fact list -> (string * fact list) list
- val kill_learners : Proof.context -> params -> unit
+ val kill_learners : unit -> unit
val running_learners : unit -> unit
end;
@@ -140,84 +139,28 @@
end
datatype mash_engine =
- MaSh_Py
-| MaSh_SML_kNN
-| MaSh_SML_kNN_Ext
-| MaSh_SML_NB
-| MaSh_SML_NB_Ext
+ MaSh_kNN
+| MaSh_kNN_Ext
+| MaSh_NB
+| MaSh_NB_Ext
fun mash_engine () =
let val flag1 = Options.default_string @{system_option MaSh} in
(case if flag1 <> "none" (* default *) then flag1 else getenv "MASH" of
- "yes" => SOME MaSh_SML_NB
- | "py" => SOME MaSh_Py
- | "sml" => SOME MaSh_SML_NB
- | "sml_knn" => SOME MaSh_SML_kNN
- | "sml_knn_ext" => SOME MaSh_SML_kNN_Ext
- | "sml_nb" => SOME MaSh_SML_NB
- | "sml_nb_ext" => SOME MaSh_SML_NB_Ext
+ "yes" => SOME MaSh_NB
+ | "sml" => SOME MaSh_NB
+ | "knn" => SOME MaSh_kNN
+ | "knn_ext" => SOME MaSh_kNN_Ext
+ | "nb" => SOME MaSh_NB
+ | "nb_ext" => SOME MaSh_NB_Ext
| _ => NONE)
end
val is_mash_enabled = is_some o mash_engine
-val the_mash_engine = the_default MaSh_SML_NB o mash_engine
+val the_mash_engine = the_default MaSh_NB o mash_engine
-(*** Low-level communication with the Python version of MaSh ***)
-
-val save_models_arg = "--saveModels"
-val shutdown_server_arg = "--shutdownServer"
-
-fun wipe_out_file file = ignore (try (File.rm o Path.explode) file)
-
-fun write_file banner (xs, f) path =
- (case banner of SOME s => File.write path s | NONE => ();
- xs |> chunk_list 500 |> List.app (File.append path o implode o map f))
- handle IO.Io _ => ()
-
-fun run_mash_tool ctxt overlord extra_args background write_cmds read_suggs =
- let
- val (temp_dir, serial) =
- if overlord then (getenv "ISABELLE_HOME_USER", "")
- else (getenv "ISABELLE_TMP", serial_string ())
- val log_file = temp_dir ^ "/mash_log" ^ serial
- val err_file = temp_dir ^ "/mash_err" ^ serial
- val sugg_file = temp_dir ^ "/mash_suggs" ^ serial
- val sugg_path = Path.explode sugg_file
- val cmd_file = temp_dir ^ "/mash_commands" ^ serial
- val cmd_path = Path.explode cmd_file
- val model_dir = File.shell_path (mash_state_dir ())
-
- val command =
- "cd \"$ISABELLE_SLEDGEHAMMER_MASH\"/src; \
- \PYTHONDONTWRITEBYTECODE=y ./mash.py\
- \ --quiet\
- \ --port=$MASH_PORT\
- \ --outputDir " ^ model_dir ^
- " --modelFile=" ^ model_dir ^ "/model.pickle\
- \ --dictsFile=" ^ model_dir ^ "/dict.pickle\
- \ --log " ^ log_file ^
- " --inputFile " ^ cmd_file ^
- " --predictions " ^ sugg_file ^
- (if extra_args = [] then "" else " " ^ space_implode " " extra_args) ^ " >& " ^ err_file ^
- (if background then " &" else "")
-
- fun run_on () =
- (Isabelle_System.bash command
- |> tap (fn _ =>
- (case try File.read (Path.explode err_file) |> the_default "" of
- "" => trace_msg ctxt (K "Done")
- | s => warning ("MaSh error: " ^ elide_string 1000 s)));
- read_suggs (fn () => try File.read_lines sugg_path |> these))
-
- fun clean_up () =
- if overlord then () else List.app wipe_out_file [err_file, sugg_file, cmd_file]
- in
- write_file (SOME "") ([], K "") sugg_path;
- write_file (SOME "") write_cmds cmd_path;
- trace_msg ctxt (fn () => "Running " ^ command);
- with_cleanup clean_up run_on ()
- end
+(*** Maintenance of the persistent, string-based state ***)
fun meta_char c =
if Char.isAlphaNum c orelse c = #"_" orelse c = #"." orelse c = #"(" orelse c = #")" orelse
@@ -246,70 +189,10 @@
val encode_features = map encode_feature #> space_implode " "
-fun str_of_learn (name, parents, feats, deps) =
- "! " ^ encode_str name ^ ": " ^ encode_strs parents ^ "; " ^ encode_strs feats ^ "; " ^
- encode_strs deps ^ "\n"
-
-fun str_of_relearn (name, deps) = "p " ^ encode_str name ^ ": " ^ encode_strs deps ^ "\n"
-
-fun str_of_query max_suggs (learns, parents, feats) =
- implode (map str_of_learn learns) ^
- "? " ^ string_of_int max_suggs ^ " # " ^ encode_strs parents ^ "; " ^ encode_features feats ^ "\n"
-
-(* The suggested weights do not make much sense. *)
-fun extract_suggestion sugg =
- (case space_explode "=" sugg of
- [name, weight] => SOME (decode_str name, Real.fromString weight |> the_default 1.0)
- | [name] => SOME (decode_str name, 1.0)
- | _ => NONE)
-
-fun extract_suggestions line =
- (case space_explode ":" line of
- [goal, suggs] => (decode_str goal, map_filter extract_suggestion (space_explode " " suggs))
- | _ => ("", []))
-
-structure MaSh_Py =
-struct
-
-fun shutdown ctxt overlord =
- (trace_msg ctxt (K "MaSh_Py shutdown");
- run_mash_tool ctxt overlord [shutdown_server_arg] false ([], K "") (K ()))
-fun save ctxt overlord =
- (trace_msg ctxt (K "MaSh_Py save");
- run_mash_tool ctxt overlord [save_models_arg] true ([], K "") (K ()))
-
-fun unlearn ctxt overlord =
- (trace_msg ctxt (K "MaSh_Py unlearn");
- shutdown ctxt overlord;
- wipe_out_mash_state_dir ())
-
-fun learn _ _ _ [] = ()
- | learn ctxt overlord save learns =
- (trace_msg ctxt (fn () =>
- "MaSh_Py learn {" ^ elide_string 1000 (space_implode " " (map #1 learns)) ^ "}");
- run_mash_tool ctxt overlord ([] |> save ? cons save_models_arg) false (learns, str_of_learn)
- (K ()))
+(*** Isabelle-agnostic machine learning ***)
-fun relearn _ _ _ [] = ()
- | relearn ctxt overlord save relearns =
- (trace_msg ctxt (fn () => "MaSh_Py relearn " ^
- elide_string 1000 (space_implode " " (map #1 relearns)));
- run_mash_tool ctxt overlord ([] |> save ? cons save_models_arg) false
- (relearns, str_of_relearn) (K ()))
-
-fun query ctxt overlord max_suggs (query as (_, _, feats)) =
- (trace_msg ctxt (fn () => "MaSh_Py query " ^ encode_features feats);
- run_mash_tool ctxt overlord [] false ([query], str_of_query max_suggs) (fn suggs =>
- (case suggs () of [] => [] | suggs => snd (extract_suggestions (List.last suggs))))
- handle List.Empty => [])
-
-end;
-
-
-(*** Standard ML version of MaSh ***)
-
-structure MaSh_SML =
+structure MaSh =
struct
fun heap cmp bnd al a =
@@ -542,25 +425,6 @@
end
(* experimental *)
-fun naive_bayes_py ctxt overlord num_facts depss featss max_suggs goal_feats =
- let
- fun name_of_fact j = "f" ^ string_of_int j
- fun fact_of_name s = the (Int.fromString (unprefix "f" s))
- fun name_of_feature j = "F" ^ string_of_int j
- fun parents_of j = if j = 0 then [] else [name_of_fact (j - 1)]
-
- val learns = map (fn j => (name_of_fact j, parents_of j,
- map name_of_feature (Vector.sub (featss, j)),
- map name_of_fact (Vector.sub (depss, j)))) (0 upto num_facts - 1)
- val parents' = parents_of num_facts
- in
- MaSh_Py.unlearn ctxt overlord;
- OS.Process.sleep (seconds 2.0); (* hack *)
- MaSh_Py.query ctxt overlord max_suggs (learns, parents', goal_feats)
- |> map (apfst fact_of_name)
- end
-
-(* experimental *)
fun external_tool tool max_suggs learns goal_feats =
let
val ser = string_of_int (serial ()) (* poor person's attempt at thread-safety *)
@@ -600,17 +464,17 @@
val naive_bayes_ext = external_tool "predict/nbayes"
fun query_external ctxt engine max_suggs learns goal_feats =
- (trace_msg ctxt (fn () => "MaSh_SML query external " ^ encode_features goal_feats);
+ (trace_msg ctxt (fn () => "MaSh query external " ^ encode_features goal_feats);
(case engine of
- MaSh_SML_kNN_Ext => k_nearest_neighbors_ext max_suggs learns goal_feats
- | MaSh_SML_NB_Ext => naive_bayes_ext max_suggs learns goal_feats))
+ MaSh_kNN_Ext => k_nearest_neighbors_ext max_suggs learns goal_feats
+ | MaSh_NB_Ext => naive_bayes_ext max_suggs learns goal_feats))
fun query_internal ctxt engine num_facts num_feats (fact_names, featss, depss)
(freqs as (_, _, dffreq)) visible_facts max_suggs goal_feats int_goal_feats =
- (trace_msg ctxt (fn () => "MaSh_SML query internal " ^ encode_features goal_feats ^ " from {" ^
+ (trace_msg ctxt (fn () => "MaSh query internal " ^ encode_features goal_feats ^ " from {" ^
elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] fact_names)) ^ "}");
(case engine of
- MaSh_SML_kNN =>
+ MaSh_kNN =>
let
val feat_facts = Array.array (num_feats, [])
val _ =
@@ -620,7 +484,7 @@
in
k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts int_goal_feats
end
- | MaSh_SML_NB => naive_bayes freqs num_facts max_suggs visible_facts int_goal_feats)
+ | MaSh_NB => naive_bayes freqs num_facts max_suggs visible_facts int_goal_feats)
|> map (curry Vector.sub fact_names o fst))
end;
@@ -698,7 +562,7 @@
Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns)]
in
((fact_names, featss, depss),
- MaSh_SML.learn_facts freqs0 num_facts0 num_facts num_feats depss featss)
+ MaSh.learn_facts freqs0 num_facts0 num_facts num_feats depss featss)
end
fun reorder_learns (num_facts, fact_tab) learns =
@@ -734,7 +598,7 @@
| _ => NONE)
| _ => NONE)
-fun load_state ctxt overlord (time_state as (memory_time, _)) =
+fun load_state ctxt (time_state as (memory_time, _)) =
let val path = mash_state_file () in
(case try OS.FileSys.modTime (Path.implode path) of
NONE => time_state
@@ -758,11 +622,7 @@
EQUAL =>
try_graph ctxt "loading state" empty_G_etc
(fn () => fold extract_line_and_add_node node_lines empty_G_etc)
- | LESS =>
- (* cannot parse old file *)
- (if the_mash_engine () = MaSh_Py then MaSh_Py.unlearn ctxt overlord
- else wipe_out_mash_state_dir ();
- empty_G_etc)
+ | LESS => (wipe_out_mash_state_dir (); empty_G_etc) (* cannot parse old file *)
| GREATER => raise FILE_VERSION_TOO_NEW ())
val (ffds, freqs) =
@@ -794,7 +654,9 @@
SOME names => (NONE, fold (append_entry o Graph.get_entry access_G) names [])
| NONE => (SOME (version ^ "\n"), Graph.fold append_entry access_G []))
in
- write_file banner (entries, str_of_entry) path;
+ (case banner of SOME s => File.write path s | NONE => ();
+ entries |> chunk_list 500 |> List.app (File.append path o implode o map str_of_entry))
+ handle IO.Io _ => ();
trace_msg ctxt (fn () =>
"Saved fact graph (" ^ graph_info access_G ^
(case dirty_facts of
@@ -808,25 +670,19 @@
in
-fun map_state ctxt overlord f =
- Synchronized.change global_state (load_state ctxt overlord ##> f #> save_state ctxt)
+fun map_state ctxt f =
+ Synchronized.change global_state (load_state ctxt ##> f #> save_state ctxt)
handle FILE_VERSION_TOO_NEW () => ()
-fun peek_state ctxt overlord f =
- Synchronized.change_result global_state (perhaps (try (load_state ctxt overlord)) #> `snd #>> f)
+fun peek_state ctxt =
+ Synchronized.change_result global_state (perhaps (try (load_state ctxt)) #> `snd)
-fun clear_state ctxt overlord =
- (* "MaSh_Py.unlearn" also removes the state file *)
+fun clear_state () =
Synchronized.change global_state (fn _ =>
- (if the_mash_engine () = MaSh_Py then MaSh_Py.unlearn ctxt overlord
- else wipe_out_mash_state_dir ();
- (Time.zeroTime, empty_state)))
+ (wipe_out_mash_state_dir (); (Time.zeroTime, empty_state)))
end
-fun mash_unlearn ctxt ({overlord, ...} : params) =
- (clear_state ctxt overlord; Output.urgent_message "Reset MaSh.")
-
(*** Isabelle helpers ***)
@@ -1284,7 +1140,7 @@
(mesh_facts (eq_snd (gen_eq_thm ctxt)) max_facts mess, unknown)
end
-fun mash_suggested_facts ctxt thy ({debug, overlord, ...} : params) max_suggs hyp_ts concl_t facts =
+fun mash_suggested_facts ctxt thy ({debug, ...} : params) max_suggs hyp_ts concl_t facts =
let
val thy_name = Context.theory_name thy
val engine = the_mash_engine ()
@@ -1323,50 +1179,40 @@
(parents, feats)
end
- val ((access_G, ((num_facts, fact_tab), (num_feats, feat_tab)), ffds, freqs), py_suggs) =
- peek_state ctxt overlord (fn {access_G, xtabs, ffds, freqs, ...} =>
- ((access_G, xtabs, ffds, freqs),
- if Graph.is_empty access_G then
- (trace_msg ctxt (K "Nothing has been learned yet"); [])
- else if engine = MaSh_Py then
- let val (parents, feats) = query_args access_G in
- MaSh_Py.query ctxt overlord max_suggs ([], parents, feats)
- |> map fst
- end
- else
- []))
+ val {access_G, xtabs = ((num_facts, fact_tab), (num_feats, feat_tab)), ffds, freqs, ...} =
+ peek_state ctxt
- val sml_suggs =
- if engine = MaSh_Py then
- []
- else
- let
- val (parents, goal_feats) = query_args access_G
- val visible_facts = map_filter (Symtab.lookup fact_tab) (Graph.all_preds access_G parents)
- in
- if engine = MaSh_SML_kNN_Ext orelse engine = MaSh_SML_NB_Ext then
- let
- val learns =
- Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
- in
- MaSh_SML.query_external ctxt engine max_suggs learns goal_feats
- end
- else
- let
- val int_goal_feats =
- map_filter (fn (s, w) => Option.map (rpair w) (Symtab.lookup feat_tab s)) goal_feats
- in
- MaSh_SML.query_internal ctxt engine num_facts num_feats ffds freqs visible_facts
- max_suggs goal_feats int_goal_feats
- end
- end
+ val suggs =
+ let
+ val (parents, goal_feats) = query_args access_G
+ val visible_facts = map_filter (Symtab.lookup fact_tab) (Graph.all_preds access_G parents)
+ in
+ if engine = MaSh_kNN_Ext orelse engine = MaSh_NB_Ext then
+ let
+ val learns =
+ Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
+ in
+ MaSh.query_external ctxt engine max_suggs learns goal_feats
+ end
+ else
+ let
+ val int_goal_feats =
+ map_filter (fn (s, w) => Option.map (rpair w) (Symtab.lookup feat_tab s)) goal_feats
+ in
+ MaSh.query_internal ctxt engine num_facts num_feats ffds freqs visible_facts max_suggs
+ goal_feats int_goal_feats
+ end
+ end
val unknown = filter_out (is_fact_in_graph access_G o snd) facts
in
- find_mash_suggestions ctxt max_suggs (py_suggs @ sml_suggs) facts chained unknown
+ find_mash_suggestions ctxt max_suggs suggs facts chained unknown
|> pairself (map fact_of_raw_fact)
end
+fun mash_unlearn () =
+ (clear_state (); Output.urgent_message "Reset MaSh.")
+
fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (access_G, (fact_xtab, feat_xtab)) =
let
fun maybe_learn_from from (accum as (parents, access_G)) =
@@ -1413,35 +1259,30 @@
fun learned_proof_name () =
Date.fmt ".%Y%m%d.%H%M%S." (Date.fromTimeLocal (Time.now ())) ^ serial_string ()
-fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) t facts used_ths =
+fun mash_learn_proof ctxt ({timeout, ...} : params) t facts used_ths =
if not (null used_ths) andalso is_mash_enabled () then
launch_thread timeout (fn () =>
let
val thy = Proof_Context.theory_of ctxt
val feats = features_of ctxt thy (Local, General) [t]
in
- map_state ctxt overlord
- (fn state as {access_G, xtabs as ((num_facts0, _), _), ffds, freqs, dirty_facts} =>
+ map_state ctxt
+ (fn {access_G, xtabs as ((num_facts0, _), _), ffds, freqs, dirty_facts} =>
let
val parents = maximal_wrt_access_graph access_G facts
val deps = used_ths
|> filter (is_fact_in_graph access_G)
|> map nickname_of_thm
+
+ val name = learned_proof_name ()
+ val (access_G', xtabs', rev_learns) =
+ add_node Automatic_Proof name parents feats deps (access_G, xtabs, [])
+
+ val (ffds', freqs') =
+ recompute_ffds_freqs_from_learns (rev rev_learns) xtabs' num_facts0 ffds freqs
in
- if the_mash_engine () = MaSh_Py then
- (MaSh_Py.learn ctxt overlord true [("", parents, feats, deps)]; state)
- else
- let
- val name = learned_proof_name ()
- val (access_G', xtabs', rev_learns) =
- add_node Automatic_Proof name parents feats deps (access_G, xtabs, [])
-
- val (ffds', freqs') =
- recompute_ffds_freqs_from_learns (rev rev_learns) xtabs' num_facts0 ffds freqs
- in
- {access_G = access_G', xtabs = xtabs', ffds = ffds', freqs = freqs',
- dirty_facts = Option.map (cons name) dirty_facts}
- end
+ {access_G = access_G', xtabs = xtabs', ffds = ffds', freqs = freqs',
+ dirty_facts = Option.map (cons name) dirty_facts}
end);
(true, "")
end)
@@ -1453,14 +1294,13 @@
val commit_timeout = seconds 30.0
(* The timeout is understood in a very relaxed fashion. *)
-fun mash_learn_facts ctxt (params as {debug, verbose, overlord, ...}) prover save auto_level
- run_prover learn_timeout facts =
+fun mash_learn_facts ctxt (params as {debug, verbose, ...}) prover auto_level run_prover
+ learn_timeout facts =
let
val timer = Timer.startRealTimer ()
fun next_commit_time () = Time.+ (Timer.checkRealTimer timer, commit_timeout)
- val engine = the_mash_engine ()
- val {access_G, ...} = peek_state ctxt overlord I
+ val {access_G, ...} = peek_state ctxt
val is_in_access_G = is_fact_in_graph access_G o snd
val no_new_facts = forall is_in_access_G facts
in
@@ -1511,18 +1351,13 @@
else
recompute_ffds_freqs_from_access_G access_G xtabs
in
- if engine = MaSh_Py then
- (MaSh_Py.learn ctxt overlord (save andalso null relearns) learns;
- MaSh_Py.relearn ctxt overlord save relearns)
- else
- ();
{access_G = access_G, xtabs = xtabs, ffds = ffds', freqs = freqs',
dirty_facts = dirty_facts}
end
fun commit last learns relearns flops =
(if debug andalso auto_level = 0 then Output.urgent_message "Committing..." else ();
- map_state ctxt overlord (do_commit (rev learns) relearns flops);
+ map_state ctxt (do_commit (rev learns) relearns flops);
if not last andalso auto_level = 0 then
let val num_proofs = length learns + length relearns in
Output.urgent_message ("Learned " ^ string_of_int num_proofs ^ " " ^
@@ -1633,7 +1468,7 @@
val prover = hd provers
fun learn auto_level run_prover =
- mash_learn_facts ctxt params prover true auto_level run_prover one_year facts
+ mash_learn_facts ctxt params prover auto_level run_prover one_year facts
|> Output.urgent_message
in
if run_prover then
@@ -1650,8 +1485,8 @@
learn 0 false)
end
-fun mash_can_suggest_facts ctxt overlord =
- not (Graph.is_empty (#access_G (peek_state ctxt overlord I)))
+fun mash_can_suggest_facts ctxt =
+ not (Graph.is_empty (#access_G (peek_state ctxt)))
(* Generate more suggestions than requested, because some might be thrown out later for various
reasons (e.g., duplicates). *)
@@ -1666,7 +1501,7 @@
and Try. *)
val min_secs_for_learning = 15
-fun relevant_facts ctxt (params as {verbose, overlord, learn, fact_filter, timeout, ...}) prover
+fun relevant_facts ctxt (params as {verbose, learn, fact_filter, timeout, ...}) prover
max_facts ({add, only, ...} : fact_override) hyp_ts concl_t facts =
if not (subset (op =) (the_list fact_filter, fact_filters)) then
error ("Unknown fact filter: " ^ quote (the fact_filter) ^ ".")
@@ -1689,42 +1524,41 @@
else
());
launch_thread timeout
- (fn () => (true, mash_learn_facts ctxt params prover true 2 false timeout facts))
+ (fn () => (true, mash_learn_facts ctxt params prover 2 false timeout facts))
end
else
()
fun maybe_learn () =
- if is_mash_enabled () andalso learn then
+ if learn then
let
- val {access_G, xtabs = ((num_facts0, _), _), ...} = peek_state ctxt overlord I
+ val {access_G, xtabs = ((num_facts0, _), _), ...} = peek_state ctxt
val is_in_access_G = is_fact_in_graph access_G o snd
val min_num_facts_to_learn = length facts - num_facts0
in
if min_num_facts_to_learn <= max_facts_to_learn_before_query then
(case length (filter_out is_in_access_G facts) of
- 0 => false
+ 0 => ()
| num_facts_to_learn =>
if num_facts_to_learn <= max_facts_to_learn_before_query then
- (mash_learn_facts ctxt params prover false 2 false timeout facts
- |> (fn "" => () | s => Output.urgent_message (MaShN ^ ": " ^ s));
- true)
+ mash_learn_facts ctxt params prover 2 false timeout facts
+ |> (fn "" => () | s => Output.urgent_message (MaShN ^ ": " ^ s))
else
- (maybe_launch_thread num_facts_to_learn; false))
+ maybe_launch_thread num_facts_to_learn)
else
- (maybe_launch_thread min_num_facts_to_learn; false)
+ maybe_launch_thread min_num_facts_to_learn
end
else
- false
+ ()
- val (save, effective_fact_filter) =
+ val effective_fact_filter =
(case fact_filter of
- SOME ff => (ff <> mepoN andalso maybe_learn (), ff)
+ SOME ff => ff
| NONE =>
if is_mash_enabled () then
- (maybe_learn (), if mash_can_suggest_facts ctxt overlord then meshN else mepoN)
+ (maybe_learn (); if mash_can_suggest_facts ctxt then meshN else mepoN)
else
- (false, mepoN))
+ mepoN)
val unique_facts = drop_duplicate_facts facts
val add_ths = Attrib.eval_thms ctxt add
@@ -1754,7 +1588,6 @@
|> Par_List.map (apsnd (fn f => f ()))
val mesh = mesh_facts (eq_snd (gen_eq_thm ctxt)) max_facts mess |> add_and_take
in
- if the_mash_engine () = MaSh_Py andalso save then MaSh_Py.save ctxt overlord else ();
(case (fact_filter, mess) of
(NONE, [(_, (mepo, _)), (_, (mash, _))]) =>
[(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take),
@@ -1762,9 +1595,7 @@
| _ => [(effective_fact_filter, mesh)])
end
-fun kill_learners ctxt ({overlord, ...} : params) =
- (Async_Manager.kill_threads MaShN "learner";
- if the_mash_engine () = MaSh_Py then MaSh_Py.shutdown ctxt overlord else ())
+fun kill_learners () = Async_Manager.kill_threads MaShN "learner"
fun running_learners () = Async_Manager.running_threads MaShN "learner"
--- a/src/HOL/Tools/etc/options Sat Jun 28 22:13:23 2014 +0200
+++ b/src/HOL/Tools/etc/options Sun Jun 29 18:28:27 2014 +0200
@@ -36,4 +36,4 @@
-- "status of Z3 activation for non-commercial use (yes, no, unknown)"
public option MaSh : string = "sml"
- -- "machine learning engine to use by Sledgehammer (sml, sml_knn, sml_nb, py, none)"
+ -- "machine learning engine to use by Sledgehammer (sml, nb, knn, none)"