src/HOL/Tools/Sledgehammer/MaSh/src/server.py
changeset 53100 1133b9e83f09
child 53115 e08a58161bf1
equal deleted inserted replaced
53099:5c7780d21d24 53100:1133b9e83f09
       
     1 #!/usr/bin/env python
       
     2 #     Title:      HOL/Tools/Sledgehammer/MaSh/src/server.py
       
     3 #     Author:     Daniel Kuehlwein, ICIS, Radboud University Nijmegen
       
     4 #     Copyright   2013
       
     5 #
       
     6 # The MaSh Server.
       
     7 
       
     8 import SocketServer,os,string,logging
       
     9 from multiprocessing import Manager
       
    10 from time import time
       
    11 from dictionaries import Dictionaries
       
    12 from parameters import init_parser
       
    13 from sparseNaiveBayes import sparseNBClassifier
       
    14 from stats import Statistics
       
    15 
       
    16 
       
    17 class ThreadingTCPServer(SocketServer.ThreadingTCPServer): 
       
    18     def __init__(self, *args, **kwargs):
       
    19         SocketServer.ThreadingTCPServer.__init__(self,*args, **kwargs)
       
    20         self.manager = Manager()
       
    21         self.lock = Manager().Lock()
       
    22 
       
    23 class MaShHandler(SocketServer.BaseRequestHandler):
       
    24 
       
    25     def init(self,argv):
       
    26         if argv == '':
       
    27             self.server.args = init_parser([])
       
    28         else:
       
    29             argv = argv.split(';')
       
    30             self.server.args = init_parser(argv)
       
    31         # Pick model
       
    32         if self.server.args.algorithm == 'nb':
       
    33             self.server.model = sparseNBClassifier(self.server.args.NBDefaultPriorWeight,self.server.args.NBPosWeight,self.server.args.NBDefVal)
       
    34         else: # Default case
       
    35             self.server.model = sparseNBClassifier(self.server.args.NBDefaultPriorWeight,self.server.args.NBPosWeight,self.server.args.NBDefVal)
       
    36         # Load all data
       
    37         # TODO: rewrite dicts for concurrency and without sine
       
    38         self.server.dicts = Dictionaries()
       
    39         if os.path.isfile(self.server.args.dictsFile):
       
    40             self.server.dicts.load(self.server.args.dictsFile)            
       
    41         elif self.server.args.init:
       
    42             self.server.dicts.init_all(self.server.args)
       
    43         # Create Model
       
    44         if os.path.isfile(self.server.args.modelFile):
       
    45             self.server.model.load(self.server.args.modelFile)          
       
    46         elif self.server.args.init:
       
    47             trainData = self.server.dicts.featureDict.keys()
       
    48             self.server.model.initializeModel(trainData,self.server.dicts)
       
    49             
       
    50         if self.server.args.statistics:
       
    51             self.server.stats = Statistics(self.server.args.cutOff)
       
    52             self.server.statementCounter = 1
       
    53             self.server.computeStats = False
       
    54 
       
    55         # Set up logging
       
    56         logging.basicConfig(level=logging.DEBUG,
       
    57                             format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
       
    58                             datefmt='%d-%m %H:%M:%S',
       
    59                             filename=self.server.args.log+'server',
       
    60                             filemode='w')    
       
    61         self.server.logger = logging.getLogger('server')
       
    62         self.server.logger.debug('Initialized in '+str(round(time()-self.startTime,2))+' seconds.')
       
    63         self.request.sendall('Server initialized in '+str(round(time()-self.startTime,2))+' seconds.')
       
    64         self.server.callCounter = 1
       
    65 
       
    66     def update(self):
       
    67         problemId = self.server.dicts.parse_fact(self.data)    
       
    68         # Statistics
       
    69         if self.server.args.statistics and self.server.computeStats:
       
    70             self.server.computeStats = False
       
    71             # Assume '!' comes after '?'
       
    72             if self.server.args.algorithm == 'predef':
       
    73                 self.server.predictions = self.server.model.predict(problemId)
       
    74             self.server.stats.update(self.server.predictions,self.server.dicts.dependenciesDict[problemId],self.server.statementCounter)
       
    75             if not self.server.stats.badPreds == []:
       
    76                 bp = string.join([str(self.server.dicts.idNameDict[x]) for x in self.server.stats.badPreds], ',')
       
    77                 self.server.logger.debug('Poor predictions: %s',bp)
       
    78             self.server.statementCounter += 1
       
    79 
       
    80         # Update Dependencies, p proves p
       
    81         self.server.dicts.dependenciesDict[problemId] = [problemId]+self.server.dicts.dependenciesDict[problemId]
       
    82         self.server.model.update(problemId,self.server.dicts.featureDict[problemId],self.server.dicts.dependenciesDict[problemId])
       
    83 
       
    84     def overwrite(self):
       
    85         # Overwrite old proof.
       
    86         problemId,newDependencies = self.server.dicts.parse_overwrite(self.data)
       
    87         newDependencies = [problemId]+newDependencies
       
    88         self.server.model.overwrite(problemId,newDependencies,self.server.dicts)
       
    89         self.server.dicts.dependenciesDict[problemId] = newDependencies
       
    90         
       
    91     def predict(self):
       
    92         self.server.computeStats = True
       
    93         if self.server.args.algorithm == 'predef':
       
    94             return
       
    95         name,features,accessibles,hints,numberOfPredictions = self.server.dicts.parse_problem(self.data)  
       
    96         if numberOfPredictions == None:
       
    97             numberOfPredictions = self.server.args.numberOfPredictions
       
    98         if not hints == []:
       
    99             self.server.model.update('hints',features,hints)
       
   100         
       
   101         # Create predictions
       
   102         self.server.logger.debug('Starting computation for line %s',self.server.callCounter)
       
   103         predictionsFeatures = features                    
       
   104         self.server.predictions,predictionValues = self.server.model.predict(predictionsFeatures,accessibles,self.server.dicts)
       
   105         assert len(self.server.predictions) == len(predictionValues)
       
   106         self.server.logger.debug('Time needed: '+str(round(time()-self.startTime,2)))
       
   107 
       
   108         # Output        
       
   109         predictionNames = [str(self.server.dicts.idNameDict[p]) for p in self.server.predictions[:numberOfPredictions]]
       
   110         predictionValues = [str(x) for x in predictionValues[:self.server.args.numberOfPredictions]]
       
   111         predictionsStringList = ['%s=%s' % (predictionNames[i],predictionValues[i]) for i in range(len(predictionNames))]
       
   112         predictionsString = string.join(predictionsStringList,' ')
       
   113         outString = '%s: %s' % (name,predictionsString)
       
   114         self.request.sendall(outString)
       
   115     
       
   116     def shutdown(self):
       
   117         self.request.sendall('Shutting down server.')
       
   118         # Save Models
       
   119         self.server.model.save(self.server.args.modelFile)
       
   120         self.server.dicts.save(self.server.args.dictsFile)
       
   121         if not self.server.args.saveStats == None:
       
   122             statsFile = os.path.join(self.server.args.outputDir,self.server.args.saveStats)
       
   123             self.server.stats.save(statsFile)
       
   124         self.server.shutdown()
       
   125     
       
   126     def handle(self):
       
   127         # self.request is the TCP socket connected to the client
       
   128         self.data = self.request.recv(262144).strip()
       
   129         #print "{} wrote:".format(self.client_address[0])
       
   130         #print self.data
       
   131         self.startTime = time()
       
   132         if self.data == 's':
       
   133             self.shutdown()            
       
   134         elif self.data.startswith('i'):            
       
   135             self.init(self.data[2:])
       
   136         elif self.data.startswith('!'):
       
   137             self.update()
       
   138         elif self.data.startswith('p'):
       
   139             self.overwrite()
       
   140         elif self.data.startswith('?'):               
       
   141             self.predict()
       
   142         elif self.data == '':
       
   143             # Empty Socket
       
   144             return
       
   145         elif self.data == 'avgStats':
       
   146             self.request.sendall(self.server.stats.printAvg())            
       
   147         else:
       
   148             self.request.sendall('Unspecified input format: \n%s',self.data)
       
   149         self.server.callCounter += 1
       
   150 
       
   151         #returnString = 'Time needed: '+str(round(time()-self.startTime,2))
       
   152         #print returnString
       
   153 
       
   154 if __name__ == "__main__":
       
   155     HOST, PORT = "localhost", 9255
       
   156     #print 'Started Server'
       
   157     # Create the server, binding to localhost on port 9999
       
   158     SocketServer.TCPServer.allow_reuse_address = True
       
   159     server = ThreadingTCPServer((HOST, PORT), MaShHandler)
       
   160     #server = SocketServer.TCPServer((HOST, PORT), MaShHandler)
       
   161 
       
   162     # Activate the server; this will keep running until you
       
   163     # interrupt the program with Ctrl-C
       
   164     server.serve_forever()        
       
   165 
       
   166 
       
   167 
       
   168     
       
   169     
       
   170