src/HOL/Tools/Sledgehammer/MaSh/src/dictionaries.py
changeset 50388 a5b666e0c3c2
parent 50222 40e3c3be6bca
child 50619 b958a94cf811
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/dictionaries.py	Wed Dec 05 15:59:08 2012 +0100
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/dictionaries.py	Thu Dec 06 11:25:10 2012 +0100
@@ -33,18 +33,18 @@
         self.accessibleDict = {}
         self.expandedAccessibles = {}
         self.changed = True
-    
+
     """
     Init functions. Side Effect: nameIdDict, idNameDict, featureIdDict get filled!
-    """    
+    """
     def init_featureDict(self,featureFile):
         self.featureDict,self.maxNameId,self.maxFeatureId = create_feature_dict(self.nameIdDict,self.idNameDict,self.maxNameId,self.featureIdDict,\
-                                                                                self.maxFeatureId,featureFile)        
+                                                                                self.maxFeatureId,featureFile)
     def init_dependenciesDict(self,depFile):
         self.dependenciesDict = create_dependencies_dict(self.nameIdDict,depFile)
     def init_accessibleDict(self,accFile):
         self.accessibleDict,self.maxNameId = create_accessible_dict(self.nameIdDict,self.idNameDict,self.maxNameId,accFile)
-    
+
     def init_all(self,inputFolder,featureFileName = 'mash_features',depFileName='mash_dependencies',accFileName = 'mash_accessibility'):
         featureFile = join(inputFolder,featureFileName)
         depFile = join(inputFolder,depFileName)
@@ -54,7 +54,7 @@
         self.init_dependenciesDict(depFile)
         self.expandedAccessibles = {}
         self.changed = True
-        
+
     def get_name_id(self,name):
         """
         Return the Id for a name.
@@ -66,7 +66,7 @@
             self.nameIdDict[name] = self.maxNameId
             self.idNameDict[self.maxNameId] = name
             nameId = self.maxNameId
-            self.maxNameId += 1 
+            self.maxNameId += 1
             self.changed = True
         return nameId
 
@@ -74,8 +74,23 @@
         if not self.featureIdDict.has_key(featureName):
             self.featureIdDict[featureName] = self.maxFeatureId
             self.maxFeatureId += 1
-            self.changed = True 
-            
+            self.changed = True
+        return self.featureIdDict[featureName]
+
+    def get_features(self,line):
+        # Feature Ids
+        featureNames = [f.strip() for f in line[1].split()]
+        features = []
+        for fn in featureNames:
+            tmp = fn.split('=')
+            if len(tmp) == 2:
+                fId = self.add_feature(tmp[0])
+                features.append((fId,float(tmp[1])))
+            else:
+                fId = self.add_feature(fn)
+                features.append((fId,1.0))
+        return features
+
     def expand_accessibles(self,acc):
         accessibles = set(acc)
         unexpandedQueue = Queue()
@@ -86,71 +101,67 @@
                 unexpandedQueue.put(a)
         while not unexpandedQueue.empty():
             nextUnExp = unexpandedQueue.get()
-            nextUnExpAcc = self.accessibleDict[nextUnExp]            
+            nextUnExpAcc = self.accessibleDict[nextUnExp]
             for a in nextUnExpAcc:
                 if not a in accessibles:
                     accessibles = accessibles.union([a])
                     if self.expandedAccessibles.has_key(a):
                         accessibles = accessibles.union(self.expandedAccessibles[a])
                     else:
-                        unexpandedQueue.put(a)                    
+                        unexpandedQueue.put(a)
         return list(accessibles)
-            
+
     def parse_fact(self,line):
         """
         Parses a single line, extracting accessibles, features, and dependencies.
         """
         assert line.startswith('! ')
         line = line[2:]
-        
+
         # line = name:accessibles;features;dependencies
         line = line.split(':')
         name = line[0].strip()
         nameId = self.get_name_id(name)
-    
-        line = line[1].split(';')       
+
+        line = line[1].split(';')
         # Accessible Ids
         unExpAcc = [self.nameIdDict[a.strip()] for a in line[0].split()]
         self.accessibleDict[nameId] = unExpAcc
-        # Feature Ids
-        featureNames = [f.strip() for f in line[1].split()]
-        for fn in featureNames:
-            self.add_feature(fn)
-        self.featureDict[nameId] = [self.featureIdDict[fn] for fn in featureNames]
+        self.featureDict[nameId] = self.get_features(line)
         self.dependenciesDict[nameId] = [self.nameIdDict[d.strip()] for d in line[2].split()]
         self.changed = True
         return nameId
-    
+
     def parse_overwrite(self,line):
         """
         Parses a single line, extracts the problemId and the Ids of the dependencies.
         """
         assert line.startswith('p ')
         line = line[2:]
-        
+
         # line = name:dependencies
         line = line.split(':')
         name = line[0].strip()
         nameId = self.get_name_id(name)
-    
+
         dependencies = [self.nameIdDict[d.strip()] for d in line[1].split()]
         self.changed = True
         return nameId,dependencies
-    
+
     def parse_problem(self,line):
         """
-        Parses a problem and returns the features and the accessibles. 
+        Parses a problem and returns the features and the accessibles.
         """
         assert line.startswith('? ')
         line = line[2:]
         name = None
-        
+
         # Check whether there is a problem name:
         tmp = line.split(':')
         if len(tmp) == 2:
             name = tmp[0].strip()
             line = tmp[1]
-        
+
         # line = accessibles;features
         line = line.split(';')
         # Accessible Ids, expand and store the accessibles.
@@ -164,13 +175,14 @@
                 self.expandedAccessibles[accId] = self.expand_accessibles(accIdAcc)
                 self.changed = True
         accessibles = self.expand_accessibles(unExpAcc)
-        # Feature Ids
-        featureNames = [f.strip() for f in line[1].split()]
-        for fn in featureNames:
-            self.add_feature(fn)
-        features = [self.featureIdDict[fn] for fn in featureNames]
-        return name,features,accessibles    
-    
+#        # Feature Ids
+#        featureNames = [f.strip() for f in line[1].split()]
+#        for fn in featureNames:
+#            self.add_feature(fn)
+#        features = [self.featureIdDict[fn] for fn in featureNames]
+        features = self.get_features(line)
+        return name,features,accessibles
+
     def save(self,fileName):
         if self.changed:
             dictsStream = open(fileName, 'wb')
@@ -179,10 +191,8 @@
             self.changed = False
             dictsStream.close()
     def load(self,fileName):
-        dictsStream = open(fileName, 'rb')        
+        dictsStream = open(fileName, 'rb')
         self.accessibleDict,self.dependenciesDict,self.expandedAccessibles,self.featureDict,\
               self.featureIdDict,self.idNameDict,self.maxFeatureId,self.maxNameId,self.nameIdDict = load(dictsStream)
         self.changed = False
         dictsStream.close()
-    
-            
\ No newline at end of file