new version of MaSh tool, with less broken server
authorblanchet
Wed, 21 Aug 2013 09:25:40 +0200
changeset 53115 e08a58161bf1
parent 53114 4c2b1e64c990
child 53116 b1907f6b3c86
new version of MaSh tool, with less broken server
src/HOL/Tools/Sledgehammer/MaSh/src/mash.py
src/HOL/Tools/Sledgehammer/MaSh/src/parameters.py
src/HOL/Tools/Sledgehammer/MaSh/src/server.py
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py	Tue Aug 20 23:40:23 2013 +0200
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py	Wed Aug 21 09:25:40 2013 +0200
@@ -35,7 +35,7 @@
     try:
         sock.connect((host,port))
         sock.sendall(data)
-        received = sock.recv(262144)
+        received = sock.recv(4194304)
     except:
         logger = logging.getLogger('communicate')
         logger.warning('Communication with server failed.')
@@ -69,7 +69,16 @@
     if not os.path.exists(args.outputDir):
         os.makedirs(args.outputDir)
 
+    # Shutdown commands need not start the server fist.
+    if args.shutdownServer:
+        try:
+            communicate('shutdown',args.host,args.port)
+        except:
+            pass
+        return
+
     # If server is not running, start it.
+    startedServer = False
     try:
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         sock.connect((args.host,args.port))       
@@ -79,7 +88,10 @@
         spawnDaemon('server.py')
         # TODO: Make this fault tolerant
         time.sleep(0.5)
-        # Init server
+        startedServer = True
+        
+    if args.init or startedServer:
+        logger.info('Initializing Server.')
         data = "i "+";".join(argv)
         received = communicate(data,args.host,args.port)
         logger.info(received)     
@@ -90,15 +102,10 @@
     # IO Streams
     OS = open(args.predictions,'w')
     IS = open(args.inputFile,'r')
-    count = 0
-    for line in IS:
-        count += 1
-        #if count == 127:
-        #    break as       
+    for line in IS:    
         received = communicate(line,args.host,args.port)
         if not received == '':
             OS.write('%s\n' % received)
-        #logger.info(received)
     OS.close()
     IS.close()
 
@@ -106,6 +113,8 @@
     if args.statistics:
         received = communicate('avgStats',args.host,args.port)
         logger.info(received)
+    elif args.saveModels:
+        communicate('save',args.host,args.port)
 
 
 if __name__ == "__main__":
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/parameters.py	Tue Aug 20 23:40:23 2013 +0200
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/parameters.py	Wed Aug 21 09:25:40 2013 +0200
@@ -22,7 +22,7 @@
     parser.add_argument('--depFile', default='mash_dependencies',
                         help='Name of the file with the premise dependencies. The file must be in inputDir. Default = mash_dependencies')
     
-    parser.add_argument('--algorithm',default='nb',action='store_true',help="Which learning algorithm is used. nb = Naive Bayes,predef=predefined. Default=nb.")
+    parser.add_argument('--algorithm',default='nb',help="Which learning algorithm is used. nb = Naive Bayes,predef=predefined. Default=nb.")
     # NB Parameters
     parser.add_argument('--NBDefaultPriorWeight',default=20.0,help="Initializes classifiers with value * p |- p. Default=20.0.",type=float)
     parser.add_argument('--NBDefVal',default=-15.0,help="Default value for unknown features. Default=-15.0.",type=float)
@@ -42,5 +42,7 @@
     
     parser.add_argument('--port', default='9255', help='Port of the Mash server. Default=9255',type=int)
     parser.add_argument('--host', default='localhost', help='Host of the Mash server. Default=localhost')
+    parser.add_argument('--shutdownServer',default=False,action='store_true',help="Shutdown server without saving the models.")
+    parser.add_argument('--saveModels',default=False,action='store_true',help="Server saves the models.")
     args = parser.parse_args(argv)    
     return args
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/server.py	Tue Aug 20 23:40:23 2013 +0200
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/server.py	Wed Aug 21 09:25:40 2013 +0200
@@ -107,30 +107,36 @@
 
         # Output        
         predictionNames = [str(self.server.dicts.idNameDict[p]) for p in self.server.predictions[:numberOfPredictions]]
-        predictionValues = [str(x) for x in predictionValues[:self.server.args.numberOfPredictions]]
+        predictionValues = [str(x) for x in predictionValues[:numberOfPredictions]]
         predictionsStringList = ['%s=%s' % (predictionNames[i],predictionValues[i]) for i in range(len(predictionNames))]
         predictionsString = string.join(predictionsStringList,' ')
         outString = '%s: %s' % (name,predictionsString)
         self.request.sendall(outString)
     
-    def shutdown(self):
+    def shutdown(self,saveModels=True):
         self.request.sendall('Shutting down server.')
+        if saveModels:
+            self.save()
+        self.server.shutdown()
+    
+    def save(self):
         # Save Models
         self.server.model.save(self.server.args.modelFile)
         self.server.dicts.save(self.server.args.dictsFile)
         if not self.server.args.saveStats == None:
             statsFile = os.path.join(self.server.args.outputDir,self.server.args.saveStats)
             self.server.stats.save(statsFile)
-        self.server.shutdown()
     
     def handle(self):
         # self.request is the TCP socket connected to the client
-        self.data = self.request.recv(262144).strip()
+        self.data = self.request.recv(4194304).strip()
+        self.server.lock.acquire()
         #print "{} wrote:".format(self.client_address[0])
-        #print self.data
-        self.startTime = time()
-        if self.data == 's':
-            self.shutdown()            
+        self.startTime = time()  
+        if self.data == 'shutdown':
+            self.shutdown()         
+        elif self.data == 'save':
+            self.save()       
         elif self.data.startswith('i'):            
             self.init(self.data[2:])
         elif self.data.startswith('!'):
@@ -141,15 +147,13 @@
             self.predict()
         elif self.data == '':
             # Empty Socket
-            return
+            pass
         elif self.data == 'avgStats':
             self.request.sendall(self.server.stats.printAvg())            
         else:
             self.request.sendall('Unspecified input format: \n%s',self.data)
         self.server.callCounter += 1
-
-        #returnString = 'Time needed: '+str(round(time()-self.startTime,2))
-        #print returnString
+        self.server.lock.release()
 
 if __name__ == "__main__":
     HOST, PORT = "localhost", 9255