Add Question.multiline and associated handling
[quizzer.git] / quizzer / question.py
index f2e8ed037cb725f24189ad1b088b3665040b78cc..64594cf14cfa2fbab6a5a6d3c90cd7a41a20ddc0 100644 (file)
@@ -1,3 +1,12 @@
+import logging as _logging
+import os.path as _os_path
+import tempfile as _tempfile
+
+from . import error as _error
+from . import util as _util
+
+
+LOG = _logging.getLogger(__name__)
 QUESTION_CLASS = {}
 
 
@@ -10,21 +19,13 @@ class Question (object):
         'id',
         'prompt',
         'answer',
+        'multiline',
         'help',
         'dependencies',
         ]
 
-    def __init__(self, id=None, prompt=None, answer=None, help=None,
-                 dependencies=None):
-        if id is None:
-            id = prompt
-        self.id = id
-        self.prompt = prompt
-        self.answer = answer
-        self.help = help
-        if dependencies is None:
-            dependencies = []
-        self.dependencies = dependencies
+    def __init__(self, **kwargs):
+        self.__setstate__(kwargs)
 
     def __str__(self):
         return '<{} id:{!r}>'.format(type(self).__name__, self.id)
@@ -40,8 +41,13 @@ class Question (object):
     def __setstate__(self, state):
         if 'id' not in state:
             state['id'] = state.get('prompt', None)
+        if 'multiline' not in state:
+            state['multiline'] = False
         if 'dependencies' not in state:
             state['dependencies'] = []
+        for attr in self._state_attributes:
+            if attr not in state:
+                state[attr] = None
         self.__dict__.update(state)
 
     def check(self, answer):
@@ -56,6 +62,74 @@ class NormalizedStringQuestion (Question):
         return self.normalize(answer) == self.normalize(self.answer)
 
 
+class ChoiceQuestion (Question):
+    def check(self, answer):
+        return answer in self.answer
+
+
+class ScriptQuestion (Question):
+    _state_attributes = Question._state_attributes + [
+        'interpreter',
+        'setup',
+        'teardown',
+        'timeout',
+        ]
+
+    def __setstate__(self, state):
+        if 'interpreter' not in state:
+            state['interpreter'] = 'sh'  # POSIX-compatible shell
+        if 'timeout' not in state:
+            state['timeout'] = 3
+        for key in ['setup', 'teardown']:
+            if key not in state:
+                state[key] = []
+        super(ScriptQuestion, self).__setstate__(state)
+
+    def check(self, answer):
+        # figure out the expected values
+        e_status,e_stdout,e_stderr = self._invoke(self.answer)
+        # get values for the user-supplied answer
+        try:
+            a_status,a_stdout,a_stderr = self._invoke(answer)
+        except _error.CommandError as e:
+            LOG.warning(e)
+            return False
+        for (name, e, a) in [
+                ('stderr', e_stderr, a_stderr),
+                ('status', e_status, a_status),
+                ('stdout', e_stdout, a_stdout),
+                ]:
+            if a != e:
+                if name == 'status':
+                    LOG.info(
+                        'missmatched {}, expected {!r} but got {!r}'.format(
+                            name, e, a))
+                else:
+                    LOG.info('missmatched {}, expected:'.format(name))
+                    LOG.info(e)
+                    LOG.info('but got:')
+                    LOG.info(a)
+                return False
+        return True
+
+    def _invoke(self, answer):
+        prefix = '{}-'.format(type(self).__name__)
+        if not self.multiline:
+            answer = [answer]
+        with _tempfile.TemporaryDirectory(prefix=prefix) as tempdir:
+            script = '\n'.join(self.setup + answer + self.teardown)
+            status,stdout,stderr = _util.invoke(
+                args=[self.interpreter],
+                stdin=script,
+                cwd=tempdir,
+                universal_newlines=True,
+                timeout=self.timeout,
+                )
+            dirname = _os_path.basename(tempdir)
+        stdout = stdout.replace(dirname, '{}XXXXXX'.format(prefix))
+        stderr = stderr.replace(dirname, '{}XXXXXX'.format(prefix))
+        return status,stdout,stderr
+
 for name,obj in list(locals().items()):
     if name.startswith('_'):
         continue