new version of MaSh, with theory-level reasoning
author blanchet
Thu, 27 Dec 2012 10:01:40 +0100
changeset 50619 b958a94cf811
parent 50617 9df2f825422b
child 50620 07e08250a880
new version of MaSh, with theory-level reasoning
src/HOL/Tools/Sledgehammer/MaSh/src/dictionaries.py
src/HOL/Tools/Sledgehammer/MaSh/src/mash.py
src/HOL/Tools/Sledgehammer/MaSh/src/predefined.py
src/HOL/Tools/Sledgehammer/MaSh/src/singleNaiveBayes.py
src/HOL/Tools/Sledgehammer/MaSh/src/snow.py
src/HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py
src/HOL/Tools/Sledgehammer/MaSh/src/theoryModels.py
src/HOL/Tools/Sledgehammer/MaSh/src/theoryStats.py
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/dictionaries.py	Wed Dec 26 11:06:21 2012 +0100
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/dictionaries.py	Thu Dec 27 10:01:40 2012 +0100
@@ -35,7 +35,7 @@
         self.changed = True
 
     """
-    Init functions. Side Effect: nameIdDict, idNameDict, featureIdDict get filled!
+    Init functions. Side Effect: nameIdDict, idNameDict, featureIdDict, articleDict get filled!
     """
     def init_featureDict(self,featureFile):
         self.featureDict,self.maxNameId,self.maxFeatureId = create_feature_dict(self.nameIdDict,self.idNameDict,self.maxNameId,self.featureIdDict,\
@@ -175,12 +175,8 @@
                 self.expandedAccessibles[accId] = self.expand_accessibles(accIdAcc)
                 self.changed = True
         accessibles = self.expand_accessibles(unExpAcc)
-#        # Feature Ids
-#        featureNames = [f.strip() for f in line[1].split()]
-#        for fn in featureNames:
-#            self.add_feature(fn)
-#        features = [self.featureIdDict[fn] for fn in featureNames]
         features = self.get_features(line)
+
         return name,features,accessibles
 
     def save(self,fileName):
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py	Wed Dec 26 11:06:21 2012 +0100
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py	Thu Dec 27 10:01:40 2012 +0100
@@ -19,10 +19,11 @@
 from argparse import ArgumentParser,RawDescriptionHelpFormatter
 from time import time
 from stats import Statistics
+from theoryStats import TheoryStatistics
+from theoryModels import TheoryModels
 from dictionaries import Dictionaries
 #from fullNaiveBayes import NBClassifier
 from sparseNaiveBayes import sparseNBClassifier
-#from naiveBayes import sparseNBClassifier
 from snow import SNoW
 from predefined import Predefined
 
@@ -41,11 +42,13 @@
 parser.add_argument('--numberOfPredictions',default=200,help="Number of premises to write in the output. Default=200.",type=int)
 
 parser.add_argument('--init',default=False,action='store_true',help="Initialize Mash. Requires --inputDir to be defined. Default=False.")
-parser.add_argument('--inputDir',default='../data/Jinja/',\
+parser.add_argument('--inputDir',default='../data/20121212/Jinja/',\
                     help='Directory containing all the input data. MaSh expects the following files: mash_features,mash_dependencies,mash_accessibility')
 parser.add_argument('--depFile', default='mash_dependencies',
                     help='Name of the file with the premise dependencies. The file must be in inputDir. Default = mash_dependencies')
 parser.add_argument('--saveModel',default=False,action='store_true',help="Stores the learned Model at the end of a prediction run. Default=False.")
+parser.add_argument('--learnTheories',default=False,action='store_true',help="Uses a two-lvl prediction mode. First the theories, then the premises. Default=False.")
+
 
 parser.add_argument('--nb',default=False,action='store_true',help="Use Naive Bayes for learning. This is the default learning method.")
 parser.add_argument('--snow',default=False,action='store_true',help="Use SNoW's naive bayes instead of Naive Bayes for learning.")
@@ -101,6 +104,7 @@
         model = sparseNBClassifier()
         modelFile = os.path.join(args.outputDir,'NB.pickle')
     dictsFile = os.path.join(args.outputDir,'dicts.pickle')
+    theoryFile = os.path.join(args.outputDir,'theory.pickle')
 
     # Initializing model
     if args.init:
@@ -110,14 +114,17 @@
         # Load all data
         dicts = Dictionaries()
         dicts.init_all(args.inputDir,depFileName=args.depFile)
-
+        
         # Create Model
         trainData = dicts.featureDict.keys()
-        if args.predef:
-            model.initializeModel(trainData,dicts)
-        else:
-            model.initializeModel(trainData,dicts)
+        model.initializeModel(trainData,dicts)
 
+        if args.learnTheories:
+            depFile = os.path.join(args.inputDir,args.depFile)
+            theoryModels = TheoryModels()
+            theoryModels.init(depFile,dicts)
+            theoryModels.save(theoryFile)
+            
         model.save(modelFile)
         dicts.save(dictsFile)
 
@@ -129,11 +136,14 @@
         statementCounter = 1
         computeStats = False
         dicts = Dictionaries()
+        theoryModels = TheoryModels()
         # Load Files
         if os.path.isfile(dictsFile):
             dicts.load(dictsFile)
         if os.path.isfile(modelFile):
             model.load(modelFile)
+        if os.path.isfile(theoryFile) and args.learnTheories:
+            theoryModels.load(theoryFile)
 
         # IO Streams
         OS = open(args.predictions,'w')
@@ -142,32 +152,37 @@
         # Statistics
         if args.statistics:
             stats = Statistics(args.cutOff)
+            if args.learnTheories:
+                theoryStats = TheoryStatistics()
 
         predictions = None
+        predictedTheories = None
         #Reading Input File
         for line in IS:
 #           try:
             if True:
                 if line.startswith('!'):
-                    problemId = dicts.parse_fact(line)                    
+                    problemId = dicts.parse_fact(line)                        
                     # Statistics
                     if args.statistics and computeStats:
                         computeStats = False
-                        acc = dicts.accessibleDict[problemId]
+                        # Assume '!' comes after '?'
                         if args.predef:
                             predictions = model.predict(problemId)
-                        else:
-                            if args.snow:
-                                predictions,_predictionsValues = model.predict(dicts.featureDict[problemId],dicts.expand_accessibles(acc),dicts)
-                            else:
-                                predictions,_predictionsValues = model.predict(dicts.featureDict[problemId],dicts.expand_accessibles(acc))                        
+                        if args.learnTheories:
+                            tmp = [dicts.idNameDict[x] for x in dicts.dependenciesDict[problemId]]
+                            usedTheories = set([x.split('.')[0] for x in tmp]) 
+                            theoryStats.update((dicts.idNameDict[problemId]).split('.')[0],predictedTheories,usedTheories)                        
                         stats.update(predictions,dicts.dependenciesDict[problemId],statementCounter)
                         if not stats.badPreds == []:
                             bp = string.join([str(dicts.idNameDict[x]) for x in stats.badPreds], ',')
                             logger.debug('Bad predictions: %s',bp)
+
                     statementCounter += 1
                     # Update Dependencies, p proves p
                     dicts.dependenciesDict[problemId] = [problemId]+dicts.dependenciesDict[problemId]
+                    if args.learnTheories:
+                        theoryModels.update(problemId,dicts)
                     if args.snow:
                         model.update(problemId,dicts.featureDict[problemId],dicts.dependenciesDict[problemId],dicts)
                     else:
@@ -177,15 +192,19 @@
                     problemId,newDependencies = dicts.parse_overwrite(line)
                     newDependencies = [problemId]+newDependencies
                     model.overwrite(problemId,newDependencies,dicts)
+                    if args.learnTheories:
+                        theoryModels.overwrite(problemId,newDependencies,dicts)
                     dicts.dependenciesDict[problemId] = newDependencies
-                elif line.startswith('?'):                    
+                elif line.startswith('?'):               
                     startTime = time()
                     computeStats = True
                     if args.predef:
                         continue
-                    name,features,accessibles = dicts.parse_problem(line)
+                    name,features,accessibles = dicts.parse_problem(line)    
                     # Create predictions
                     logger.info('Starting computation for problem on line %s',lineCounter)
+                    if args.learnTheories:
+                        predictedTheories,accessibles = theoryModels.predict(features,accessibles,dicts)
                     if args.snow:
                         predictions,predictionValues = model.predict(features,accessibles,dicts)
                     else:
@@ -214,13 +233,20 @@
 
         # Statistics
         if args.statistics:
+            if args.learnTheories:
+                theoryStats.printAvg()
             stats.printAvg()
 
         # Save
         if args.saveModel:
             model.save(modelFile)
+            if args.learnTheories:
+                theoryModels.save(theoryFile)
         dicts.save(dictsFile)
         if not args.saveStats == None:
+            if args.learnTheories:
+                theoryStatsFile = os.path.join(args.outputDir,'theoryStats')
+                theoryStats.save(theoryStatsFile)
             statsFile = os.path.join(args.outputDir,args.saveStats)
             stats.save(statsFile)
     return 0
@@ -228,28 +254,37 @@
 if __name__ == '__main__':
     # Example:
     # Jinja
-    #args = ['-l','testIsabelle.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/','--predef']
-    #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/testIsabelle.pred','-l','testIsabelle.log','--predef','-o','../tmp/','--statistics','--saveStats','../tmp/natATPMP.stats']
-    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/']
-    #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/natATPNB.stats','--cutOff','500']
-    # List
-    #args = ['-l','testIsabelle.log','-o','../tmp/','--statistics','--init','--inputDir','../data/List/','--isabelle']
-    #args = ['-i', '../data/List/mash_commands','-p','../tmp/testIsabelle.pred','-l','testIsabelle.log','--isabelle','-o','../tmp/','--statistics']
-    # Huffmann
-    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Huffman/','--depFile','mash_atp_dependencies']
-    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Huffman/']
-    #args = ['-i', '../data/Huffman/mash_commands','-p','../tmp/testNB.pred','-l','testNB.log','--nb','-o','../tmp/','--statistics']
-    # Jinja
-    # ISAR
-    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/']    
-    #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaIsarNB.stats','--cutOff','500']
-    #args = ['-l','testIsabelle.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/','--predef']
-    #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/JinjaMePo.pred','-l','testIsabelle.log','--predef','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaMePo.stats']
+    # ISAR Theories
+    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/20121221/Jinja/','--learnTheories']    
+    #args = ['-i', '../data/20121221/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaIsarNB.stats','--cutOff','500','--learnTheories']
+    # ISAR NB
+    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/20121221/Jinja/']    
+    #args = ['-i', '../data/20121221/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaIsarNB.stats','--cutOff','500']
+    # ISAR MePo
+    #args = ['-l','testIsabelle.log','-o','../tmp/','--statistics','--init','--inputDir','../data/20121212/Jinja/','--predef']
+    #args = ['-i', '../data/20121212/Jinja/mash_commands','-p','../tmp/JinjaMePo.pred','-l','testIsabelle.log','--predef','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaMePo.stats']
+    # ISAR NB ATP
+    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/20121212/Jinja/','--depFile','mash_atp_dependencies']    
+    #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaIsarNB.stats','--cutOff','500','--depFile','mash_atp_dependencies']
+    #args = ['-l','testIsabelle.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/','--predef','--depFile','mash_atp_dependencies']
+    #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/JinjaMePo.pred','-l','testIsabelle.log','--predef','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaMePo.stats','--depFile','mash_atp_dependencies']
     #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/','--depFile','mash_atp_dependencies','--snow']    
-    #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaIsarNB.stats','--cutOff','500','--depFile','mash_atp_dependencies']
+    #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--snow','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaIsarNB.stats','--cutOff','500','--depFile','mash_atp_dependencies']
+    # ISAR Snow
+    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/20121212/Jinja/','--snow']    
+    #args = ['-i', '../data/20121212/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--snow','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaIsarNB.stats','--cutOff','500']
+ 
+
 
-    # ATP
-    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/','--depFile','mash_atp_dependencies']    
+    # Probability
+    # ISAR NB
+    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/20121213/Probability/']    
+    #args = ['-i', '../data/20121213/Probability/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/ProbIsarNB.stats','--cutOff','500']
+    # ISAR MePo
+    #args = ['-l','testIsabelle.log','-o','../tmp/','--statistics','--init','--inputDir','../data/20121213/Probability/','--predef']
+    #args = ['-i', '../data/20121213/Probability/mash_commands','-p','../tmp/JinjaMePo.pred','-l','testIsabelle.log','--predef','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaMePo.stats']
+    # ISAR NB ATP
+    #args = ['-l','testNB.log','-o','../tmp/','--statistics','--init','--inputDir','../data/20121212/Jinja/','--depFile','mash_atp_dependencies']    
     #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/testNB.pred','-l','../tmp/testNB.log','--nb','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaIsarNB.stats','--cutOff','500','--depFile','mash_atp_dependencies']
     #args = ['-l','testIsabelle.log','-o','../tmp/','--statistics','--init','--inputDir','../data/Jinja/','--predef','--depFile','mash_atp_dependencies']
     #args = ['-i', '../data/Jinja/mash_commands','-p','../tmp/JinjaMePo.pred','-l','testIsabelle.log','--predef','-o','../tmp/','--statistics','--saveStats','../tmp/JinjaMePo.stats','--depFile','mash_atp_dependencies']
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/predefined.py	Wed Dec 26 11:06:21 2012 +0100
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/predefined.py	Thu Dec 27 10:01:40 2012 +0100
@@ -37,8 +37,7 @@
             line = line[1].split()
             preds = [dicts.get_name_id(x.strip())for x in line]
             self.predictions[predId] = preds
-        IS.close()
-        return dicts
+        IS.close()        
 
     def update(self,dataPoint,features,dependencies):
         """
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/singleNaiveBayes.py	Thu Dec 27 10:01:40 2012 +0100
@@ -0,0 +1,173 @@
+#     Title:      HOL/Tools/Sledgehammer/MaSh/src/singleNaiveBayes.py
+#     Author:     Daniel Kuehlwein, ICIS, Radboud University Nijmegen
+#     Copyright   2012
+#
+# An updatable sparse naive Bayes classifier.
+
+'''
+Created on Jul 11, 2012
+
+@author: Daniel Kuehlwein
+'''
+
+from cPickle import dump,load
+from math import log,exp
+
+
+class singleNBClassifier(object):
+    '''
+    An updateable naive Bayes classifier.
+    '''
+
+    def __init__(self):
+        '''
+        Constructor
+        '''
+        self.neg = 0.0
+        self.pos = 0.0
+        self.counts = {} # Counts is the tuple poscounts,negcounts
+    
+    def update(self,features,label):
+        """
+        Updates the Model.
+        
+        @param label: True or False, True if the features belong to a positive label, false else.
+        """
+        #print label,self.pos,self.neg,self.counts
+        if label:
+            self.pos += 1
+        else:
+            self.neg += 1
+        
+        for f,_w in features:
+            if not self.counts.has_key(f):
+                fPosCount = 0.0
+                fNegCount = 0.0
+                self.counts[f] = [fPosCount,fNegCount]
+            posCount,negCount = self.counts[f]
+            if label:
+                posCount += 1
+            else:
+                negCount += 1
+            self.counts[f] = [posCount,negCount]
+        #print label,self.pos,self.neg,self.counts
+                
+ 
+    def delete(self,features,label):
+        """
+        Deletes a single datapoint from the model.
+        """
+        if label:
+            self.pos -= 1
+        else:
+            self.neg -= 1
+        for f in features:
+            posCount,negCount = self.counts[f]
+            if label:
+                posCount -= 1
+            else:
+                negCount -= 1
+            self.counts[f] = [posCount,negCount]
+
+            
+    def overwrite(self,features,label):
+        """
+        Deletes the old dependencies of problemId and replaces them with the new ones. Updates the model accordingly.
+        """
+        self.delete(features,label)
+        self.update(features,label)
+    
+    def predict_sparse(self,features):
+        """
+        Returns 1 if the probability is greater than 50%.
+        """
+        if self.neg == 0:
+            return 1
+        elif self.pos ==0:
+            return 0
+        defValPos = -7.5       
+        defValNeg = -15.0
+        posWeight = 10.0
+        
+        logneg = log(self.neg)
+        logpos = log(self.pos)
+        prob = logpos - logneg
+        
+        for f,_w in features:
+            if self.counts.has_key(f):
+                posCount,negCount = self.counts[f]
+                if posCount > 0:
+                    prob += (log(posWeight * posCount) - logpos)
+                else:
+                    prob += defValPos
+                if negCount > 0:
+                    prob -= (log(negCount) - logneg)
+                else:
+                    prob -= defValNeg 
+        if prob >= 0 : 
+            return 1
+        else:
+            return 0
+    
+    def predict(self,features):    
+        """
+        Returns 1 if the probability is greater than 50%.
+        """
+        if self.neg == 0:
+            return 1
+        elif self.pos ==0:
+            return 0
+        defVal = -15.0       
+        expDefVal = exp(defVal)
+        
+        logneg = log(self.neg)
+        logpos = log(self.pos)
+        prob = logpos - logneg
+        
+        for f in self.counts.keys():
+            posCount,negCount = self.counts[f]
+            if f in features:
+                if posCount == 0:
+                    prob += defVal
+                else:
+                    prob += log(float(posCount)/self.pos)
+                if negCount == 0:
+                    prob -= defVal
+                else:
+                    prob -= log(float(negCount)/self.neg)
+            else:
+                if posCount == self.pos:
+                    prob += log(1-expDefVal)
+                else:
+                    prob += log(1-float(posCount)/self.pos)
+                if negCount == self.neg:
+                    prob -= log(1-expDefVal)
+                else:
+                    prob -= log(1-float(negCount)/self.neg)
+
+        if prob >= 0 : 
+            return 1
+        else:
+            return 0        
+        
+    def save(self,fileName):
+        OStream = open(fileName, 'wb')
+        dump(self.counts,OStream)        
+        OStream.close()
+        
+    def load(self,fileName):
+        OStream = open(fileName, 'rb')
+        self.counts = load(OStream)      
+        OStream.close()
+
+if __name__ == '__main__':
+    x = singleNBClassifier()
+    x.update([0], True)
+    assert x.predict([0]) == 1
+    x = singleNBClassifier()
+    x.update([0], False)
+    assert x.predict([0]) == 0    
+    
+    x.update([0], True)
+    x.update([1], True)
+    print x.pos,x.neg,x.predict([0,1])
\ No newline at end of file
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/snow.py	Wed Dec 26 11:06:21 2012 +0100
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/snow.py	Thu Dec 27 10:01:40 2012 +0100
@@ -5,16 +5,14 @@
 # Wrapper for SNoW.
 
 '''
-THIS FILE IS NOT UP TO DATE!
-NEEDS SOME FIXING BEFORE IT WILL WORK WITH THE MAIN ALGORITHM
 
 Created on Jul 12, 2012
 
 @author: daniel
 '''
 
-import logging,shlex,subprocess,string
-from cPickle import load,dump
+import logging,shlex,subprocess,string,shutil
+#from cPickle import load,dump
 
 class SNoW(object):
     '''
@@ -29,6 +27,7 @@
         self.SNoWTrainFile = '../tmp/snow.train'
         self.SNoWTestFile = '../snow.test'
         self.SNoWNetFile = '../tmp/snow.net'
+        self.defMaxNameId = 20000
 
     def initializeModel(self,trainData,dicts):
         """
@@ -38,7 +37,8 @@
         self.logger.debug('Creating IO Files')
         OS = open(self.SNoWTrainFile,'w')
         for nameId in trainData:
-            features = [f+dicts.maxNameId for f in dicts.featureDict[nameId]]
+            features = [f+dicts.maxNameId for f,_w in dicts.featureDict[nameId]]
+            #features = [f+self.defMaxNameId for f,_w in dicts.featureDict[nameId]]
             features = map(str,features)
             featureString = string.join(features,',')
             dependencies = dicts.dependenciesDict[nameId]
@@ -51,25 +51,51 @@
         # Build Model
         self.logger.debug('Building Model START.')
         snowTrainCommand = '../bin/snow -train -M+ -I %s -F %s -g- -B :0-%s' % (self.SNoWTrainFile,self.SNoWNetFile,dicts.maxNameId-1)
+        #snowTrainCommand = '../bin/snow -train -M+ -I %s -F %s -g- -B :0-%s' % (self.SNoWTrainFile,self.SNoWNetFile,self.defMaxNameId-1)
         args = shlex.split(snowTrainCommand)
         p = subprocess.Popen(args,stdout=subprocess.PIPE,stderr=subprocess.STDOUT)
         p.wait()
         self.logger.debug('Building Model END.')
 
-
     def update(self,dataPoint,features,dependencies,dicts):
         """
         Updates the Model.
-        THIS IS NOT WORKING ATM< BUT WE DONT CARE
+        """
         """
         self.logger.debug('Updating Model START')
-        trainData = dicts.featureDict.keys()
-        self.initializeModel(trainData,dicts)
+        # Ignore Feature weights        
+        features = [f+self.defMaxNameId for f,_w in features]
+        
+        OS = open(self.SNoWTestFile,'w')
+        features = map(str,features)
+        featureString = string.join(features, ',')
+        dependencies = map(str,dependencies)
+        dependenciesString = string.join(dependencies,',')
+        snowString = string.join([featureString,dependenciesString],',')+':\n'
+        OS.write(snowString)
+        OS.close()
+        snowTestCommand = '../bin/snow -test -I %s -F %s -o allboth -i+' % (self.SNoWTestFile,self.SNoWNetFile) 
+        args = shlex.split(snowTestCommand)
+        p = subprocess.Popen(args,stdout=subprocess.PIPE,stderr=subprocess.STDOUT)
+        (_lines, _stderrdata) = p.communicate()
+        # Move new net file        
+        src = self.SNoWNetFile+'.new'
+        dst = self.SNoWNetFile
+        shutil.move(src, dst)        
         self.logger.debug('Updating Model END')
+        """
+        # Do nothing, only update at evaluation. Is a lot faster.
+        pass
 
 
     def predict(self,features,accessibles,dicts):
-        logger = logging.getLogger('predict_SNoW')
+        trainData = dicts.featureDict.keys()
+        self.initializeModel(trainData, dicts)        
+        
+        logger = logging.getLogger('predict_SNoW')        
+        # Ignore Feature weights
+        #features = [f+self.defMaxNameId for f,_w in features]
+        features = [f+dicts.maxNameId for f,_w in features]
 
         OS = open(self.SNoWTestFile,'w')
         features = map(str,features)
@@ -87,17 +113,22 @@
         assert lines[9].startswith('Example ')
         assert lines[-4] == ''
         predictionsCon = []
+        predictionsValues = []
         for line in lines[10:-4]:
             premiseId = int(line.split()[0][:-1])
             predictionsCon.append(premiseId)
-        return predictionsCon
+            val = line.split()[4]
+            if val.endswith('*'):
+                val = float(val[:-1])
+            else:
+                val = float(val)
+            predictionsValues.append(val)
+        return predictionsCon,predictionsValues
 
     def save(self,fileName):
-        OStream = open(fileName, 'wb')
-        dump(self.counts,OStream)
-        OStream.close()
-
+        # Nothing to do since we don't update
+        pass
+    
     def load(self,fileName):
-        OStream = open(fileName, 'rb')
-        self.counts = load(OStream)
-        OStream.close()
+        # Nothing to do since we don't update
+        pass
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py	Wed Dec 26 11:06:21 2012 +0100
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py	Thu Dec 27 10:01:40 2012 +0100
@@ -2,7 +2,7 @@
 #     Author:     Daniel Kuehlwein, ICIS, Radboud University Nijmegen
 #     Copyright   2012
 #
-# An updatable naive Bayes classifier.
+# An updatable sparse naive Bayes classifier.
 
 '''
 Created on Jul 11, 2012
@@ -37,7 +37,6 @@
         for key in dicts.dependenciesDict.keys():
             # Add p proves p
             keyDeps = [key]+dicts.dependenciesDict[key]
-
             for dep in keyDeps:
                 self.counts[dep][0] += 1
                 depFeatures = dicts.featureDict[key]
@@ -89,6 +88,8 @@
         For each accessible, predicts the probability of it being useful given the features.
         Returns a ranking of the accessibles.
         """
+        posWeight = 20.0
+        defVal = 15
         predictions = []
         for a in accessibles:
             posA = self.counts[a][0]
@@ -96,14 +97,16 @@
             fWeightsA = self.counts[a][1]
             resultA = log(posA)
             for f,w in features:
+                # DEBUG
+                #w = 1
                 if f in fA:
                     if fWeightsA[f] == 0:
-                        resultA -= w*15
+                        resultA -= w*defVal
                     else:
                         assert fWeightsA[f] <= posA
-                        resultA += w*log(float(fWeightsA[f])/posA)
+                        resultA += w*log(float(posWeight*fWeightsA[f])/posA)
                 else:
-                    resultA -= w*15
+                    resultA -= w*defVal
             predictions.append(resultA)
         #expPredictions = array([exp(x) for x in predictions])
         predictions = array(predictions)
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/theoryModels.py	Thu Dec 27 10:01:40 2012 +0100
@@ -0,0 +1,136 @@
+#     Title:      HOL/Tools/Sledgehammer/MaSh/src/theoryModels.py
+#     Author:     Daniel Kuehlwein, ICIS, Radboud University Nijmegen
+#     Copyright   2012
+#
+# An updatable sparse naive Bayes classifier.
+
+'''
+Created on Dec 26, 2012
+
+@author: Daniel Kuehlwein
+'''
+
+from singleNaiveBayes import singleNBClassifier
+from cPickle import load,dump
+import sys,logging
+
+class TheoryModels(object):
+    '''
+    MetaClass for all the theory models.
+    '''
+
+
+    def __init__(self):
+        '''
+        Constructor
+        '''
+        self.theoryModels = {}
+        self.theoryDict = {}
+        self.accessibleTheories = []
+        self.currentTheory = None
+  
+    def init(self,depFile,dicts):      
+        logger = logging.getLogger('TheoryModels')
+        IS = open(depFile,'r')
+        for line in IS:
+            line = line.split(':')
+            name = line[0]
+            theory = name.split('.')[0]
+            # Name Id
+            if not dicts.nameIdDict.has_key(name):
+                logger.warning('%s is missing in nameIdDict. Aborting.',name)
+                sys.exit(-1)
+    
+            nameId = dicts.nameIdDict[name]
+            features = dicts.featureDict[nameId]
+            if not self.theoryDict.has_key(theory):
+                assert not theory == self.currentTheory
+                if not self.currentTheory == None:
+                    self.accessibleTheories.append(self.currentTheory)
+                self.currentTheory = theory
+                self.theoryDict[theory] = set([nameId])
+                theoryModel = singleNBClassifier()
+                self.theoryModels[theory] = theoryModel 
+            else:
+                self.theoryDict[theory] = self.theoryDict[theory].union([nameId])               
+            
+            # Find the actually used theories
+            usedtheories = []    
+            dependencies = line[1].split()
+            if len(dependencies) == 0:
+                continue
+            for dep in dependencies:
+                depId = dicts.nameIdDict[dep.strip()]
+                deptheory = dep.split('.')[0]
+                usedtheories.append(deptheory)
+                if not self.theoryDict.has_key(deptheory):
+                    self.theoryDict[deptheory] = set([depId])
+                else:
+                    self.theoryDict[deptheory] = self.theoryDict[deptheory].union([depId])                   
+                        
+            # Update theoryModels
+            self.theoryModels[self.currentTheory].update(features,self.currentTheory in usedtheories)
+            for a in self.accessibleTheories:                
+                self.theoryModels[a].update(dicts.featureDict[nameId],a in usedtheories)
+        IS.close()
+    
+    def overwrite(self,problemId,newDependencies,dicts):
+        pass
+    
+    def delete(self):
+        pass
+    
+    def update(self,problemId,dicts):        
+        features = dicts.featureDict[problemId]
+        
+        # Find the actually used theories
+        tmp = [dicts.idNameDict[x] for x in dicts.dependenciesDict[problemId]]
+        usedTheories = set([x.split('.')[0] for x in tmp]) 
+        currentTheory = (dicts.idNameDict[problemId]).split('.')[0]       
+        # Create new theory model, if there is a new theory 
+        if not self.theoryDict.has_key(currentTheory):
+            assert not currentTheory == self.currentTheory
+            if not currentTheory == None:
+                self.theoryDict[currentTheory] = []
+                self.currentTheory = currentTheory
+                theoryModel = singleNBClassifier()
+                self.theoryModels[currentTheory] = theoryModel          
+        if not len(usedTheories) == 0:
+            for a in self.accessibleTheories:                
+                self.theoryModels[a].update(features,a in usedTheories)   
+    
+    def predict(self,features,accessibles,dicts):
+        """
+        Predicts the relevant theories. Returns the predicted theories and a list of all accessible premises in these theories.
+        """         
+        # TODO: This can be made a lot faster!    
+        self.accessibleTheories = []
+        for x in accessibles:
+            xArt = (dicts.idNameDict[x]).split('.')[0]
+            self.accessibleTheories.append(xArt)
+        self.accessibleTheories = set(self.accessibleTheories)
+        
+        # Predict Theories
+        predictedTheories = [self.currentTheory]
+        for a in self.accessibleTheories:
+            if self.theoryModels[a].predict_sparse(features):
+            #if theoryModels[a].predict(dicts.featureDict[nameId]):
+                predictedTheories.append(a)
+        predictedTheories = set(predictedTheories)
+
+        # Delete accessibles in unpredicted theories
+        newAcc = []
+        for x in accessibles:
+            xArt = (dicts.idNameDict[x]).split('.')[0]
+            if xArt in predictedTheories:
+                newAcc.append(x)
+        return predictedTheories,newAcc
+        
+    def save(self,fileName):
+        outStream = open(fileName, 'wb')
+        dump((self.currentTheory,self.accessibleTheories,self.theoryModels,self.theoryDict),outStream)
+        outStream.close()
+    def load(self,fileName):
+        inStream = open(fileName, 'rb')
+        self.currentTheory,self.accessibleTheories,self.theoryModels,self.theoryDict = load(inStream)
+        inStream.close()
\ No newline at end of file
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/theoryStats.py	Thu Dec 27 10:01:40 2012 +0100
@@ -0,0 +1,63 @@
+#     Title:      HOL/Tools/Sledgehammer/MaSh/src/theoryStats.py
+#     Author:     Daniel Kuehlwein, ICIS, Radboud University Nijmegen
+#     Copyright   2012
+#
+# Statistics for theory-level predictions.
+
+'''
+Created on Dec 26, 2012
+
+@author: Daniel Kuehlwein
+'''
+
+from cPickle import load,dump
+import logging,string
+
class TheoryStatistics(object):
    '''
    Accumulates evaluation statistics for theory-level predictions:
    per-problem precision/recall and the number of predicted theories.
    '''

    def __init__(self):
        '''
        Constructor
        '''
        self.logger = logging.getLogger('TheoryStatistics')
        self.count = 0        # number of problems evaluated
        self.precision = 0.0  # sum of per-problem precision values
        self.recall100 = 0    # problems where every used theory was predicted
        self.recall = 0.0     # sum of per-problem recall values
        self.predicted = 0.0  # total number of predicted theories

    def update(self, currentTheory, predictedTheories, usedTheories):
        """
        Record the outcome of one problem.

        currentTheory is always counted as predicted; usedTheories are the
        theories actually needed by the proof.
        """
        self.count += 1
        allPredTheories = predictedTheories.union([currentTheory])
        usedSet = set(usedTheories)  # compute once, used several times below
        if usedSet.issubset(allPredTheories):
            self.recall100 += 1
        localPredicted = len(allPredTheories)
        self.predicted += localPredicted
        overlap = len(usedSet.intersection(allPredTheories))
        # localPredicted >= 1 always (currentTheory is included), so this is safe.
        localPrec = float(overlap) / localPredicted
        self.precision += localPrec
        # Guard: a proof using no facts has trivially perfect recall
        # (avoids ZeroDivisionError on an empty usedTheories).
        localRecall = float(overlap) / len(usedSet) if usedSet else 1.0
        self.recall += localRecall
        self.logger.info('Theory prediction results:')
        # Fixed log-message typo: 'PredictedTeories' -> 'PredictedTheories'.
        self.logger.info('Problem: %s \t Recall100: %s \t Precision: %s \t Recall: %s \t PredictedTheories: %s',\
                         self.count, self.recall100, round(localPrec, 2), round(localRecall, 2), localPredicted)

    def printAvg(self):
        """Log the averages over all recorded problems."""
        if self.count == 0:
            # Nothing recorded yet; averaging would divide by zero.
            self.logger.info('No theory results recorded.')
            return
        self.logger.info('Average theory results:')
        self.logger.info('avgPrecision: %s \t avgRecall100: %s \t avgRecall: %s \t avgPredicted: %s', \
                         round(self.precision / self.count, 2),\
                         round(float(self.recall100) / self.count, 2),\
                         round(self.recall / self.count, 2),\
                         round(self.predicted / self.count, 2))

    def save(self, fileName):
        """Pickle the accumulated counters to fileName."""
        with open(fileName, 'wb') as oStream:
            dump((self.count, self.precision, self.recall100, self.recall, self.predicted), oStream)

    def load(self, fileName):
        """Restore counters previously written by save()."""
        with open(fileName, 'rb') as iStream:
            self.count, self.precision, self.recall100, self.recall, self.predicted = load(iStream)
\ No newline at end of file