src/HOL/Tools/Sledgehammer/MaSh/src/server.py
changeset 54432 68f8bd1641da
parent 54150 942bb9d9b7a8
child 57124 e4c2c792226f
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/server.py	Thu Nov 14 15:40:06 2013 +0100
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/server.py	Thu Nov 14 15:57:48 2013 +0100
@@ -14,6 +14,7 @@
 from sparseNaiveBayes import sparseNBClassifier
 from KNN import KNN,euclidean
 from KNNs import KNNAdaptPointFeatures,KNNUrban
+#from bayesPlusMetric import sparseNBPlusClassifier
 from predefined import Predefined
 from ExpandFeatures import ExpandFeatures
 from stats import Statistics
@@ -58,15 +59,28 @@
         else:
             argv = argv.split(';')
             self.server.args = init_parser(argv)
+
+        # Set up logging
+        logging.basicConfig(level=logging.DEBUG,
+                            format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
+                            datefmt='%d-%m %H:%M:%S',
+                            filename=self.server.args.log+'server',
+                            filemode='w')    
+        self.server.logger = logging.getLogger('server')
+            
         # Load all data
         self.server.dicts = Dictionaries()
         if os.path.isfile(self.server.args.dictsFile):
-            self.server.dicts.load(self.server.args.dictsFile)            
+            self.server.dicts.load(self.server.args.dictsFile)
+        #elif not self.server.args.dictsFile == '../tmp/dict.pickle':
+        #    raise IOError('Cannot find dictsFile at %s '% self.server.args.dictsFile)        
         elif self.server.args.init:
             self.server.dicts.init_all(self.server.args)
         # Pick model
         if self.server.args.algorithm == 'nb':
-            self.server.model = sparseNBClassifier(self.server.args.NBDefaultPriorWeight,self.server.args.NBPosWeight,self.server.args.NBDefVal)
+            ###TODO: !! 
+            self.server.model = sparseNBClassifier(self.server.args.NBDefaultPriorWeight,self.server.args.NBPosWeight,self.server.args.NBDefVal)            
+            #self.server.model = sparseNBPlusClassifier(self.server.args.NBDefaultPriorWeight,self.server.args.NBPosWeight,self.server.args.NBDefVal)
         elif self.server.args.algorithm == 'KNN':
             #self.server.model = KNN(self.server.dicts)
             self.server.model = KNNAdaptPointFeatures(self.server.dicts)
@@ -80,6 +94,8 @@
         # Create Model
         if os.path.isfile(self.server.args.modelFile):
             self.server.model.load(self.server.args.modelFile)          
+        #elif not self.server.args.modelFile == '../tmp/model.pickle':
+        #    raise IOError('Cannot find modelFile at %s '% self.server.args.modelFile)        
         elif self.server.args.init:
             trainData = self.server.dicts.featureDict.keys()
             self.server.model.initializeModel(trainData,self.server.dicts)
@@ -89,13 +105,6 @@
             self.server.statementCounter = 1
             self.server.computeStats = False
 
-        # Set up logging
-        logging.basicConfig(level=logging.DEBUG,
-                            format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
-                            datefmt='%d-%m %H:%M:%S',
-                            filename=self.server.args.log+'server',
-                            filemode='w')    
-        self.server.logger = logging.getLogger('server')
         self.server.logger.debug('Initialized in '+str(round(time()-self.startTime,2))+' seconds.')
         self.request.sendall('Server initialized in '+str(round(time()-self.startTime,2))+' seconds.')
         self.server.callCounter = 1
@@ -117,8 +126,11 @@
         if self.server.args.expandFeatures:
             self.server.expandFeatures.update(self.server.dicts.featureDict[problemId],self.server.dicts.dependenciesDict[problemId])
         # Update Dependencies, p proves p
-        self.server.dicts.dependenciesDict[problemId] = [problemId]+self.server.dicts.dependenciesDict[problemId]
+        if not problemId == 0:
+            self.server.dicts.dependenciesDict[problemId] = [problemId]+self.server.dicts.dependenciesDict[problemId]
+        ###TODO: 
         self.server.model.update(problemId,self.server.dicts.featureDict[problemId],self.server.dicts.dependenciesDict[problemId])
+        #self.server.model.update(problemId,self.server.dicts.featureDict[problemId],self.server.dicts.dependenciesDict[problemId],self.server.dicts)
 
     def overwrite(self):
         # Overwrite old proof.