a9d00c06d54aa6491489b90a903eacedff37d1b4
[quizzer.git] / quizzer / question.py
1 import logging as _logging
2 import os.path as _os_path
3 import tempfile as _tempfile
4
5 from . import error as _error
6 from . import util as _util
7
8
LOG = _logging.getLogger(__name__)
# Registry of question classes keyed by class name; populated through
# register_question().
QUESTION_CLASS = dict()
11
12
def register_question(question_class):
    """Record *question_class* in QUESTION_CLASS under its class name.

    Registering another class with the same ``__name__`` replaces the
    earlier entry.
    """
    key = question_class.__name__
    QUESTION_CLASS[key] = question_class
15
16
class Question (object):
    """A single quiz question with a prompt and an expected answer.

    Instances are built from keyword arguments (or restored from a state
    dict); every name in ``_state_attributes`` ends up as an instance
    attribute, and any extra keys are stored as attributes as well.
    """
    # Attributes saved by __getstate__ and defaulted by __setstate__.
    _state_attributes = [
        'id',
        'prompt',
        'answer',
        'help',
        'dependencies',
        ]

    def __init__(self, **kwargs):
        self.__setstate__(kwargs)

    def __str__(self):
        return '<{} id:{!r}>'.format(type(self).__name__, self.id)

    def __repr__(self):
        return '<{} id:{!r} at {:#x}>'.format(
            type(self).__name__, self.id, id(self))

    def __getstate__(self):
        """Return a dict mapping each name in ``_state_attributes`` to its value."""
        return {attr: getattr(self, attr)
                for attr in self._state_attributes}

    def __setstate__(self, state):
        """Restore the question from the mapping *state*, filling in defaults.

        ``id`` defaults to ``prompt`` (or ``None`` if there is no prompt),
        ``dependencies`` defaults to a new empty list, and every other
        missing name from ``_state_attributes`` defaults to ``None``.
        Extra keys in *state* become instance attributes too.

        *state* itself is left unmodified.
        """
        # Work on a copy: the caller may still own the dict it passed in
        # and should not see defaults injected into it.
        state = dict(state)
        state.setdefault('id', state.get('prompt', None))
        state.setdefault('dependencies', [])
        for attr in self._state_attributes:
            state.setdefault(attr, None)
        self.__dict__.update(state)

    def check(self, answer):
        """Return True if *answer* equals the expected answer."""
        return answer == self.answer
52
53
class NormalizedStringQuestion (Question):
    """Question whose answer is compared after whitespace and case folding."""

    def normalize(self, string):
        """Return *string* with surrounding whitespace removed, case-folded.

        ``str.casefold`` is used rather than ``str.lower`` because it is the
        transform intended for caseless matching. For example, it maps
        German "ß" to "ss", which ``lower`` leaves unchanged, so "STRASSE"
        and "straße" compare equal.
        """
        return string.strip().casefold()

    def check(self, answer):
        """Return True if *answer* matches the expected answer once both
        are normalized."""
        return self.normalize(answer) == self.normalize(self.answer)
60
61
class ChoiceQuestion (Question):
    """Question whose ``answer`` holds the acceptable responses."""

    def check(self, answer):
        """Return True if *answer* is contained in ``self.answer``.

        NOTE: membership is tested with ``in``, so if ``answer`` was given
        as a plain string this is a substring test, not an equality test.
        """
        accepted = self.answer
        return answer in accepted
65
66
class ScriptQuestion (Question):
    """Question answered with a script and graded by comparing its output.

    The expected answer and the user's answer are each run through
    ``interpreter``, wrapped by the ``setup`` and ``teardown`` lines, in a
    fresh temporary directory. The answer is correct when the exit status,
    stdout and stderr all match.
    """
    _state_attributes = Question._state_attributes + [
        'interpreter',
        'setup',
        'teardown',
        'timeout',
        ]

    def __setstate__(self, state):
        """Restore state, defaulting ``interpreter`` to ``'sh'``,
        ``timeout`` to 3 and ``setup``/``teardown`` to empty lists.

        *state* itself is left unmodified.
        """
        state = dict(state)  # don't inject defaults into the caller's dict
        state.setdefault('interpreter', 'sh')  # POSIX-compatible shell
        state.setdefault('timeout', 3)  # passed to _util.invoke(timeout=...)
        for key in ['setup', 'teardown']:
            state.setdefault(key, [])  # a fresh list per key
        super(ScriptQuestion, self).__setstate__(state)

    def check(self, answer):
        """Return True if *answer* yields the same stderr, status and stdout
        as the expected answer.

        A ``CommandError`` raised while running the user's answer is logged
        and counts as a wrong answer. One raised while running the expected
        answer is not caught here.
        """
        e_status, e_stdout, e_stderr = self._invoke(self.answer)
        try:
            a_status, a_stdout, a_stderr = self._invoke(answer)
        except _error.CommandError as error:
            LOG.warning(error)
            return False
        # Compared in this order; only the first mismatch is reported.
        for name, expected, got in [
                ('stderr', e_stderr, a_stderr),
                ('status', e_status, a_status),
                ('stdout', e_stdout, a_stdout),
                ]:
            if got != expected:
                if name == 'status':
                    LOG.info(
                        'mismatched %s, expected %r but got %r',
                        name, expected, got)
                else:
                    LOG.info('mismatched %s, expected:', name)
                    LOG.info(expected)
                    LOG.info('but got:')
                    LOG.info(got)
                return False
        return True

    def _invoke(self, answer):
        """Run *answer* between the setup and teardown lines.

        The newline-joined script is fed on stdin to ``self.interpreter``,
        with a fresh temporary directory as the working directory.

        Returns ``(status, stdout, stderr)``. The temporary directory's
        name is replaced in stdout and stderr by a fixed placeholder, so
        output from separate runs can be compared.
        """
        prefix = '{}-'.format(type(self).__name__)
        placeholder = '{}XXXXXX'.format(prefix)
        with _tempfile.TemporaryDirectory(prefix=prefix) as tempdir:
            script = '\n'.join(self.setup + [answer] + self.teardown)
            status, stdout, stderr = _util.invoke(
                args=[self.interpreter],
                stdin=script,
                cwd=tempdir,
                universal_newlines=True,
                timeout=self.timeout,
                )
            dirname = _os_path.basename(tempdir)
        stdout = stdout.replace(dirname, placeholder)
        stderr = stderr.replace(dirname, placeholder)
        return status, stdout, stderr
127
# Register every public Question subclass defined in this module (Question
# itself included) in QUESTION_CLASS, keyed by class name.
for name, obj in list(locals().items()):
    # Skip private names and non-class objects (e.g. LOG, QUESTION_CLASS,
    # register_question), for which issubclass() would raise TypeError.
    if name.startswith('_') or not isinstance(obj, type):
        continue
    if issubclass(obj, Question):
        register_question(obj)
del name, obj