Merged Anton Batenev's report of Nicolas Alvarez' unicode-in-be-new bug
[be.git] / libbe / command / base.py
1 # Copyright (C) 2009-2010 W. Trevor King <wking@drexel.edu>
2 #
3 # This program is free software; you can redistribute it and/or modify
4 # it under the terms of the GNU General Public License as published by
5 # the Free Software Foundation; either version 2 of the License, or
6 # (at your option) any later version.
7 #
8 # This program is distributed in the hope that it will be useful,
9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 # GNU General Public License for more details.
12 #
13 # You should have received a copy of the GNU General Public License along
14 # with this program; if not, write to the Free Software Foundation, Inc.,
15 # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
16
17 import codecs
18 import optparse
19 import os.path
20 import StringIO
21 import sys
22
23 import libbe
24 import libbe.storage
25 import libbe.ui.util.user
26 import libbe.util.encoding
27 import libbe.util.plugin
28
29 class UserError(Exception):
30     pass
31
32 class UnknownCommand(UserError):
33     def __init__(self, cmd):
34         Exception.__init__(self, "Unknown command '%s'" % cmd)
35         self.cmd = cmd
36
37 def get_command(command_name):
38     """Retrieves the module for a user command
39
40     >>> try:
41     ...     get_command('asdf')
42     ... except UnknownCommand, e:
43     ...     print e
44     Unknown command 'asdf'
45     >>> repr(get_command('list')).startswith("<module 'libbe.command.list' from ")
46     True
47     """
48     try:
49         cmd = libbe.util.plugin.import_by_name(
50             'libbe.command.%s' % command_name.replace("-", "_"))
51     except ImportError, e:
52         raise UnknownCommand(command_name)
53     return cmd
54
55 def get_command_class(module=None, command_name=None):
56     """Retrieves a command class from a module.
57
58     >>> import_xml_mod = get_command('import-xml')
59     >>> import_xml = get_command_class(import_xml_mod, 'import-xml')
60     >>> repr(import_xml)
61     "<class 'libbe.command.import_xml.Import_XML'>"
62     >>> import_xml = get_command_class(command_name='import-xml')
63     >>> repr(import_xml)
64     "<class 'libbe.command.import_xml.Import_XML'>"
65     """
66     if module == None:
67         module = get_command(command_name)
68     try:
69         cname = command_name.capitalize().replace('-', '_')
70         cmd = getattr(module, cname)
71     except ImportError, e:
72         raise UnknownCommand(command_name)
73     return cmd
74
75 def modname_to_command_name(modname):
76     """Little hack to replicate
77     >>> import sys
78     >>> def real_modname_to_command_name(modname):
79     ...     mod = libbe.util.plugin.import_by_name(
80     ...         'libbe.command.%s' % modname)
81     ...     attrs = [getattr(mod, name) for name in dir(mod)]
82     ...     commands = []
83     ...     for attr_name in dir(mod):
84     ...         attr = getattr(mod, attr_name)
85     ...         try:
86     ...             if issubclass(attr, Command):
87     ...                 commands.append(attr)
88     ...         except TypeError, e:
89     ...             pass
90     ...     if len(commands) == 0:
91     ...         raise Exception('No Command classes in %s' % dir(mod))
92     ...     return commands[0].name
93     >>> real_modname_to_command_name('new')
94     'new'
95     >>> real_modname_to_command_name('import_xml')
96     'import-xml'
97     """
98     return modname.replace('_', '-')
99
100 def commands(command_names=False):
101     for modname in libbe.util.plugin.modnames('libbe.command'):
102         if modname not in ['base', 'util']:
103             if command_names == False:
104                 yield modname
105             else:
106                 yield modname_to_command_name(modname)
107
108 class CommandInput (object):
109     def __init__(self, name, help=''):
110         self.name = name
111         self.help = help
112
113     def __str__(self):
114         return '<%s %s>' % (self.__class__.__name__, self.name)
115
116     def __repr__(self):
117         return self.__str__()
118
119 class Argument (CommandInput):
120     def __init__(self, metavar=None, default=None, type='string',
121                  optional=False, repeatable=False,
122                  completion_callback=None, *args, **kwargs):
123         CommandInput.__init__(self, *args, **kwargs)
124         self.metavar = metavar
125         self.default = default
126         self.type = type
127         self.optional = optional
128         self.repeatable = repeatable
129         self.completion_callback = completion_callback
130         if self.metavar == None:
131             self.metavar = self.name.upper()
132
133 class Option (CommandInput):
134     def __init__(self, callback=None, short_name=None, arg=None,
135                  *args, **kwargs):
136         CommandInput.__init__(self, *args, **kwargs)
137         self.callback = callback
138         self.short_name = short_name
139         self.arg = arg
140         if self.arg == None and self.callback == None:
141             # use an implicit boolean argument
142             self.arg = Argument(name=self.name, help=self.help,
143                                 default=False, type='bool')
144         self.validate()
145
146     def validate(self):
147         if self.arg == None:
148             assert self.callback != None, self.name
149             return
150         assert self.callback == None, '%s: %s' (self.name, self.callback)
151         assert self.arg.name == self.name, \
152             'Name missmatch: %s != %s' % (self.arg.name, self.name)
153         assert self.arg.optional == False, self.name
154         assert self.arg.repeatable == False, self.name
155
156     def __str__(self):
157         return '--%s' % self.name
158
159     def __repr__(self):
160         return '<Option %s>' % self.__str__()
161
162 class _DummyParser (optparse.OptionParser):
163     def __init__(self, command):
164         optparse.OptionParser.__init__(self)
165         self.remove_option('-h')
166         self.command = command
167         self._command_opts = []
168         for option in self.command.options:
169             self._add_option(option)
170
171     def _add_option(self, option):
172         # from libbe.ui.command_line.CmdOptionParser._add_option
173         option.validate()
174         long_opt = '--%s' % option.name
175         if option.short_name != None:
176             short_opt = '-%s' % option.short_name
177         assert '_' not in option.name, \
178             'Non-reconstructable option name %s' % option.name
179         kwargs = {'dest':option.name.replace('-', '_'),
180                   'help':option.help}
181         if option.arg == None or option.arg.type == 'bool':
182             kwargs['action'] = 'store_true'
183             kwargs['metavar'] = None
184             kwargs['default'] = False
185         else:
186             kwargs['type'] = option.arg.type
187             kwargs['action'] = 'store'
188             kwargs['metavar'] = option.arg.metavar
189             kwargs['default'] = option.arg.default
190         if option.short_name != None:
191             opt = optparse.Option(short_opt, long_opt, **kwargs)
192         else:
193             opt = optparse.Option(long_opt, **kwargs)
194         #option.takes_value = lambda : option.arg != None
195         opt._option = option
196         self._command_opts.append(opt)
197         self.add_option(opt)
198
199 class OptionFormatter (optparse.IndentedHelpFormatter):
200     def __init__(self, command):
201         optparse.IndentedHelpFormatter.__init__(self)
202         self.command = command
203     def option_help(self):
204         # based on optparse.OptionParser.format_option_help()
205         parser = _DummyParser(self.command)
206         self.store_option_strings(parser)
207         ret = []
208         ret.append(self.format_heading('Options'))
209         self.indent()
210         for option in parser._command_opts:
211             ret.append(self.format_option(option))
212             ret.append('\n')
213         self.dedent()
214         # Drop the last '\n', or the header if no options or option groups:
215         return ''.join(ret[:-1])
216
217 class Command (object):
218     """One-line command description here.
219
220     >>> c = Command()
221     >>> print c.help()
222     usage: be command [options]
223     <BLANKLINE>
224     Options:
225       -h, --help  Print a help message.
226     <BLANKLINE>
227       --complete  Print a list of possible completions.
228     <BLANKLINE>
229     A detailed help message.
230     """
231
232     name = 'command'
233
234     def __init__(self, ui=None):
235         self.ui = ui # calling user-interface
236         self.status = None
237         self.result = None
238         self.restrict_file_access = True
239         self.options = [
240             Option(name='help', short_name='h',
241                 help='Print a help message.',
242                 callback=self.help),
243             Option(name='complete',
244                 help='Print a list of possible completions.',
245                 callback=self.complete),
246                 ]
247         self.args = []
248
249     def run(self, options=None, args=None):
250         self.status = 1 # in case we raise an exception
251         params = self._parse_options_args(options, args)
252         if params['help'] == True:
253             pass
254         else:
255             params.pop('help')
256         if params['complete'] != None:
257             pass
258         else:
259             params.pop('complete')
260
261         self.status = self._run(**params)
262         return self.status
263
264     def _parse_options_args(self, options=None, args=None):
265         if options == None:
266             options = {}
267         if args == None:
268             args = []
269         params = {}
270         for option in self.options:
271             assert option.name not in params, params[option.name]
272             if option.name in options:
273                 params[option.name] = options.pop(option.name)
274             elif option.arg != None:
275                 params[option.name] = option.arg.default
276             else: # non-arg options are flags, set to default flag value
277                 params[option.name] = False
278         assert 'user-id' not in params, params['user-id']
279         if 'user-id' in options:
280             self._user_id = options.pop('user-id')
281         if len(options) > 0:
282             raise UserError, 'Invalid option passed to command %s:\n  %s' \
283                 % (self.name, '\n  '.join(['%s: %s' % (k,v)
284                                            for k,v in options.items()]))
285         in_optional_args = False
286         for i,arg in enumerate(self.args):
287             if arg.repeatable == True:
288                 assert i == len(self.args)-1, arg.name
289             if in_optional_args == True:
290                 assert arg.optional == True, arg.name
291             else:
292                 in_optional_args = arg.optional
293             if i < len(args):
294                 if arg.repeatable == True:
295                     params[arg.name] = [args[i]]
296                 else:
297                     params[arg.name] = args[i]
298             else:  # no value given
299                 assert in_optional_args == True, arg.name
300                 params[arg.name] = arg.default
301         if len(args) > len(self.args):  # add some additional repeats
302             assert self.args[-1].repeatable == True, self.args[-1].name
303             params[self.args[-1].name].extend(args[len(self.args):])
304         return params
305
306     def _run(self, **kwargs):
307         raise NotImplementedError
308
309     def help(self, *args):
310         return '\n\n'.join([self.usage(),
311                             self._option_help(),
312                             self._long_help().rstrip('\n')])
313
314     def usage(self):
315         usage = 'usage: be %s [options]' % self.name
316         num_optional = 0
317         for arg in self.args:
318             usage += ' '
319             if arg.optional == True:
320                 usage += '['
321                 num_optional += 1
322             usage += arg.metavar
323             if arg.repeatable == True:
324                 usage += ' ...'
325         usage += ']'*num_optional
326         return usage
327
328     def _option_help(self):
329         o = OptionFormatter(self)
330         return o.option_help().strip('\n')
331
332     def _long_help(self):
333         return "A detailed help message."
334
335     def complete(self, argument=None, fragment=None):
336         if argument == None:
337             ret = ['--%s' % o.name for o in self.options]
338             if len(self.args) > 0 and self.args[0].completion_callback != None:
339                 ret.extend(self.args[0].completion_callback(self, argument, fragment))
340             return ret
341         elif argument.completion_callback != None:
342             # finish a particular argument
343             return argument.completion_callback(self, argument, fragment)
344         return [] # the particular argument doesn't supply completion info
345
346     def _check_restricted_access(self, storage, path):
347         """
348         Check that the file at path is inside bugdir.root.  This is
349         important if you allow other users to execute becommands with
350         your username (e.g. if you're running be-handle-mail through
351         your ~/.procmailrc).  If this check wasn't made, a user could
352         e.g.  run
353           be commit -b ~/.ssh/id_rsa "Hack to expose ssh key"
354         which would expose your ssh key to anyone who could read the
355         VCS log.
356
357         >>> class DummyStorage (object): pass
358         >>> s = DummyStorage()
359         >>> s.repo = os.path.expanduser('~/x/')
360         >>> c = Command()
361         >>> try:
362         ...     c._check_restricted_access(s, os.path.expanduser('~/.ssh/id_rsa'))
363         ... except UserError, e:
364         ...     assert str(e).startswith('file access restricted!'), str(e)
365         ...     print 'we got the expected error'
366         we got the expected error
367         >>> c._check_restricted_access(s, os.path.expanduser('~/x'))
368         >>> c._check_restricted_access(s, os.path.expanduser('~/x/y'))
369         >>> c.restrict_file_access = False
370         >>> c._check_restricted_access(s, os.path.expanduser('~/.ssh/id_rsa'))
371         """
372         if self.restrict_file_access == True:
373             path = os.path.abspath(path)
374             repo = os.path.abspath(storage.repo).rstrip(os.path.sep)
375             if path == repo or path.startswith(repo+os.path.sep):
376                 return
377             raise UserError('file access restricted!\n  %s not in %s'
378                             % (path, repo))
379
380     def cleanup(self):
381         pass
382
383 class InputOutput (object):
384     def __init__(self, stdin=None, stdout=None):
385         self.stdin = stdin
386         self.stdout = stdout
387
388     def setup_command(self, command):
389         if not hasattr(self.stdin, 'encoding'):
390             self.stdin.encoding = libbe.util.encoding.get_input_encoding()
391         if not hasattr(self.stdout, 'encoding'):
392             self.stdout.encoding = libbe.util.encoding.get_output_encoding()
393         command.stdin = self.stdin
394         command.stdin.encoding = self.stdin.encoding
395         command.stdout = self.stdout
396         command.stdout.encoding = self.stdout.encoding
397
398     def cleanup(self):
399         pass
400
401 class StdInputOutput (InputOutput):
402     def __init__(self, input_encoding=None, output_encoding=None):
403         stdin,stdout = self._get_io(input_encoding, output_encoding)
404         InputOutput.__init__(self, stdin, stdout)
405
406     def _get_io(self, input_encoding=None, output_encoding=None):
407         if input_encoding == None:
408             input_encoding = libbe.util.encoding.get_input_encoding()
409         if output_encoding == None:
410             output_encoding = libbe.util.encoding.get_output_encoding()
411         stdin = codecs.getwriter(input_encoding)(sys.stdin)
412         stdin.encoding = input_encoding
413         stdout = codecs.getwriter(output_encoding)(sys.stdout)
414         stdout.encoding = output_encoding
415         return (stdin, stdout)
416
417 class StringInputOutput (InputOutput):
418     """
419     >>> s = StringInputOutput()
420     >>> s.set_stdin('hello')
421     >>> s.stdin.read()
422     'hello'
423     >>> s.stdin.read()
424     ''
425     >>> print >> s.stdout, 'goodbye'
426     >>> s.get_stdout()
427     'goodbye\\n'
428     >>> s.get_stdout()
429     ''
430
431     Also works with unicode strings
432
433     >>> s.set_stdin(u'hello')
434     >>> s.stdin.read()
435     u'hello'
436     >>> print >> s.stdout, u'goodbye'
437     >>> s.get_stdout()
438     u'goodbye\\n'
439     """
440     def __init__(self):
441         stdin = StringIO.StringIO()
442         stdin.encoding = 'utf-8'
443         stdout = StringIO.StringIO()
444         stdout.encoding = 'utf-8'
445         InputOutput.__init__(self, stdin, stdout)
446
447     def set_stdin(self, stdin_string):
448         self.stdin = StringIO.StringIO(stdin_string)
449
450     def get_stdout(self):
451         ret = self.stdout.getvalue()
452         self.stdout = StringIO.StringIO() # clear stdout for next read
453         self.stdin.encoding = 'utf-8'
454         return ret
455
456 class UnconnectedStorageGetter (object):
457     def __init__(self, location):
458         self.location = location
459
460     def __call__(self):
461         return libbe.storage.get_storage(self.location)
462
463 class StorageCallbacks (object):
464     def __init__(self, location=None):
465         if location == None:
466             location = '.'
467         self.location = location
468         self._get_unconnected_storage = UnconnectedStorageGetter(location)
469
470     def setup_command(self, command):
471         command._get_unconnected_storage = self.get_unconnected_storage
472         command._get_storage = self.get_storage
473         command._get_bugdir = self.get_bugdir
474
475     def get_unconnected_storage(self):
476         """
477         Callback for use by commands that need it.
478         
479         The returned Storage instance is may actually be connected,
480         but commands that make use of the returned value should only
481         make use of non-connected Storage methods.  This is mainly
482         intended for the init command, which calls Storage.init().
483         """
484         if not hasattr(self, '_unconnected_storage'):
485             if self._get_unconnected_storage == None:
486                 raise NotImplementedError
487             self._unconnected_storage = self._get_unconnected_storage()
488         return self._unconnected_storage
489
490     def set_unconnected_storage(self, unconnected_storage):
491         self._unconnected_storage = unconnected_storage
492
493     def get_storage(self):
494         """Callback for use by commands that need it."""
495         if not hasattr(self, '_storage'):
496             self._storage = self.get_unconnected_storage()
497             self._storage.connect()
498             version = self._storage.storage_version()
499             if version != libbe.storage.STORAGE_VERSION:
500                 raise libbe.storage.InvalidStorageVersion(version)
501         return self._storage
502
503     def set_storage(self, storage):
504         self._storage = storage
505
506     def get_bugdir(self):
507         """Callback for use by commands that need it."""
508         if not hasattr(self, '_bugdir'):
509             self._bugdir = libbe.bugdir.BugDir(self.get_storage(),
510                                                from_storage=True)
511         return self._bugdir
512
513     def set_bugdir(self, bugdir):
514         self._bugdir = bugdir
515
516     def cleanup(self):
517         if hasattr(self, '_storage'):
518             self._storage.disconnect()
519
520 class UserInterface (object):
521     def __init__(self, io=None, location=None):
522         if io == None:
523             io = StringInputOutput()
524         self.io = io
525         self.storage_callbacks = StorageCallbacks(location)
526         self.restrict_file_access = True
527
528     def help(self):
529         raise NotImplementedError
530
531     def run(self, command, options=None, args=None):
532         self.setup_command(command)
533         return command.run(options, args)
534
535     def setup_command(self, command):
536         if command.ui == None:
537             command.ui = self
538         if self.io != None:
539             self.io.setup_command(command)
540         if self.storage_callbacks != None:
541             self.storage_callbacks.setup_command(command)        
542         command.restrict_file_access = self.restrict_file_access
543         command._get_user_id = self._get_user_id
544
545     def _get_user_id(self):
546         """Callback for use by commands that need it."""
547         if not hasattr(self, '_user_id'):
548             self._user_id = libbe.ui.util.user.get_user_id(
549                 self.storage_callbacks.get_storage())
550         return self._user_id
551
552     def cleanup(self):
553         self.storage_callbacks.cleanup()
554         self.io.cleanup()