src/HOL/Tools/Sledgehammer/MaSh/src/theoryModels.py
changeset 50840 a5cc092156da
parent 50827 aba769dc82e9
equal deleted inserted replaced
50839:9cc70b273e90 50840:a5cc092156da
    18     '''
    18     '''
    19     MetaClass for all the theory models.
    19     MetaClass for all the theory models.
    20     '''
    20     '''
    21 
    21 
    22 
    22 
    23     def __init__(self):
    23     def __init__(self,defValPos = -7.5,defValNeg = -15.0,posWeight = 10.0):
    24         '''
    24         '''
    25         Constructor
    25         Constructor
    26         '''
    26         '''
    27         self.theoryModels = {}
    27         self.theoryModels = {}
       
    28         # Model Params
       
    29         self.defValPos = defValPos       
       
    30         self.defValNeg = defValNeg
       
    31         self.posWeight = posWeight
    28         self.theoryDict = {}
    32         self.theoryDict = {}
    29         self.accessibleTheories = set([])
    33         self.accessibleTheories = set([])
    30         self.currentTheory = None
    34         self.currentTheory = None
    31   
    35   
    32     def init(self,depFile,dicts):      
    36     def init(self,depFile,dicts):      
    47                 assert not theory == self.currentTheory
    51                 assert not theory == self.currentTheory
    48                 if not self.currentTheory == None:
    52                 if not self.currentTheory == None:
    49                     self.accessibleTheories.add(self.currentTheory)
    53                     self.accessibleTheories.add(self.currentTheory)
    50                 self.currentTheory = theory
    54                 self.currentTheory = theory
    51                 self.theoryDict[theory] = set([nameId])
    55                 self.theoryDict[theory] = set([nameId])
    52                 theoryModel = singleNBClassifier()
    56                 theoryModel = singleNBClassifier(self.defValPos,self.defValNeg,self.posWeight)
    53                 self.theoryModels[theory] = theoryModel 
    57                 self.theoryModels[theory] = theoryModel 
    54             else:
    58             else:
    55                 self.theoryDict[theory] = self.theoryDict[theory].union([nameId])               
    59                 self.theoryDict[theory] = self.theoryDict[theory].union([nameId])               
    56             
    60             
    57             # Find the actually used theories
    61             # Find the actually used theories
    92         usedTheories = set([x.split('.')[0] for x in tmp])   
    96         usedTheories = set([x.split('.')[0] for x in tmp])   
    93         for a in self.accessibleTheories:        
    97         for a in self.accessibleTheories:        
    94             self.theoryModels[a].delete(features,a in usedTheories)  
    98             self.theoryModels[a].delete(features,a in usedTheories)  
    95     
    99     
    96     def update(self,problemId,features,dependencies,dicts): 
   100     def update(self,problemId,features,dependencies,dicts): 
       
   101         # TODO: Implicit assumption that self.accessibleTheories contains all accessible theories!
    97         currentTheory = (dicts.idNameDict[problemId]).split('.')[0]       
   102         currentTheory = (dicts.idNameDict[problemId]).split('.')[0]       
    98         # Create new theory model, if there is a new theory 
   103         # Create new theory model, if there is a new theory 
    99         if not self.theoryDict.has_key(currentTheory):
   104         if not self.theoryDict.has_key(currentTheory):
   100             assert not currentTheory == self.currentTheory            
   105             assert not currentTheory == self.currentTheory            
   101             if not currentTheory == None:
   106             if not currentTheory == None:
   102                 self.theoryDict[currentTheory] = []
   107                 self.theoryDict[currentTheory] = []
   103                 self.currentTheory = currentTheory
   108                 self.currentTheory = currentTheory
   104                 theoryModel = singleNBClassifier()
   109                 theoryModel = singleNBClassifier(self.defValPos,self.defValNeg,self.posWeight)
   105                 self.theoryModels[currentTheory] = theoryModel
   110                 self.theoryModels[currentTheory] = theoryModel
   106                 self.accessibleTheories.add(self.currentTheory) 
   111                 self.accessibleTheories.add(self.currentTheory) 
   107         self.update_with_acc(problemId,features,dependencies,dicts,self.accessibleTheories)  
   112         self.update_with_acc(problemId,features,dependencies,dicts,self.accessibleTheories)  
   108     
   113     
   109     def update_with_acc(self,problemId,features,dependencies,dicts,accessibleTheories):        
   114     def update_with_acc(self,problemId,features,dependencies,dicts,accessibleTheories):        
   116     
   121     
   117     def predict(self,features,accessibles,dicts):
   122     def predict(self,features,accessibles,dicts):
   118         """
   123         """
   119         Predicts the relevant theories. Returns the predicted theories and a list of all accessible premises in these theories.
   124         Predicts the relevant theories. Returns the predicted theories and a list of all accessible premises in these theories.
   120         """         
   125         """         
   121         # TODO: This can be made a lot faster!    
   126         self.accessibleTheories = set([(dicts.idNameDict[x]).split('.')[0] for x in accessibles])
   122         self.accessibleTheories = []
       
   123         for x in accessibles:
       
   124             xArt = (dicts.idNameDict[x]).split('.')[0]
       
   125             self.accessibleTheories.append(xArt)
       
   126         self.accessibleTheories = set(self.accessibleTheories)
       
   127         
   127         
   128         # Predict Theories
   128         # Predict Theories
   129         predictedTheories = [self.currentTheory]
   129         predictedTheories = [self.currentTheory]
   130         for a in self.accessibleTheories:
   130         for a in self.accessibleTheories:
   131             if self.theoryModels[a].predict_sparse(features):
   131             if self.theoryModels[a].predict_sparse(features):
   141                 newAcc.append(x)
   141                 newAcc.append(x)
   142         return predictedTheories,newAcc
   142         return predictedTheories,newAcc
   143         
   143         
   144     def save(self,fileName):
   144     def save(self,fileName):
   145         outStream = open(fileName, 'wb')
   145         outStream = open(fileName, 'wb')
   146         dump((self.currentTheory,self.accessibleTheories,self.theoryModels,self.theoryDict),outStream)
   146         dump((self.currentTheory,self.accessibleTheories,self.theoryModels,self.theoryDict,self.defValPos,self.defValNeg,self.posWeight),outStream)
   147         outStream.close()
   147         outStream.close()
   148     def load(self,fileName):
   148     def load(self,fileName):
   149         inStream = open(fileName, 'rb')
   149         inStream = open(fileName, 'rb')
   150         self.currentTheory,self.accessibleTheories,self.theoryModels,self.theoryDict = load(inStream)
   150         self.currentTheory,self.accessibleTheories,self.theoryModels,self.theoryDict,self.defValPos,self.defValNeg,self.posWeight = load(inStream)
   151         inStream.close()
   151         inStream.close()