src/HOL/Tools/Sledgehammer/MaSh/src/mash.py
author blanchet
Mon, 26 Nov 2012 13:35:05 +0100
changeset 50222 40e3c3be6bca
parent 50220 90280d85cd03
child 50388 a5b666e0c3c2
permissions -rwxr-xr-x
added file headers

#!/usr/bin/python
#     Title:      HOL/Tools/Sledgehammer/MaSh/src/mash.py
#     Author:     Daniel Kuehlwein, ICIS, Radboud University Nijmegen
#     Copyright   2012
#
# Entry point for MaSh (Machine Learning for Sledgehammer).

'''
MaSh - Machine Learning for Sledgehammer

MaSh allows the use of different machine learning algorithms to predict relevant facts for Sledgehammer.

Created on July 12, 2012

@author: Daniel Kuehlwein
'''

import logging,datetime,string,os,sys
from argparse import ArgumentParser,RawDescriptionHelpFormatter
from time import time
from stats import Statistics
from dictionaries import Dictionaries
from naiveBayes import NBClassifier
from snow import SNoW
from predefined import Predefined

# Set up the command-line parser.  Implicit string concatenation is used for the
# long texts so that no stray indentation ends up inside the help messages.
parser = ArgumentParser(description='MaSh - Machine Learning for Sledgehammer.\n\n'
                        'MaSh allows the use of different machine learning algorithms to predict relevant facts for Sledgehammer.\n\n'
                        '--------------- Example Usage ---------------\n'
                        'First initialize:\n'
                        './mash.py -l test.log -o ../tmp/ --init --inputDir ../data/Nat/\n'
                        'Then create predictions:\n'
                        './mash.py -i ../data/Nat/mash_commands -p ../tmp/test.pred -l test.log -o ../tmp/ --statistics\n'
                        '\n\n'
                        'Author: Daniel Kuehlwein, July 2012',
                        formatter_class=RawDescriptionHelpFormatter)
parser.add_argument('-i','--inputFile',help='File containing all problems to be solved.')
parser.add_argument('-o','--outputDir', default='../tmp/',help='Directory where all created files are stored. Default=../tmp/.')
# Note: the timestamped default file names are computed once, when this module is imported.
parser.add_argument('-p','--predictions',default='../tmp/%s.predictions' % datetime.datetime.now(),
                    help='File where the predictions are stored. Default=../tmp/dateTime.predictions.')
parser.add_argument('--numberOfPredictions',default=200,help="Number of premises to write in the output. Default=200.",type=int)

parser.add_argument('--init',default=False,action='store_true',
                    help="Initialize MaSh from the data in --inputDir. Default=False.")
parser.add_argument('--inputDir',default='../data/Nat/',
                    help='Directory containing all the input data. MaSh expects the following files: '
                         'mash_features,mash_dependencies,mash_accessibility. Default=../data/Nat/.')
parser.add_argument('--depFile', default='mash_dependencies',
                    help='Name of the file with the premise dependencies. The file must be in inputDir. Default = mash_dependencies')
parser.add_argument('--saveModel',default=False,action='store_true',help="Stores the learned model at the end of a prediction run. Default=False.")

# Learning algorithm selection; if several flags are given, --nb wins over --snow, which wins over --predef.
parser.add_argument('--nb',default=False,action='store_true',help="Use Naive Bayes for learning. This is the default learning method.")
parser.add_argument('--snow',default=False,action='store_true',help="Use SNoW's naive Bayes instead of Naive Bayes for learning.")
parser.add_argument('--predef',default=False,action='store_true',
                    help="Use predefined predictions. Used only for comparison with the actual learning. "
                         "Expects mash_meng_paulson_suggestions in inputDir.")
parser.add_argument('--statistics',default=False,action='store_true',
                    help="Create and show statistics for the top CUTOFF predictions. "
                         "WARNING: This will make the program a lot slower! Default=False.")
parser.add_argument('--saveStats',default=None,
                    help="If defined, stores the statistics in the given file name (relative to outputDir). Requires --statistics.")
parser.add_argument('--cutOff',default=500,help="Option for statistics. Only consider the first cutOff predictions. Default=500.",type=int)
parser.add_argument('-l','--log', default='../tmp/%s.log' % datetime.datetime.now(), help='Log file name. Default=../tmp/dateTime.log')
parser.add_argument('-q','--quiet',default=False,action='store_true',help="If enabled, only print warnings. Default=False.")

def main(argv=None):
    """Run MaSh with the given command-line arguments.

    With --init, builds the dictionaries and the selected model from the files
    in --inputDir and pickles both into --outputDir.  Otherwise, loads the
    pickled dictionaries/model (when present) and processes --inputFile line
    by line:

      '!' lines  parse a fact (Dictionaries.parse_fact) and update the model with it;
      'p' lines  overwrite the stored dependencies (proof) of a fact;
      '?' lines  request predictions, which are appended to --predictions
                 (skipped when --predef is used).

    Args:
        argv: argument list to parse.  None (the default) makes argparse read
            sys.argv[1:] at call time rather than at import time.

    Returns:
        0 on success.  The process exits with status -1 on an input line with
        an unrecognized prefix, and with an argparse usage error when
        --inputFile is missing outside --init mode.
    """
    # Initializing command-line arguments
    args = parser.parse_args(argv)
    if not args.init and args.inputFile is None:
        # Fail early with a usage message instead of a TypeError from open(None).
        parser.error('--inputFile is required unless --init is given.')

    # Set up logging: DEBUG and above goes to the log file, INFO and above to stdout.
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
                        datefmt='%d-%m %H:%M:%S',
                        filename=args.log,
                        filemode='w')
    console = logging.StreamHandler(sys.stdout)
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter('# %(message)s'))
    logging.getLogger('').addHandler(console)
    logger = logging.getLogger('main.py')
    if args.quiet:
        logger.setLevel(logging.WARNING)
        console.setLevel(logging.WARNING)
    if not os.path.exists(args.outputDir):
        os.makedirs(args.outputDir)

    logger.info('Using the following settings: %s',args)
    # Pick algorithm.  The flags are tested in order, so --nb wins over --snow,
    # which wins over --predef; Naive Bayes is the fallback.
    if args.nb:
        logger.info('Using Naive Bayes for learning.')
        model = NBClassifier()
        modelFile = os.path.join(args.outputDir,'NB.pickle')
    elif args.snow:
        logger.info('Using naive bayes (SNoW) for learning.')
        model = SNoW()
        modelFile = os.path.join(args.outputDir,'SNoW.pickle')
    elif args.predef:
        logger.info('Using predefined predictions.')
        predictionFile = os.path.join(args.inputDir,'mash_meng_paulson_suggestions')
        model = Predefined(predictionFile)
        modelFile = os.path.join(args.outputDir,'isabelle.pickle')
    else:
        logger.info('No algorithm specified. Using Naive Bayes.')
        model = NBClassifier()
        modelFile = os.path.join(args.outputDir,'NB.pickle')
    dictsFile = os.path.join(args.outputDir,'dicts.pickle')

    # Initializing model
    if args.init:
        logger.info('Initializing Model.')
        startTime = time()

        # Load all data
        dicts = Dictionaries()
        dicts.init_all(args.inputDir,depFileName=args.depFile)

        # Create Model.  The predefined model hands back the dictionaries to keep using.
        trainData = dicts.featureDict.keys()
        if args.predef:
            dicts = model.initializeModel(trainData,dicts)
        else:
            model.initializeModel(trainData,dicts)

        model.save(modelFile)
        dicts.save(dictsFile)

        logger.info('All Done. %s seconds needed.',round(time()-startTime,2))
        return 0

    # Create predictions and/or update model
    lineCounter = 0  # number of '?' queries answered so far
    dicts = Dictionaries()
    # Load the state pickled by a previous run, if any.
    if os.path.isfile(dictsFile):
        dicts.load(dictsFile)
    if os.path.isfile(modelFile):
        model.load(modelFile)

    if args.statistics:
        stats = Statistics(args.cutOff)

    # Both files are closed (and pending predictions flushed) even when processing aborts.
    with open(args.predictions,'a') as predictionStream, open(args.inputFile,'r') as inputStream:
        for line in inputStream:
            if line.startswith('!'):
                problemId = dicts.parse_fact(line)
                # Statistics: score the prediction for the new fact before learning from it.
                if args.statistics:
                    if args.predef:
                        # NOTE(review): assumes Predefined.predict(problemId) returns the stored
                        # predictions; the previous `model.predict[problemId]` subscripted the
                        # method object itself.  Confirm against predefined.py.
                        predictions = model.predict(problemId)
                    else:
                        acc = dicts.accessibleDict[problemId]
                        predictions,_predictionsValues = model.predict(dicts.featureDict[problemId],dicts.expand_accessibles(acc))
                    stats.update(predictions,dicts.dependenciesDict[problemId])
                    if stats.badPreds:
                        bp = ','.join(str(dicts.idNameDict[x]) for x in stats.badPreds)
                        logger.debug('Bad predictions: %s',bp)
                # Update Dependencies, p proves p
                dicts.dependenciesDict[problemId] = [problemId]+dicts.dependenciesDict[problemId]
                model.update(problemId,dicts.featureDict[problemId],dicts.dependenciesDict[problemId])
            elif line.startswith('p'):
                # Overwrite old proof; the fact is again listed as its own dependency.
                problemId,newDependencies = dicts.parse_overwrite(line)
                newDependencies = [problemId]+newDependencies
                model.overwrite(problemId,newDependencies,dicts)
                dicts.dependenciesDict[problemId] = newDependencies
            elif line.startswith('?'):
                if args.predef:
                    # Queries are not answered when using predefined predictions.
                    continue
                startTime = time()
                name,features,accessibles = dicts.parse_problem(line)
                # Create predictions
                logger.info('Starting computation for problem on line %s',lineCounter)
                predictions,predictionValues = model.predict(features,accessibles)
                assert len(predictions) == len(predictionValues)
                logger.info('Done. %s seconds needed.',round(time()-startTime,2))

                # Output the top predictions as "name: fact1=value1 fact2=value2 ..."
                limit = args.numberOfPredictions
                predictionsString = ' '.join('%s=%s' % (str(dicts.idNameDict[p]),str(v))
                                             for p,v in zip(predictions[:limit],predictionValues[:limit]))
                predictionStream.write('%s: %s\n' % (name,predictionsString))
                lineCounter += 1
            else:
                logger.warning('Unspecified input format: \n%s',line)
                sys.exit(-1)

    # Statistics
    if args.statistics:
        stats.printAvg()

    # Save
    if args.saveModel:
        model.save(modelFile)
    dicts.save(dictsFile)
    if args.saveStats is not None:
        if args.statistics:
            stats.save(os.path.join(args.outputDir,args.saveStats))
        else:
            # `stats` only exists when --statistics was given.
            logger.warning('--saveStats ignored: statistics are only computed with --statistics.')
    return 0

if __name__ == '__main__':
    # Example invocations are printed by `./mash.py --help`.
    sys.exit(main())