src/HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py
changeset 54697 b08e1bbde10a
parent 54692 5ce1b9613705
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py	Mon Dec 09 04:03:30 2013 +0100
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py	Mon Dec 09 04:03:30 2013 +0100
@@ -9,7 +9,6 @@
 
 @author: Daniel Kuehlwein
 '''
-
 from cPickle import dump,load
 from numpy import array
 from math import log
@@ -104,32 +103,41 @@
         tau = 0.05 # Jasmin, change value here
         predictions = []
         observedFeatures = features.keys()
+        for fVal in features.itervalues():
+            _w,alternateF = fVal
+            observedFeatures += alternateF
+            
         for a in accessibles:
             posA = self.counts[a][0]
             fA = set(self.counts[a][1].keys())
             fWeightsA = self.counts[a][1]
             resultA = log(posA)
-            for f,w in features.iteritems():
+            for f,fVal in features.iteritems():
+                w,alternateF = fVal
                 # DEBUG
                 #w = 1.0
                 # Test for multiple features
                 isMatch = False
-                if not isinstance( f, ( int, long ) ):
-                    f = f[0]
-                    inter = set(f).intersection(fA)
+                matchF = None
+                if f in fA:
+                    isMatch = True
+                    matchF = f
+                elif len(alternateF) > 0:
+                    inter = set(alternateF).intersection(fA)
                     if len(inter) > 0:
                         isMatch = True
-                else:
-                    if f in fA:
-                        isMatch = True
+                        for mF in inter:
+                            ### TODO: matchF is randomly selected
+                            matchF = mF
+                            break
                  
                 if isMatch:
                 #if f in fA:
-                    if fWeightsA[f] == 0:
+                    if fWeightsA[matchF] == 0:
                         resultA += w*self.defVal
                     else:
-                        assert fWeightsA[f] <= posA
-                        resultA += w*log(float(self.posWeight*fWeightsA[f])/posA)
+                        assert fWeightsA[matchF] <= posA
+                        resultA += w*log(float(self.posWeight*fWeightsA[matchF])/posA)
                 else:
                     resultA += w*self.defVal
             if not tau == 0.0: