|
1 #!/usr/bin/env python |
|
2 # Title: HOL/Tools/Sledgehammer/MaSh/src/server.py |
|
3 # Author: Daniel Kuehlwein, ICIS, Radboud University Nijmegen |
|
4 # Copyright 2013 |
|
5 # |
|
6 # The MaSh Server. |
|
7 |
|
8 import SocketServer,os,string,logging |
|
9 from multiprocessing import Manager |
|
10 from time import time |
|
11 from dictionaries import Dictionaries |
|
12 from parameters import init_parser |
|
13 from sparseNaiveBayes import sparseNBClassifier |
|
14 from stats import Statistics |
|
15 |
|
16 |
|
class ThreadingTCPServer(SocketServer.ThreadingTCPServer):
    """Threaded TCP server that also owns a multiprocessing Manager.

    A single Manager and a Lock obtained from it are created once per server
    and exposed as ``self.manager`` and ``self.lock``.
    """

    def __init__(self, *args, **kwargs):
        # Explicit base call: SocketServer classes are old-style in Python 2,
        # so super() cannot be used here.
        SocketServer.ThreadingTCPServer.__init__(self, *args, **kwargs)
        self.manager = Manager()
        # Reuse the manager created above. Every Manager() call starts its own
        # manager server process, so calling Manager() again just to obtain a
        # lock would spawn a redundant extra process.
        self.lock = self.manager.Lock()
|
22 |
|
class MaShHandler(SocketServer.BaseRequestHandler):
    """Request handler implementing the MaSh command protocol.

    Each request carries one command, dispatched by ``handle``. A new handler
    instance is created per request, so all state that must survive between
    requests (arguments, model, dictionaries, statistics, logger, counters) is
    stored on ``self.server``.
    """

    def init(self, argv):
        """Handle an 'i' command: (re)initialise the server state.

        :param argv: ';'-separated argument string passed to ``init_parser``;
            an empty string means "parser defaults" (``''.split(';')`` would
            yield ``['']``, not ``[]``, hence the special case).

        Loads the dictionaries and the model from their files when those exist,
        otherwise builds them from scratch if ``args.init`` is set. It then
        configures logging and replies to the client with the elapsed time.
        """
        if argv == '':
            self.server.args = init_parser([])
        else:
            self.server.args = init_parser(argv.split(';'))
        args = self.server.args

        # Pick model. Every value of args.algorithm (including 'nb') currently
        # selects the sparse naive Bayes classifier.
        self.server.model = sparseNBClassifier(args.NBDefaultPriorWeight,
                                               args.NBPosWeight,
                                               args.NBDefVal)

        # Load all data
        # TODO: rewrite dicts for concurrency and without sine
        self.server.dicts = Dictionaries()
        if os.path.isfile(args.dictsFile):
            self.server.dicts.load(args.dictsFile)
        elif args.init:
            self.server.dicts.init_all(args)

        # Create model
        if os.path.isfile(args.modelFile):
            self.server.model.load(args.modelFile)
        elif args.init:
            trainData = self.server.dicts.featureDict.keys()
            self.server.model.initializeModel(trainData, self.server.dicts)

        # Statistics state exists only when statistics are enabled.
        if args.statistics:
            self.server.stats = Statistics(args.cutOff)
            self.server.statementCounter = 1
            self.server.computeStats = False

        # Set up logging
        logging.basicConfig(level=logging.DEBUG,
                            format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
                            datefmt='%d-%m %H:%M:%S',
                            filename=args.log + 'server',
                            filemode='w')
        self.server.logger = logging.getLogger('server')
        elapsed = str(round(time() - self.startTime, 2))
        self.server.logger.debug('Initialized in ' + elapsed + ' seconds.')
        self.request.sendall('Server initialized in ' + elapsed + ' seconds.')
        self.server.callCounter = 1

    def update(self):
        """Handle a '!' command: record a new fact and train the model on it.

        When statistics are enabled and a '?' prediction preceded this fact,
        the stored predictions are first scored against the fact's recorded
        dependencies, and poorly predicted facts are logged.
        """
        problemId = self.server.dicts.parse_fact(self.data)
        # Statistics
        if self.server.args.statistics and self.server.computeStats:
            self.server.computeStats = False
            # Assume '!' comes after '?'
            if self.server.args.algorithm == 'predef':
                self.server.predictions = self.server.model.predict(problemId)
            self.server.stats.update(self.server.predictions,
                                     self.server.dicts.dependenciesDict[problemId],
                                     self.server.statementCounter)
            if self.server.stats.badPreds:
                bp = ','.join([str(self.server.dicts.idNameDict[x])
                               for x in self.server.stats.badPreds])
                self.server.logger.debug('Poor predictions: %s', bp)
            self.server.statementCounter += 1

        # Update dependencies: p proves p, so it is prepended to its own list.
        dependencies = [problemId] + self.server.dicts.dependenciesDict[problemId]
        self.server.dicts.dependenciesDict[problemId] = dependencies
        self.server.model.update(problemId,
                                 self.server.dicts.featureDict[problemId],
                                 dependencies)

    def overwrite(self):
        """Handle a 'p' command: replace the stored proof (dependencies) of a fact."""
        problemId, newDependencies = self.server.dicts.parse_overwrite(self.data)
        newDependencies = [problemId] + newDependencies
        # NOTE(review): the model is told before dependenciesDict changes;
        # presumably model.overwrite reads the old entry via dicts -- keep order.
        self.server.model.overwrite(problemId, newDependencies, self.server.dicts)
        self.server.dicts.dependenciesDict[problemId] = newDependencies

    def predict(self):
        """Handle a '?' command: reply with the best-ranked facts for a problem.

        The reply has the form ``'<name>: <fact>=<value> <fact>=<value> ...'``.
        It holds at most ``numberOfPredictions`` entries, taken from the
        request when it specifies one and from ``args.numberOfPredictions``
        otherwise. With the 'predef' algorithm nothing is computed or sent.
        Predictions are kept in ``self.server.predictions`` for later
        statistics.
        """
        self.server.computeStats = True
        if self.server.args.algorithm == 'predef':
            return
        name, features, accessibles, hints, numberOfPredictions = \
            self.server.dicts.parse_problem(self.data)
        if numberOfPredictions is None:
            numberOfPredictions = self.server.args.numberOfPredictions
        if hints:
            self.server.model.update('hints', features, hints)

        # Create predictions
        self.server.logger.debug('Starting computation for line %s', self.server.callCounter)
        self.server.predictions, predictionValues = \
            self.server.model.predict(features, accessibles, self.server.dicts)
        assert len(self.server.predictions) == len(predictionValues)
        self.server.logger.debug('Time needed: ' + str(round(time() - self.startTime, 2)))

        # Output. Names and values are truncated to the SAME count. Using the
        # default count for values only would break the pairing (IndexError)
        # whenever a request asks for more than the default.
        top = zip(self.server.predictions[:numberOfPredictions],
                  predictionValues[:numberOfPredictions])
        predictionsString = ' '.join(['%s=%s' % (str(self.server.dicts.idNameDict[p]), str(v))
                                      for p, v in top])
        self.request.sendall('%s: %s' % (name, predictionsString))

    def shutdown(self):
        """Handle an 's' command: persist model, dictionaries and stats, then stop.

        Statistics are written to ``outputDir/saveStats`` only when
        ``args.saveStats`` is set and a Statistics object exists. ``init``
        creates that object only if ``args.statistics`` is enabled; otherwise
        a missing ``stats`` would abort this method before the server stops.
        """
        self.request.sendall('Shutting down server.')
        # Save models
        self.server.model.save(self.server.args.modelFile)
        self.server.dicts.save(self.server.args.dictsFile)
        stats = getattr(self.server, 'stats', None)
        if self.server.args.saveStats is not None and stats is not None:
            statsFile = os.path.join(self.server.args.outputDir, self.server.args.saveStats)
            stats.save(statsFile)
        # When run via this module's ThreadingTCPServer, this handler runs in a
        # worker thread while serve_forever() runs in the main thread, so
        # shutdown() here does not block on itself.
        self.server.shutdown()

    def handle(self):
        """Read one command from the client socket and dispatch it.

        Commands:
        - 's': shut down.
        - 'i': init; the first two characters are skipped and the rest is
          passed as the argument string.
        - '!': update.
        - 'p': overwrite.
        - '?': predict.
        - 'avgStats': send the average statistics.

        An empty message is ignored and does not advance ``callCounter``.
        Anything else is answered with an 'Unspecified input format' message.
        """
        # self.request is the TCP socket connected to the client.
        # NOTE(review): a single recv() of at most 262144 bytes is assumed to
        # hold the whole command; a larger or fragmented message may be read
        # only partially.
        self.data = self.request.recv(262144).strip()
        self.startTime = time()
        if self.data == 's':
            self.shutdown()
        elif self.data.startswith('i'):
            self.init(self.data[2:])
        elif self.data.startswith('!'):
            self.update()
        elif self.data.startswith('p'):
            self.overwrite()
        elif self.data.startswith('?'):
            self.predict()
        elif self.data == '':
            # Empty socket
            return
        elif self.data == 'avgStats':
            self.request.sendall(self.server.stats.printAvg())
        else:
            # Format the message ourselves; a second sendall() argument would
            # be interpreted as socket flags.
            self.request.sendall('Unspecified input format: \n%s' % self.data)
        self.server.callCounter += 1
|
150 |
|
151 #returnString = 'Time needed: '+str(round(time()-self.startTime,2)) |
|
152 #print returnString |
|
153 |
|
if __name__ == "__main__":
    HOST, PORT = "localhost", 9255
    # Create the server, binding to localhost on port 9255. Address reuse
    # (SO_REUSEADDR) must be enabled before construction, because the
    # constructor binds the socket. It is set on this module's subclass so
    # the stdlib SocketServer.TCPServer class is not modified globally.
    ThreadingTCPServer.allow_reuse_address = True
    server = ThreadingTCPServer((HOST, PORT), MaShHandler)
    try:
        # Serve until a client sends 's' (MaShHandler.shutdown) or the
        # program is interrupted with Ctrl-C.
        server.serve_forever()
    finally:
        # Release the listening socket in both cases.
        server.server_close()
|
165 |
|
166 |
|
167 |
|
168 |
|
169 |
|
170 |