1 import logging as _logging
2 import os.path as _os_path
3 import tempfile as _tempfile
5 from . import error as _error
6 from . import util as _util
# Module-wide logger; ScriptQuestion.check reports answer mismatches through it.
LOG = _logging.getLogger(name=__name__)
def register_question(question_class):
    """Register *question_class* in ``QUESTION_CLASS`` under its class name.

    The registry key is ``question_class.__name__``, so registering a second
    class with the same name replaces the earlier entry.

    Returns *question_class* unchanged. This lets the function also be used
    as a class decorator (``@register_question``); existing callers that
    ignore the return value are unaffected.
    """
    QUESTION_CLASS[question_class.__name__] = question_class
    return question_class
# Base quiz question. Instances are built from keyword arguments via
# __setstate__, serialize to the attributes named in _state_attributes
# (that list is not visible in this view), and `check` accepts an answer
# only if it compares equal (`==`) to self.answer.
17 class Question (object):
# Build the instance by treating the keyword arguments as a state dict.
26 def __init__(self, **kwargs):
27 self.__setstate__(kwargs)
# NOTE(review): the `def` lines for the two `return` statements below are not
# visible in this view. Both render "<ClassName id:...>" using repr(self.id);
# the second also appends the object's address (id(self) in hex).
30 return '<{} id:{!r}>'.format(type(self).__name__, self.id)
33 return '<{} id:{!r} at {:#x}>'.format(
34 type(self).__name__, self.id, id(self))
# Return a dict mapping each name in _state_attributes to its current value.
36 def __getstate__(self):
37 return {attr: getattr(self, attr)
38 for attr in self._state_attributes}
# Load a state dict: default 'id' from the 'prompt' value and 'dependencies'
# to a fresh empty list, then copy the entries of `state` onto the instance
# with __dict__.update.
# NOTE(review): `state` is modified in place, so a dict passed in directly is
# changed. As shown, 'id' is always overwritten from 'prompt'; a guard may be
# missing from this view - confirm. The body of the `for` loop is not
# visible here either.
40 def __setstate__(self, state):
42 state['id'] = state.get('prompt', None)
43 if 'dependencies' not in state:
44 state['dependencies'] = []
45 for attr in self._state_attributes:
48 self.__dict__.update(state)
# Exact-match check: True iff `answer == self.answer`.
50 def check(self, answer):
51 return answer == self.answer
class NormalizedStringQuestion (Question):
    """Question whose answer is compared after whitespace/case normalization."""

    def normalize(self, string):
        """Return *string* with surrounding whitespace removed and case folded.

        ``str.casefold`` is used rather than ``str.lower`` because it removes
        all case distinctions. For example, ``'Straße'`` and ``'STRASSE'`` both
        normalize to ``'strasse'``, which is what a case-insensitive answer
        comparison needs.
        """
        return string.strip().casefold()

    def check(self, answer):
        """Return True if *answer* equals ``self.answer`` once both are normalized."""
        return self.normalize(answer) == self.normalize(self.answer)
class ChoiceQuestion (Question):
    """Question that accepts any one of several valid answers."""

    def check(self, answer):
        """Return True if *answer* is one of the accepted choices.

        ``self.answer`` is expected to be a collection of acceptable answers,
        and membership is tested with ``in``. A plain ``str`` is treated as a
        single choice. Otherwise ``in`` would perform substring matching and
        accept fragments such as ``''`` or a single character.
        """
        choices = self.answer
        if isinstance(choices, str):
            return answer == choices
        return answer in choices
# Question whose answer is a script. The reference answer and the user's
# answer are each run through self.interpreter, with self.setup lines before
# them and self.teardown lines after, and the results are compared.
67 class ScriptQuestion (Question):
# NOTE(review): the additional attribute names appended here are not visible
# in this view.
68 _state_attributes = Question._state_attributes + [
# Fill script-specific defaults (interpreter 'sh'), then defer to
# Question.__setstate__. NOTE(review): the 'timeout' default and the
# per-key handling of 'setup'/'teardown' are not visible in this view.
75 def __setstate__(self, state):
76 if 'interpreter' not in state:
77 state['interpreter'] = 'sh' # POSIX-compatible shell
78 if 'timeout' not in state:
80 for key in ['setup', 'teardown']:
83 super(ScriptQuestion, self).__setstate__(state)
# Run self.answer and `answer`, then compare their results. The tuples below
# suggest the order is stderr, exit status, stdout. Mismatches are reported
# with "missmatched ..." messages (at least one through LOG.info).
# NOTE(review): only part of this method is visible. Presumably only the
# user's invocation sits inside the `try` whose `except _error.CommandError`
# appears below, so a failing reference answer would propagate - confirm.
# NOTE(review): the log text spells "missmatched"; "mismatched" is correct.
85 def check(self, answer):
86 # figure out the expected values
87 e_status,e_stdout,e_stderr = self._invoke(self.answer)
88 # get values for the user-supplied answer
90 a_status,a_stdout,a_stderr = self._invoke(answer)
91 except _error.CommandError as e:
95 ('stderr', e_stderr, a_stderr),
96 ('status', e_status, a_status),
97 ('stdout', e_stdout, a_stdout),
102 'missmatched {}, expected {!r} but got {!r}'.format(
105 LOG.info('missmatched {}, expected:'.format(name))
# Run `answer`, wrapped by the self.setup and self.teardown lines (these must
# support `+` with a list), using self.interpreter inside a fresh temporary
# directory. Return (status, stdout, stderr).
# The temp directory's random basename is replaced by "<prefix>XXXXXX" in the
# output, so runs made in different directories produce comparable text.
# NOTE(review): some _util.invoke arguments are not visible here. Presumably
# the script is passed as input and the working directory is tempdir - confirm.
112 def _invoke(self, answer):
113 prefix = '{}-'.format(type(self).__name__)
114 with _tempfile.TemporaryDirectory(prefix=prefix) as tempdir:
115 script = '\n'.join(self.setup + [answer] + self.teardown)
116 status,stdout,stderr = _util.invoke(
117 args=[self.interpreter],
120 universal_newlines=True,
121 timeout=self.timeout,
123 dirname = _os_path.basename(tempdir)
124 stdout = stdout.replace(dirname, '{}XXXXXX'.format(prefix))
125 stderr = stderr.replace(dirname, '{}XXXXXX'.format(prefix))
126 return status,stdout,stderr
# Pass the classes defined in this module to register_question.
# list() snapshots the namespace first. At module level locals() is the
# module's globals, and binding `name`/`obj` during iteration would otherwise
# change the dict's size mid-loop (RuntimeError).
# NOTE(review): the statements after the '_' prefix test and before the
# register_question call are not visible in this view. Presumably private
# names and non-Question objects are skipped - confirm. The loop variables
# (name, obj, subclass) remain afterwards as module globals.
128 for name,obj in list(locals().items()):
129 if name.startswith('_'):
132 subclass = issubclass(obj, Question)
133 except TypeError: # obj is not a class
136 register_question(obj)