4e1e41f839449d58cbf087dce49d14425387b5fd
[be.git] / libbe / command / base.py
1 # Copyright (C) 2009-2012 Chris Ball <cjb@laptop.org>
2 #                         Phil Schumm <philschumm@gmail.com>
3 #                         Robert Lehmann <mail@robertlehmann.de>
4 #                         W. Trevor King <wking@drexel.edu>
5 #
6 # This file is part of Bugs Everywhere.
7 #
8 # Bugs Everywhere is free software: you can redistribute it and/or modify it
9 # under the terms of the GNU General Public License as published by the Free
10 # Software Foundation, either version 2 of the License, or (at your option) any
11 # later version.
12 #
13 # Bugs Everywhere is distributed in the hope that it will be useful, but
14 # WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
15 # FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
16 # more details.
17 #
18 # You should have received a copy of the GNU General Public License along with
19 # Bugs Everywhere.  If not, see <http://www.gnu.org/licenses/>.
20
21 import codecs
22 import optparse
23 import os.path
24 import StringIO
25 import sys
26
27 import libbe
28 import libbe.storage
29 import libbe.ui.util.user
30 import libbe.util.encoding
31 import libbe.util.plugin
32
33
34 class UserError (Exception):
35     "An error due to improper BE usage."
36     pass
37
38
39 class UsageError (UserError):
40     """A serious parsing error due to invalid BE command construction.
41
42     The distinction between `UserError`\s and the more specific
43     `UsageError`\s is that when displaying a `UsageError` to the user,
44     the user is pointed towards the command usage information.  Use
45     the more general `UserError` if you feel that usage information
46     would not be particularly enlightening.
47     """
48     def __init__(self, command=None, command_name=None, message=None):
49         super(UsageError, self).__init__(message)
50         self.command = command
51         if command_name is None and command is not None:
52             command_name = command.name
53         self.command_name = command_name
54         self.message = message
55
56
57 class UnknownCommand (UsageError):
58     def __init__(self, command_name, message=None):
59         uc_message = "Unknown command '%s'" % command_name
60         if message is None:
61             message = uc_message
62         else:
63             message = '%s\n(%s)' % (uc_message, message)
64         super(UnknownCommand, self).__init__(
65             command_name=command_name,
66             message=message)
67
68
69 def get_command(command_name):
70     """Retrieves the module for a user command
71
72     >>> try:
73     ...     get_command('asdf')
74     ... except UnknownCommand, e:
75     ...     print e
76     Unknown command 'asdf'
77     (No module named asdf)
78     >>> repr(get_command('list')).startswith("<module 'libbe.command.list' from ")
79     True
80     """
81     try:
82         cmd = libbe.util.plugin.import_by_name(
83             'libbe.command.%s' % command_name.replace("-", "_"))
84     except ImportError, e:
85         raise UnknownCommand(command_name, message=unicode(e))
86     return cmd
87
88 def get_command_class(module=None, command_name=None):
89     """Retrieves a command class from a module.
90
91     >>> import_xml_mod = get_command('import-xml')
92     >>> import_xml = get_command_class(import_xml_mod, 'import-xml')
93     >>> repr(import_xml)
94     "<class 'libbe.command.import_xml.Import_XML'>"
95     >>> import_xml = get_command_class(command_name='import-xml')
96     >>> repr(import_xml)
97     "<class 'libbe.command.import_xml.Import_XML'>"
98     """
99     if module == None:
100         module = get_command(command_name)
101     try:
102         cname = command_name.capitalize().replace('-', '_')
103         cmd = getattr(module, cname)
104     except ImportError, e:
105         raise UnknownCommand(command_name)
106     return cmd
107
108 def modname_to_command_name(modname):
109     """Little hack to replicate
110     >>> import sys
111     >>> def real_modname_to_command_name(modname):
112     ...     mod = libbe.util.plugin.import_by_name(
113     ...         'libbe.command.%s' % modname)
114     ...     attrs = [getattr(mod, name) for name in dir(mod)]
115     ...     commands = []
116     ...     for attr_name in dir(mod):
117     ...         attr = getattr(mod, attr_name)
118     ...         try:
119     ...             if issubclass(attr, Command):
120     ...                 commands.append(attr)
121     ...         except TypeError, e:
122     ...             pass
123     ...     if len(commands) == 0:
124     ...         raise Exception('No Command classes in %s' % dir(mod))
125     ...     return commands[0].name
126     >>> real_modname_to_command_name('new')
127     'new'
128     >>> real_modname_to_command_name('import_xml')
129     'import-xml'
130     """
131     return modname.replace('_', '-')
132
133 def commands(command_names=False):
134     for modname in libbe.util.plugin.modnames('libbe.command'):
135         if modname not in ['base', 'util']:
136             if command_names == False:
137                 yield modname
138             else:
139                 yield modname_to_command_name(modname)
140
141 class CommandInput (object):
142     def __init__(self, name, help=''):
143         self.name = name
144         self.help = help
145
146     def __str__(self):
147         return '<%s %s>' % (self.__class__.__name__, self.name)
148
149     def __repr__(self):
150         return self.__str__()
151
152 class Argument (CommandInput):
153     def __init__(self, metavar=None, default=None, type='string',
154                  optional=False, repeatable=False,
155                  completion_callback=None, *args, **kwargs):
156         CommandInput.__init__(self, *args, **kwargs)
157         self.metavar = metavar
158         self.default = default
159         self.type = type
160         self.optional = optional
161         self.repeatable = repeatable
162         self.completion_callback = completion_callback
163         if self.metavar == None:
164             self.metavar = self.name.upper()
165
166 class Option (CommandInput):
167     def __init__(self, callback=None, short_name=None, arg=None,
168                  *args, **kwargs):
169         CommandInput.__init__(self, *args, **kwargs)
170         self.callback = callback
171         self.short_name = short_name
172         self.arg = arg
173         if self.arg == None and self.callback == None:
174             # use an implicit boolean argument
175             self.arg = Argument(name=self.name, help=self.help,
176                                 default=False, type='bool')
177         self.validate()
178
179     def validate(self):
180         if self.arg == None:
181             assert self.callback != None, self.name
182             return
183         assert self.callback == None, '%s: %s' (self.name, self.callback)
184         assert self.arg.name == self.name, \
185             'Name missmatch: %s != %s' % (self.arg.name, self.name)
186         assert self.arg.optional == False, self.name
187         assert self.arg.repeatable == False, self.name
188
189     def __str__(self):
190         return '--%s' % self.name
191
192     def __repr__(self):
193         return '<Option %s>' % self.__str__()
194
195 class _DummyParser (optparse.OptionParser):
196     def __init__(self, command):
197         optparse.OptionParser.__init__(self)
198         self.remove_option('-h')
199         self.command = command
200         self._command_opts = []
201         for option in self.command.options:
202             self._add_option(option)
203
204     def _add_option(self, option):
205         # from libbe.ui.command_line.CmdOptionParser._add_option
206         option.validate()
207         long_opt = '--%s' % option.name
208         if option.short_name != None:
209             short_opt = '-%s' % option.short_name
210         assert '_' not in option.name, \
211             'Non-reconstructable option name %s' % option.name
212         kwargs = {'dest':option.name.replace('-', '_'),
213                   'help':option.help}
214         if option.arg == None or option.arg.type == 'bool':
215             kwargs['action'] = 'store_true'
216             kwargs['metavar'] = None
217             kwargs['default'] = False
218         else:
219             kwargs['type'] = option.arg.type
220             kwargs['action'] = 'store'
221             kwargs['metavar'] = option.arg.metavar
222             kwargs['default'] = option.arg.default
223         if option.short_name != None:
224             opt = optparse.Option(short_opt, long_opt, **kwargs)
225         else:
226             opt = optparse.Option(long_opt, **kwargs)
227         #option.takes_value = lambda : option.arg != None
228         opt._option = option
229         self._command_opts.append(opt)
230         self.add_option(opt)
231
232 class OptionFormatter (optparse.IndentedHelpFormatter):
233     def __init__(self, command):
234         optparse.IndentedHelpFormatter.__init__(self)
235         self.command = command
236     def option_help(self):
237         # based on optparse.OptionParser.format_option_help()
238         parser = _DummyParser(self.command)
239         self.store_option_strings(parser)
240         ret = []
241         ret.append(self.format_heading('Options'))
242         self.indent()
243         for option in parser._command_opts:
244             ret.append(self.format_option(option))
245             ret.append('\n')
246         self.dedent()
247         # Drop the last '\n', or the header if no options or option groups:
248         return ''.join(ret[:-1])
249
250 class Command (object):
251     """One-line command description here.
252
253     >>> c = Command()
254     >>> print c.help()
255     usage: be command [options]
256     <BLANKLINE>
257     Options:
258       -h, --help  Print a help message.
259     <BLANKLINE>
260       --complete  Print a list of possible completions.
261     <BLANKLINE>
262     A detailed help message.
263     """
264
265     name = 'command'
266
267     def __init__(self, ui=None):
268         self.ui = ui # calling user-interface
269         self.status = None
270         self.result = None
271         self.restrict_file_access = True
272         self.options = [
273             Option(name='help', short_name='h',
274                 help='Print a help message.',
275                 callback=self.help),
276             Option(name='complete',
277                 help='Print a list of possible completions.',
278                 callback=self.complete),
279                 ]
280         self.args = []
281
282     def run(self, options=None, args=None):
283         self.status = 1 # in case we raise an exception
284         params = self._parse_options_args(options, args)
285         if params['help'] == True:
286             pass
287         else:
288             params.pop('help')
289         if params['complete'] != None:
290             pass
291         else:
292             params.pop('complete')
293
294         self.status = self._run(**params)
295         return self.status
296
297     def _parse_options_args(self, options=None, args=None):
298         if options == None:
299             options = {}
300         if args == None:
301             args = []
302         params = {}
303         for option in self.options:
304             assert option.name not in params, params[option.name]
305             if option.name in options:
306                 params[option.name] = options.pop(option.name)
307             elif option.arg != None:
308                 params[option.name] = option.arg.default
309             else: # non-arg options are flags, set to default flag value
310                 params[option.name] = False
311         assert 'user-id' not in params, params['user-id']
312         if 'user-id' in options:
313             self._user_id = options.pop('user-id')
314         if len(options) > 0:
315             raise UserError, 'Invalid option passed to command %s:\n  %s' \
316                 % (self.name, '\n  '.join(['%s: %s' % (k,v)
317                                            for k,v in options.items()]))
318         in_optional_args = False
319         for i,arg in enumerate(self.args):
320             if arg.repeatable == True:
321                 assert i == len(self.args)-1, arg.name
322             if in_optional_args == True:
323                 assert arg.optional == True, arg.name
324             else:
325                 in_optional_args = arg.optional
326             if i < len(args):
327                 if arg.repeatable == True:
328                     params[arg.name] = [args[i]]
329                 else:
330                     params[arg.name] = args[i]
331             else:  # no value given
332                 assert in_optional_args == True, arg.name
333                 params[arg.name] = arg.default
334         if len(args) > len(self.args):  # add some additional repeats
335             assert self.args[-1].repeatable == True, self.args[-1].name
336             params[self.args[-1].name].extend(args[len(self.args):])
337         return params
338
339     def _run(self, **kwargs):
340         raise NotImplementedError
341
342     def help(self, *args):
343         return '\n\n'.join([self.usage(),
344                             self._option_help(),
345                             self._long_help().rstrip('\n')])
346
347     def usage(self):
348         usage = 'usage: be %s [options]' % self.name
349         num_optional = 0
350         for arg in self.args:
351             usage += ' '
352             if arg.optional == True:
353                 usage += '['
354                 num_optional += 1
355             usage += arg.metavar
356             if arg.repeatable == True:
357                 usage += ' ...'
358         usage += ']'*num_optional
359         return usage
360
361     def _option_help(self):
362         o = OptionFormatter(self)
363         return o.option_help().strip('\n')
364
365     def _long_help(self):
366         return "A detailed help message."
367
368     def complete(self, argument=None, fragment=None):
369         if argument == None:
370             ret = ['--%s' % o.name for o in self.options
371                     if o.name != 'complete']
372             if len(self.args) > 0 and self.args[0].completion_callback != None:
373                 ret.extend(self.args[0].completion_callback(self, argument, fragment))
374             return ret
375         elif argument.completion_callback != None:
376             # finish a particular argument
377             return argument.completion_callback(self, argument, fragment)
378         return [] # the particular argument doesn't supply completion info
379
380     def _check_restricted_access(self, storage, path):
381         """
382         Check that the file at path is inside bugdir.root.  This is
383         important if you allow other users to execute becommands with
384         your username (e.g. if you're running be-handle-mail through
385         your ~/.procmailrc).  If this check wasn't made, a user could
386         e.g.  run
387           be commit -b ~/.ssh/id_rsa "Hack to expose ssh key"
388         which would expose your ssh key to anyone who could read the
389         VCS log.
390
391         >>> class DummyStorage (object): pass
392         >>> s = DummyStorage()
393         >>> s.repo = os.path.expanduser('~/x/')
394         >>> c = Command()
395         >>> try:
396         ...     c._check_restricted_access(s, os.path.expanduser('~/.ssh/id_rsa'))
397         ... except UserError, e:
398         ...     assert str(e).startswith('file access restricted!'), str(e)
399         ...     print 'we got the expected error'
400         we got the expected error
401         >>> c._check_restricted_access(s, os.path.expanduser('~/x'))
402         >>> c._check_restricted_access(s, os.path.expanduser('~/x/y'))
403         >>> c.restrict_file_access = False
404         >>> c._check_restricted_access(s, os.path.expanduser('~/.ssh/id_rsa'))
405         """
406         if self.restrict_file_access == True:
407             path = os.path.abspath(path)
408             repo = os.path.abspath(storage.repo).rstrip(os.path.sep)
409             if path == repo or path.startswith(repo+os.path.sep):
410                 return
411             raise UserError('file access restricted!\n  %s not in %s'
412                             % (path, repo))
413
414     def cleanup(self):
415         pass
416
417 class InputOutput (object):
418     def __init__(self, stdin=None, stdout=None):
419         self.stdin = stdin
420         self.stdout = stdout
421
422     def setup_command(self, command):
423         if not hasattr(self.stdin, 'encoding'):
424             self.stdin.encoding = libbe.util.encoding.get_input_encoding()
425         if not hasattr(self.stdout, 'encoding'):
426             self.stdout.encoding = libbe.util.encoding.get_output_encoding()
427         command.stdin = self.stdin
428         command.stdin.encoding = self.stdin.encoding
429         command.stdout = self.stdout
430         command.stdout.encoding = self.stdout.encoding
431
432     def cleanup(self):
433         pass
434
435 class StdInputOutput (InputOutput):
436     def __init__(self, input_encoding=None, output_encoding=None):
437         stdin,stdout = self._get_io(input_encoding, output_encoding)
438         InputOutput.__init__(self, stdin, stdout)
439
440     def _get_io(self, input_encoding=None, output_encoding=None):
441         if input_encoding == None:
442             input_encoding = libbe.util.encoding.get_input_encoding()
443         if output_encoding == None:
444             output_encoding = libbe.util.encoding.get_output_encoding()
445         stdin = codecs.getreader(input_encoding)(sys.stdin)
446         stdin.encoding = input_encoding
447         stdout = codecs.getwriter(output_encoding)(sys.stdout)
448         stdout.encoding = output_encoding
449         return (stdin, stdout)
450
451 class StringInputOutput (InputOutput):
452     """
453     >>> s = StringInputOutput()
454     >>> s.set_stdin('hello')
455     >>> s.stdin.read()
456     'hello'
457     >>> s.stdin.read()
458     ''
459     >>> print >> s.stdout, 'goodbye'
460     >>> s.get_stdout()
461     'goodbye\\n'
462     >>> s.get_stdout()
463     ''
464
465     Also works with unicode strings
466
467     >>> s.set_stdin(u'hello')
468     >>> s.stdin.read()
469     u'hello'
470     >>> print >> s.stdout, u'goodbye'
471     >>> s.get_stdout()
472     u'goodbye\\n'
473     """
474     def __init__(self):
475         stdin = StringIO.StringIO()
476         stdin.encoding = 'utf-8'
477         stdout = StringIO.StringIO()
478         stdout.encoding = 'utf-8'
479         InputOutput.__init__(self, stdin, stdout)
480
481     def set_stdin(self, stdin_string):
482         self.stdin = StringIO.StringIO(stdin_string)
483
484     def get_stdout(self):
485         ret = self.stdout.getvalue()
486         self.stdout = StringIO.StringIO() # clear stdout for next read
487         self.stdin.encoding = 'utf-8'
488         return ret
489
490 class UnconnectedStorageGetter (object):
491     def __init__(self, location):
492         self.location = location
493
494     def __call__(self):
495         return libbe.storage.get_storage(self.location)
496
497 class StorageCallbacks (object):
498     def __init__(self, location=None):
499         if location == None:
500             location = '.'
501         self.location = location
502         self._get_unconnected_storage = UnconnectedStorageGetter(location)
503
504     def setup_command(self, command):
505         command._get_unconnected_storage = self.get_unconnected_storage
506         command._get_storage = self.get_storage
507         command._get_bugdir = self.get_bugdir
508
509     def get_unconnected_storage(self):
510         """
511         Callback for use by commands that need it.
512         
513         The returned Storage instance is may actually be connected,
514         but commands that make use of the returned value should only
515         make use of non-connected Storage methods.  This is mainly
516         intended for the init command, which calls Storage.init().
517         """
518         if not hasattr(self, '_unconnected_storage'):
519             if self._get_unconnected_storage == None:
520                 raise NotImplementedError
521             self._unconnected_storage = self._get_unconnected_storage()
522         return self._unconnected_storage
523
524     def set_unconnected_storage(self, unconnected_storage):
525         self._unconnected_storage = unconnected_storage
526
527     def get_storage(self):
528         """Callback for use by commands that need it."""
529         if not hasattr(self, '_storage'):
530             self._storage = self.get_unconnected_storage()
531             self._storage.connect()
532             version = self._storage.storage_version()
533             if version != libbe.storage.STORAGE_VERSION:
534                 raise libbe.storage.InvalidStorageVersion(version)
535         return self._storage
536
537     def set_storage(self, storage):
538         self._storage = storage
539
540     def get_bugdir(self):
541         """Callback for use by commands that need it."""
542         if not hasattr(self, '_bugdir'):
543             self._bugdir = libbe.bugdir.BugDir(self.get_storage(),
544                                                from_storage=True)
545         return self._bugdir
546
547     def set_bugdir(self, bugdir):
548         self._bugdir = bugdir
549
550     def cleanup(self):
551         if hasattr(self, '_storage'):
552             self._storage.disconnect()
553
554 class UserInterface (object):
555     def __init__(self, io=None, location=None):
556         if io == None:
557             io = StringInputOutput()
558         self.io = io
559         self.storage_callbacks = StorageCallbacks(location)
560         self.restrict_file_access = True
561
562     def help(self):
563         raise NotImplementedError
564
565     def run(self, command, options=None, args=None):
566         self.setup_command(command)
567         return command.run(options, args)
568
569     def setup_command(self, command):
570         if command.ui == None:
571             command.ui = self
572         if self.io != None:
573             self.io.setup_command(command)
574         if self.storage_callbacks != None:
575             self.storage_callbacks.setup_command(command)        
576         command.restrict_file_access = self.restrict_file_access
577         command._get_user_id = self._get_user_id
578
579     def _get_user_id(self):
580         """Callback for use by commands that need it."""
581         if not hasattr(self, '_user_id'):
582             self._user_id = libbe.ui.util.user.get_user_id(
583                 self.storage_callbacks.get_storage())
584         return self._user_id
585
586     def cleanup(self):
587         self.storage_callbacks.cleanup()
588         self.io.cleanup()