1 # Copyright (C) 2009-2010 W. Trevor King <wking@drexel.edu>
3 # This program is free software; you can redistribute it and/or modify
4 # it under the terms of the GNU General Public License as published by
5 # the Free Software Foundation; either version 2 of the License, or
6 # (at your option) any later version.
8 # This program is distributed in the hope that it will be useful,
9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # GNU General Public License for more details.
13 # You should have received a copy of the GNU General Public License along
14 # with this program; if not, write to the Free Software Foundation, Inc.,
15 # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
25 import libbe.ui.util.user
26 import libbe.util.encoding
27 import libbe.util.plugin
class UserError(Exception):
    """Error attributable to the user's request rather than to libbe.

    Raised, for example, for an unknown command name, an invalid option
    passed to a command, or a restricted file access.
    """
class UnknownCommand(UserError):
    """Raised when no command matches the requested name.

    The offending name is kept in the ``cmd`` attribute so callers can
    report or recover from it without parsing the message.
    """
    def __init__(self, cmd):
        Exception.__init__(self, "Unknown command '%s'" % cmd)
        self.cmd = cmd
def get_command(command_name):
    """Retrieves the module for a user command

    Dashes in ``command_name`` are mapped to underscores to form the
    module name, so ``'import-xml'`` loads ``libbe.command.import_xml``.
    Any ImportError raised while importing -- including one raised by a
    dependency of the command module -- is reported as UnknownCommand.

    >>> try:
    ...     get_command('asdf')
    ... except UnknownCommand, e:
    ...     print e
    Unknown command 'asdf'
    >>> repr(get_command('list')).startswith("<module 'libbe.command.list' from ")
    True
    """
    try:
        cmd = libbe.util.plugin.import_by_name(
            'libbe.command.%s' % command_name.replace("-", "_"))
    except ImportError:
        raise UnknownCommand(command_name)
    return cmd
def get_command_class(module=None, command_name=None):
    """Retrieves a command class from a module.

    The class name is derived from ``command_name`` by capitalizing it
    and replacing dashes with underscores (``'import-xml'`` ->
    ``Import_xml``).  If ``module`` is omitted it is loaded with
    get_command(command_name).  Raises UnknownCommand if the module does
    not define a class under the derived name.

    >>> import_xml_mod = get_command('import-xml')
    >>> import_xml = get_command_class(import_xml_mod, 'import-xml')
    >>> repr(import_xml)
    "<class 'libbe.command.import_xml.Import_XML'>"
    >>> import_xml = get_command_class(command_name='import-xml')
    >>> repr(import_xml)
    "<class 'libbe.command.import_xml.Import_XML'>"
    """
    if module is None:
        module = get_command(command_name)
    cname = command_name.capitalize().replace('-', '_')
    try:
        cmd = getattr(module, cname)
    except AttributeError:
        # getattr() reports a missing class with AttributeError, not
        # ImportError (import failures are already handled by get_command).
        raise UnknownCommand(command_name)
    return cmd
def modname_to_command_name(modname):
    """Return the user-facing command name for command module *modname*.

    This is a little hack: rather than importing the module and reading
    the ``name`` of its Command subclass, it simply turns underscores
    into dashes.  The doctest below checks that the shortcut agrees with
    the real lookup for a couple of commands.

    >>> def real_modname_to_command_name(modname):
    ...     mod = libbe.util.plugin.import_by_name(
    ...         'libbe.command.%s' % modname)
    ...     commands = []
    ...     for attr_name in dir(mod):
    ...         attr = getattr(mod, attr_name)
    ...         try:
    ...             if issubclass(attr, Command):
    ...                 commands.append(attr)
    ...         except TypeError, e:
    ...             pass
    ...     if len(commands) == 0:
    ...         raise Exception('No Command classes in %s' % dir(mod))
    ...     return commands[0].name
    >>> real_modname_to_command_name('new')
    'new'
    >>> real_modname_to_command_name('import_xml')
    'import-xml'
    """
    pieces = modname.split('_')
    return '-'.join(pieces)
def commands(command_names=False):
    """Iterate over the available commands.

    Walks the modules of the ``libbe.command`` package, skipping the
    ``base`` and ``util`` helper modules.  When ``command_names`` is
    False (the default) each module name is yielded (e.g.
    ``'import_xml'``); otherwise the user-facing command name from
    modname_to_command_name() is yielded (e.g. ``'import-xml'``).
    """
    for modname in libbe.util.plugin.modnames('libbe.command'):
        if modname in ['base', 'util']:
            continue  # helper modules, not commands
        if command_names == False:
            yield modname
        else:
            yield modname_to_command_name(modname)
class CommandInput (object):
    """Base class for the named inputs (options and arguments) of a command.

    name : the input's name
    help : human-readable description, empty by default
    """
    def __init__(self, name, help=''):
        self.name = name
        self.help = help

    def __str__(self):
        return '<%s %s>' % (self.__class__.__name__, self.name)

    def __repr__(self):
        return self.__str__()
class Argument (CommandInput):
    """A positional command argument, or the value taken by an Option.

    metavar             : placeholder shown in help; defaults to the
                          upper-cased name
    default             : value used when the argument is not given
    type                : optparse type name such as 'string', or 'bool'
                          for flag-style options
    optional            : whether the argument may be omitted
    repeatable          : whether the argument may be given several times
    completion_callback : called as callback(command, argument, fragment)
                          to list completions for this argument
    Remaining positional/keyword arguments (name, help) go to CommandInput.
    """
    def __init__(self, metavar=None, default=None, type='string',
                 optional=False, repeatable=False,
                 completion_callback=None, *args, **kwargs):
        CommandInput.__init__(self, *args, **kwargs)
        self.metavar = metavar
        self.default = default
        # Read back as option.arg.type when building optparse options.
        self.type = type
        self.optional = optional
        self.repeatable = repeatable
        self.completion_callback = completion_callback
        if self.metavar is None:
            self.metavar = self.name.upper()
class Option (CommandInput):
    """A command-line option, rendered as ``--name`` (and ``-short_name``).

    An option either takes a value described by ``arg`` (an Argument) or
    triggers ``callback``, never both.  When neither is given, the option
    becomes a boolean flag backed by an implicit ``bool`` Argument.
    Remaining positional/keyword arguments (name, help) go to CommandInput.
    """
    def __init__(self, callback=None, short_name=None, arg=None,
                 *args, **kwargs):
        CommandInput.__init__(self, *args, **kwargs)
        self.callback = callback
        self.short_name = short_name
        self.arg = arg
        if self.arg is None and self.callback is None:
            # use an implicit boolean argument
            self.arg = Argument(name=self.name, help=self.help,
                                default=False, type='bool')
        self.validate()

    def validate(self):
        """Check the option's consistency; raises AssertionError if broken.

        These are assert statements, so they are skipped under python -O.
        """
        if self.arg is None:
            assert self.callback is not None, self.name
            return
        assert self.callback is None, '%s: %s' % (self.name, self.callback)
        assert self.arg.name == self.name, \
            'Name missmatch: %s != %s' % (self.arg.name, self.name)
        assert self.arg.optional == False, self.name
        assert self.arg.repeatable == False, self.name

    def __str__(self):
        return '--%s' % self.name

    def __repr__(self):
        return '<Option %s>' % self.__str__()
# Throw-away optparse parser built from a Command's options.  It is used by
# OptionFormatter.option_help() to format those options with optparse.
162 class _DummyParser (optparse.OptionParser):
163 def __init__(self, command):
164 optparse.OptionParser.__init__(self)
# Command defines its own -h/--help Option, so drop optparse's default one.
165 self.remove_option('-h')
166 self.command = command
# optparse.Option objects built by _add_option(), in command.options order.
167 self._command_opts = []
168 for option in self.command.options:
169 self._add_option(option)
# Translate one libbe Option into an optparse.Option.
171 def _add_option(self, option):
172 # from libbe.ui.command_line.CmdOptionParser._add_option
174 long_opt = '--%s' % option.name
175 if option.short_name != None:
176 short_opt = '-%s' % option.short_name
# dest swaps '-' for '_'; forbidding '_' in option names keeps that mapping
# reversible, so the option name can be recovered from dest.
177 assert '_' not in option.name, \
178 'Non-reconstructable option name %s' % option.name
179 kwargs = {'dest':option.name.replace('-', '_'),
181 if option.arg == None or option.arg.type == 'bool':
# Callback-only options and bool arguments become store_true flags.
182 kwargs['action'] = 'store_true'
183 kwargs['metavar'] = None
184 kwargs['default'] = False
# Value-taking options copy type/metavar/default from their Argument.
186 kwargs['type'] = option.arg.type
187 kwargs['action'] = 'store'
188 kwargs['metavar'] = option.arg.metavar
189 kwargs['default'] = option.arg.default
190 if option.short_name != None:
191 opt = optparse.Option(short_opt, long_opt, **kwargs)
193 opt = optparse.Option(long_opt, **kwargs)
194 #option.takes_value = lambda : option.arg != None
196 self._command_opts.append(opt)
# optparse help formatter that renders only a Command's own options.
199 class OptionFormatter (optparse.IndentedHelpFormatter):
200 def __init__(self, command):
201 optparse.IndentedHelpFormatter.__init__(self)
202 self.command = command
# Return the "Options" help section for self.command as a single string.
203 def option_help(self):
204 # based on optparse.OptionParser.format_option_help()
# Build a parser from the command's options so store_option_strings() can
# record their option strings and column widths before formatting.
205 parser = _DummyParser(self.command)
206 self.store_option_strings(parser)
208 ret.append(self.format_heading('Options'))
210 for option in parser._command_opts:
211 ret.append(self.format_option(option))
214 # Drop the last '\n', or the header if no options or option groups:
215 return ''.join(ret[:-1])
# Base class for be commands.  run() parses options/args into keyword
# parameters and passes them to _run(), which subclasses are expected to
# override; help(), usage() and complete() build help and completion text
# from self.options and self.args.
217 class Command (object):
218 """One-line command description here.
222 usage: be command [options]
225 -h, --help Print a help message.
227 --complete Print a list of possible completions.
229 A detailed help message.
234 def __init__(self, ui=None):
235 self.ui = ui # calling user-interface
# _check_restricted_access() only enforces its repo-path check while this
# is True.
238 self.restrict_file_access = True
240 Option(name='help', short_name='h',
241 help='Print a help message.',
243 Option(name='complete',
244 help='Print a list of possible completions.',
245 callback=self.complete),
# Parse options/args, call self._run(**params) and store its return value
# in self.status (status stays 1 if an exception escapes).
249 def run(self, options=None, args=None):
250 self.status = 1 # in case we raise an exception
251 params = self._parse_options_args(options, args)
252 if params['help'] == True:
# NOTE(review): _parse_options_args() sets non-arg options such as
# 'complete' to False by default, never None, so this test is true unless
# a caller passes complete=None explicitly -- confirm the intended check.
256 if params['complete'] != None:
259 params.pop('complete')
261 self.status = self._run(**params)
# Build the keyword dict for _run(): every option gets a value (explicit,
# its Argument's default, or False for flags); a 'user-id' option is stored
# on self._user_id instead; left-over options are reported with UserError;
# positional args are then matched against self.args.
264 def _parse_options_args(self, options=None, args=None):
270 for option in self.options:
271 assert option.name not in params, params[option.name]
272 if option.name in options:
273 params[option.name] = options.pop(option.name)
274 elif option.arg != None:
275 params[option.name] = option.arg.default
276 else: # non-arg options are flags, set to default flag value
277 params[option.name] = False
278 assert 'user-id' not in params, params['user-id']
279 if 'user-id' in options:
280 self._user_id = options.pop('user-id')
282 raise UserError, 'Invalid option passed to command %s:\n %s' \
283 % (self.name, '\n '.join(['%s: %s' % (k,v)
284 for k,v in options.items()]))
# Positional args: only the last declared argument may be repeatable (its
# value is a list collecting any extra args), and once an optional
# argument appears every later one must be optional too.  Missing args
# take their Argument's default.
285 in_optional_args = False
286 for i,arg in enumerate(self.args):
287 if arg.repeatable == True:
288 assert i == len(self.args)-1, arg.name
289 if in_optional_args == True:
290 assert arg.optional == True, arg.name
292 in_optional_args = arg.optional
294 if arg.repeatable == True:
295 params[arg.name] = [args[i]]
297 params[arg.name] = args[i]
298 else: # no value given
299 assert in_optional_args == True, arg.name
300 params[arg.name] = arg.default
301 if len(args) > len(self.args): # add some additional repeats
302 assert self.args[-1].repeatable == True, self.args[-1].name
303 params[self.args[-1].name].extend(args[len(self.args):])
# The command itself; run() calls it with the parsed params as keyword
# arguments and stores its return value as the exit status.
306 def _run(self, **kwargs):
307 raise NotImplementedError
# Assemble the help text from self.usage() ... self._long_help(), joined
# with blank lines.
309 def help(self, *args):
310 return '\n\n'.join([self.usage(),
312 self._long_help().rstrip('\n')])
# Builds the 'usage: be NAME [options] ...' line from self.args; optional
# arguments open '[' brackets that are all closed at the end.
315 usage = 'usage: be %s [options]' % self.name
317 for arg in self.args:
319 if arg.optional == True:
323 if arg.repeatable == True:
325 usage += ']'*num_optional
# The options section rendered by OptionFormatter, stripped of leading and
# trailing newlines.
328 def _option_help(self):
329 o = OptionFormatter(self)
330 return o.option_help().strip('\n')
# Default long help text; presumably overridden by subclasses.
332 def _long_help(self):
333 return "A detailed help message."
# Completion: the top-level branch (its condition is not visible in this
# copy) lists every --option plus the first positional argument's
# completions; otherwise delegate to argument.completion_callback, or
# return [] when that argument has none.
335 def complete(self, argument=None, fragment=None):
337 ret = ['--%s' % o.name for o in self.options]
338 if len(self.args) > 0 and self.args[0].completion_callback != None:
339 ret.extend(self.args[0].completion_callback(self, argument, fragment))
341 elif argument.completion_callback != None:
342 # finish a particular argument
343 return argument.completion_callback(self, argument, fragment)
344 return [] # the particular argument doesn't supply completion info
# Raise UserError unless path lies inside storage.repo; the check is
# skipped when restrict_file_access is not True.
346 def _check_restricted_access(self, storage, path):
348 Check that the file at path is inside bugdir.root. This is
349 important if you allow other users to execute becommands with
350 your username (e.g. if you're running be-handle-mail through
351 your ~/.procmailrc). If this check wasn't made, a user could
353 be commit -b ~/.ssh/id_rsa "Hack to expose ssh key"
354 which would expose your ssh key to anyone who could read the
357 >>> class DummyStorage (object): pass
358 >>> s = DummyStorage()
359 >>> s.repo = os.path.expanduser('~/x/')
362 ... c._check_restricted_access(s, os.path.expanduser('~/.ssh/id_rsa'))
363 ... except UserError, e:
364 ... assert str(e).startswith('file access restricted!'), str(e)
365 ... print 'we got the expected error'
366 we got the expected error
367 >>> c._check_restricted_access(s, os.path.expanduser('~/x'))
368 >>> c._check_restricted_access(s, os.path.expanduser('~/x/y'))
369 >>> c.restrict_file_access = False
370 >>> c._check_restricted_access(s, os.path.expanduser('~/.ssh/id_rsa'))
372 if self.restrict_file_access == True:
373 path = os.path.abspath(path)
374 repo = os.path.abspath(storage.repo).rstrip(os.path.sep)
# NOTE(review): os.path.abspath() normalises '..' but does not resolve
# symlinks, so a symlink inside the repo that points elsewhere (e.g. at
# ~/.ssh/id_rsa) passes this check; os.path.realpath() on both paths would
# close that hole -- confirm before changing.
375 if path == repo or path.startswith(repo+os.path.sep):
377 raise UserError('file access restricted!\n %s not in %s'
# Holds the stdin/stdout streams that are handed to commands.
383 class InputOutput (object):
384 def __init__(self, stdin=None, stdout=None):
# Give command this object's stdin/stdout, first filling in a missing
# .encoding attribute on either stream from libbe.util.encoding.
388 def setup_command(self, command):
389 if not hasattr(self.stdin, 'encoding'):
390 self.stdin.encoding = libbe.util.encoding.get_input_encoding()
391 if not hasattr(self.stdout, 'encoding'):
392 self.stdout.encoding = libbe.util.encoding.get_output_encoding()
393 command.stdin = self.stdin
# command.stdin is self.stdin at this point, so this (and the matching
# stdout line) re-assigns the attribute on the same object.
394 command.stdin.encoding = self.stdin.encoding
395 command.stdout = self.stdout
396 command.stdout.encoding = self.stdout.encoding
class StdInputOutput (InputOutput):
    """InputOutput wrapping the process's sys.stdin and sys.stdout.

    stdin is wrapped in a codecs StreamReader that decodes with
    ``input_encoding``; stdout in a StreamWriter that encodes with
    ``output_encoding``.  Either encoding defaults to the value reported
    by libbe.util.encoding when not given.
    """
    def __init__(self, input_encoding=None, output_encoding=None):
        stdin, stdout = self._get_io(input_encoding, output_encoding)
        InputOutput.__init__(self, stdin, stdout)

    def _get_io(self, input_encoding=None, output_encoding=None):
        if input_encoding is None:
            input_encoding = libbe.util.encoding.get_input_encoding()
        if output_encoding is None:
            output_encoding = libbe.util.encoding.get_output_encoding()
        # stdin is read from, so it needs a decoding StreamReader; a
        # StreamWriter would pass read() straight through undecoded.
        stdin = codecs.getreader(input_encoding)(sys.stdin)
        stdin.encoding = input_encoding
        stdout = codecs.getwriter(output_encoding)(sys.stdout)
        stdout.encoding = output_encoding
        return (stdin, stdout)
# InputOutput backed by in-memory StringIO buffers, both tagged 'utf-8'.
# UserInterface falls back to it when no io object is supplied.
417 class StringInputOutput (InputOutput):
419 >>> s = StringInputOutput()
420 >>> s.set_stdin('hello')
425 >>> print >> s.stdout, 'goodbye'
431 Also works with unicode strings
433 >>> s.set_stdin(u'hello')
436 >>> print >> s.stdout, u'goodbye'
441 stdin = StringIO.StringIO()
442 stdin.encoding = 'utf-8'
443 stdout = StringIO.StringIO()
444 stdout.encoding = 'utf-8'
445 InputOutput.__init__(self, stdin, stdout)
# Replace stdin with a fresh buffer holding stdin_string.
# NOTE(review): unlike the buffer made in __init__, the new one has no
# .encoding attribute, and a command that was already set up keeps its old
# stdin reference (setup_command copies the reference) -- confirm.
447 def set_stdin(self, stdin_string):
448 self.stdin = StringIO.StringIO(stdin_string)
# Read everything written to stdout so far, then swap in an empty buffer.
450 def get_stdout(self):
451 ret = self.stdout.getvalue()
452 self.stdout = StringIO.StringIO() # clear stdout for next read
# NOTE(review): this sets stdin's encoding right after stdout was replaced
# by a buffer without one; self.stdout.encoding looks like the intent.
453 self.stdin.encoding = 'utf-8'
class UnconnectedStorageGetter (object):
    """Callable that builds the Storage instance for ``location``.

    StorageCallbacks calls an instance with no arguments to obtain
    ``libbe.storage.get_storage(location)``.
    """
    def __init__(self, location):
        self.location = location

    def __call__(self):
        return libbe.storage.get_storage(self.location)
# Lazily creates and caches the unconnected storage, connected storage and
# bugdir that commands reach through their _get_* hooks.
463 class StorageCallbacks (object):
464 def __init__(self, location=None):
467 self.location = location
468 self._get_unconnected_storage = UnconnectedStorageGetter(location)
# Install this object's getters as command's _get_* hooks.
470 def setup_command(self, command):
471 command._get_unconnected_storage = self.get_unconnected_storage
472 command._get_storage = self.get_storage
473 command._get_bugdir = self.get_bugdir
475 def get_unconnected_storage(self):
477 Callback for use by commands that need it.
479 The returned Storage instance may actually be connected,
480 but commands that make use of the returned value should only
481 make use of non-connected Storage methods. This is mainly
482 intended for the init command, which calls Storage.init().
484 if not hasattr(self, '_unconnected_storage'):
# first call: build the storage through the getter and cache it
485 if self._get_unconnected_storage == None:
486 raise NotImplementedError
487 self._unconnected_storage = self._get_unconnected_storage()
488 return self._unconnected_storage
490 def set_unconnected_storage(self, unconnected_storage):
491 self._unconnected_storage = unconnected_storage
493 def get_storage(self):
494 """Callback for use by commands that need it."""
495 if not hasattr(self, '_storage'):
# first call: connect the unconnected storage and check its version
496 self._storage = self.get_unconnected_storage()
497 self._storage.connect()
# NOTE(review): self._storage is cached before this version check, so if
# InvalidStorageVersion is raised a later call returns the same storage
# without re-checking -- confirm that is acceptable.
498 version = self._storage.storage_version()
499 if version != libbe.storage.STORAGE_VERSION:
500 raise libbe.storage.InvalidStorageVersion(version)
503 def set_storage(self, storage):
504 self._storage = storage
506 def get_bugdir(self):
507 """Callback for use by commands that need it."""
508 if not hasattr(self, '_bugdir'):
# first call: build a BugDir on top of get_storage() and cache it
509 self._bugdir = libbe.bugdir.BugDir(self.get_storage(),
513 def set_bugdir(self, bugdir):
514 self._bugdir = bugdir
# Disconnect the storage, but only if one was obtained or set.
517 if hasattr(self, '_storage'):
518 self._storage.disconnect()
# Glue between a front end and commands: owns an io object and storage
# callbacks and wires them, plus the file-access policy and a user-id
# callback, into each command before running it.
520 class UserInterface (object):
521 def __init__(self, io=None, location=None):
# presumably the fallback when no io is given (the guarding line is missing
# from this copy): use in-memory StringInputOutput
523 io = StringInputOutput()
525 self.storage_callbacks = StorageCallbacks(location)
526 self.restrict_file_access = True
529 raise NotImplementedError
# Prepare command with setup_command(), then run it with options/args and
# return its result.
531 def run(self, command, options=None, args=None):
532 self.setup_command(command)
533 return command.run(options, args)
# Attach this UI's io streams, storage callbacks (when present),
# file-access policy and user-id callback to command.
535 def setup_command(self, command):
536 if command.ui == None:
539 self.io.setup_command(command)
540 if self.storage_callbacks != None:
541 self.storage_callbacks.setup_command(command)
542 command.restrict_file_access = self.restrict_file_access
543 command._get_user_id = self._get_user_id
# Look up the user id from the storage on first use and cache it on
# self._user_id.
545 def _get_user_id(self):
546 """Callback for use by commands that need it."""
547 if not hasattr(self, '_user_id'):
548 self._user_id = libbe.ui.util.user.get_user_id(
549 self.storage_callbacks.get_storage())
# presumably the body of a cleanup method (its def line is missing here):
# release the storage held by the storage callbacks
553 self.storage_callbacks.cleanup()