# HG changeset patch # User blanchet # Date 1377069940 -7200 # Node ID e08a58161bf1e3b4be401665085e3e4eef0352ab # Parent 4c2b1e64c99056b1bc14cd9933e6ce8f51853be4 new version of MaSh tool, with less broken server diff -r 4c2b1e64c990 -r e08a58161bf1 src/HOL/Tools/Sledgehammer/MaSh/src/mash.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py Tue Aug 20 23:40:23 2013 +0200 +++ b/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py Wed Aug 21 09:25:40 2013 +0200 @@ -35,7 +35,7 @@ try: sock.connect((host,port)) sock.sendall(data) - received = sock.recv(262144) + received = sock.recv(4194304) except: logger = logging.getLogger('communicate') logger.warning('Communication with server failed.') @@ -69,7 +69,16 @@ if not os.path.exists(args.outputDir): os.makedirs(args.outputDir) + # Shutdown commands need not start the server fist. + if args.shutdownServer: + try: + communicate('shutdown',args.host,args.port) + except: + pass + return + # If server is not running, start it. + startedServer = False try: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect((args.host,args.port)) @@ -79,7 +88,10 @@ spawnDaemon('server.py') # TODO: Make this fault tolerant time.sleep(0.5) - # Init server + startedServer = True + + if args.init or startedServer: + logger.info('Initializing Server.') data = "i "+";".join(argv) received = communicate(data,args.host,args.port) logger.info(received) @@ -90,15 +102,10 @@ # IO Streams OS = open(args.predictions,'w') IS = open(args.inputFile,'r') - count = 0 - for line in IS: - count += 1 - #if count == 127: - # break as + for line in IS: received = communicate(line,args.host,args.port) if not received == '': OS.write('%s\n' % received) - #logger.info(received) OS.close() IS.close() @@ -106,6 +113,8 @@ if args.statistics: received = communicate('avgStats',args.host,args.port) logger.info(received) + elif args.saveModels: + communicate('save',args.host,args.port) if __name__ == "__main__": diff -r 4c2b1e64c990 -r e08a58161bf1 src/HOL/Tools/Sledgehammer/MaSh/src/parameters.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/parameters.py Tue Aug 20 23:40:23 2013 +0200 +++ b/src/HOL/Tools/Sledgehammer/MaSh/src/parameters.py Wed Aug 21 09:25:40 2013 +0200 @@ -22,7 +22,7 @@ parser.add_argument('--depFile', default='mash_dependencies', help='Name of the file with the premise dependencies. The file must be in inputDir. Default = mash_dependencies') - parser.add_argument('--algorithm',default='nb',action='store_true',help="Which learning algorithm is used. nb = Naive Bayes,predef=predefined. Default=nb.") + parser.add_argument('--algorithm',default='nb',help="Which learning algorithm is used. nb = Naive Bayes,predef=predefined. Default=nb.") # NB Parameters parser.add_argument('--NBDefaultPriorWeight',default=20.0,help="Initializes classifiers with value * p |- p. Default=20.0.",type=float) parser.add_argument('--NBDefVal',default=-15.0,help="Default value for unknown features. Default=-15.0.",type=float) @@ -42,5 +42,7 @@ parser.add_argument('--port', default='9255', help='Port of the Mash server. Default=9255',type=int) parser.add_argument('--host', default='localhost', help='Host of the Mash server. Default=localhost') + parser.add_argument('--shutdownServer',default=False,action='store_true',help="Shutdown server without saving the models.") + parser.add_argument('--saveModels',default=False,action='store_true',help="Server saves the models.") args = parser.parse_args(argv) return args diff -r 4c2b1e64c990 -r e08a58161bf1 src/HOL/Tools/Sledgehammer/MaSh/src/server.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/server.py Tue Aug 20 23:40:23 2013 +0200 +++ b/src/HOL/Tools/Sledgehammer/MaSh/src/server.py Wed Aug 21 09:25:40 2013 +0200 @@ -107,30 +107,36 @@ # Output predictionNames = [str(self.server.dicts.idNameDict[p]) for p in self.server.predictions[:numberOfPredictions]] - predictionValues = [str(x) for x in predictionValues[:self.server.args.numberOfPredictions]] + predictionValues = [str(x) for x in predictionValues[:numberOfPredictions]] predictionsStringList = ['%s=%s' % (predictionNames[i],predictionValues[i]) for i in range(len(predictionNames))] predictionsString = string.join(predictionsStringList,' ') outString = '%s: %s' % (name,predictionsString) self.request.sendall(outString) - def shutdown(self): + def shutdown(self,saveModels=True): self.request.sendall('Shutting down server.') + if saveModels: + self.save() + self.server.shutdown() + + def save(self): # Save Models self.server.model.save(self.server.args.modelFile) self.server.dicts.save(self.server.args.dictsFile) if not self.server.args.saveStats == None: statsFile = os.path.join(self.server.args.outputDir,self.server.args.saveStats) self.server.stats.save(statsFile) - self.server.shutdown() def handle(self): # self.request is the TCP socket connected to the client - self.data = self.request.recv(262144).strip() + self.data = self.request.recv(4194304).strip() + self.server.lock.acquire() #print "{} wrote:".format(self.client_address[0]) - #print self.data - self.startTime = time() - if self.data == 's': - self.shutdown() + self.startTime = time() + if self.data == 'shutdown': + self.shutdown() + elif self.data == 'save': + self.save() elif self.data.startswith('i'): self.init(self.data[2:]) elif self.data.startswith('!'): @@ -141,15 +147,13 @@ self.predict() elif self.data == '': # Empty Socket - return + pass elif self.data == 'avgStats': self.request.sendall(self.server.stats.printAvg()) else: self.request.sendall('Unspecified input format: \n%s',self.data) self.server.callCounter += 1 - - #returnString = 'Time needed: '+str(round(time()-self.startTime,2)) - #print returnString + self.server.lock.release() if __name__ == "__main__": HOST, PORT = "localhost", 9255