18 ''' |
18 ''' |
19 MetaClass for all the theory models. |
19 MetaClass for all the theory models. |
20 ''' |
20 ''' |
21 |
21 |
22 |
22 |
def __init__(self, defValPos=-7.5, defValNeg=-15.0, posWeight=10.0):
    '''
    Constructor

    Sets up an empty collection of per-theory models.

    defValPos, defValNeg, posWeight are stored on the instance and handed
    unchanged to every singleNBClassifier that this object creates later.
    NOTE(review): their exact meaning is defined by singleNBClassifier --
    presumably default scores and the weight of positive examples; confirm
    there.
    '''
    # theory name -> classifier for that theory
    self.theoryModels = {}
    # Model Params
    self.defValPos = defValPos
    self.defValNeg = defValNeg
    self.posWeight = posWeight
    # theory name -> collection of ids belonging to that theory
    self.theoryDict = {}
    # names of the theories whose models may currently be consulted/updated
    self.accessibleTheories = set()
    # name of the theory most recently started; None until one is seen
    self.currentTheory = None
31 |
35 |
32 def init(self,depFile,dicts): |
36 def init(self,depFile,dicts): |
47 assert not theory == self.currentTheory |
51 assert not theory == self.currentTheory |
48 if not self.currentTheory == None: |
52 if not self.currentTheory == None: |
49 self.accessibleTheories.add(self.currentTheory) |
53 self.accessibleTheories.add(self.currentTheory) |
50 self.currentTheory = theory |
54 self.currentTheory = theory |
51 self.theoryDict[theory] = set([nameId]) |
55 self.theoryDict[theory] = set([nameId]) |
52 theoryModel = singleNBClassifier() |
56 theoryModel = singleNBClassifier(self.defValPos,self.defValNeg,self.posWeight) |
53 self.theoryModels[theory] = theoryModel |
57 self.theoryModels[theory] = theoryModel |
54 else: |
58 else: |
55 self.theoryDict[theory] = self.theoryDict[theory].union([nameId]) |
59 self.theoryDict[theory] = self.theoryDict[theory].union([nameId]) |
56 |
60 |
57 # Find the actually used theories |
61 # Find the actually used theories |
92 usedTheories = set([x.split('.')[0] for x in tmp]) |
96 usedTheories = set([x.split('.')[0] for x in tmp]) |
93 for a in self.accessibleTheories: |
97 for a in self.accessibleTheories: |
94 self.theoryModels[a].delete(features,a in usedTheories) |
98 self.theoryModels[a].delete(features,a in usedTheories) |
95 |
99 |
def update(self, problemId, features, dependencies, dicts):
    '''
    Registers the theory of problemId (creating a fresh model for it if it
    has not been seen before) and then updates the accessible theory models.

    problemId    -- key into dicts.idNameDict; the theory name is the part of
                    the mapped name before the first '.'.
    features, dependencies, dicts -- passed through unchanged to
                    update_with_acc together with self.accessibleTheories.
    '''
    # TODO: Implicit assumption that self.accessibleTheories contains all accessible theories!
    currentTheory = (dicts.idNameDict[problemId]).split('.')[0]
    # Create new theory model, if there is a new theory
    if currentTheory not in self.theoryDict:
        # An unknown theory can never be the one we are currently in.
        assert not currentTheory == self.currentTheory
        if currentTheory is not None:
            # NOTE(review): a list is stored here although other code in this
            # class stores sets in theoryDict -- confirm this is intended.
            self.theoryDict[currentTheory] = []
            self.currentTheory = currentTheory
            theoryModel = singleNBClassifier(self.defValPos, self.defValNeg, self.posWeight)
            self.theoryModels[currentTheory] = theoryModel
    # The current theory takes part in the update below.
    self.accessibleTheories.add(self.currentTheory)
    self.update_with_acc(problemId, features, dependencies, dicts, self.accessibleTheories)
108 |
113 |
109 def update_with_acc(self,problemId,features,dependencies,dicts,accessibleTheories): |
114 def update_with_acc(self,problemId,features,dependencies,dicts,accessibleTheories): |
116 |
121 |
117 def predict(self,features,accessibles,dicts): |
122 def predict(self,features,accessibles,dicts): |
118 """ |
123 """ |
119 Predicts the relevant theories. Returns the predicted theories and a list of all accessible premises in these theories. |
124 Predicts the relevant theories. Returns the predicted theories and a list of all accessible premises in these theories. |
120 """ |
125 """ |
121 # TODO: This can be made a lot faster! |
126 self.accessibleTheories = set([(dicts.idNameDict[x]).split('.')[0] for x in accessibles]) |
122 self.accessibleTheories = [] |
|
123 for x in accessibles: |
|
124 xArt = (dicts.idNameDict[x]).split('.')[0] |
|
125 self.accessibleTheories.append(xArt) |
|
126 self.accessibleTheories = set(self.accessibleTheories) |
|
127 |
127 |
128 # Predict Theories |
128 # Predict Theories |
129 predictedTheories = [self.currentTheory] |
129 predictedTheories = [self.currentTheory] |
130 for a in self.accessibleTheories: |
130 for a in self.accessibleTheories: |
131 if self.theoryModels[a].predict_sparse(features): |
131 if self.theoryModels[a].predict_sparse(features): |
141 newAcc.append(x) |
141 newAcc.append(x) |
142 return predictedTheories,newAcc |
142 return predictedTheories,newAcc |
143 |
143 |
def save(self, fileName):
    '''
    Serialises the model state to fileName.

    The state is written with dump() as one 7-tuple, in this order:
    (currentTheory, accessibleTheories, theoryModels, theoryDict,
     defValPos, defValNeg, posWeight). load() expects the same layout.
    '''
    state = (self.currentTheory,
             self.accessibleTheories,
             self.theoryModels,
             self.theoryDict,
             self.defValPos,
             self.defValNeg,
             self.posWeight)
    # 'with' guarantees the file is closed even if dump() raises.
    with open(fileName, 'wb') as outStream:
        dump(state, outStream)
def load(self, fileName):
    '''
    Restores the model state from fileName.

    Expects a single 7-tuple in this order:
    (currentTheory, accessibleTheories, theoryModels, theoryDict,
     defValPos, defValNeg, posWeight). A file with a different number of
    fields makes the unpacking below raise ValueError.

    NOTE(review): if load() is a pickle loader, only open files from a
    trusted source -- unpickling can execute arbitrary code.
    '''
    # 'with' guarantees the file is closed even if unpickling/unpacking fails.
    with open(fileName, 'rb') as inStream:
        (self.currentTheory,
         self.accessibleTheories,
         self.theoryModels,
         self.theoryDict,
         self.defValPos,
         self.defValNeg,
         self.posWeight) = load(inStream)