# HG changeset patch # User blanchet # Date 1386558210 -3600 # Node ID b08e1bbde10a967379b87ea09e71345a23611380 # Parent 34496126a60c701fed9244f7340fc3aa0b9f0686 bugfixes to Python MaSh related to alternative features diff -r 34496126a60c -r b08e1bbde10a src/HOL/Tools/Sledgehammer/MaSh/src/dictionaries.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/dictionaries.py Mon Dec 09 04:03:30 2013 +0100 +++ b/src/HOL/Tools/Sledgehammer/MaSh/src/dictionaries.py Mon Dec 09 04:03:30 2013 +0100 @@ -105,10 +105,7 @@ self.changed = True fId = self.featureIdDict[f] fIds.append(fId) - if len(fIds) == 1: - return fIds[0] - else: - return fIds + return fIds def get_features(self,line): featureNames = [f.strip() for f in line[1].split()] @@ -119,8 +116,8 @@ if len(tmp) == 2: fn = tmp[0] weight = float(tmp[1]) - fId = self.add_feature(tmp[0]) - features[fId] = weight + fIds = self.add_feature(tmp[0]) + features[fIds[0]] = (weight,fIds[1:]) #features[fId] = 1.0 ### return features diff -r 34496126a60c -r b08e1bbde10a src/HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py --- a/src/HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py Mon Dec 09 04:03:30 2013 +0100 +++ b/src/HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py Mon Dec 09 04:03:30 2013 +0100 @@ -9,7 +9,6 @@ @author: Daniel Kuehlwein ''' - from cPickle import dump,load from numpy import array from math import log @@ -104,32 +103,41 @@ tau = 0.05 # Jasmin, change value here predictions = [] observedFeatures = features.keys() + for fVal in features.itervalues(): + _w,alternateF = fVal + observedFeatures += alternateF + for a in accessibles: posA = self.counts[a][0] fA = set(self.counts[a][1].keys()) fWeightsA = self.counts[a][1] resultA = log(posA) - for f,w in features.iteritems(): + for f,fVal in features.iteritems(): + w,alternateF = fVal # DEBUG #w = 1.0 # Test for multiple features isMatch = False - if not isinstance( f, ( int, long ) ): - f = f[0] - inter = set(f).intersection(fA) + matchF = None + if f in fA: + isMatch = True + matchF = f + elif len(alternateF) > 0: + inter = set(alternateF).intersection(fA) if len(inter) > 0: isMatch = True - else: - if f in fA: - isMatch = True + for mF in inter: + ### TODO: matchF is randomly selected + matchF = mF + break if isMatch: #if f in fA: - if fWeightsA[f] == 0: + if fWeightsA[matchF] == 0: resultA += w*self.defVal else: - assert fWeightsA[f] <= posA - resultA += w*log(float(self.posWeight*fWeightsA[f])/posA) + assert fWeightsA[matchF] <= posA + resultA += w*log(float(self.posWeight*fWeightsA[matchF])/posA) else: resultA += w*self.defVal if not tau == 0.0: