Add a persistent answers database
authorW. Trevor King <wking@tremily.us>
Tue, 5 Feb 2013 15:01:25 +0000 (10:01 -0500)
committerW. Trevor King <wking@tremily.us>
Tue, 5 Feb 2013 15:01:25 +0000 (10:01 -0500)
quizzer/answerdb.py [new file with mode: 0644]
quizzer/cli.py
quizzer/ui/__init__.py
quizzer/ui/cli.py

diff --git a/quizzer/answerdb.py b/quizzer/answerdb.py
new file mode 100644 (file)
index 0000000..869641e
--- /dev/null
@@ -0,0 +1,58 @@
+import codecs as _codecs
+import json as _json
+
+from . import __version__
+
+
+class AnswerDatabase (dict):
+    def __init__(self, path=None, encoding=None):
+        super(AnswerDatabase, self).__init__()
+        self.path = path
+        self.encoding = encoding
+
+    def _open(self, mode='r', path=None, encoding=None):
+        if path:
+            self.path = path
+        if encoding:
+            self.encoding = encoding
+        return _codecs.open(self.path, mode, self.encoding)
+
+    def load(self, **kwargs):
+        with self._open(mode='r', **kwargs) as f:
+            data = _json.load(f)
+        version = data.get('version', None)
+        if version != __version__:
+            raise NotImplementedError('upgrade from {} to {}'.format(
+                    version, __version__))
+        self.update(data['answers'])
+
+    def save(self, **kwargs):
+        data = {
+            'version': __version__,
+            'answers': self,
+            }
+        with self._open(mode='w', **kwargs) as f:
+            _json.dump(
+                data, f, indent=2, separators=(',', ': '), sort_keys=True)
+
+    def add(self, question, answer, correct):
+        if question.prompt not in self:
+            self[question.prompt] = []
+        self[question.prompt].append({
+                'answer': answer,
+                'correct': correct,
+                })
+
+    def get_answered(self, questions):
+        return [q for q in questions if q.prompt in self]
+
+    def get_unanswered(self, questions):
+        return [q for q in questions if q.prompt not in self]
+
+    def get_correctly_answered(self, questions):
+        return [q for q in questions
+                if True in [a['correct'] for a in self.get(q.prompt, [])]]
+
+    def get_never_correctly_answered(self, questions):
+        return [q for q in questions
+                if True not in [a['correct'] for a in self.get(q.prompt, [])]]
index 6faf1cbf8b2eb8c55a86eb5ffd99f36a5cb10b4c..4cf25f8b065889ff460b39a359ca6198e39aa727 100644 (file)
@@ -3,6 +3,7 @@ import locale as _locale
 
 from . import __doc__ as _module_doc
 from . import __version__
+from . import answerdb as _answerdb
 from . import quiz as _quiz
 from .ui import cli as _cli
 
@@ -14,6 +15,9 @@ def main():
     parser.add_argument(
         '--version', action='version',
         version='%(prog)s {}'.format(__version__))
+    parser.add_argument(
+        '-a', '--answers', metavar='ANSWERS', default='answers.json',
+        help='path to an answers database')
     parser.add_argument(
         'quiz', metavar='QUIZ',
         help='path to a quiz file')
@@ -22,6 +26,12 @@ def main():
 
     quiz = _quiz.Quiz(path=args.quiz, encoding=encoding)
     quiz.load()
-    ui = _cli.CommandLineInterface(quiz=quiz)
+    answers = _answerdb.AnswerDatabase(path=args.answers, encoding=encoding)
+    try:
+        answers.load()
+    except IOError:
+        pass
+    ui = _cli.CommandLineInterface(quiz=quiz, answers=answers)
     ui.run()
+    ui.answers.save()
     ui.display_results()
index 69b120315e463074075c3e91e6e0c84a770441f7..8fb8b06cd9f53abfca28a39d7e031699b24687cb 100644 (file)
@@ -1,9 +1,12 @@
+from .. import answerdb as _answerdb
+
+
 class UserInterface (object):
     "Give a quiz over a generic user interface"
     def __init__(self, quiz=None, answers=None):
         self.quiz = quiz
         if answers is None:
-            answers = {}
+            answers = _answerdb.AnswerDatabase()
         self.answers = answers
 
     def run(self):
@@ -13,30 +16,11 @@ class UserInterface (object):
         raise NotImplementedError()
 
     def get_question(self):
-        remaining = self.get_unanswered()
+        remaining = self.answers.get_unanswered(questions=self.quiz)
         if remaining:
             return remaining[0]
 
     def process_answer(self, question, answer):
-        if question not in self.answers:
-            self.answers[question] = []
         correct = question.check(answer)
-        self.answers[question].append({
-                'answer': answer,
-                'correct': correct,
-                })
+        self.answers.add(question=question, answer=answer, correct=correct)
         return correct
-
-    def get_answered(self):
-        return [q for q in self.quiz if q in self.answers]
-
-    def get_unanswered(self):
-        return [q for q in self.quiz if q not in self.answers]
-
-    def get_correctly_answered(self):
-        return [q for q in self.quiz
-                if True in [a['correct'] for a in self.answers.get(q, [])]]
-
-    def get_never_correctly_answered(self):
-        return [q for q in self.quiz
-                if True not in [a['correct'] for a in self.answers.get(q, [])]]
index 385a9a4044ff51b4c78a79b60837a85a23656688..9b17339166fed9df843d5da511a62d17b83e21c8 100644 (file)
@@ -29,9 +29,8 @@ class CommandLineInterface (UserInterface):
     def display_results(self):
         print('results:')
         for question in self.quiz:
-            if question in self.answers:
-                for answer in self.answers[question]:
-                    self.display_result(question=question, answer=answer)
+            for answer in self.answers.get(question.prompt, []):
+                self.display_result(question=question, answer=answer)
         self.display_totals()
 
     def display_result(self, question, answer):
@@ -45,8 +44,9 @@ class CommandLineInterface (UserInterface):
         print()
 
     def display_totals(self):
-        answered = self.get_answered()
-        correctly_answered = self.get_correctly_answered()
+        answered = self.answers.get_answered(questions=self.quiz)
+        correctly_answered = self.answers.get_correctly_answered(
+            questions=self.quiz)
         la = len(answered)
         lc = len(correctly_answered)
         print('answered {} of {} questions'.format(la, len(self.quiz)))