src/HOL/Tools/Sledgehammer/MaSh/src/server.py
changeset 53115 e08a58161bf1
parent 53100 1133b9e83f09
child 53119 ac18480cbf9d
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/server.py	Tue Aug 20 23:40:23 2013 +0200
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/server.py	Wed Aug 21 09:25:40 2013 +0200
@@ -107,30 +107,36 @@
 
         # Output        
         predictionNames = [str(self.server.dicts.idNameDict[p]) for p in self.server.predictions[:numberOfPredictions]]
-        predictionValues = [str(x) for x in predictionValues[:self.server.args.numberOfPredictions]]
+        predictionValues = [str(x) for x in predictionValues[:numberOfPredictions]]
         predictionsStringList = ['%s=%s' % (predictionNames[i],predictionValues[i]) for i in range(len(predictionNames))]
         predictionsString = string.join(predictionsStringList,' ')
         outString = '%s: %s' % (name,predictionsString)
         self.request.sendall(outString)
     
-    def shutdown(self):
+    def shutdown(self,saveModels=True):
         self.request.sendall('Shutting down server.')
+        if saveModels:
+            self.save()
+        self.server.shutdown()
+    
+    def save(self):
         # Save Models
         self.server.model.save(self.server.args.modelFile)
         self.server.dicts.save(self.server.args.dictsFile)
         if not self.server.args.saveStats == None:
             statsFile = os.path.join(self.server.args.outputDir,self.server.args.saveStats)
             self.server.stats.save(statsFile)
-        self.server.shutdown()
     
     def handle(self):
         # self.request is the TCP socket connected to the client
-        self.data = self.request.recv(262144).strip()
+        self.data = self.request.recv(4194304).strip()
+        self.server.lock.acquire()
         #print "{} wrote:".format(self.client_address[0])
-        #print self.data
-        self.startTime = time()
-        if self.data == 's':
-            self.shutdown()            
+        self.startTime = time()  
+        if self.data == 'shutdown':
+            self.shutdown()         
+        elif self.data == 'save':
+            self.save()       
         elif self.data.startswith('i'):            
             self.init(self.data[2:])
         elif self.data.startswith('!'):
@@ -141,15 +147,13 @@
             self.predict()
         elif self.data == '':
             # Empty Socket
-            return
+            pass
         elif self.data == 'avgStats':
             self.request.sendall(self.server.stats.printAvg())            
         else:
             self.request.sendall('Unspecified input format: \n%s',self.data)
         self.server.callCounter += 1
-
-        #returnString = 'Time needed: '+str(round(time()-self.startTime,2))
-        #print returnString
+        self.server.lock.release()
 
 if __name__ == "__main__":
     HOST, PORT = "localhost", 9255