1 import logging as _logging
2 import tempfile as _tempfile
4 from . import error as _error
5 from . import util as _util
# Module-level logger named after this module; ScriptQuestion.check reports
# answer mismatches through it.
LOG = _logging.getLogger(__name__)
def register_question(question_class):
    """Record *question_class* in ``QUESTION_CLASS`` under its class name.

    The registry is keyed by ``question_class.__name__``, so registering a
    second class with the same name replaces the earlier entry.

    Returns *question_class* unchanged so this function can also be applied
    as a class decorator (``@register_question``); existing callers that
    ignore the return value are unaffected.
    """
    QUESTION_CLASS[question_class.__name__] = question_class
    return question_class
16 class Question (object):
25 def __init__(self, **kwargs):
26 self.__setstate__(kwargs)
29 return '<{} id:{!r}>'.format(type(self).__name__, self.id)
32 return '<{} id:{!r} at {:#x}>'.format(
33 type(self).__name__, self.id, id(self))
35 def __getstate__(self):
36 return {attr: getattr(self, attr)
37 for attr in self._state_attributes}
39 def __setstate__(self, state):
41 state['id'] = state.get('prompt', None)
42 if 'dependencies' not in state:
43 state['dependencies'] = []
44 for attr in self._state_attributes:
47 self.__dict__.update(state)
49 def check(self, answer):
50 return answer == self.answer
class NormalizedStringQuestion (Question):
    """A question whose answer is compared after light normalization.

    Surrounding whitespace is ignored and letter case is folded, so an
    answer of ``'  Paris '`` matches an expected answer of ``'paris'``.
    """

    def normalize(self, string):
        """Return *string* with surrounding whitespace removed and case folded.

        ``str.casefold`` is used instead of ``str.lower`` because it removes
        all case distinctions (e.g. ``'ß'`` folds to ``'ss'``), which is what
        a caseless comparison requires; ``lower`` would leave ``'Straße'``
        and ``'STRASSE'`` unequal.
        """
        return string.strip().casefold()

    def check(self, answer):
        """Return True if *answer* matches ``self.answer`` after normalization."""
        return self.normalize(answer) == self.normalize(self.answer)
61 class ScriptQuestion (Question):
62 _state_attributes = Question._state_attributes + [
69 def __setstate__(self, state):
70 if 'interpreter' not in state:
71 state['interpreter'] = 'sh' # POSIX-compatible shell
72 if 'timeout' not in state:
74 for key in ['setup', 'teardown']:
77 super(ScriptQuestion, self).__setstate__(state)
79 def check(self, answer):
80 # figure out the expected values
81 e_status,e_stdout,e_stderr = self._invoke(self.answer)
82 # get values for the user-supplied answer
84 a_status,a_stdout,a_stderr = self._invoke(answer)
85 except _error.CommandError as e:
89 ('stderr', e_stderr, a_stderr),
90 ('status', e_status, a_status),
91 ('stdout', e_stdout, a_stdout),
96 'missmatched {}, expected {!r} but got {!r}'.format(
99 LOG.info('missmatched {}, expected:'.format(name))
106 def _invoke(self, answer):
107 with _tempfile.TemporaryDirectory(
108 prefix='{}-'.format(type(self).__name__),
110 script = '\n'.join(self.setup + [answer] + self.teardown)
112 args=[self.interpreter],
115 universal_newlines=True,
116 timeout=self.timeout,)
118 for name,obj in list(locals().items()):
119 if name.startswith('_'):
122 subclass = issubclass(obj, Question)
123 except TypeError: # obj is not a class
126 register_question(obj)