--- a/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py Tue Aug 20 23:40:23 2013 +0200
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py Wed Aug 21 09:25:40 2013 +0200
@@ -35,7 +35,7 @@
try:
sock.connect((host,port))
sock.sendall(data)
- received = sock.recv(262144)
+ received = sock.recv(4194304)
except:
logger = logging.getLogger('communicate')
logger.warning('Communication with server failed.')
@@ -69,7 +69,16 @@
if not os.path.exists(args.outputDir):
os.makedirs(args.outputDir)
+ # Shutdown commands need not start the server fist.
+ if args.shutdownServer:
+ try:
+ communicate('shutdown',args.host,args.port)
+ except:
+ pass
+ return
+
# If server is not running, start it.
+ startedServer = False
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect((args.host,args.port))
@@ -79,7 +88,10 @@
spawnDaemon('server.py')
# TODO: Make this fault tolerant
time.sleep(0.5)
- # Init server
+ startedServer = True
+
+ if args.init or startedServer:
+ logger.info('Initializing Server.')
data = "i "+";".join(argv)
received = communicate(data,args.host,args.port)
logger.info(received)
@@ -90,15 +102,10 @@
# IO Streams
OS = open(args.predictions,'w')
IS = open(args.inputFile,'r')
- count = 0
- for line in IS:
- count += 1
- #if count == 127:
- # break as
+ for line in IS:
received = communicate(line,args.host,args.port)
if not received == '':
OS.write('%s\n' % received)
- #logger.info(received)
OS.close()
IS.close()
@@ -106,6 +113,8 @@
if args.statistics:
received = communicate('avgStats',args.host,args.port)
logger.info(received)
+ elif args.saveModels:
+ communicate('save',args.host,args.port)
if __name__ == "__main__":
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/parameters.py Tue Aug 20 23:40:23 2013 +0200
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/parameters.py Wed Aug 21 09:25:40 2013 +0200
@@ -22,7 +22,7 @@
parser.add_argument('--depFile', default='mash_dependencies',
help='Name of the file with the premise dependencies. The file must be in inputDir. Default = mash_dependencies')
- parser.add_argument('--algorithm',default='nb',action='store_true',help="Which learning algorithm is used. nb = Naive Bayes,predef=predefined. Default=nb.")
+ parser.add_argument('--algorithm',default='nb',help="Which learning algorithm is used. nb = Naive Bayes,predef=predefined. Default=nb.")
# NB Parameters
parser.add_argument('--NBDefaultPriorWeight',default=20.0,help="Initializes classifiers with value * p |- p. Default=20.0.",type=float)
parser.add_argument('--NBDefVal',default=-15.0,help="Default value for unknown features. Default=-15.0.",type=float)
@@ -42,5 +42,7 @@
parser.add_argument('--port', default='9255', help='Port of the Mash server. Default=9255',type=int)
parser.add_argument('--host', default='localhost', help='Host of the Mash server. Default=localhost')
+ parser.add_argument('--shutdownServer',default=False,action='store_true',help="Shutdown server without saving the models.")
+ parser.add_argument('--saveModels',default=False,action='store_true',help="Server saves the models.")
args = parser.parse_args(argv)
return args
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/server.py Tue Aug 20 23:40:23 2013 +0200
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/server.py Wed Aug 21 09:25:40 2013 +0200
@@ -107,30 +107,36 @@
# Output
predictionNames = [str(self.server.dicts.idNameDict[p]) for p in self.server.predictions[:numberOfPredictions]]
- predictionValues = [str(x) for x in predictionValues[:self.server.args.numberOfPredictions]]
+ predictionValues = [str(x) for x in predictionValues[:numberOfPredictions]]
predictionsStringList = ['%s=%s' % (predictionNames[i],predictionValues[i]) for i in range(len(predictionNames))]
predictionsString = string.join(predictionsStringList,' ')
outString = '%s: %s' % (name,predictionsString)
self.request.sendall(outString)
- def shutdown(self):
+ def shutdown(self,saveModels=True):
self.request.sendall('Shutting down server.')
+ if saveModels:
+ self.save()
+ self.server.shutdown()
+
+ def save(self):
# Save Models
self.server.model.save(self.server.args.modelFile)
self.server.dicts.save(self.server.args.dictsFile)
if not self.server.args.saveStats == None:
statsFile = os.path.join(self.server.args.outputDir,self.server.args.saveStats)
self.server.stats.save(statsFile)
- self.server.shutdown()
def handle(self):
# self.request is the TCP socket connected to the client
- self.data = self.request.recv(262144).strip()
+ self.data = self.request.recv(4194304).strip()
+ self.server.lock.acquire()
#print "{} wrote:".format(self.client_address[0])
- #print self.data
- self.startTime = time()
- if self.data == 's':
- self.shutdown()
+ self.startTime = time()
+ if self.data == 'shutdown':
+ self.shutdown()
+ elif self.data == 'save':
+ self.save()
elif self.data.startswith('i'):
self.init(self.data[2:])
elif self.data.startswith('!'):
@@ -141,15 +147,13 @@
self.predict()
elif self.data == '':
# Empty Socket
- return
+ pass
elif self.data == 'avgStats':
self.request.sendall(self.server.stats.printAvg())
else:
self.request.sendall('Unspecified input format: \n%s',self.data)
self.server.callCounter += 1
-
- #returnString = 'Time needed: '+str(round(time()-self.startTime,2))
- #print returnString
+ self.server.lock.release()
if __name__ == "__main__":
HOST, PORT = "localhost", 9255