src/HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py
author blanchet
Sat, 12 Jan 2013 16:49:39 +0100
changeset 50840 a5cc092156da
parent 50827 aba769dc82e9
child 50951 e1cbaa7d5536
permissions -rw-r--r--
new version of MaSh Python component

#     Title:      HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py
#     Author:     Daniel Kuehlwein, ICIS, Radboud University Nijmegen
#     Copyright   2012
#
# An updatable sparse naive Bayes classifier.

'''
Created on Jul 11, 2012

@author: Daniel Kuehlwein
'''

from cPickle import dump,load
from numpy import array
from math import log

class sparseNBClassifier(object):
    """
    An updatable sparse naive Bayes classifier.

    The model lives in ``self.counts``, which maps every known data point ``d``
    to a two-element list ``[posCount, featureCounts]``:

    * ``posCount``      -- the prior weight given to ``d`` when it was added,
                           plus one for every time ``d`` was recorded as a
                           dependency.
    * ``featureCounts`` -- dict mapping a feature to its count: the prior
                           weight for ``d``'s own features, plus one for every
                           time that feature occurred on a data point that
                           depended on ``d``.

    Features are always passed as iterables of ``(feature, weight)`` pairs.
    """

    def __init__(self, defaultPriorWeight=20.0, posWeight=20.0, defVal=-15.0,
                 useSinePrior=False, sineWeight=100.0):
        """
        Create an empty model.

        :param defaultPriorWeight: initial ``posCount`` of a new data point and
            initial count of each of its own features (gives ``p |- p`` a
            higher weight). When 0, own features are not pre-seeded.
        :param posWeight: factor applied to ``featureCount / posCount`` in
            predict().
        :param defVal: per-unit-weight score for a query feature that never
            co-occurred with an accessible.
        :param useSinePrior: if true, predict() adds a trigger-feature bonus to
            each accessible's prior.
        :param sineWeight: numerator of that trigger-feature bonus.
        """
        self.counts = {}
        self.sinePrior = useSinePrior
        self.sineWeight = sineWeight
        self.defaultPriorWeight = defaultPriorWeight
        self.posWeight = posWeight
        self.defVal = defVal

    def _newEntry(self, features):
        """Return a fresh ``[posCount, featureCounts]`` entry seeded with the prior weight."""
        featureCounts = {}
        # Give p |- p a higher weight by pre-seeding p's own features.
        if self.defaultPriorWeight != 0:
            for f, _w in features:
                featureCounts[f] = self.defaultPriorWeight
        return [self.defaultPriorWeight, featureCounts]

    def _addDependency(self, dep, features):
        """Record one occurrence of ``dep`` as a dependency of a point with ``features``."""
        entry = self.counts[dep]
        entry[0] += 1
        featureCounts = entry[1]
        for f, _w in features:
            featureCounts[f] = featureCounts.get(f, 0) + 1

    def initializeModel(self, trainData, dicts):
        """
        Build the basic model from training data.

        :param trainData: iterable of data-point ids; each gets a fresh entry.
        :param dicts: object exposing ``featureDict`` (id -> list of
            ``(feature, weight)``) and ``dependenciesDict`` (id -> iterable of
            ids). Every dependency must be a key of ``self.counts`` (i.e. in
            ``trainData``), otherwise ``KeyError`` is raised.
        """
        for d in trainData:
            # Only look up d's features when the prior weight actually uses them.
            ownFeatures = dicts.featureDict[d] if self.defaultPriorWeight != 0 else ()
            self.counts[d] = self._newEntry(ownFeatures)

        for key, deps in dicts.dependenciesDict.items():
            keyFeatures = dicts.featureDict[key]
            # Add p proves p: every data point counts as its own dependency.
            self._addDependency(key, keyFeatures)
            for dep in deps:
                self._addDependency(dep, keyFeatures)

    def update(self, dataPoint, features, dependencies):
        """
        Add ``dataPoint`` (if new) and record its ``dependencies``.

        :param dataPoint: id of the data point; created with the prior weight
            if not yet in the model.
        :param features: iterable of ``(feature, weight)`` pairs of dataPoint.
        :param dependencies: ids that must already be in the model (may include
            dataPoint itself).
        """
        if dataPoint not in self.counts:
            self.counts[dataPoint] = self._newEntry(features)
        for dep in dependencies:
            self._addDependency(dep, features)

    def delete(self, dataPoint, features, dependencies):
        """
        Undo the counts that update() added for one data point.

        The entry for ``dataPoint`` itself is kept. ``features`` and
        ``dependencies`` must match what was previously recorded; a feature
        never counted for a dependency raises ``KeyError``.
        """
        for dep in dependencies:
            entry = self.counts[dep]
            entry[0] -= 1
            featureCounts = entry[1]
            for f, _w in features:
                featureCounts[f] -= 1

    def overwrite(self, problemId, newDependencies, dicts):
        """
        Replace the recorded dependencies of ``problemId`` with new ones.

        The old dependencies and the features are read from
        ``dicts.dependenciesDict`` and ``dicts.featureDict``; ``dicts`` itself
        is not modified here.
        """
        assert problemId in self.counts
        oldDeps = dicts.dependenciesDict[problemId]
        features = dicts.featureDict[problemId]
        self.delete(problemId, features, oldDeps)
        self.update(problemId, features, newDependencies)

    def predict(self, features, accessibles, dicts):
        """
        Rank the accessibles by their estimated usefulness for ``features``.

        The score of accessible ``a`` is ``log(prior)`` plus, for every query
        feature ``(f, w)``, ``w * log(posWeight * featureCounts[f] / posCount)``.
        Features with a missing or zero count contribute ``w * defVal`` instead.
        ``prior`` is ``posCount``. With the sine prior enabled, it is
        additionally increased by ``sineWeight / dicts.featureCountDict[f]``
        for every query feature ``f`` in ``dicts.triggerFeatures[a]``.

        :param features: iterable of ``(feature, weight)`` pairs (iterated once
            per accessible, so pass a sequence, not a generator).
        :param accessibles: sequence of data-point ids present in the model.
        :param dicts: only read when the sine prior is enabled.
        :returns: ``(rankedAccessibles, scores)`` numpy arrays sorted by
            descending score.
        """
        queryFeatures = set(f for f, _w in features)
        scores = []
        for a in accessibles:
            posA, featureCountsA = self.counts[a]
            prior = posA
            if self.sinePrior:
                triggered = queryFeatures.intersection(dicts.triggerFeatures[a])
                for f in triggered:
                    prior += self.sineWeight / dicts.featureCountDict[f]
            score = log(prior)
            for f, w in features:
                count = featureCountsA.get(f, 0)
                if count == 0:
                    # Feature never (or no longer) seen with a: fixed penalty.
                    score += w * self.defVal
                else:
                    assert count <= posA
                    score += w * log(float(self.posWeight * count) / posA)
            scores.append(score)
        scores = array(scores)
        order = (-scores).argsort()
        return array(accessibles)[order], scores[order]

    def save(self, fileName):
        """Pickle the model counts and all hyper-parameters to ``fileName``."""
        # `with` guarantees the file is closed even if pickling fails.
        with open(fileName, 'wb') as outStream:
            dump((self.counts, self.defaultPriorWeight, self.posWeight,
                  self.defVal, self.sinePrior, self.sineWeight), outStream)

    def load(self, fileName):
        """
        Restore a model written by save(), replacing the current state.

        Security: this unpickles the file, which can execute arbitrary code.
        Only load model files from a trusted source.
        """
        # `load` below is the module-level pickle function, not this method.
        with open(fileName, 'rb') as inStream:
            (self.counts, self.defaultPriorWeight, self.posWeight,
             self.defVal, self.sinePrior, self.sineWeight) = load(inStream)


if __name__ == '__main__':
    # Small smoke test of the classifier. The model reads `dicts` through
    # attributes (featureDict, dependenciesDict) and expects features as
    # (feature, weight) pairs, so the demo data is shaped accordingly.
    import os
    import tempfile

    class _DemoDicts(object):
        """Minimal stand-in for the dictionaries object the model reads."""
        def __init__(self, featureDict, dependenciesDict):
            self.featureDict = featureDict
            self.dependenciesDict = dependenciesDict

    featureDict = {0: [(0, 1.0), (1, 1.0), (2, 1.0)],
                   1: [(3, 1.0), (2, 1.0), (1, 1.0)]}
    dependenciesDict = {0: [0], 1: [0, 1]}
    libDicts = _DemoDicts(featureDict, dependenciesDict)
    c = sparseNBClassifier()
    c.initializeModel([0, 1], libDicts)
    c.update(2, [(14, 1.0), (1, 1.0), (3, 1.0)], [0, 2])
    print(c.counts)
    print(c.predict([(0, 1.0), (14, 1.0)], [0, 1, 2], libDicts))
    # Round-trip through a temporary file instead of littering the cwd.
    fd, modelFile = tempfile.mkstemp()
    os.close(fd)
    try:
        c.save(modelFile)
        d = sparseNBClassifier()
        d.load(modelFile)
    finally:
        os.remove(modelFile)
    print(c.counts)
    print(d.counts)
    print('Done')