have MaSh support nameless facts (i.e. proofs) and use that support
author: blanchet
Thu, 14 Nov 2013 15:57:48 +0100
changeset 54432 68f8bd1641da
parent 54431 e98996c2a32c
child 54433 b1721e5b8717
have MaSh support nameless facts (i.e. proofs) and use that support
src/HOL/Tools/Sledgehammer/MaSh/src/dictionaries.py
src/HOL/Tools/Sledgehammer/MaSh/src/readData.py
src/HOL/Tools/Sledgehammer/MaSh/src/server.py
src/HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/dictionaries.py	Thu Nov 14 15:40:06 2013 +0100
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/dictionaries.py	Thu Nov 14 15:57:48 2013 +0100
@@ -22,7 +22,7 @@
         self.nameIdDict = {}
         self.idNameDict = {}
         self.featureIdDict={}
-        self.maxNameId = 0
+        self.maxNameId = 1
         self.maxFeatureId = 0
         self.featureDict = {}
         self.dependenciesDict = {}
@@ -30,6 +30,9 @@
         self.expandedAccessibles = {}
         self.accFile =  ''
         self.changed = True
+        # Unnamed facts
+        self.nameIdDict[''] = 0
+        self.idNameDict[0] = 'Unnamed Fact'
 
     """
     Init functions. nameIdDict, idNameDict, featureIdDict, articleDict get filled!
@@ -153,13 +156,18 @@
         name = line[0].strip()
         nameId = self.get_name_id(name)
         line = line[1].split(';')
-        # Accessible Ids
-        #unExpAcc = [self.nameIdDict[a.strip()] for a in line[0].split()]
-        #self.accessibleDict[nameId] = unExpAcc
-        self.accessibleDict[nameId] = self.parse_unExpAcc(line[0])
         features = self.get_features(line)
         self.featureDict[nameId] = features
-        self.dependenciesDict[nameId] = [self.nameIdDict[d.strip()] for d in line[2].split()]        
+        try:
+            self.dependenciesDict[nameId] = [self.nameIdDict[d.strip()] for d in line[2].split()]        
+        except:
+            unknownDeps = []
+            for d in line[2].split():
+                if not self.nameIdDict.has_key(d):
+                    unknownDeps.append(d)
+            raise LookupError('Unknown fact used as dependency: %s. Facts need to be introduced before being used as dependency.' % ','.join(unknownDeps))
+        self.accessibleDict[nameId] = self.parse_unExpAcc(line[0])
+
         self.changed = True
         return nameId
 
@@ -173,9 +181,18 @@
         # line = name:dependencies
         line = line.split(':')
         name = line[0].strip()
-        nameId = self.get_name_id(name)
-
-        dependencies = [self.nameIdDict[d.strip()] for d in line[1].split()]
+        try:
+            nameId = self.nameIdDict[name]
+        except:
+            raise LookupError('Trying to overwrite dependencies for unknown fact: %s. Facts need to be introduced before overwriting them.' % name)
+        try:
+            dependencies = [self.nameIdDict[d.strip()] for d in line[1].split()]
+        except:
+            unknownDeps = []
+            for d in line[1].split():
+                if not self.nameIdDict.has_key(d):
+                    unknownDeps.append(d)
+            raise LookupError('Unknown fact used as dependency: %s. Facts need to be introduced before being used as dependency.' % ','.join(unknownDeps))
         self.changed = True
         return nameId,dependencies
 
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/readData.py	Thu Nov 14 15:40:06 2013 +0100
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/readData.py	Thu Nov 14 15:57:48 2013 +0100
@@ -29,7 +29,10 @@
         nameId = nameIdDict[name]
         dependenciesIds = [nameIdDict[f.strip()] for f in line[1].split()]
         # Store results, add p proves p
-        dependenciesDict[nameId] = [nameId] + dependenciesIds
+        if nameId == 0:
+            dependenciesDict[nameId] = dependenciesIds
+        else:
+            dependenciesDict[nameId] = [nameId] + dependenciesIds
     IS.close()
     return dependenciesDict
 
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/server.py	Thu Nov 14 15:40:06 2013 +0100
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/server.py	Thu Nov 14 15:57:48 2013 +0100
@@ -14,6 +14,7 @@
 from sparseNaiveBayes import sparseNBClassifier
 from KNN import KNN,euclidean
 from KNNs import KNNAdaptPointFeatures,KNNUrban
+#from bayesPlusMetric import sparseNBPlusClassifier
 from predefined import Predefined
 from ExpandFeatures import ExpandFeatures
 from stats import Statistics
@@ -58,15 +59,28 @@
         else:
             argv = argv.split(';')
             self.server.args = init_parser(argv)
+
+        # Set up logging
+        logging.basicConfig(level=logging.DEBUG,
+                            format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
+                            datefmt='%d-%m %H:%M:%S',
+                            filename=self.server.args.log+'server',
+                            filemode='w')    
+        self.server.logger = logging.getLogger('server')
+            
         # Load all data
         self.server.dicts = Dictionaries()
         if os.path.isfile(self.server.args.dictsFile):
-            self.server.dicts.load(self.server.args.dictsFile)            
+            self.server.dicts.load(self.server.args.dictsFile)
+        #elif not self.server.args.dictsFile == '../tmp/dict.pickle':
+        #    raise IOError('Cannot find dictsFile at %s '% self.server.args.dictsFile)        
         elif self.server.args.init:
             self.server.dicts.init_all(self.server.args)
         # Pick model
         if self.server.args.algorithm == 'nb':
-            self.server.model = sparseNBClassifier(self.server.args.NBDefaultPriorWeight,self.server.args.NBPosWeight,self.server.args.NBDefVal)
+            ###TODO: !! 
+            self.server.model = sparseNBClassifier(self.server.args.NBDefaultPriorWeight,self.server.args.NBPosWeight,self.server.args.NBDefVal)            
+            #self.server.model = sparseNBPlusClassifier(self.server.args.NBDefaultPriorWeight,self.server.args.NBPosWeight,self.server.args.NBDefVal)
         elif self.server.args.algorithm == 'KNN':
             #self.server.model = KNN(self.server.dicts)
             self.server.model = KNNAdaptPointFeatures(self.server.dicts)
@@ -80,6 +94,8 @@
         # Create Model
         if os.path.isfile(self.server.args.modelFile):
             self.server.model.load(self.server.args.modelFile)          
+        #elif not self.server.args.modelFile == '../tmp/model.pickle':
+        #    raise IOError('Cannot find modelFile at %s '% self.server.args.modelFile)        
         elif self.server.args.init:
             trainData = self.server.dicts.featureDict.keys()
             self.server.model.initializeModel(trainData,self.server.dicts)
@@ -89,13 +105,6 @@
             self.server.statementCounter = 1
             self.server.computeStats = False
 
-        # Set up logging
-        logging.basicConfig(level=logging.DEBUG,
-                            format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
-                            datefmt='%d-%m %H:%M:%S',
-                            filename=self.server.args.log+'server',
-                            filemode='w')    
-        self.server.logger = logging.getLogger('server')
         self.server.logger.debug('Initialized in '+str(round(time()-self.startTime,2))+' seconds.')
         self.request.sendall('Server initialized in '+str(round(time()-self.startTime,2))+' seconds.')
         self.server.callCounter = 1
@@ -117,8 +126,11 @@
         if self.server.args.expandFeatures:
             self.server.expandFeatures.update(self.server.dicts.featureDict[problemId],self.server.dicts.dependenciesDict[problemId])
         # Update Dependencies, p proves p
-        self.server.dicts.dependenciesDict[problemId] = [problemId]+self.server.dicts.dependenciesDict[problemId]
+        if not problemId == 0:
+            self.server.dicts.dependenciesDict[problemId] = [problemId]+self.server.dicts.dependenciesDict[problemId]
+        ###TODO: 
         self.server.model.update(problemId,self.server.dicts.featureDict[problemId],self.server.dicts.dependenciesDict[problemId])
+        #self.server.model.update(problemId,self.server.dicts.featureDict[problemId],self.server.dicts.dependenciesDict[problemId],self.server.dicts)
 
     def overwrite(self):
         # Overwrite old proof.
--- a/src/HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py	Thu Nov 14 15:40:06 2013 +0100
+++ b/src/HOL/Tools/Sledgehammer/MaSh/src/sparseNaiveBayes.py	Thu Nov 14 15:57:48 2013 +0100
@@ -41,10 +41,11 @@
             self.counts[d] = [self.defaultPriorWeight,dFeatureCounts]
 
         for key,keyDeps in dicts.dependenciesDict.iteritems():
+            keyFeatures = dicts.featureDict[key]
             for dep in keyDeps:
                 self.counts[dep][0] += 1
-                depFeatures = dicts.featureDict[key]
-                for f in depFeatures.iterkeys():
+                #depFeatures = dicts.featureDict[key]
+                for f in keyFeatures.iterkeys():
                     if self.counts[dep][1].has_key(f):
                         self.counts[dep][1][f] += 1
                     else:
@@ -55,7 +56,7 @@
         """
         Updates the Model.
         """
-        if not self.counts.has_key(dataPoint):
+        if (not self.counts.has_key(dataPoint)) and (not dataPoint == 0):
             dFeatureCounts = {}            
             # Give p |- p a higher weight
             if not self.defaultPriorWeight == 0:               
@@ -86,7 +87,10 @@
         """
         Deletes the old dependencies of problemId and replaces them with the new ones. Updates the model accordingly.
         """
-        assert self.counts.has_key(problemId)
+        try:
+            assert self.counts.has_key(problemId)
+        except:
+            raise LookupError('Trying to overwrite dependencies for unknown fact: %s. Facts need to be introduced before overwriting them.' % dicts.idNameDict[problemId])
         oldDeps = dicts.dependenciesDict[problemId]
         features = dicts.featureDict[problemId]
         self.delete(problemId,features,oldDeps)
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Nov 14 15:40:06 2013 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Nov 14 15:57:48 2013 +0100
@@ -216,10 +216,6 @@
 val unencode_strs =
   space_explode " " #> filter_out (curry (op =) "") #> map unencode_str
 
-fun freshish_name () =
-  Date.fmt ".%Y%m%d_%H%M%S__" (Date.fromTimeLocal (Time.now ())) ^
-  serial_string ()
-
 (* Avoid scientific notation *)
 fun safe_str_of_real r =
   if r < 0.00001 then "0.00001"
@@ -282,10 +278,11 @@
 
 fun learn _ _ _ [] = ()
   | learn ctxt overlord save learns =
-    (trace_msg ctxt (fn () => "MaSh learn " ^
-         elide_string 1000 (space_implode " " (map #1 learns)));
-     run_mash_tool ctxt overlord ([] |> save ? cons save_models_arg) false
-                   (learns, str_of_learn) (K ()))
+    let val names = elide_string 1000 (space_implode " " (map #1 learns)) in
+      (trace_msg ctxt (fn () => "MaSh learn" ^ (if names = "" then "" else " " ^ names));
+       run_mash_tool ctxt overlord ([] |> save ? cons save_models_arg) false
+                     (learns, str_of_learn) (K ()))
+    end
 
 fun relearn _ _ _ [] = ()
   | relearn ctxt overlord save relearns =
@@ -1026,7 +1023,6 @@
     launch_thread (timeout |> the_default one_day) (fn () =>
         let
           val thy = Proof_Context.theory_of ctxt
-          val name = freshish_name ()
           val feats = features_of ctxt thy 0 Symtab.empty (Local, General) [t] |> map fst
         in
           peek_state ctxt overlord (fn {access_G, ...} =>
@@ -1036,7 +1032,7 @@
                   used_ths |> filter (is_fact_in_graph access_G)
                            |> map nickname_of_thm
               in
-                MaSh.learn ctxt overlord true [(name, parents, feats, deps)]
+                MaSh.learn ctxt overlord true [("", parents, feats, deps)]
               end);
           (true, "")
         end)