74573a38874b4fa3a133d197972479719f938f10
[be.git] / libbe / command / base.py
1 # Copyright
2
3 import codecs
4 import optparse
5 import os.path
6 import sys
7
8 import libbe
9 import libbe.ui.util.user
10 import libbe.util.encoding
11 import libbe.util.plugin
12
13 class UserError(Exception):
14     pass
15
16 class UnknownCommand(UserError):
17     def __init__(self, cmd):
18         Exception.__init__(self, "Unknown command '%s'" % cmd)
19         self.cmd = cmd
20
21
22 def get_command(command_name):
23     """Retrieves the module for a user command
24
25     >>> try:
26     ...     get_command('asdf')
27     ... except UnknownCommand, e:
28     ...     print e
29     Unknown command 'asdf'
30     >>> repr(get_command('list')).startswith("<module 'libbe.command.list' from ")
31     True
32     """
33     try:
34         cmd = libbe.util.plugin.import_by_name(
35             'libbe.command.%s' % command_name.replace("-", "_"))
36     except ImportError, e:
37         raise UnknownCommand(command_name)
38     return cmd
39
40 def get_command_class(module, command_name):
41     """Retrieves a command class from a module.
42
43     >>> import_xml_mod = get_command('import-xml')
44     >>> import_xml = get_command_class(import_xml_mod, 'import-xml')
45     >>> repr(import_xml)
46     "<class 'libbe.command.import_xml.Import_XML'>"
47     """
48     try:
49         cname = command_name.capitalize().replace('-', '_')
50         cmd = getattr(module, cname)
51     except ImportError, e:
52         raise UnknownCommand(command_name)
53     return cmd
54
55 def commands():
56     for modname in libbe.util.plugin.modnames('libbe.command'):
57         if modname not in ['base', 'util']:
58             yield modname
59
60 class CommandInput (object):
61     def __init__(self, name, help=''):
62         self.name = name
63         self.help = help
64
65 class Argument (CommandInput):
66     def __init__(self, metavar=None, default=None, type='string',
67                  optional=False, repeatable=False,
68                  completion_callback=None, *args, **kwargs):
69         CommandInput.__init__(self, *args, **kwargs)
70         self.metavar = metavar
71         self.default = default
72         self.type = type
73         self.optional = optional
74         self.repeatable = repeatable
75         self.completion_callback = completion_callback
76         if self.metavar == None:
77             self.metavar = self.name.upper()
78
79 class Option (CommandInput):
80     def __init__(self, callback=None, short_name=None, arg=None,
81                  *args, **kwargs):
82         CommandInput.__init__(self, *args, **kwargs)
83         self.callback = callback
84         self.short_name = short_name
85         self.arg = arg
86         if self.arg == None and self.callback == None:
87             # use an implicit boolean argument
88             self.arg = Argument(name=self.name, help=self.help,
89                                 default=False, type='bool')
90         self.validate()
91
92     def validate(self):
93         if self.arg == None:
94             assert self.callback != None, self.name
95             return
96         assert self.callback == None, '%s: %s' (self.name, self.callback)
97         assert self.arg.name == self.name, \
98             'Name missmatch: %s != %s' % (self.arg.name, self.name)
99         assert self.arg.optional == False, self.name
100         assert self.arg.repeatable == False, self.name
101
102     def __str__(self):
103         return '--%s' % self.name
104
105     def __repr__(self):
106         return '<Option %s>' % self.__str__()
107
108 class _DummyParser (optparse.OptionParser):
109     def __init__(self, command):
110         optparse.OptionParser.__init__(self)
111         self.remove_option('-h')
112         self.command = command
113         self._command_opts = []
114         for option in self.command.options:
115             self._add_option(option)
116
117     def _add_option(self, option):
118         # from libbe.ui.command_line.CmdOptionParser._add_option
119         option.validate()
120         long_opt = '--%s' % option.name
121         if option.short_name != None:
122             short_opt = '-%s' % option.short_name
123         assert '_' not in option.name, \
124             'Non-reconstructable option name %s' % option.name
125         kwargs = {'dest':option.name.replace('-', '_'),
126                   'help':option.help}
127         if option.arg == None or option.arg.type == 'bool':
128             kwargs['action'] = 'store_true'
129             kwargs['metavar'] = None
130             kwargs['default'] = False
131         else:
132             kwargs['type'] = option.arg.type
133             kwargs['action'] = 'store'
134             kwargs['metavar'] = option.arg.metavar
135             kwargs['default'] = option.arg.default
136         if option.short_name != None:
137             opt = optparse.Option(short_opt, long_opt, **kwargs)
138         else:
139             opt = optparse.Option(long_opt, **kwargs)
140         #option.takes_value = lambda : option.arg != None
141         opt._option = option
142         self._command_opts.append(opt)
143         self.add_option(opt)
144
145 class OptionFormatter (optparse.IndentedHelpFormatter):
146     def __init__(self, command):
147         optparse.IndentedHelpFormatter.__init__(self)
148         self.command = command
149     def option_help(self):
150         # based on optparse.OptionParser.format_option_help()
151         parser = _DummyParser(self.command)
152         self.store_option_strings(parser)
153         ret = []
154         ret.append(self.format_heading('Options'))
155         self.indent()
156         for option in parser._command_opts:
157             ret.append(self.format_option(option))
158             ret.append('\n')
159         self.dedent()
160         # Drop the last '\n', or the header if no options or option groups:
161         return ''.join(ret[:-1])
162
163 class Command (object):
164     """One-line command description.
165
166     >>> c = Command()
167     >>> print c.help()
168     usage: be command [options]
169     <BLANKLINE>
170     Options:
171       -h, --help  Print a help message.
172     <BLANKLINE>
173       --complete  Print a list of possible completions.
174     <BLANKLINE>
175     A detailed help message.
176     """
177
178     name = 'command'
179
180     def __init__(self, input_encoding=None, output_encoding=None,
181                  get_unconnected_storage=None, ui=None):
182         self.input_encoding = input_encoding
183         self.output_encoding = output_encoding
184         self.get_unconnected_storage = get_unconnected_storage
185         self.ui = ui # calling user-interface, e.g. for Help()
186         self.status = None
187         self.result = None
188         self.restrict_file_access = True
189         self.options = [
190             Option(name='help', short_name='h',
191                 help='Print a help message.',
192                 callback=self.help),
193             Option(name='complete',
194                 help='Print a list of possible completions.',
195                 callback=self.complete),
196                 ]
197         self.args = []
198
199     def run(self, options=None, args=None):
200         if options == None:
201             options = {}
202         if args == None:
203             args = []
204         params = {}
205         for option in self.options:
206             assert option.name not in params, params[option.name]
207             if option.name in options:
208                 params[option.name] = options.pop(option.name)
209             elif option.arg != None:
210                 params[option.name] = option.arg.default
211             else: # non-arg options are flags, set to default flag value
212                 params[option.name] = False
213         assert 'user-id' not in params, params['user-id']
214         if 'user-id' in options:
215             self._user_id = options.pop('user-id')
216         if len(options) > 0:
217             raise UserError, 'Invalid option passed to command %s:\n  %s' \
218                 % (self.name, '\n  '.join(['%s: %s' % (k,v)
219                                            for k,v in options.items()]))
220         in_optional_args = False
221         for i,arg in enumerate(self.args):
222             if arg.repeatable == True:
223                 assert i == len(self.args)-1, arg.name
224             if in_optional_args == True:
225                 assert arg.optional == True, arg.name
226             else:
227                 in_optional_args = arg.optional
228             if i < len(args):
229                 if arg.repeatable == True:
230                     params[arg.name] = [args[i]]
231                 else:
232                     params[arg.name] = args[i]
233             else:  # no value given
234                 assert in_optional_args == True, arg.name
235                 params[arg.name] = arg.default
236         if len(args) > len(self.args):  # add some additional repeats
237             assert self.args[-1].repeatable == True, self.args[-1].name
238             params[self.args[-1].name].extend(args[len(self.args):])
239
240         if params['help'] == True:
241             pass
242         else:
243             params.pop('help')
244         if params['complete'] != None:
245             pass
246         else:
247             params.pop('complete')
248
249         self._setup_io(self.input_encoding, self.output_encoding)
250         self.status = self._run(**params)
251         return self.status
252
253     def _run(self, **kwargs):
254         raise NotImplementedError
255
256     def _setup_io(self, input_encoding=None, output_encoding=None):
257         if input_encoding == None:
258             input_encoding = libbe.util.encoding.get_input_encoding()
259         if output_encoding == None:
260             output_encoding = libbe.util.encoding.get_output_encoding()
261         self.stdin = codecs.getwriter(input_encoding)(sys.stdin)
262         self.stdin.encoding = input_encoding
263         self.stdout = codecs.getwriter(output_encoding)(sys.stdout)
264         self.stdout.encoding = output_encoding
265
266     def help(self, *args):       
267         return '\n\n'.join([self._usage(),
268                             self._option_help(),
269                             self._long_help()])
270
271     def _usage(self):
272         usage = 'usage: be %s [options]' % self.name
273         num_optional = 0
274         for arg in self.args:
275             usage += ' '
276             if arg.optional == True:
277                 usage += '['
278                 num_optional += 1
279             usage += arg.metavar
280             if arg.repeatable == True:
281                 usage += ' ...'
282         usage += ']'*num_optional
283         return usage
284
285     def _option_help(self):
286         o = OptionFormatter(self)
287         return o.option_help().strip('\n')
288
289     def _long_help(self):
290         return "A detailed help message."
291
292     def complete(self, argument=None, fragment=None):
293         if argument == None:
294             ret = ['--%s' % o.name for o in self.options]
295             if len(self.args) > 0 and self.args[0].completion_callback != None:
296                 ret.extend(self.args[0].completion_callback(self, argument))
297             return ret
298         elif argument.completion_callback != None:
299             # finish a particular argument
300             return argument.completion_callback(self, argument, fragment)
301         return [] # the particular argument doesn't supply completion info
302
303     def _check_restricted_access(self, storage, path):
304         """
305         Check that the file at path is inside bugdir.root.  This is
306         important if you allow other users to execute becommands with
307         your username (e.g. if you're running be-handle-mail through
308         your ~/.procmailrc).  If this check wasn't made, a user could
309         e.g.  run
310           be commit -b ~/.ssh/id_rsa "Hack to expose ssh key"
311         which would expose your ssh key to anyone who could read the
312         VCS log.
313
314         >>> class DummyStorage (object): pass
315         >>> s = DummyStorage()
316         >>> s.repo = os.path.expanduser('~/x/')
317         >>> c = Command()
318         >>> try:
319         ...     c._check_restricted_access(s, os.path.expanduser('~/.ssh/id_rsa'))
320         ... except UserError, e:
321         ...     assert str(e).startswith('file access restricted!'), str(e)
322         ...     print 'we got the expected error'
323         we got the expected error
324         >>> c._check_restricted_access(s, os.path.expanduser('~/x'))
325         >>> c._check_restricted_access(s, os.path.expanduser('~/x/y'))
326         >>> c.restrict_file_access = False
327         >>> c._check_restricted_access(s, os.path.expanduser('~/.ssh/id_rsa'))
328         """
329         if self.restrict_file_access == True:
330             path = os.path.abspath(path)
331             repo = os.path.abspath(storage.repo).rstrip(os.path.sep)
332             if path == repo or path.startswith(repo+os.path.sep):
333                 return
334             raise UserError('file access restricted!\n  %s not in %s'
335                             % (path, repo))
336
337     def _get_unconnected_storage(self):
338         """Callback for use by commands that need it."""
339         if not hasattr(self, '_unconnected_storage'):
340             if self.get_unconnected_storage == None:
341                 raise NotImplementedError
342             self._unconnected_storage = self.get_unconnected_storage()
343         return self._unconnected_storage
344
345     def _get_storage(self):
346         """
347         Callback for use by commands that need it.
348         
349         Note that with the current implementation,
350         _get_unconnected_storage() will not work after this method
351         runs, but that shouldn't be an issue for any command I can
352         think of...
353         """
354         if not hasattr(self, '_storage'):
355             self._storage = self._get_unconnected_storage()
356             self._storage.connect()
357         return self._storage
358
359     def _get_bugdir(self):
360         """Callback for use by commands that need it."""
361         if not hasattr(self, '_bugdir'):
362             self._bugdir = libbe.bugdir.BugDir(self._get_storage(), from_storage=True)
363         return self._bugdir
364
365     def _get_user_id(self):
366         """Callback for use by commands that need it."""
367         if not hasattr(self, '_user_id'):
368             self._user_id = libbe.ui.util.user.get_user_id(self._get_storage())
369         return self._user_id
370
371     def cleanup(self):
372         if hasattr(self, '_storage'):
373             self._storage.disconnect()