Transitioned severity to Command-format, also added Command._get_*()
[be.git] / libbe / command / base.py
1 # Copyright
2
3 import codecs
4 import optparse
5 import os.path
6 import sys
7
8 import libbe
9 import libbe.ui.util.user
10 import libbe.util.encoding
11 import libbe.util.plugin
12
class UserError(Exception):
    """Error caused by what the user asked for (bad options or
    arguments, an unknown command, restricted file access, ...) rather
    than by a bug in the program."""
15
class UnknownCommand(UserError):
    """Raised when a user command name cannot be resolved.

    The offending name is kept in the `cmd` attribute.
    """
    def __init__(self, cmd):
        message = "Unknown command '%s'" % cmd
        Exception.__init__(self, message)
        self.cmd = cmd
20
21
def get_command(command_name):
    """Retrieves the module for a user command

    Dashes in `command_name` become underscores, so 'import-xml' is
    looked up as the module libbe.command.import_xml.  Any ImportError
    raised by the import is reported as UnknownCommand.

    NOTE(review): this also turns an ImportError raised *inside* an
    existing command module (e.g. a missing dependency) into
    UnknownCommand, which can hide the real cause.

    >>> try:
    ...     get_command('asdf')
    ... except UnknownCommand, e:
    ...     print e
    Unknown command 'asdf'
    >>> repr(get_command('list')).startswith("<module 'libbe.command.list' from ")
    True
    """
    module_name = 'libbe.command.%s' % command_name.replace("-", "_")
    try:
        module = libbe.util.plugin.import_by_name(module_name)
    except ImportError:
        raise UnknownCommand(command_name)
    return module
39
def get_command_class(module, command_name):
    """Retrieves a command class from a module.

    The class is looked up under the name obtained by capitalizing
    `command_name` and replacing '-' with '_' (e.g. 'import-xml' ->
    'Import_xml').  A module whose class is spelled differently must
    expose it under that name as well; the doctest below presumably
    relies on such an alias in libbe.command.import_xml.

    Raises UnknownCommand if `module` has no attribute of that name.

    >>> import_xml_mod = get_command('import-xml')
    >>> import_xml = get_command_class(import_xml_mod, 'import-xml')
    >>> repr(import_xml)
    "<class 'libbe.command.import_xml.Import_XML'>"
    """
    cname = command_name.capitalize().replace('-', '_')
    try:
        return getattr(module, cname)
    except AttributeError:
        # getattr() reports a missing name with AttributeError (never
        # ImportError), so that is the exception to translate.
        raise UnknownCommand(command_name)
54
def commands():
    """Yield the module names of the available user commands.

    Every name reported by libbe.util.plugin.modnames('libbe.command')
    is yielded, except 'base' and 'util', which are skipped as
    non-command support modules ('base' is this module; 'util' is
    presumably shared helpers).
    """
    skipped = ('base', 'util')
    for name in libbe.util.plugin.modnames('libbe.command'):
        if name in skipped:
            continue
        yield name
59
class CommandInput (object):
    """Common base for command inputs: a `name` plus its `help` text."""
    def __init__(self, name, help=''):
        self.name, self.help = name, help
64
class Argument (CommandInput):
    """A value accepted by a command, either positionally or as the
    value of an Option.

    metavar -- placeholder shown in usage/help text; defaults to the
        upper-cased name.
    default -- value used when the argument is not supplied.
    type -- type name such as 'string' or 'bool'; for option arguments
        it is handed to optparse.
    optional -- whether the argument may be omitted.
    repeatable -- whether the argument may be given several times.
    completion_callback -- optional callable consulted by
        Command.complete().
    """
    def __init__(self, metavar=None, default=None, type='string',
                 optional=False, repeatable=False,
                 completion_callback=None, *args, **kwargs):
        CommandInput.__init__(self, *args, **kwargs)
        if metavar is None:
            metavar = self.name.upper()
        self.metavar = metavar
        self.default = default
        self.type = type
        self.optional = optional
        self.repeatable = repeatable
        self.completion_callback = completion_callback
78
class Option (CommandInput):
    """A command-line option (`--name`, optionally also `-x`).

    After construction exactly one of `callback` and `arg` is set.  If
    neither is given, an implicit boolean Argument with the option's
    name is created, making the option a simple on/off flag.
    """
    def __init__(self, callback=None, short_name=None, arg=None,
                 *args, **kwargs):
        CommandInput.__init__(self, *args, **kwargs)
        self.callback = callback
        self.short_name = short_name
        self.arg = arg
        if self.arg is None and self.callback is None:
            # use an implicit boolean argument
            self.arg = Argument(name=self.name, help=self.help,
                                default=False, type='bool')
        self.validate()

    def validate(self):
        """Assert that the option is internally consistent.

        Either there is a callback and no argument, or there is an
        argument carrying the same name that is neither optional nor
        repeatable (optparse cannot express those for options).
        """
        if self.arg is None:
            assert self.callback is not None, self.name
            return
        assert self.callback is None, '%s: %s' % (self.name, self.callback)
        assert self.arg.name == self.name, \
            'Name missmatch: %s != %s' % (self.arg.name, self.name)
        assert self.arg.optional == False, self.name
        assert self.arg.repeatable == False, self.name

    def __str__(self):
        """Return the long-option spelling, e.g. '--help'."""
        return '--%s' % self.name

    def __repr__(self):
        return '<Option %s>' % self.__str__()
107
class _DummyParser (optparse.OptionParser):
    """OptionParser populated from a Command's options.

    OptionFormatter uses it to render option help.  The optparse.Option
    objects built for the command are kept, in order, in
    `self._command_opts`; each carries `_option`, a reference back to
    the libbe Option it was built from.
    """
    def __init__(self, command):
        optparse.OptionParser.__init__(self)
        # Drop optparse's built-in -h/--help option; commands declare
        # their own.
        self.remove_option('-h')
        self.command = command
        self._command_opts = []
        for cmd_option in self.command.options:
            self._add_option(cmd_option)

    def _add_option(self, option):
        # from libbe.ui.command_line.CmdOptionParser._add_option
        option.validate()
        flags = []
        if option.short_name is not None:
            flags.append('-%s' % option.short_name)
        flags.append('--%s' % option.name)
        assert '_' not in option.name, \
            'Non-reconstructable option name %s' % option.name
        settings = {'dest': option.name.replace('-', '_'),
                    'help': option.help}
        if option.arg is None or option.arg.type == 'bool':
            # Flags: present -> True, absent -> False.
            settings.update(action='store_true', metavar=None,
                            default=False)
        else:
            settings.update(type=option.arg.type, action='store',
                            metavar=option.arg.metavar,
                            default=option.arg.default)
        parsed = optparse.Option(*flags, **settings)
        parsed._option = option
        self._command_opts.append(parsed)
        self.add_option(parsed)
144
class OptionFormatter (optparse.IndentedHelpFormatter):
    """Help formatter that renders a Command's options as an 'Options'
    section, in the order they appear in `command.options`."""
    def __init__(self, command):
        optparse.IndentedHelpFormatter.__init__(self)
        self.command = command

    def option_help(self):
        """Return the formatted option help for self.command.

        Based on optparse.OptionParser.format_option_help(): a heading,
        then each option's help with a blank line between options.
        """
        parser = _DummyParser(self.command)
        self.store_option_strings(parser)
        pieces = [self.format_heading('Options')]
        self.indent()
        for opt in parser._command_opts:
            pieces.append(self.format_option(opt))
            pieces.append('\n')
        self.dedent()
        # Drop the trailing '\n' -- or the heading itself when the
        # command has no options.
        del pieces[-1]
        return ''.join(pieces)
162
class Command (object):
    """One-line command description.

    Base class for `be` user commands.  Subclasses are expected to
    override `name`, `_run()` and `_long_help()`, and to add their own
    Option/Argument instances to `self.options` and `self.args`.

    >>> c = Command()
    >>> print c.help()
    usage: be command [options]
    <BLANKLINE>
    Options:
      -h, --help  Print a help message.
    <BLANKLINE>
      --complete  Print a list of possible completions.
    <BLANKLINE>
    A detailed help message.
    """

    name = 'command'

    def __init__(self, input_encoding=None, output_encoding=None,
                 get_unconnected_storage=None, ui=None):
        """
        input_encoding, output_encoding -- encodings used by _setup_io()
            for self.stdin/self.stdout; None means use the defaults from
            libbe.util.encoding.
        get_unconnected_storage -- zero-argument callable returning a
            storage object; called lazily by _get_unconnected_storage().
        ui -- calling user-interface, e.g. for Help().
        """
        self.input_encoding = input_encoding
        self.output_encoding = output_encoding
        self.get_unconnected_storage = get_unconnected_storage
        self.ui = ui # calling user-interface, e.g. for Help()
        self.status = None
        self.result = None
        # Confine file paths to the storage repo; see
        # _check_restricted_access().
        self.restrict_file_access = True
        self.options = [
            Option(name='help', short_name='h',
                help='Print a help message.',
                callback=self.help),
            Option(name='complete',
                help='Print a list of possible completions.',
                callback=self.complete),
                ]
        self.args = []

    def run(self, options=None, args=None):
        """Translate `options` and `args` into keyword params for _run().

        options -- dict mapping option names (as in self.options) to
            values.  A 'user-id' entry is also accepted and stored for
            _get_user_id().  The dict is copied; the caller's dict is
            left untouched.
        args -- list of positional values, matched in order against
            self.args; surplus values are appended to a trailing
            repeatable argument.

        Returns the result of _run(), which is also stored in
        self.status.  Raises UserError for unrecognized options, missing
        required arguments, or surplus arguments.
        """
        if options is None:
            options = {}
        else:
            # Copy so that popping recognized entries below does not
            # empty the caller's dict.
            options = dict(options)
        if args is None:
            args = []
        params = {}
        for option in self.options:
            assert option.name not in params, params[option.name]
            if option.name in options:
                params[option.name] = options.pop(option.name)
            elif option.arg is not None:
                params[option.name] = option.arg.default
            else: # non-arg options are flags, set to default flag value
                params[option.name] = False
        assert 'user-id' not in params, params['user-id']
        if 'user-id' in options:
            self._user_id = options.pop('user-id')
        if len(options) > 0:
            raise UserError('Invalid option passed to command %s:\n  %s'
                            % (self.name, '\n  '.join(['%s: %s' % (k,v)
                                                       for k,v in options.items()])))
        in_optional_args = False
        for i,arg in enumerate(self.args):
            # These asserts check the command definition itself
            # (programmer errors), not the user's input.
            if arg.repeatable == True:
                assert i == len(self.args)-1, arg.name
            if in_optional_args == True:
                assert arg.optional == True, arg.name
            else:
                in_optional_args = arg.optional
            if i < len(args):
                value = args[i]
            elif in_optional_args == True:
                value = arg.default
            else:
                # A user error, so raise rather than assert: it must
                # survive `python -O` and reach the user as a UserError.
                raise UserError('Missing required argument %s' % arg.metavar)
            if arg.repeatable == True:
                params[arg.name] = [value]
            else:
                params[arg.name] = value
        if len(args) > len(self.args):  # add some additional repeats
            if len(self.args) == 0 or self.args[-1].repeatable != True:
                raise UserError(
                    'Too many arguments (%d given, at most %d accepted)'
                    % (len(args), len(self.args)))
            params[self.args[-1].name].extend(args[len(self.args):])

        # Forward 'help' and 'complete' to _run() only when requested.
        # 'complete' is a flag that defaults to False (see above), so
        # both None and False mean "not requested".
        if params.get('help') != True:
            params.pop('help', None)
        if params.get('complete') in (None, False):
            params.pop('complete', None)

        self._setup_io(self.input_encoding, self.output_encoding)
        self.status = self._run(**params)
        return self.status

    def _run(self, **kwargs):
        """Do the command's work; subclasses must override."""
        raise NotImplementedError

    def _setup_io(self, input_encoding=None, output_encoding=None):
        """Set self.stdin/self.stdout to codec-wrapped sys streams.

        None encodings fall back to libbe.util.encoding's input/output
        defaults.  The chosen encoding is recorded on each wrapper's
        `encoding` attribute.
        """
        if input_encoding is None:
            input_encoding = libbe.util.encoding.get_input_encoding()
        if output_encoding is None:
            output_encoding = libbe.util.encoding.get_output_encoding()
        # stdin is read from, so it needs a StreamReader; a StreamWriter
        # would pass read() straight through to the raw stream undecoded.
        self.stdin = codecs.getreader(input_encoding)(sys.stdin)
        self.stdin.encoding = input_encoding
        self.stdout = codecs.getwriter(output_encoding)(sys.stdout)
        self.stdout.encoding = output_encoding

    def help(self, *args):
        """Return the full help text: usage, options, then long help.

        Extra positional arguments are ignored; this method is
        registered as the --help option's callback.
        """
        return '\n\n'.join([self._usage(),
                            self._option_help(),
                            self._long_help()])

    def _usage(self):
        """Return the usage line, e.g. 'usage: be NAME [options] ID [X ...]'.

        Optional arguments are bracketed, with all closing brackets
        placed at the end; repeatable arguments get a trailing ' ...'.
        """
        usage = 'usage: be %s [options]' % self.name
        num_optional = 0
        for arg in self.args:
            usage += ' '
            if arg.optional == True:
                usage += '['
                num_optional += 1
            usage += arg.metavar
            if arg.repeatable == True:
                usage += ' ...'
        usage += ']'*num_optional
        return usage

    def _option_help(self):
        """Return the formatted 'Options:' section for self.options."""
        o = OptionFormatter(self)
        return o.option_help().strip('\n')

    def _long_help(self):
        """Return the detailed help text; subclasses should override."""
        return "A detailed help message."

    def complete(self, argument=None, fragment=None):
        """Return a list of possible completions.

        With no `argument`, return every '--option' plus whatever the
        first positional argument's completion_callback offers.  With an
        `argument`, return that argument's completion_callback results,
        or [] if it has none.  Callbacks are always called as
        callback(command, argument, fragment).
        """
        if argument is None:
            ret = ['--%s' % o.name for o in self.options]
            if len(self.args) > 0 and self.args[0].completion_callback is not None:
                ret.extend(self.args[0].completion_callback(
                        self, argument, fragment))
            return ret
        elif argument.completion_callback is not None:
            # finish a particular argument
            return argument.completion_callback(self, argument, fragment)
        return [] # the particular argument doesn't supply completion info

    def _check_restricted_access(self, storage, path):
        """
        Check that the file at path is inside storage.repo.  This is
        important if you allow other users to execute becommands with
        your username (e.g. if you're running be-handle-mail through
        your ~/.procmailrc).  If this check weren't made, a user could
        e.g. run
          be commit -b ~/.ssh/id_rsa "Hack to expose ssh key"
        which would expose your ssh key to anyone who could read the
        VCS log.  Raises UserError when the path is outside the repo and
        self.restrict_file_access is True.

        >>> class DummyStorage (object): pass
        >>> s = DummyStorage()
        >>> s.repo = os.path.expanduser('~/x/')
        >>> c = Command()
        >>> try:
        ...     c._check_restricted_access(s, os.path.expanduser('~/.ssh/id_rsa'))
        ... except UserError, e:
        ...     assert str(e).startswith('file access restricted!'), str(e)
        ...     print 'we got the expected error'
        we got the expected error
        >>> c._check_restricted_access(s, os.path.expanduser('~/x'))
        >>> c._check_restricted_access(s, os.path.expanduser('~/x/y'))
        >>> c.restrict_file_access = False
        >>> c._check_restricted_access(s, os.path.expanduser('~/.ssh/id_rsa'))
        """
        if self.restrict_file_access == True:
            # NOTE(review): abspath() does not resolve symlinks, so a
            # symlink inside the repo that points elsewhere passes this
            # check; consider os.path.realpath() on both sides.
            path = os.path.abspath(path)
            repo = os.path.abspath(storage.repo).rstrip(os.path.sep)
            # Compare against repo + separator so that e.g. '~/xy' is not
            # accepted as being inside '~/x'.
            if path == repo or path.startswith(repo+os.path.sep):
                return
            raise UserError('file access restricted!\n  %s not in %s'
                            % (path, repo))

    def _get_unconnected_storage(self):
        """Callback for use by commands that need it.

        Calls the get_unconnected_storage callable given to __init__()
        once and caches the result.  Raises NotImplementedError if no
        callable was given.
        """
        if not hasattr(self, '_unconnected_storage'):
            if self.get_unconnected_storage is None:
                raise NotImplementedError
            self._unconnected_storage = self.get_unconnected_storage()
        return self._unconnected_storage

    def _get_storage(self):
        """
        Callback for use by commands that need it.

        Connects the storage from _get_unconnected_storage() (once) and
        caches it.  Note that this connects the very object cached by
        _get_unconnected_storage(), so afterwards that method returns a
        connected storage; that shouldn't be an issue for any command I
        can think of...
        """
        if not hasattr(self, '_storage'):
            self._storage = self._get_unconnected_storage()
            self._storage.connect()
        return self._storage

    def _get_bugdir(self):
        """Callback for use by commands that need it.

        Builds (once) and caches a libbe.bugdir.BugDir loaded from
        _get_storage().
        """
        if not hasattr(self, '_bugdir'):
            # libbe.bugdir is not imported at module level, so load it
            # here before use rather than relying on someone else having
            # imported it.
            import libbe.bugdir
            self._bugdir = libbe.bugdir.BugDir(self._get_storage(),
                                               from_storage=True)
        return self._bugdir

    def _get_user_id(self):
        """Callback for use by commands that need it.

        Returns the 'user-id' passed to run() if there was one, otherwise
        libbe.ui.util.user.get_user_id() for the storage (cached).
        """
        if not hasattr(self, '_user_id'):
            self._user_id = libbe.ui.util.user.get_user_id(self._get_storage())
        return self._user_id

    def cleanup(self):
        """Disconnect the storage if _get_storage() connected one."""
        if hasattr(self, '_storage'):
            self._storage.disconnect()