bugfixes to Python MaSh related to alternative features
author blanchet
Mon, 09 Dec 2013 04:03:30 +0100
changeset 54697 b08e1bbde10a
parent 54696 34496126a60c
child 54698 fed04f257898
bugfixes to Python MaSh related to alternative features
src/HOL/Tools/Sledgehammer/MaSh/src/dictionaries.py
src/HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/dictionaries.py	Mon Dec 09 04:03:30 2013 +0100
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/dictionaries.py	Mon Dec 09 04:03:30 2013 +0100
@@ -105,10 +105,7 @@
                 self.changed = True
             fId = self.featureIdDict[f]
             fIds.append(fId)
-        if len(fIds) == 1:
-            return fIds[0]
-        else:
-            return fIds
+        return fIds
 
     def get_features(self,line):
         featureNames = [f.strip() for f in line[1].split()]
@@ -119,8 +116,8 @@
             if len(tmp) == 2:
                 fn = tmp[0]
                 weight = float(tmp[1])
-            fId = self.add_feature(tmp[0])
-            features[fId] = weight
+            fIds = self.add_feature(tmp[0])
+            features[fIds[0]] = (weight,fIds[1:])
             #features[fId] = 1.0 ###
         return features
 
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py	Mon Dec 09 04:03:30 2013 +0100
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py	Mon Dec 09 04:03:30 2013 +0100
@@ -9,7 +9,6 @@
 
 @author: Daniel Kuehlwein
 '''
-
 from cPickle import dump,load
 from numpy import array
 from math import log
@@ -104,32 +103,41 @@
         tau = 0.05 # Jasmin, change value here
         predictions = []
         observedFeatures = features.keys()
+        for fVal in features.itervalues():
+            _w,alternateF = fVal
+            observedFeatures += alternateF
+            
         for a in accessibles:
             posA = self.counts[a][0]
             fA = set(self.counts[a][1].keys())
             fWeightsA = self.counts[a][1]
             resultA = log(posA)
-            for f,w in features.iteritems():
+            for f,fVal in features.iteritems():
+                w,alternateF = fVal
                 # DEBUG
                 #w = 1.0
                 # Test for multiple features
                 isMatch = False
-                if not isinstance( f, ( int, long ) ):
-                    f = f[0]
-                    inter = set(f).intersection(fA)
+                matchF = None
+                if f in fA:
+                    isMatch = True
+                    matchF = f
+                elif len(alternateF) > 0:
+                    inter = set(alternateF).intersection(fA)
                     if len(inter) > 0:
                         isMatch = True
-                else:
-                    if f in fA:
-                        isMatch = True
+                        for mF in inter:
+                            ### TODO: matchF is randomly selected
+                            matchF = mF
+                            break
                  
                 if isMatch:
                 #if f in fA:
-                    if fWeightsA[f] == 0:
+                    if fWeightsA[matchF] == 0:
                         resultA += w*self.defVal
                     else:
-                        assert fWeightsA[f] <= posA
-                        resultA += w*log(float(self.posWeight*fWeightsA[f])/posA)
+                        assert fWeightsA[matchF] <= posA
+                        resultA += w*log(float(self.posWeight*fWeightsA[matchF])/posA)
                 else:
                     resultA += w*self.defVal
             if not tau == 0.0: