Add trailing newlines to saved JSON files
[quizzer.git] / quizzer / answerdb.py
1 import codecs as _codecs
2 import json as _json
3
4 from . import __version__
5
6
class AnswerDatabase(dict):
    """In-memory record of answers, persisted to a JSON file.

    Keys are question prompts (``question.prompt``). Each value, as built by
    ``add``, is a list of ``{'answer': ..., 'correct': ...}`` dicts, one per
    recorded attempt, in the order they were added.
    """

    def __init__(self, path=None, encoding=None):
        """Create an empty database.

        ``path`` and ``encoding`` are stored on the instance and used by
        ``load``/``save`` unless those calls override them.
        """
        super(AnswerDatabase, self).__init__()
        self.path = path
        self.encoding = encoding

    def _open(self, mode='r', path=None, encoding=None):
        """Return ``codecs.open(self.path, mode, self.encoding)``.

        A truthy ``path`` or ``encoding`` argument first replaces the stored
        attribute, so the override also sticks for later calls.
        """
        self.path = path or self.path
        self.encoding = encoding or self.encoding
        return _codecs.open(self.path, mode=mode, encoding=self.encoding)

    def load(self, **kwargs):
        """Merge the answers stored in the backing JSON file into ``self``.

        Keyword arguments (``path``, ``encoding``) are forwarded to ``_open``.
        Prompts present in the file replace same-named entries already in
        the database; other existing entries are kept.

        Raises NotImplementedError if the file's ``version`` field (``None``
        when absent) differs from this package's ``__version__``.
        """
        with self._open(mode='r', **kwargs) as stream:
            payload = _json.load(stream)
        stored_version = payload.get('version')
        if stored_version != __version__:
            # No migration path between on-disk formats exists yet.
            raise NotImplementedError('upgrade from {} to {}'.format(
                    stored_version, __version__))
        self.update(payload['answers'])

    def save(self, **kwargs):
        """Write the database to its backing file as JSON.

        The document holds ``version`` and ``answers`` keys, is dumped with
        sorted keys and a two-space indent, and ends with a newline.
        Keyword arguments (``path``, ``encoding``) are forwarded to ``_open``.
        """
        document = dict(version=__version__, answers=self)
        with self._open(mode='w', **kwargs) as stream:
            _json.dump(document, stream, indent=2, separators=(',', ': '),
                       sort_keys=True)
            # json.dump emits no trailing newline; add one for tidy files.
            stream.write('\n')

    def add(self, question, answer, correct):
        """Append one attempt at ``question`` (keyed by its ``prompt``)."""
        attempts = self.setdefault(question.prompt, [])
        attempts.append({'answer': answer, 'correct': correct})

    def _ever_correct(self, question):
        """True if any recorded attempt's ``correct`` value equals ``True``."""
        # Equality-based membership (not truthiness), so 1 counts but 'yes'
        # does not; the full list is built so every attempt is inspected.
        flags = [attempt['correct'] for attempt in self.get(question.prompt, [])]
        return True in flags

    def get_answered(self, questions):
        """Return, in input order, the questions with a recorded attempt."""
        return [question for question in questions
                if question.prompt in self]

    def get_unanswered(self, questions):
        """Return, in input order, the questions with no recorded attempt."""
        return [question for question in questions
                if question.prompt not in self]

    def get_correctly_answered(self, questions):
        """Return, in input order, the questions answered correctly at least once."""
        return [question for question in questions
                if self._ever_correct(question)]

    def get_never_correctly_answered(self, questions):
        """Return, in input order, the questions never answered correctly."""
        return [question for question in questions
                if not self._ever_correct(question)]