src/HOL/Tools/Sledgehammer/MaSh/src/server.py
changeset 53555 12251bc889f1
parent 53135 f08f66b55cb5
child 53557 5d3ec1198a64
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/server.py	Thu Sep 12 00:34:48 2013 +0200
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/server.py	Thu Sep 12 09:59:45 2013 +0200
@@ -7,10 +7,15 @@
 
 import SocketServer,os,string,logging
 from multiprocessing import Manager
+from threading import Timer
 from time import time
 from dictionaries import Dictionaries
 from parameters import init_parser
 from sparseNaiveBayes import sparseNBClassifier
+from KNN import KNN,euclidean
+from KNNs import KNNAdaptPointFeatures,KNNUrban
+from predefined import Predefined
+from ExpandFeatures import ExpandFeatures
 from stats import Statistics
 
 
@@ -19,6 +24,21 @@
         SocketServer.ThreadingTCPServer.__init__(self,*args, **kwargs)
         self.manager = Manager()
         self.lock = Manager().Lock()
+        self.idle_timeout = 28800.0 # 8 hours in seconds
+        self.idle_timer = Timer(self.idle_timeout, self.shutdown)
+        self.idle_timer.start()        
+        
+    def save(self):
+        # Save Models
+        self.model.save(self.args.modelFile)
+        self.dicts.save(self.args.dictsFile)
+        if not self.args.saveStats == None:
+            statsFile = os.path.join(self.args.outputDir,self.args.saveStats)
+            self.stats.save(statsFile)   
+               
+    def save_and_shutdown(self):     
+        self.save()          
+        self.shutdown()
 
 class MaShHandler(SocketServer.BaseRequestHandler):
 
@@ -28,25 +48,32 @@
         else:
             argv = argv.split(';')
             self.server.args = init_parser(argv)
-        # Pick model
-        if self.server.args.algorithm == 'nb':
-            self.server.model = sparseNBClassifier(self.server.args.NBDefaultPriorWeight,self.server.args.NBPosWeight,self.server.args.NBDefVal)
-        else: # Default case
-            self.server.model = sparseNBClassifier(self.server.args.NBDefaultPriorWeight,self.server.args.NBPosWeight,self.server.args.NBDefVal)
         # Load all data
-        # TODO: rewrite dicts for concurrency and without sine
         self.server.dicts = Dictionaries()
         if os.path.isfile(self.server.args.dictsFile):
             self.server.dicts.load(self.server.args.dictsFile)            
         elif self.server.args.init:
             self.server.dicts.init_all(self.server.args)
+        # Pick model
+        if self.server.args.algorithm == 'nb':
+            self.server.model = sparseNBClassifier(self.server.args.NBDefaultPriorWeight,self.server.args.NBPosWeight,self.server.args.NBDefVal)
+        elif self.server.args.algorithm == 'KNN':
+            #self.server.model = KNN(self.server.dicts)
+            self.server.model = KNNAdaptPointFeatures(self.server.dicts)
+        elif self.server.args.algorithm == 'predef':
+            self.server.model = Predefined(self.server.args.predef)
+        else: # Default case
+            self.server.model = sparseNBClassifier(self.server.args.NBDefaultPriorWeight,self.server.args.NBPosWeight,self.server.args.NBDefVal)
+        if self.server.args.expandFeatures:
+            self.server.expandFeatures = ExpandFeatures(self.server.dicts)
+            self.server.expandFeatures.initialize(self.server.dicts)
         # Create Model
         if os.path.isfile(self.server.args.modelFile):
             self.server.model.load(self.server.args.modelFile)          
         elif self.server.args.init:
             trainData = self.server.dicts.featureDict.keys()
             self.server.model.initializeModel(trainData,self.server.dicts)
-            
+           
         if self.server.args.statistics:
             self.server.stats = Statistics(self.server.args.cutOff)
             self.server.statementCounter = 1
@@ -77,6 +104,8 @@
                 self.server.logger.debug('Poor predictions: %s',bp)
             self.server.statementCounter += 1
 
+        if self.server.args.expandFeatures:
+            self.server.expandFeatures.update(self.server.dicts.featureDict[problemId],self.server.dicts.dependenciesDict[problemId])
         # Update Dependencies, p proves p
         self.server.dicts.dependenciesDict[problemId] = [problemId]+self.server.dicts.dependenciesDict[problemId]
         self.server.model.update(problemId,self.server.dicts.featureDict[problemId],self.server.dicts.dependenciesDict[problemId])
@@ -92,22 +121,25 @@
         self.server.computeStats = True
         if self.server.args.algorithm == 'predef':
             return
-        name,features,accessibles,hints,numberOfPredictions = self.server.dicts.parse_problem(self.data)  
+        name,features,accessibles,hints,numberOfPredictions = self.server.dicts.parse_problem(self.data)
         if numberOfPredictions == None:
             numberOfPredictions = self.server.args.numberOfPredictions
         if not hints == []:
             self.server.model.update('hints',features,hints)
-        
+        if self.server.args.expandFeatures:
+            features = self.server.expandFeatures.expand(features)
         # Create predictions
         self.server.logger.debug('Starting computation for line %s',self.server.callCounter)
-        predictionsFeatures = features                    
-        self.server.predictions,predictionValues = self.server.model.predict(predictionsFeatures,accessibles,self.server.dicts)
+                
+        self.server.predictions,predictionValues = self.server.model.predict(features,accessibles,self.server.dicts)
         assert len(self.server.predictions) == len(predictionValues)
         self.server.logger.debug('Time needed: '+str(round(time()-self.startTime,2)))
 
         # Output        
         predictionNames = [str(self.server.dicts.idNameDict[p]) for p in self.server.predictions[:numberOfPredictions]]
-        predictionValues = [str(x) for x in predictionValues[:numberOfPredictions]]
+        #predictionValues = [str(x) for x in predictionValues[:numberOfPredictions]]
+        #predictionsStringList = ['%s=%s' % (predictionNames[i],predictionValues[i]) for i in range(len(predictionNames))]
+        #predictionsString = string.join(predictionsStringList,' ')
         predictionsString = string.join(predictionNames,' ')
         outString = '%s: %s' % (name,predictionsString)
         self.request.sendall(outString)
@@ -115,27 +147,18 @@
     def shutdown(self,saveModels=True):
         self.request.sendall('Shutting down server.')
         if saveModels:
-            self.save()
+            self.server.save()
         self.server.shutdown()
     
-    def save(self):
-        # Save Models
-        self.server.model.save(self.server.args.modelFile)
-        self.server.dicts.save(self.server.args.dictsFile)
-        if not self.server.args.saveStats == None:
-            statsFile = os.path.join(self.server.args.outputDir,self.server.args.saveStats)
-            self.server.stats.save(statsFile)
-    
     def handle(self):
         # self.request is the TCP socket connected to the client
         self.data = self.request.recv(4194304).strip()
         self.server.lock.acquire()
-        #print "{} wrote:".format(self.client_address[0])
         self.startTime = time()  
         if self.data == 'shutdown':
             self.shutdown()         
         elif self.data == 'save':
-            self.save()       
+            self.server.save()       
         elif self.data.startswith('i'):            
             self.init(self.data[2:])
         elif self.data.startswith('!'):
@@ -153,15 +176,16 @@
         else:
             self.request.sendall('Unspecified input format: \n%s',self.data)
         self.server.callCounter += 1
+        # Update idle shutdown timer
+        self.server.idle_timer.cancel()
+        self.server.idle_timer = Timer(self.server.idle_timeout, self.server.save_and_shutdown)
+        self.server.idle_timer.start()        
         self.server.lock.release()
 
 if __name__ == "__main__":
     HOST, PORT = "localhost", 9255
-    #print 'Started Server'
-    # Create the server, binding to localhost on port 9999
     SocketServer.TCPServer.allow_reuse_address = True
     server = ThreadingTCPServer((HOST, PORT), MaShHandler)
-    #server = SocketServer.TCPServer((HOST, PORT), MaShHandler)
 
     # Activate the server; this will keep running until you
     # interrupt the program with Ctrl-C
@@ -171,4 +195,4 @@
 
     
     
-    
+    
\ No newline at end of file