Add Question.multiline and associated handling
[quizzer.git] / quizzer / question.py
1 import logging as _logging
2 import os.path as _os_path
3 import tempfile as _tempfile
4
5 from . import error as _error
6 from . import util as _util
7
8
9 LOG = _logging.getLogger(__name__)
10 QUESTION_CLASS = {}
11
12
def register_question(question_class):
    """Record *question_class* in ``QUESTION_CLASS`` under its class name.

    A later registration with the same ``__name__`` replaces the earlier one.
    """
    key = question_class.__name__
    QUESTION_CLASS[key] = question_class
15
16
class Question (object):
    """A quiz question that is correct when the answer equals ``answer``.

    The persistent fields are listed in ``_state_attributes``.
    ``__getstate__`` / ``__setstate__`` convert to and from a dict of those
    fields, and ``__init__`` simply forwards its keyword arguments to
    ``__setstate__``.
    """
    _state_attributes = [
        'id',
        'prompt',
        'answer',
        'multiline',
        'help',
        'dependencies',
        ]

    def __init__(self, **kwargs):
        self.__setstate__(kwargs)

    def __str__(self):
        return '<{} id:{!r}>'.format(type(self).__name__, self.id)

    def __repr__(self):
        return '<{} id:{!r} at {:#x}>'.format(
            type(self).__name__, self.id, id(self))

    def __getstate__(self):
        """Return a dict mapping each ``_state_attributes`` name to its value."""
        return {attr: getattr(self, attr)
                for attr in self._state_attributes}

    def __setstate__(self, state):
        """Populate the instance from the mapping *state*.

        Missing fields get defaults:

        - ``id`` falls back to ``prompt``, or None.
        - ``multiline`` falls back to False.
        - ``dependencies`` falls back to a fresh empty list.
        - Any other ``_state_attributes`` entry falls back to None.

        Keys not listed in ``_state_attributes`` are also copied onto the
        instance. *state* itself is left unmodified.
        """
        # Work on a shallow copy so defaults never leak back into the
        # caller's mapping (e.g. a dict that was loaded from a quiz file).
        state = dict(state)
        state.setdefault('id', state.get('prompt', None))
        state.setdefault('multiline', False)
        state.setdefault('dependencies', [])
        for attr in self._state_attributes:
            state.setdefault(attr, None)
        self.__dict__.update(state)

    def check(self, answer):
        """Return True when *answer* compares equal to the stored answer."""
        return answer == self.answer
55
56
class NormalizedStringQuestion (Question):
    """Question whose answers are compared after light normalization.

    Both sides are stripped of surrounding whitespace and lower-cased
    before the equality test.
    """

    def normalize(self, string):
        """Return *string* without outer whitespace, in lower case."""
        trimmed = string.strip()
        return trimmed.lower()

    def check(self, answer):
        """Return True when *answer* normalizes to the stored answer."""
        given = self.normalize(answer)
        wanted = self.normalize(self.answer)
        return given == wanted
63
64
class ChoiceQuestion (Question):
    """Question that accepts any one of several stored answers."""

    def check(self, answer):
        """Return True when ``answer in self.answer`` holds.

        ``self.answer`` is presumably a collection of acceptable responses.
        Because the test is ``in``, a plain-string ``answer`` would match
        substrings instead.
        """
        accepted = self.answer
        return answer in accepted
68
69
class ScriptQuestion (Question):
    """Question whose answer is script text judged by running it.

    Both the reference ``answer`` and the user's answer are handled the same
    way. The ``setup`` lines, the answer line(s) and the ``teardown`` lines
    are joined with newlines and fed on stdin to ``interpreter``, inside a
    fresh temporary directory. The user is correct when the exit status,
    stdout and stderr all match the reference run.
    """
    _state_attributes = Question._state_attributes + [
        'interpreter',
        'setup',
        'teardown',
        'timeout',
        ]

    def __setstate__(self, state):
        """Populate from *state*, adding script-specific defaults.

        The defaults are ``interpreter='sh'``, ``timeout=3``, and empty
        ``setup`` / ``teardown`` lists. The caller's mapping is not
        modified.
        """
        state = dict(state)  # never write defaults into the caller's dict
        state.setdefault('interpreter', 'sh')  # POSIX-compatible shell
        state.setdefault('timeout', 3)
        for key in ['setup', 'teardown']:
            state.setdefault(key, [])
        super(ScriptQuestion, self).__setstate__(state)

    def check(self, answer):
        """Return True when *answer* behaves exactly like the reference.

        The reference answer is run first, and any error it raises
        propagates. If running *answer* raises ``CommandError``, the error
        is logged as a warning and False is returned. The first mismatching
        stream (checked in the order stderr, status, stdout) is logged at
        INFO level.
        """
        # figure out the expected values
        e_status, e_stdout, e_stderr = self._invoke(self.answer)
        # get values for the user-supplied answer
        try:
            a_status, a_stdout, a_stderr = self._invoke(answer)
        except _error.CommandError as error:
            LOG.warning(error)
            return False
        for (name, expected, actual) in [
                ('stderr', e_stderr, a_stderr),
                ('status', e_status, a_status),
                ('stdout', e_stdout, a_stdout),
                ]:
            if actual != expected:
                if name == 'status':
                    LOG.info(
                        'missmatched %s, expected %r but got %r',
                        name, expected, actual)
                else:
                    LOG.info('missmatched %s, expected:', name)
                    LOG.info(expected)
                    LOG.info('but got:')
                    LOG.info(actual)
                return False
        return True

    def _invoke(self, answer):
        """Run *answer* and return ``(status, stdout, stderr)``.

        For a non-multiline question, *answer* is a single line. For a
        multiline question, it is a sequence of lines. A plain string is
        accepted in both cases and used as-is. The randomly named temporary
        directory's basename is replaced by ``<ClassName>-XXXXXX`` in the
        captured output, so separate runs compare equal.
        """
        prefix = '{}-'.format(type(self).__name__)
        # A multiline answer given as one string would otherwise raise
        # TypeError on list + str; wrapping it is equivalent once joined.
        if isinstance(answer, str) or not self.multiline:
            answer = [answer]
        lines = list(self.setup) + list(answer) + list(self.teardown)
        with _tempfile.TemporaryDirectory(prefix=prefix) as tempdir:
            script = '\n'.join(lines)
            status, stdout, stderr = _util.invoke(
                args=[self.interpreter],
                stdin=script,
                cwd=tempdir,
                universal_newlines=True,
                timeout=self.timeout,
                )
            dirname = _os_path.basename(tempdir)
        placeholder = '{}XXXXXX'.format(prefix)
        stdout = stdout.replace(dirname, placeholder)
        stderr = stderr.replace(dirname, placeholder)
        return status, stdout, stderr
132
# Register every public class in this module that is Question or one of
# its subclasses. Private names (leading underscore) and non-class objects
# are skipped. The loop variables are deleted afterwards so they do not
# linger in the module namespace.
for name, obj in list(locals().items()):
    is_public = not name.startswith('_')
    if is_public and isinstance(obj, type) and issubclass(obj, Question):
        register_question(obj)
del name, obj, is_public