b9038a0a127d943f9f01177780f2d5ab1a3063fd
[be.git] / libbe / command / base.py
1 # Copyright (C) 2009-2012 Chris Ball <cjb@laptop.org>
2 #                         Phil Schumm <philschumm@gmail.com>
3 #                         Robert Lehmann <mail@robertlehmann.de>
4 #                         W. Trevor King <wking@drexel.edu>
5 #
6 # This file is part of Bugs Everywhere.
7 #
8 # Bugs Everywhere is free software: you can redistribute it and/or modify it
9 # under the terms of the GNU General Public License as published by the Free
10 # Software Foundation, either version 2 of the License, or (at your option) any
11 # later version.
12 #
13 # Bugs Everywhere is distributed in the hope that it will be useful, but
14 # WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
15 # FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
16 # more details.
17 #
18 # You should have received a copy of the GNU General Public License along with
19 # Bugs Everywhere.  If not, see <http://www.gnu.org/licenses/>.
20
21 import codecs
22 import optparse
23 import os.path
24 import StringIO
25 import sys
26 import urlparse
27 import yaml
28
29 import libbe
30 import libbe.storage
31 import libbe.ui.util.user
32 import libbe.util.encoding
33 import libbe.util.http
34 import libbe.util.plugin
35
36
37 class UserError (Exception):
38     "An error due to improper BE usage."
39     pass
40
41
42 class UsageError (UserError):
43     """A serious parsing error due to invalid BE command construction.
44
45     The distinction between `UserError`\s and the more specific
46     `UsageError`\s is that when displaying a `UsageError` to the user,
47     the user is pointed towards the command usage information.  Use
48     the more general `UserError` if you feel that usage information
49     would not be particularly enlightening.
50     """
51     def __init__(self, command=None, command_name=None, message=None):
52         super(UsageError, self).__init__(message)
53         self.command = command
54         if command_name is None and command is not None:
55             command_name = command.name
56         self.command_name = command_name
57         self.message = message
58
59
60 class UnknownCommand (UsageError):
61     def __init__(self, command_name, message=None):
62         uc_message = "Unknown command '%s'" % command_name
63         if message is None:
64             message = uc_message
65         else:
66             message = '%s\n(%s)' % (uc_message, message)
67         super(UnknownCommand, self).__init__(
68             command_name=command_name,
69             message=message)
70
71
72 def get_command(command_name):
73     """Retrieves the module for a user command
74
75     >>> try:
76     ...     get_command('asdf')
77     ... except UnknownCommand, e:
78     ...     print e
79     Unknown command 'asdf'
80     (No module named asdf)
81     >>> repr(get_command('list')).startswith("<module 'libbe.command.list' from ")
82     True
83     """
84     try:
85         cmd = libbe.util.plugin.import_by_name(
86             'libbe.command.%s' % command_name.replace("-", "_"))
87     except ImportError, e:
88         raise UnknownCommand(command_name, message=unicode(e))
89     return cmd
90
91 def get_command_class(module=None, command_name=None):
92     """Retrieves a command class from a module.
93
94     >>> import_xml_mod = get_command('import-xml')
95     >>> import_xml = get_command_class(import_xml_mod, 'import-xml')
96     >>> repr(import_xml)
97     "<class 'libbe.command.import_xml.Import_XML'>"
98     >>> import_xml = get_command_class(command_name='import-xml')
99     >>> repr(import_xml)
100     "<class 'libbe.command.import_xml.Import_XML'>"
101     """
102     if module == None:
103         module = get_command(command_name)
104     try:
105         cname = command_name.capitalize().replace('-', '_')
106         cmd = getattr(module, cname)
107     except ImportError, e:
108         raise UnknownCommand(command_name)
109     return cmd
110
111 def modname_to_command_name(modname):
112     """Little hack to replicate
113     >>> import sys
114     >>> def real_modname_to_command_name(modname):
115     ...     mod = libbe.util.plugin.import_by_name(
116     ...         'libbe.command.%s' % modname)
117     ...     attrs = [getattr(mod, name) for name in dir(mod)]
118     ...     commands = []
119     ...     for attr_name in dir(mod):
120     ...         attr = getattr(mod, attr_name)
121     ...         try:
122     ...             if issubclass(attr, Command):
123     ...                 commands.append(attr)
124     ...         except TypeError, e:
125     ...             pass
126     ...     if len(commands) == 0:
127     ...         raise Exception('No Command classes in %s' % dir(mod))
128     ...     return commands[0].name
129     >>> real_modname_to_command_name('new')
130     'new'
131     >>> real_modname_to_command_name('import_xml')
132     'import-xml'
133     """
134     return modname.replace('_', '-')
135
136 def commands(command_names=False):
137     for modname in libbe.util.plugin.modnames('libbe.command'):
138         if modname not in ['base', 'util']:
139             if command_names == False:
140                 yield modname
141             else:
142                 yield modname_to_command_name(modname)
143
144 class CommandInput (object):
145     def __init__(self, name, help=''):
146         self.name = name
147         self.help = help
148
149     def __str__(self):
150         return '<%s %s>' % (self.__class__.__name__, self.name)
151
152     def __repr__(self):
153         return self.__str__()
154
155 class Argument (CommandInput):
156     def __init__(self, metavar=None, default=None, type='string',
157                  optional=False, repeatable=False,
158                  completion_callback=None, *args, **kwargs):
159         CommandInput.__init__(self, *args, **kwargs)
160         self.metavar = metavar
161         self.default = default
162         self.type = type
163         self.optional = optional
164         self.repeatable = repeatable
165         self.completion_callback = completion_callback
166         if self.metavar == None:
167             self.metavar = self.name.upper()
168
169 class Option (CommandInput):
170     def __init__(self, callback=None, short_name=None, arg=None,
171                  *args, **kwargs):
172         CommandInput.__init__(self, *args, **kwargs)
173         self.callback = callback
174         self.short_name = short_name
175         self.arg = arg
176         if self.arg == None and self.callback == None:
177             # use an implicit boolean argument
178             self.arg = Argument(name=self.name, help=self.help,
179                                 default=False, type='bool')
180         self.validate()
181
182     def validate(self):
183         if self.arg == None:
184             assert self.callback != None, self.name
185             return
186         assert self.callback == None, '%s: %s' (self.name, self.callback)
187         assert self.arg.name == self.name, \
188             'Name missmatch: %s != %s' % (self.arg.name, self.name)
189         assert self.arg.optional == False, self.name
190         assert self.arg.repeatable == False, self.name
191
192     def __str__(self):
193         return '--%s' % self.name
194
195     def __repr__(self):
196         return '<Option %s>' % self.__str__()
197
198 class _DummyParser (optparse.OptionParser):
199     def __init__(self, command):
200         optparse.OptionParser.__init__(self)
201         self.remove_option('-h')
202         self.command = command
203         self._command_opts = []
204         for option in self.command.options:
205             self._add_option(option)
206
207     def _add_option(self, option):
208         # from libbe.ui.command_line.CmdOptionParser._add_option
209         option.validate()
210         long_opt = '--%s' % option.name
211         if option.short_name != None:
212             short_opt = '-%s' % option.short_name
213         assert '_' not in option.name, \
214             'Non-reconstructable option name %s' % option.name
215         kwargs = {'dest':option.name.replace('-', '_'),
216                   'help':option.help}
217         if option.arg == None or option.arg.type == 'bool':
218             kwargs['action'] = 'store_true'
219             kwargs['metavar'] = None
220             kwargs['default'] = False
221         else:
222             kwargs['type'] = option.arg.type
223             kwargs['action'] = 'store'
224             kwargs['metavar'] = option.arg.metavar
225             kwargs['default'] = option.arg.default
226         if option.short_name != None:
227             opt = optparse.Option(short_opt, long_opt, **kwargs)
228         else:
229             opt = optparse.Option(long_opt, **kwargs)
230         #option.takes_value = lambda : option.arg != None
231         opt._option = option
232         self._command_opts.append(opt)
233         self.add_option(opt)
234
235 class OptionFormatter (optparse.IndentedHelpFormatter):
236     def __init__(self, command):
237         optparse.IndentedHelpFormatter.__init__(self)
238         self.command = command
239     def option_help(self):
240         # based on optparse.OptionParser.format_option_help()
241         parser = _DummyParser(self.command)
242         self.store_option_strings(parser)
243         ret = []
244         ret.append(self.format_heading('Options'))
245         self.indent()
246         for option in parser._command_opts:
247             ret.append(self.format_option(option))
248             ret.append('\n')
249         self.dedent()
250         # Drop the last '\n', or the header if no options or option groups:
251         return ''.join(ret[:-1])
252
253 class Command (object):
254     """One-line command description here.
255
256     >>> c = Command()
257     >>> print c.help()
258     usage: be command [options]
259     <BLANKLINE>
260     Options:
261       -h, --help  Print a help message.
262     <BLANKLINE>
263       --complete  Print a list of possible completions.
264     <BLANKLINE>
265     A detailed help message.
266     """
267     user_agent = 'BE-HTTP-Command'
268
269     name = 'command'
270
271     def __init__(self, ui=None, server=None):
272         self.ui = ui # calling user-interface
273         self.server = server # location of eventual execution
274         self.status = None
275         self.result = None
276         self.restrict_file_access = True
277         self.options = [
278             Option(name='help', short_name='h',
279                 help='Print a help message.',
280                 callback=self.help),
281             Option(name='complete',
282                 help='Print a list of possible completions.',
283                 callback=self.complete),
284                 ]
285         self.args = []
286
287     def run(self, options=None, args=None):
288         self.status = 1 # in case we raise an exception
289         params = self._parse_options_args(options, args)
290         if params['help'] == True:
291             pass
292         else:
293             params.pop('help')
294         if params['complete'] != None:
295             pass
296         else:
297             params.pop('complete')
298
299         if self.server:
300             self.status = self._run_remote(**params)
301         else:
302             self.status = self._run(**params)
303         return self.status
304
305     def _parse_options_args(self, options=None, args=None):
306         if options == None:
307             options = {}
308         if args == None:
309             args = []
310         params = {}
311         for option in self.options:
312             assert option.name not in params, params[option.name]
313             if option.name in options:
314                 params[option.name] = options.pop(option.name)
315             elif option.arg != None:
316                 params[option.name] = option.arg.default
317             else: # non-arg options are flags, set to default flag value
318                 params[option.name] = False
319         assert 'user-id' not in params, params['user-id']
320         if 'user-id' in options:
321             self._user_id = options.pop('user-id')
322         if len(options) > 0:
323             raise UserError, 'Invalid option passed to command %s:\n  %s' \
324                 % (self.name, '\n  '.join(['%s: %s' % (k,v)
325                                            for k,v in options.items()]))
326         in_optional_args = False
327         for i,arg in enumerate(self.args):
328             if arg.repeatable == True:
329                 assert i == len(self.args)-1, arg.name
330             if in_optional_args == True:
331                 assert arg.optional == True, arg.name
332             else:
333                 in_optional_args = arg.optional
334             if i < len(args):
335                 if arg.repeatable == True:
336                     params[arg.name] = [args[i]]
337                 else:
338                     params[arg.name] = args[i]
339             else:  # no value given
340                 assert in_optional_args == True, arg.name
341                 params[arg.name] = arg.default
342         if len(args) > len(self.args):  # add some additional repeats
343             assert self.args[-1].repeatable == True, self.args[-1].name
344             params[self.args[-1].name].extend(args[len(self.args):])
345         return params
346
347     def _run(self, **kwargs):
348         raise NotImplementedError
349
350     def _run_remote(self, **kwargs):
351         data = yaml.safe_dump({
352                 'command': self.name,
353                 'parameters': kwargs,
354                 })
355         url = urlparse.urljoin(self.server, 'run')
356         page,final_url,info = libbe.util.http.get_post_url(
357             url=url, get=False, data=data, agent=self.user_agent)
358         self.stdout.write(page)
359         return 0
360
361     def help(self, *args):
362         return '\n\n'.join([self.usage(),
363                             self._option_help(),
364                             self._long_help().rstrip('\n')])
365
366     def usage(self):
367         usage = 'usage: be %s [options]' % self.name
368         num_optional = 0
369         for arg in self.args:
370             usage += ' '
371             if arg.optional == True:
372                 usage += '['
373                 num_optional += 1
374             usage += arg.metavar
375             if arg.repeatable == True:
376                 usage += ' ...'
377         usage += ']'*num_optional
378         return usage
379
380     def _option_help(self):
381         o = OptionFormatter(self)
382         return o.option_help().strip('\n')
383
384     def _long_help(self):
385         return "A detailed help message."
386
387     def complete(self, argument=None, fragment=None):
388         if argument == None:
389             ret = ['--%s' % o.name for o in self.options
390                     if o.name != 'complete']
391             if len(self.args) > 0 and self.args[0].completion_callback != None:
392                 ret.extend(self.args[0].completion_callback(self, argument, fragment))
393             return ret
394         elif argument.completion_callback != None:
395             # finish a particular argument
396             return argument.completion_callback(self, argument, fragment)
397         return [] # the particular argument doesn't supply completion info
398
399     def _check_restricted_access(self, storage, path):
400         """
401         Check that the file at path is inside bugdir.root.  This is
402         important if you allow other users to execute becommands with
403         your username (e.g. if you're running be-handle-mail through
404         your ~/.procmailrc).  If this check wasn't made, a user could
405         e.g.  run
406           be commit -b ~/.ssh/id_rsa "Hack to expose ssh key"
407         which would expose your ssh key to anyone who could read the
408         VCS log.
409
410         >>> class DummyStorage (object): pass
411         >>> s = DummyStorage()
412         >>> s.repo = os.path.expanduser('~/x/')
413         >>> c = Command()
414         >>> try:
415         ...     c._check_restricted_access(s, os.path.expanduser('~/.ssh/id_rsa'))
416         ... except UserError, e:
417         ...     assert str(e).startswith('file access restricted!'), str(e)
418         ...     print 'we got the expected error'
419         we got the expected error
420         >>> c._check_restricted_access(s, os.path.expanduser('~/x'))
421         >>> c._check_restricted_access(s, os.path.expanduser('~/x/y'))
422         >>> c.restrict_file_access = False
423         >>> c._check_restricted_access(s, os.path.expanduser('~/.ssh/id_rsa'))
424         """
425         if self.restrict_file_access == True:
426             path = os.path.abspath(path)
427             repo = os.path.abspath(storage.repo).rstrip(os.path.sep)
428             if path == repo or path.startswith(repo+os.path.sep):
429                 return
430             raise UserError('file access restricted!\n  %s not in %s'
431                             % (path, repo))
432
433     def cleanup(self):
434         pass
435
436 class InputOutput (object):
437     def __init__(self, stdin=None, stdout=None):
438         self.stdin = stdin
439         self.stdout = stdout
440
441     def setup_command(self, command):
442         if not hasattr(self.stdin, 'encoding'):
443             self.stdin.encoding = libbe.util.encoding.get_input_encoding()
444         if not hasattr(self.stdout, 'encoding'):
445             self.stdout.encoding = libbe.util.encoding.get_output_encoding()
446         command.stdin = self.stdin
447         command.stdin.encoding = self.stdin.encoding
448         command.stdout = self.stdout
449         command.stdout.encoding = self.stdout.encoding
450
451     def cleanup(self):
452         pass
453
454 class StdInputOutput (InputOutput):
455     def __init__(self, input_encoding=None, output_encoding=None):
456         stdin,stdout = self._get_io(input_encoding, output_encoding)
457         InputOutput.__init__(self, stdin, stdout)
458
459     def _get_io(self, input_encoding=None, output_encoding=None):
460         if input_encoding == None:
461             input_encoding = libbe.util.encoding.get_input_encoding()
462         if output_encoding == None:
463             output_encoding = libbe.util.encoding.get_output_encoding()
464         stdin = codecs.getreader(input_encoding)(sys.stdin)
465         stdin.encoding = input_encoding
466         stdout = codecs.getwriter(output_encoding)(sys.stdout)
467         stdout.encoding = output_encoding
468         return (stdin, stdout)
469
470 class StringInputOutput (InputOutput):
471     """
472     >>> s = StringInputOutput()
473     >>> s.set_stdin('hello')
474     >>> s.stdin.read()
475     'hello'
476     >>> s.stdin.read()
477     ''
478     >>> print >> s.stdout, 'goodbye'
479     >>> s.get_stdout()
480     'goodbye\\n'
481     >>> s.get_stdout()
482     ''
483
484     Also works with unicode strings
485
486     >>> s.set_stdin(u'hello')
487     >>> s.stdin.read()
488     u'hello'
489     >>> print >> s.stdout, u'goodbye'
490     >>> s.get_stdout()
491     u'goodbye\\n'
492     """
493     def __init__(self):
494         stdin = StringIO.StringIO()
495         stdin.encoding = 'utf-8'
496         stdout = StringIO.StringIO()
497         stdout.encoding = 'utf-8'
498         InputOutput.__init__(self, stdin, stdout)
499
500     def set_stdin(self, stdin_string):
501         self.stdin = StringIO.StringIO(stdin_string)
502
503     def get_stdout(self):
504         ret = self.stdout.getvalue()
505         self.stdout.truncate(size=0)
506         return ret
507
508 class UnconnectedStorageGetter (object):
509     def __init__(self, location):
510         self.location = location
511
512     def __call__(self):
513         return libbe.storage.get_storage(self.location)
514
515 class StorageCallbacks (object):
516     def __init__(self, location=None):
517         if location == None:
518             location = '.'
519         self.location = location
520         self._get_unconnected_storage = UnconnectedStorageGetter(location)
521
522     def setup_command(self, command):
523         command._get_unconnected_storage = self.get_unconnected_storage
524         command._get_storage = self.get_storage
525         command._get_bugdirs = self.get_bugdirs
526
527     def get_unconnected_storage(self):
528         """
529         Callback for use by commands that need it.
530         
531         The returned Storage instance is may actually be connected,
532         but commands that make use of the returned value should only
533         make use of non-connected Storage methods.  This is mainly
534         intended for the init command, which calls Storage.init().
535         """
536         if not hasattr(self, '_unconnected_storage'):
537             if self._get_unconnected_storage == None:
538                 raise NotImplementedError
539             self._unconnected_storage = self._get_unconnected_storage()
540         return self._unconnected_storage
541
542     def set_unconnected_storage(self, unconnected_storage):
543         self._unconnected_storage = unconnected_storage
544
545     def get_storage(self):
546         """Callback for use by commands that need it."""
547         if not hasattr(self, '_storage'):
548             self._storage = self.get_unconnected_storage()
549             self._storage.connect()
550             version = self._storage.storage_version()
551             if version != libbe.storage.STORAGE_VERSION:
552                 raise libbe.storage.InvalidStorageVersion(version)
553         return self._storage
554
555     def set_storage(self, storage):
556         self._storage = storage
557
558     def get_bugdirs(self):
559         """Callback for use by commands that need it."""
560         if not hasattr(self, '_bugdirs'):
561             storage = self.get_storage()
562             self._bugdirs = dict(
563                 (uuid, libbe.bugdir.BugDir(
564                         storage=storage,
565                         uuid=uuid,
566                         from_storage=True))
567                 for uuid in storage.children())
568         return self._bugdirs
569
570     def set_bugdirs(self, bugdirs):
571         self._bugdirs = bugdirs
572
573     def cleanup(self):
574         if hasattr(self, '_storage'):
575             self._storage.disconnect()
576
577 class UserInterface (object):
578     def __init__(self, io=None, location=None):
579         if io == None:
580             io = StringInputOutput()
581         self.io = io
582         self.storage_callbacks = StorageCallbacks(location)
583         self.restrict_file_access = True
584
585     def help(self):
586         raise NotImplementedError
587
588     def run(self, command, options=None, args=None):
589         self.setup_command(command)
590         return command.run(options, args)
591
592     def setup_command(self, command):
593         if command.ui is None:
594             command.ui = self
595         if self.io is not None:
596             self.io.setup_command(command)
597         if self.storage_callbacks is not None:
598             self.storage_callbacks.setup_command(command)        
599         command.restrict_file_access = self.restrict_file_access
600         command._get_user_id = self._get_user_id
601
602     def _get_user_id(self):
603         """Callback for use by commands that need it."""
604         if not hasattr(self, '_user_id'):
605             self._user_id = libbe.ui.util.user.get_user_id(
606                 self.storage_callbacks.get_storage())
607         return self._user_id
608
609     def cleanup(self):
610         if self.storage_callbacks is not None:
611             self.storage_callbacks.cleanup()
612         self.io.cleanup()