--- a/src/HOL/Tools/Sledgehammer/MaSh/src/server.py Thu Sep 12 00:34:48 2013 +0200
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/server.py Thu Sep 12 09:59:45 2013 +0200
@@ -7,10 +7,15 @@
import SocketServer,os,string,logging
from multiprocessing import Manager
+from threading import Timer
from time import time
from dictionaries import Dictionaries
from parameters import init_parser
from sparseNaiveBayes import sparseNBClassifier
+from KNN import KNN,euclidean
+from KNNs import KNNAdaptPointFeatures,KNNUrban
+from predefined import Predefined
+from ExpandFeatures import ExpandFeatures
from stats import Statistics
@@ -19,6 +24,21 @@
SocketServer.ThreadingTCPServer.__init__(self,*args, **kwargs)
self.manager = Manager()
self.lock = Manager().Lock()
+ self.idle_timeout = 28800.0 # 8 hours in seconds
+ self.idle_timer = Timer(self.idle_timeout, self.shutdown)
+ self.idle_timer.start()
+
+ def save(self):
+ # Save Models
+ self.model.save(self.args.modelFile)
+ self.dicts.save(self.args.dictsFile)
+ if not self.args.saveStats == None:
+ statsFile = os.path.join(self.args.outputDir,self.args.saveStats)
+ self.stats.save(statsFile)
+
+ def save_and_shutdown(self):
+ self.save()
+ self.shutdown()
class MaShHandler(SocketServer.BaseRequestHandler):
@@ -28,25 +48,32 @@
else:
argv = argv.split(';')
self.server.args = init_parser(argv)
- # Pick model
- if self.server.args.algorithm == 'nb':
- self.server.model = sparseNBClassifier(self.server.args.NBDefaultPriorWeight,self.server.args.NBPosWeight,self.server.args.NBDefVal)
- else: # Default case
- self.server.model = sparseNBClassifier(self.server.args.NBDefaultPriorWeight,self.server.args.NBPosWeight,self.server.args.NBDefVal)
# Load all data
- # TODO: rewrite dicts for concurrency and without sine
self.server.dicts = Dictionaries()
if os.path.isfile(self.server.args.dictsFile):
self.server.dicts.load(self.server.args.dictsFile)
elif self.server.args.init:
self.server.dicts.init_all(self.server.args)
+ # Pick model
+ if self.server.args.algorithm == 'nb':
+ self.server.model = sparseNBClassifier(self.server.args.NBDefaultPriorWeight,self.server.args.NBPosWeight,self.server.args.NBDefVal)
+ elif self.server.args.algorithm == 'KNN':
+ #self.server.model = KNN(self.server.dicts)
+ self.server.model = KNNAdaptPointFeatures(self.server.dicts)
+ elif self.server.args.algorithm == 'predef':
+ self.server.model = Predefined(self.server.args.predef)
+ else: # Default case
+ self.server.model = sparseNBClassifier(self.server.args.NBDefaultPriorWeight,self.server.args.NBPosWeight,self.server.args.NBDefVal)
+ if self.server.args.expandFeatures:
+ self.server.expandFeatures = ExpandFeatures(self.server.dicts)
+ self.server.expandFeatures.initialize(self.server.dicts)
# Create Model
if os.path.isfile(self.server.args.modelFile):
self.server.model.load(self.server.args.modelFile)
elif self.server.args.init:
trainData = self.server.dicts.featureDict.keys()
self.server.model.initializeModel(trainData,self.server.dicts)
-
+
if self.server.args.statistics:
self.server.stats = Statistics(self.server.args.cutOff)
self.server.statementCounter = 1
@@ -77,6 +104,8 @@
self.server.logger.debug('Poor predictions: %s',bp)
self.server.statementCounter += 1
+ if self.server.args.expandFeatures:
+ self.server.expandFeatures.update(self.server.dicts.featureDict[problemId],self.server.dicts.dependenciesDict[problemId])
# Update Dependencies, p proves p
self.server.dicts.dependenciesDict[problemId] = [problemId]+self.server.dicts.dependenciesDict[problemId]
self.server.model.update(problemId,self.server.dicts.featureDict[problemId],self.server.dicts.dependenciesDict[problemId])
@@ -92,22 +121,25 @@
self.server.computeStats = True
if self.server.args.algorithm == 'predef':
return
- name,features,accessibles,hints,numberOfPredictions = self.server.dicts.parse_problem(self.data)
+ name,features,accessibles,hints,numberOfPredictions = self.server.dicts.parse_problem(self.data)
if numberOfPredictions == None:
numberOfPredictions = self.server.args.numberOfPredictions
if not hints == []:
self.server.model.update('hints',features,hints)
-
+ if self.server.args.expandFeatures:
+ features = self.server.expandFeatures.expand(features)
# Create predictions
self.server.logger.debug('Starting computation for line %s',self.server.callCounter)
- predictionsFeatures = features
- self.server.predictions,predictionValues = self.server.model.predict(predictionsFeatures,accessibles,self.server.dicts)
+
+ self.server.predictions,predictionValues = self.server.model.predict(features,accessibles,self.server.dicts)
assert len(self.server.predictions) == len(predictionValues)
self.server.logger.debug('Time needed: '+str(round(time()-self.startTime,2)))
# Output
predictionNames = [str(self.server.dicts.idNameDict[p]) for p in self.server.predictions[:numberOfPredictions]]
- predictionValues = [str(x) for x in predictionValues[:numberOfPredictions]]
+ #predictionValues = [str(x) for x in predictionValues[:numberOfPredictions]]
+ #predictionsStringList = ['%s=%s' % (predictionNames[i],predictionValues[i]) for i in range(len(predictionNames))]
+ #predictionsString = string.join(predictionsStringList,' ')
predictionsString = string.join(predictionNames,' ')
outString = '%s: %s' % (name,predictionsString)
self.request.sendall(outString)
@@ -115,27 +147,18 @@
def shutdown(self,saveModels=True):
self.request.sendall('Shutting down server.')
if saveModels:
- self.save()
+ self.server.save()
self.server.shutdown()
- def save(self):
- # Save Models
- self.server.model.save(self.server.args.modelFile)
- self.server.dicts.save(self.server.args.dictsFile)
- if not self.server.args.saveStats == None:
- statsFile = os.path.join(self.server.args.outputDir,self.server.args.saveStats)
- self.server.stats.save(statsFile)
-
def handle(self):
# self.request is the TCP socket connected to the client
self.data = self.request.recv(4194304).strip()
self.server.lock.acquire()
- #print "{} wrote:".format(self.client_address[0])
self.startTime = time()
if self.data == 'shutdown':
self.shutdown()
elif self.data == 'save':
- self.save()
+ self.server.save()
elif self.data.startswith('i'):
self.init(self.data[2:])
elif self.data.startswith('!'):
@@ -153,15 +176,16 @@
else:
self.request.sendall('Unspecified input format: \n%s',self.data)
self.server.callCounter += 1
+ # Update idle shutdown timer
+ self.server.idle_timer.cancel()
+ self.server.idle_timer = Timer(self.server.idle_timeout, self.server.save_and_shutdown)
+ self.server.idle_timer.start()
self.server.lock.release()
if __name__ == "__main__":
HOST, PORT = "localhost", 9255
- #print 'Started Server'
- # Create the server, binding to localhost on port 9999
SocketServer.TCPServer.allow_reuse_address = True
server = ThreadingTCPServer((HOST, PORT), MaShHandler)
- #server = SocketServer.TCPServer((HOST, PORT), MaShHandler)
# Activate the server; this will keep running until you
# interrupt the program with Ctrl-C
@@ -171,4 +195,4 @@
-
+
\ No newline at end of file