more robustness in MaSh
authorblanchet
Fri, 04 Oct 2013 09:46:08 +0200
changeset 54056 8298976acb54
parent 54055 5bf55a713232
child 54057 a2c4e0b7b1e2
more robustness in MaSh
src/HOL/Tools/Sledgehammer/MaSh/src/dictionaries.py
src/HOL/Tools/Sledgehammer/MaSh/src/server.py
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/dictionaries.py	Thu Oct 03 19:01:10 2013 +0200
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/dictionaries.py	Fri Oct 04 09:46:08 2013 +0200
@@ -4,11 +4,12 @@
 #
 # Persistent dictionaries: accessibility, dependencies, and features.
 
-import logging,sys
+import sys
 from os.path import join
 from Queue import Queue
 from readData import create_accessible_dict,create_dependencies_dict
 from cPickle import load,dump
+from exceptions import LookupError
 
 class Dictionaries(object):
     '''
@@ -56,7 +57,6 @@
         self.changed = True
 
     def create_feature_dict(self,inputFile):
-        logger = logging.getLogger('create_feature_dict')
         self.featureDict = {}
         IS = open(inputFile,'r')
         for line in IS:
@@ -64,7 +64,7 @@
             name = line[0]
             # Name Id
             if self.nameIdDict.has_key(name):
-                logger.warning('%s appears twice in the feature file. Aborting.',name)
+                raise LookupError('%s appears twice in the feature file. Aborting.'% name)
                 sys.exit(-1)
             else:
                 self.nameIdDict[name] = self.maxNameId
@@ -134,6 +134,13 @@
                         unexpandedQueue.put(a)
         return list(accessibles)
 
+    def parse_unExpAcc(self,line):
+        try:
+            unExpAcc = [self.nameIdDict[a.strip()] for a in line.split()]            
+        except:
+            raise LookupError('Cannot find the accessibles:%s. Accessibles need to be introduced before referring to them.' % line)
+        return unExpAcc
+
     def parse_fact(self,line):
         """
         Parses a single line, extracting accessibles, features, and dependencies.
@@ -147,8 +154,9 @@
         nameId = self.get_name_id(name)
         line = line[1].split(';')
         # Accessible Ids
-        unExpAcc = [self.nameIdDict[a.strip()] for a in line[0].split()]
-        self.accessibleDict[nameId] = unExpAcc
+        #unExpAcc = [self.nameIdDict[a.strip()] for a in line[0].split()]
+        #self.accessibleDict[nameId] = unExpAcc
+        self.accessibleDict[nameId] = self.parse_unExpAcc(line[0])
         features = self.get_features(line)
         self.featureDict[nameId] = features
         self.dependenciesDict[nameId] = [self.nameIdDict[d.strip()] for d in line[2].split()]        
@@ -180,7 +188,7 @@
         name = None
         numberOfPredictions = None
 
-        # Check whether there is a problem name:
+        # How many predictions should be returned:
         tmp = line.split('#')
         if len(tmp) == 2:
             numberOfPredictions = int(tmp[0].strip())
@@ -194,8 +202,11 @@
 
         # line = accessibles;features
         line = line.split(';')
+        features = self.get_features(line)
+        
         # Accessible Ids, expand and store the accessibles.
-        unExpAcc = [self.nameIdDict[a.strip()] for a in line[0].split()]
+        #unExpAcc = [self.nameIdDict[a.strip()] for a in line[0].split()]
+        unExpAcc = self.parse_unExpAcc(line[0])        
         if len(self.expandedAccessibles.keys())>=100:
             self.expandedAccessibles = {}
             self.changed = True
@@ -205,7 +216,7 @@
                 self.expandedAccessibles[accId] = self.expand_accessibles(accIdAcc)
                 self.changed = True
         accessibles = self.expand_accessibles(unExpAcc)
-        features = self.get_features(line)
+        
         # Get hints:
         if len(line) == 3:
             hints = [self.nameIdDict[d.strip()] for d in line[2].split()]
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/server.py	Thu Oct 03 19:01:10 2013 +0200
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/server.py	Fri Oct 04 09:46:08 2013 +0200
@@ -19,7 +19,8 @@
 from stats import Statistics
 
 
-class ThreadingTCPServer(SocketServer.ThreadingTCPServer): 
+class ThreadingTCPServer(SocketServer.ThreadingTCPServer):
+    
     def __init__(self, *args, **kwargs):
         SocketServer.ThreadingTCPServer.__init__(self,*args, **kwargs)
         self.manager = Manager()
@@ -27,8 +28,17 @@
         self.idle_timeout = 28800.0 # 8 hours in seconds
         self.idle_timer = Timer(self.idle_timeout, self.shutdown)
         self.idle_timer.start()        
+        self.model = None
+        self.dicts = None
+        self.callCounter = 0
         
     def save(self):
+        if self.model == None or self.dicts == None:
+            try:
+                self.logger.warning('Cannot save nonexisting models.')
+            except:
+                pass
+            return
         # Save Models
         self.model.save(self.args.modelFile)
         self.dicts.save(self.args.dictsFile)
@@ -193,8 +203,11 @@
             self.server.lock.release()
 
 if __name__ == "__main__":
-    HOST, PORT = sys.argv[1:]    
-    #HOST, PORT = "localhost", 9255
+    if not len(sys.argv[1:]) == 2:
+        print 'No Arguments for HOST and PORT found. Using localhost and 9255'
+        HOST, PORT = "localhost", 9255
+    else:
+        HOST, PORT = sys.argv[1:]
     SocketServer.TCPServer.allow_reuse_address = True
     server = ThreadingTCPServer((HOST, int(PORT)), MaShHandler)
     server.serve_forever()