Updated the docs template with a Location regexp form.
[chemdb.git] / text_db.py
1 #!/usr/bin/python
2 """
3 Simple database-style interface to text-delimited, single files.
4 Use this if, for example, your coworkers insist on keeping data compatible with M$ Excel.
5 """
6
7 import copy
8 from sys import stdin, stdout, stderr
9
10 # import os, shutil, and chem_db for managing the database files
11 import os, shutil, stat, time
12 import os.path
13
14 FILE = 'default.db'
15 STANDARD_TAB = 8
16
17 class fieldError (Exception) :
18     "database fields are not unique"
19     pass
20
21 class text_db (object) :
22     """
23     Define a simple database interface for a spread-sheet style database file.
24     
25     field_list()  : return an ordered list of available fields (fields unique)
26     long_fields() : return an dict of long field names (keyed by field names)
27     record(id)    : return a record dict (keyed by field names)
28     records()     : return an ordered list of available records.
29     backup()      : save a copy of the current database somehow.
30     set_record(i, newvals) : set a record by overwriting any preexisting data
31                            : with data from the field-name-keyed dict NEWVALS
32     new_record()  : add a blank record (use set_record to change its values)
33     """
34     def __init__(self, filename=FILE, COM_CHAR='#', FS='\t', RS='\n',
35                  current_dir='./current/', backup_dir='./backup/' ) :
36         self.filename = filename
37         self.COM_CHAR = COM_CHAR # comment character (also signals header row)
38         self.FS = FS             # field seperator
39         self.RS = RS             # record seperator
40         self.db_id_field = "db_id" # add a record field for the database index
41         # define directories used by the database
42         self.cur = current_dir
43         self.bak = backup_dir
44         
45         if self.filename == None :
46             return # for testing, don't touch the file system
47         
48         ## Generate the neccessary directory structure if neccessary
49         for d in [self.cur,self.bak] :
50             self.check_dir(d)
51
52         self._open()
53         
54     # directory and file IO operations
55     def check_dir(self, dir) :
56         "Create the database directory if it's missing"
57         if os.path.isdir(dir) :
58             return # all set to go
59         elif os.path.exists(dir) :
60             raise Exception, "Error: a non-directory file exists at %s" % dir
61         else :
62             os.mkdir(dir)
63     def curpath(self) :
64         "Return the path to the current database file."
65         return os.path.join(self.cur, self.filename)
66     def _get_mtime(self) :
67         "Get the timestamp of the last modification to the database."
68         s = os.stat(self.curpath())
69         return s[stat.ST_MTIME]
70     def exists(self) :
71         "Check if the database exists"
72         # for some reason, my system's os.path.exists
73         # returns false for valid symbolic links...
74         return os.path.exists(self.curpath())
75     def _assert_exists(self) :
76         """
77         Assert that the database exists on disk.
78         Print a reasonable error if it does not.
79         """
80         assert self.exists(), "Missing database file %s" % self.curpath()
81     def _open(self) :
82         "Load the database from disk"
83         self._assert_exists()
84         # precedence AND > OR
85         fulltext = file(self.curpath(), 'r').read()
86         self._mtime = self._get_mtime()
87         self._parse(fulltext)
88     def iscurrent(self) :
89         "Check if our memory-space db is still syncd with the disk-space db."
90         return self._mtime == self._get_mtime()
91     def _refresh(self) :
92         "If neccessary, reload the database from disk."
93         if not self.iscurrent() :
94             self._open()
95     def _save(self) :
96         "Create a new database file from a header and list of records."
97         # save the new text
98         fid = file(self.curpath(), 'w')
99         fid.write( self._file_header_string(self._header) )
100         if self._long_header :
101             fid.write( self._file_header_string(self._long_header) )
102         for record in self._records :
103             fid.write( self._file_record_string(record) )
104         fid.close()
105     def backup(self) :
106         "Back up database file"
107         if not self.exists(): return None # nothing to back up.
108         # Append a timestamp to the file & copy to self.bak
109         # the str() builtin ensures a nice, printable string
110         tname = self.filename+'.'+str(int(time.time()))
111         tpath = os.path.join(self.bak, tname)
112         spath = self.curpath()
113         shutil.copy(spath, tpath)
114         
115     # file-text to memory  operations
116     def _get_header(self, head_line, assert_unique=False,
117                      assert_no_db_id_field=True) :
118         """
119         Parse a header line (starts with the comment character COM_CHAR).
120         
121         Because doctest doesn't play well with tabs, use colons as field seps.
122         >>> db = text_db(FS=':', filename=None)
123         >>> print db._get_header("#Name:Field 1:ID: another field")
124         ['Name', 'Field 1', 'ID', ' another field']
125         >>> try :
126         ...    x = db._get_header("#Name:Field 1:Name: another field", assert_unique=True)
127         ... except fieldError, s :
128         ...    print s
129         fields 0 and 2 both 'Name'
130         """
131         assert len(head_line) > 0, 'empty header'
132         assert head_line[0] == self.COM_CHAR, 'bad header: "%s"' % head_line
133         fields = head_line[1:].split(self.FS)
134         if assert_unique :
135             for i in range(len(fields)) :
136                 for j in range(i+1,len(fields)) :
137                     if fields[i] == fields[j] :
138                         raise fieldError, "fields %d and %d both '%s'" \
139                                           % (i,j,fields[i])
140         if assert_no_db_id_field :
141             for i in range(len(fields)) :
142                 if fields[i] == self.db_id_field :
143                     raise fieldError, "fields %d uses db_id field '%s'" \
144                                       % (i,fields[i])
145         return fields
146     def _get_fields(self, line, num_fields=None) :
147         """
148         Parse a record line.
149         
150         Because doctest doesn't play well with tabs, use colons as field seps.
151         >>> db = text_db(FS=':', filename=None)
152         >>> print db._get_fields("2-Propanol:4 L:67-63-0:Fisher:6/6/2004",7)
153         ['2-Propanol', '4 L', '67-63-0', 'Fisher', '6/6/2004', '', '']
154         """
155         vals = line.split(self.FS)
156         if num_fields != None :
157             assert len(vals) <= num_fields, "Too many values in '%s'" % line
158             for i in range(len(vals), num_fields) :
159                 vals.append('') # pad with empty strings if neccessary
160         return vals
161     def _parse(self, text) :
162         reclines = text.split(self.RS)
163         assert len(reclines) > 0, "Empty database file"
164         self._header = self._get_header(reclines[0], assert_unique=True)
165         self._long_header = None
166         self._records = []
167         if len(reclines) == 1 :
168             return # Only a header
169         # check for a long-header line
170         if reclines[1][0] == self.COM_CHAR :
171             self._long_header = self._get_header(reclines[1])
172             startline = 2
173         else :
174             startline = 1
175         for recline in reclines[startline:] :
176             if len(recline) == 0 :
177                 continue # ignore blank lines
178             self._records.append(self._get_fields(recline, len(self._header)))
179             
180             
181     # memory to file-text  operations
182     def _file_record_string(self, record) :
183         """
184         Format record for creating a new database file.
185         
186         Because doctest doesn't play well with tabs, use colons as field seps.
187         >>> db = text_db(FS=':', RS=';', filename=None)
188         >>> rs="2-Propanol:4 L:67-63-0:BPA426P-4:Fisher:6/6/2004:2:3:0"
189         >>> print db._file_record_string( db._get_fields(rs)) == (rs+";")
190         True
191         """
192         return "%s%s" % (self.FS.join(record), self.RS)
193     def _file_header_string(self, header) :
194         """
195         Format header for creating a new database file.
196         """
197         return "%s%s%s" % (self.COM_CHAR, self.FS.join(header), self.RS)
198
199     
200     # nice, stable api for our users
201     def field_list(self) : 
202         "return an ordered list of available fields (fields unique)"
203         return copy.copy(self._header)
204     def long_fields(self) :
205         "return an dict of long field names (keyed by field names)"
206         if self._long_header :
207             return dict(zip(self._header, self._long_header))
208         else : # default to the standard field names
209             return dict(zip(self._header, self._header))
210     def record(self, db_id) :
211         "return a record dict (keyed by field names)"
212         assert type(db_id) == type(1), "id %s not an int!" % str(db_id)
213         assert db_id < len(self._records), "record %d does not exist" % db_id
214         d = dict(zip(self._header, self._records[db_id]))
215         d['db_id'] = db_id
216         return d
217     def records(self) :
218         "return an ordered list of available records."
219         ret = []
220         for id in range(len(self._records)) :
221             ret.append(self.record(id))
222         return ret
223     def len_records(self) :
224         "return len(self.records()), but more efficiently"
225         return len(self._records)
226     def set_record(self, db_id, newvals, backup=True) :
227         """
228         set a record by overwriting any preexisting data
229         with data from the field-name-keyed dict NEWVALS
230         """
231         if backup :
232             self.backup()
233         for k,v in newvals.items() :
234             if k == self.db_id_field :
235                 assert int(v) == db_id, \
236                     "don't set the db_id field! (attempted %d -> %d)" \
237                     % (db_id, int(v))
238                 continue
239             # get the index for the specified field
240             assert k in self._header, "unrecognized field '%s'" % k
241             fi = self._header.index(k)
242             # overwrite the record value
243             self._records[db_id][fi] = v
244         self._save()
245     def new_record(self, db_id=None) :
246         """
247         create a blank new record and return it.
248         """
249         record = {}
250         for field in self._header :
251             record[field] = ""
252         record[self.db_id_field] = len(self._records)
253         self._records.append(['']*len(self._header))
254         return record
255
256 class indexStringError (Exception) :
257     "invalid index string format"
258     pass
259
260
261 class db_pretty_printer (object) :
262     """
263     Define some pretty-print functions for text_db objects.
264     """
265     def __init__(self, db) :
266         self.db = db
267     def _norm_active_fields(self, active_fields_in=None) :
268         """
269         Normalize the active field parameter
270         
271         >>> db = text_db(FS=':', RS=' ; ', filename=None)
272         >>> pp = db_pretty_printer(db)
273         >>> db._parse("#Name:Amount:CAS#:Vendor ; 2-Propanol:4 L:67-63-0:Fisher")
274         >>> print pp._norm_active_fields(None) == {'Name':True, 'Amount':True, 'CAS#':True, 'Vendor':True}
275         True
276         >>> print pp._norm_active_fields(['Vendor', 'Amount']) == {'Name':False, 'Amount':True, 'CAS#':False, 'Vendor':True}
277         True
278         >>> print pp._norm_active_fields('1:3') == {'Name':False, 'Amount':True, 'CAS#':True, 'Vendor':False}
279         True
280         """
281         if active_fields_in == None :
282             active_fields = {}
283             for field in self.db.field_list() :
284                 active_fields[field] = True
285             return active_fields
286         elif type(active_fields_in) == type('') :
287             active_i = self._istr2ilist(active_fields_in)
288             active_fields = {}
289             fields = self.db.field_list()
290             for i in range(len(fields)) :
291                 if i in active_i :
292                     active_fields[fields[i]] = True
293                 else :
294                     active_fields[fields[i]] = False
295         else :
296             if type(active_fields_in) == type([]) :
297                 active_fields = {}
298                 for field in active_fields_in :
299                     active_fields[field] = True
300             elif type(active_fields_in) == type({}) :
301                 active_fields = active_fields_in
302             assert type(active_fields) == type({}), 'by this point, should be a dict'
303             for field in self.db.field_list() :
304                 if not field in active_fields :
305                     active_fields[field] = False
306         return active_fields
307     def full_record_string(self, record, active_fields=None) :
308         """
309         Because doctest doesn't play well with tabs, use colons as field seps.
310         >>> db = text_db(FS=':', RS=' ; ', filename=None)
311         >>> pp = db_pretty_printer(db)
312         >>> db._parse("#Name:Amount:CAS#:Vendor ; 2-Propanol:4 L:67-63-0:Fisher")
313         >>> print pp.full_record_string( db.record(0) ),
314           Name : 2-Propanol
315         Amount : 4 L
316           CAS# : 67-63-0
317         Vendor : Fisher
318         """
319         fields = self.db.field_list()
320         long_fields = self.db.long_fields()
321         active_fields = self._norm_active_fields(active_fields)
322         # scan through and determine the width of the largest field
323         w = 1
324         for field in fields :
325             if active_fields[field] and len(field) > w :
326                 w = len(field)
327         # generate the pretty-print string
328         string = ""
329         for field in fields :
330             if field in active_fields and active_fields[field] :
331                 string += "%*.*s : %s\n" \
332                           % (w, w, long_fields[field], record[field])
333         return string
334     def full_record_string_id(self, id, active_fields=None) :
335         record = self.db.record(id)
336         return self.full_record_string(record, active_fields)
337     def _istr2ilist(self, index_string) :
338         """
339         Generate index lists from assorted string formats.
340         
341         Parse index strings
342         >>> pp = db_pretty_printer('dummy')
343         >>> print pp._istr2ilist('0,2,89,4')
344         [0, 2, 89, 4]
345         >>> print pp._istr2ilist('1:6')
346         [1, 2, 3, 4, 5]
347         >>> print pp._istr2ilist('0,3,6:9,2')
348         [0, 3, 6, 7, 8, 2]
349         """
350         ret = []
351         for spl in index_string.split(',') :
352             s = spl.split(':')
353             if len(s) == 1 :
354                 ret.append(int(spl))
355             elif len(s) == 2 :
356                 for i in range(int(s[0]),int(s[1])) :
357                     ret.append(i)
358             else :
359                 raise indexStringError, "unrecognized index '%s'" % spl
360         return ret
361     def _norm_width(self, width_in=None, active_fields=None, skinny=True) :
362         "Normalize the width parameter"
363         active_fields = self._norm_active_fields(active_fields)
364         if type(width_in) == type(1) or width_in == 'a' : # constant width
365             width = {} # set all fields to this width
366             for field in active_fields.keys() :
367                 width[field] = width_in
368         else :
369             if width_in == None :
370                 width_in = {}
371             width = {}
372             for field in active_fields.keys() :
373                 # fill in the gaps in the current width 
374                 if field in width_in :
375                     width[field] = width_in[field]
376                 else : # field doesn't exist
377                     if skinny : # set to a fixed width
378                         # -1 to leave room for FS
379                         width[field] = STANDARD_TAB-1
380                     else : # set to automatic
381                         width[field] = 'a'
382         return width
383     def _norm_record_ids(self, record_ids=None) :
384         "Normalize the record_ids parameter"
385         if record_ids == None :
386             record_ids = range(len(self.db.records()))
387         if type(record_ids) == type('') :
388             record_ids = self._istr2ilist(record_ids)
389         stderr.flush()
390         return record_ids
391     def _line_record_string(self, record, width=None, active_fields=None,
392                                FS=None, RS=None, TRUNC_STRING=None) :
393         """
394         Because doctest doesn't play well with tabs, use colons as field seps.
395         >>> db = text_db(FS=':', RS=' ; ', filename=None)
396         >>> pp = db_pretty_printer(db)
397         >>> db._parse("#Name:Amount:CAS#:Vendor ; 2-Propanol:4 L:67-63-0:Fisher")
398         >>> print pp._line_record_string_id(0)
399         2-Propa:    4 L:67-63-0: Fisher ; 
400         """
401         fields = self.db.field_list()
402         active_fields = self._norm_active_fields(active_fields)
403         width = self._norm_width(width)
404         if FS == None :
405             FS = self.db.FS
406         if RS == None :
407             RS = self.db.RS
408         for field in fields :
409             if field in active_fields and active_fields[field] :
410                 lastfield = field
411         # generate the pretty-print string
412         string = ""
413         for field in fields :
414             if field in active_fields and active_fields[field] :
415                 w = width[field]
416                 string += "%*.*s" % (w, w, record[field])
417                 if field != lastfield :
418                     string += "%s" % (FS)
419         string += RS
420         return string
421     def _line_record_string_id(self, id, width=None, active_fields=None,
422                                FS=None, RS=None, TRUNC_STRING=None) :
423         return self._line_record_string(self.db.record(id),
424                                         width, active_fields, FS, RS,
425                                         TRUNC_STRING)
426     def _get_field_width(self, record_ids, field) :
427         """
428         Return the width of the longest value in FIELD
429         for all the records with db_ids in record_ids.
430         """
431         width = 1
432         for i in record_ids :
433             w = len(self.db.record(i)[field])
434             if w > width :
435                 width = w
436         return width
437     def _get_width(self, width_in, active_fields=None, record_ids=None) :
438         """
439         Return the width of the largest value in FIELD
440         for all the records with db_ids in record_ids.
441         """
442         active_fields = self._norm_active_fields(active_fields)
443         width = self._norm_width(width_in, active_fields)
444         record_ids = self._norm_record_ids(record_ids)
445
446         for field in active_fields :
447             if width[field] == 'a' :
448                 width[field] = self._get_field_width(record_ids, field)
449         return width
450     def multi_record_string(self, record_ids=None, active_fields=None,
451                             width=None, FS=None, RS=None, COM_CHAR=None,
452                             TRUNC_STRING=None) :
453         """
454         Because doctest doesn't play well with tabs, use colons as field seps.
455         >>> db = text_db(FS=':', RS=' ; ', filename=None)
456         >>> pp = db_pretty_printer(db)
457         >>> db._parse("#Name:Amount:CAS#:Vendor ; 2-Propanol:4 L:67-63-0:Fisher")
458         >>> print pp.multi_record_string('0'),
459            Name: Amount:   CAS#: Vendor ; 2-Propa:    4 L:67-63-0: Fisher ; 
460         """
461         if FS == None :
462             FS = self.db.FS
463         if RS == None :
464             RS = self.db.RS
465         if COM_CHAR == None :
466             COM_CHAR = self.db.COM_CHAR
467         active_fields = self._norm_active_fields(active_fields)
468         record_ids = self._norm_record_ids(record_ids)
469         width = self._get_width(width, active_fields, record_ids)
470         # generate the pretty-print string
471         string = ""
472         # print a header line:
473         fields = self.db.field_list()
474         hvals = dict(zip(fields,fields))
475         string += "%s" % self._line_record_string(hvals, width,
476                                                   active_fields, FS, RS,
477                                                   TRUNC_STRING=TRUNC_STRING)
478         # print the records
479         for id in record_ids :
480             string += self._line_record_string_id(id,  width,
481                                                   active_fields, FS, RS,
482                                                   TRUNC_STRING=TRUNC_STRING)
483         return string
484
485
486 def _test():
487     import doctest
488     doctest.testmod()
489     
490 if __name__ == "__main__" :
491     _test()