cd8b763ffbbe731c9f193809a2193eb81c202312
[gentoolkit.git] / pym / gentoolkit / helpers.py
1 # Copyright(c) 2009-2010, Gentoo Foundation
2 #
3 # Licensed under the GNU General Public License, v2 or higher
4 #
5 # $Header$
6
"""Assorted helper functions and classes used across gentoolkit.

@note: The find_* functions that used to be defined in this module now live
        in the query module; call them as Query('portage').find_*().
"""

__docformat__ = 'epytext'
__all__ = (
        'ChangeLog', 'FileOwner',
        'get_cpvs', 'get_installed_cpvs', 'get_uninstalled_cpvs',
        'get_bintree_cpvs',
        'uniqify',
)
23
24 # =======
25 # Imports
26 # =======
27
28 import os
29 import sys
30 import re
31 import codecs
32 from functools import partial
33 from itertools import chain
34
35 from gentoolkit import pprinter as pp
36 from gentoolkit import errors
37 from gentoolkit.atom import Atom
38 from gentoolkit.cpv import CPV
39 from gentoolkit.dbapi import BINDB, PORTDB, VARDB
40 from gentoolkit.versionmatch import VersionMatch
41 # This has to be imported below to stop circular import.
42 #from gentoolkit.package import Package
43
44 # =======
45 # Classes
46 # =======
47
class ChangeLog(object):
        """Provides methods for working with a Gentoo ChangeLog file.

        Example usage:
                >>> from gentoolkit.helpers import ChangeLog
                >>> portage = ChangeLog('/usr/portage/sys-apps/portage/ChangeLog')
                >>> print portage.latest.strip()
                *portage-2.2_rc50 (15 Nov 2009)

                  15 Nov 2009; Zac Medico <zmedico@gentoo.org> +portage-2.2_rc50.ebuild:
                  2.2_rc50 bump. This includes all fixes in 2.1.7.5.
                >>> len(portage.full)
                75
                >>> len(portage.entries_matching_range(
                ...     from_ver='2.2_rc40',
                ...     to_ver='2.2_rc50'))
                11

        """
        def __init__(self, changelog_path, invalid_entry_is_fatal=False):
                """Read, split and index the ChangeLog at changelog_path.

                @type changelog_path: str
                @param changelog_path: path to a ChangeLog file
                @type invalid_entry_is_fatal: bool
                @param invalid_entry_is_fatal: if True, raise ValueError when an
                        entry header cannot be parsed as a package version;
                        otherwise such entries are left out of the index
                @raise errors.GentoolkitFatalError: if the file does not exist or
                        is unreadable
                """
                if not (os.path.isfile(changelog_path) and
                        os.access(changelog_path, os.R_OK)):
                        raise errors.GentoolkitFatalError(
                                "%s does not exist or is unreadable" % pp.path(changelog_path)
                        )
                self.changelog_path = changelog_path
                self.invalid_entry_is_fatal = invalid_entry_is_fatal

                # Process the ChangeLog:
                self.entries = self._split_changelog()
                self.indexed_entries = self._index_changelog()
                self.full = self.entries
                # A ChangeLog with no entries (e.g. only '#' header lines) has no
                # latest entry; use '' instead of raising IndexError.
                self.latest = self.entries[0] if self.entries else ''

        def __repr__(self):
                return "<%s %r>" % (self.__class__.__name__, self.changelog_path)

        def entries_matching_atom(self, atom):
                """Return entries whose header versions match atom's version.

                @type atom: L{gentoolkit.atom.Atom} or str
                @param atom: a atom to find matching entries against
                @rtype: list
                @return: entries matching atom
                @raise errors.GentoolkitInvalidAtom: if atom is a string and malformed
                """
                if not isinstance(atom, Atom):
                        atom = Atom(atom)

                result = []
                for restriction, entry in self.indexed_entries:
                        # VersionMatch doesn't store .cp, so we'll force it to match here.
                        # NOTE: this mutates the VersionMatch stored in
                        # self.indexed_entries.
                        restriction.cp = atom.cp
                        if atom.intersects(restriction):
                                result.append(entry)

                return result

        def entries_matching_range(self, from_ver=None, to_ver=None):
                """Return entries whose header versions are within a range of versions.

                Both bounds are inclusive.

                @type from_ver: str
                @param from_ver: valid Gentoo version
                @type to_ver: str
                @param to_ver: valid Gentoo version
                @rtype: list
                @return: entries between from_ver and to_ver
                @raise errors.GentoolkitFatalError: if neither vers are set
                @raise errors.GentoolkitInvalidVersion: if either ver is invalid
                """
                # Make sure we have at least one version set
                if not (from_ver or to_ver):
                        raise errors.GentoolkitFatalError(
                                "Need to specifiy 'from_ver' or 'to_ver'"
                        )

                from_restriction = None
                if from_ver:
                        from_restriction = self._version_restriction(from_ver, '>=')

                to_restriction = None
                if to_ver:
                        to_restriction = self._version_restriction(to_ver, '<=')

                # Add entry to result if version ranges intersect it
                result = []
                for restriction, entry in self.indexed_entries:
                        if from_restriction and not from_restriction.match(restriction):
                                continue
                        # Do not break here: self.latest is entries[0], so entries are
                        # presumably newest-first and entries newer than to_ver come
                        # *before* the matching ones.
                        if to_restriction and not to_restriction.match(restriction):
                                continue
                        result.append(entry)

                return result

        @staticmethod
        def _version_restriction(ver, op):
                """Build a VersionMatch comparing against a bare version string.

                @type ver: str
                @param ver: valid Gentoo version (no category/package part)
                @type op: str
                @param op: VersionMatch operator, e.g. '>=' or '<='
                @rtype: L{gentoolkit.versionmatch.VersionMatch}
                @raise errors.GentoolkitInvalidVersion: if ver is invalid
                """
                # CPV needs a package name, so prefix the version with a dummy one.
                try:
                        ver_cpv = CPV("null-%s" % ver)
                except errors.GentoolkitInvalidCPV:
                        raise errors.GentoolkitInvalidVersion(ver)
                return VersionMatch(ver_cpv, op=op)

        def _index_changelog(self):
                """Use the output of L{self._split_changelog} to create an index list
                of L{gentoolkit.versionmatch.VersionMatch} objects.

                @rtype: list
                @return: tuples containing a VersionMatch instance for the release
                        version of each entry header as the first item and the entire entry
                        as the second item
                @raise ValueError: if self.invalid_entry_is_fatal is True and we hit an
                        invalid entry; the exception carries the unparsable name
                """
                result = []
                for entry in self.entries:
                        # Extract the package name from the entry header, ex:
                        # *xterm-242 (07 Mar 2009) => xterm-242
                        pkg_name = entry.split(' ', 1)[0].lstrip('*')
                        if not pkg_name.strip():
                                continue
                        try:
                                entry_ver = CPV(pkg_name)
                        except errors.GentoolkitInvalidCPV:
                                if self.invalid_entry_is_fatal:
                                        # entry_ver was not assigned for this entry, so
                                        # report the raw name that failed to parse.
                                        raise ValueError(pkg_name)
                                continue

                        result.append((VersionMatch(entry_ver, op='='), entry))

                return result

        def _split_changelog(self):
                """Split the ChangeLog into individual entries.

                Lines starting with '#' are dropped. Each line starting with '*'
                begins a new entry that runs up to the next such line. Chunks that
                are empty or whitespace-only (e.g. blank lines before the first
                header) are discarded.

                @rtype: list
                @return: individual ChangeLog entries, in file order
                """
                result = []
                partial_entries = []
                with codecs.open(self.changelog_path, encoding="utf-8",
                        errors="replace") as log:
                        for line in log:
                                if line.startswith('#'):
                                        continue
                                if line.startswith('*'):
                                        # Append last entry to result...
                                        entry = ''.join(partial_entries)
                                        if entry and not entry.isspace():
                                                result.append(entry)
                                        # ... and start a new entry
                                        partial_entries = [line]
                                else:
                                        partial_entries.append(line)

                # Append the final entry, filtered exactly like the others
                entry = ''.join(partial_entries)
                if entry and not entry.isspace():
                        result.append(entry)

                return result
217
218
class FileOwner(object):
        """Creates a function for locating the owner of filename queries.

        Example usage:
                >>> from gentoolkit.helpers import FileOwner
                >>> findowner = FileOwner()
                >>> findowner(('/usr/bin/vim',))
                [(<Package app-editors/vim-7.2.182>, '/usr/bin/vim')]
        """
        def __init__(self, is_regex=False, early_out=False, printer_fn=None):
                """Instantiate function.

                @type is_regex: bool
                @param is_regex: function args are regular expressions
                @type early_out: bool
                @param early_out: return when first result is found (safe)
                @type printer_fn: callable
                @param printer_fn: If defined, will be called as
                        printer_fn(pkg, cfile) for each result as it is found.
                """
                self.is_regex = is_regex
                self.early_out = early_out
                self.printer_fn = printer_fn

        def __call__(self, queries):
                """Run the function over all installed packages.

                @type queries: iterable
                @param queries: filepaths or filepath regexes
                @rtype: list
                @return: (Package, matching file path) tuples
                @raise errors.GentoolkitInvalidRegex: if the combined query regex
                        cannot be compiled
                """
                query_re_string = self._prepare_search_regex(queries)
                try:
                        query_re = re.compile(query_re_string)
                except (TypeError, re.error) as err:
                        raise errors.GentoolkitInvalidRegex(err)

                use_match = self._can_use_match(query_re_string)
                pkgset = get_installed_cpvs()

                return self.find_owners(query_re, use_match=use_match, pkgset=pkgset)

        def _can_use_match(self, query_re_string):
                """Decide whether re.match can be used instead of re.search.

                If we were passed a single regex, or a single path anchored at the
                root, we can use re.match; with several alternatives ('|') we fall
                back to re.search.

                @type query_re_string: str
                @param query_re_string: output of L{_prepare_search_regex}
                @rtype: bool
                """
                if '|' in query_re_string:
                        return False
                if self.is_regex:
                        return True
                # re.escape() escapes '/' before Python 3.7 but not from 3.7 on, so
                # an anchored absolute path may start with either form.
                return query_re_string.startswith(('^/', '^\\/'))

        def find_owners(self, query_re, use_match=False, pkgset=None):
                """Find owners and feed data to supplied output function.

                @type query_re: _sre.SRE_Pattern
                @param query_re: file regex
                @type use_match: bool
                @param use_match: use re.match or re.search
                @type pkgset: iterable or None
                @param pkgset: cpvs of the packages to look through; if None, all
                        installed packages (get_installed_cpvs()) are searched
                @rtype: list
                @return: (Package, matching file path) tuples in sorted package
                        order; stops after the first match if self.early_out is set
                """
                # FIXME: Remove when lazyimport supports objects:
                from gentoolkit.package import Package

                if pkgset is None:
                        # Default to every installed package, like __call__ does.
                        pkgset = get_installed_cpvs()

                query_fn = query_re.match if use_match else query_re.search

                results = []
                for pkg in sorted(Package(x) for x in pkgset):
                        for cfile in pkg.parsed_contents():
                                if not query_fn(cfile):
                                        continue
                                results.append((pkg, cfile))
                                if self.printer_fn is not None:
                                        self.printer_fn(pkg, cfile)
                                if self.early_out:
                                        return results
                return results

        @staticmethod
        def expand_abspaths(paths):
                """Expand relative paths (./file, ../file) to absolute paths.

                @type paths: iterable
                @param paths: file path strs
                @rtype: list
                @return: a new list in which paths starting with './' or '../'
                        are made absolute (against the current directory) and all
                        other paths are kept unchanged
                """
                osp = os.path
                return [osp.abspath(p) if p.startswith(('./', '../')) else p
                        for p in paths]

        @staticmethod
        def extend_realpaths(paths):
                """Extend a list of paths with the realpaths for any symlinks.

                @type paths: list
                @param paths: file path strs
                @rtype: list
                @return: the original list (modified in place) plus the realpaths
                        of any symlinks, each added only once and only if it is not
                        already in the list
                @raise AttributeError: if paths does not have attribute 'extend'
                """
                osp = os.path
                extra = []
                for path in paths:
                        if not osp.islink(path):
                                continue
                        real = osp.realpath(path)
                        # Also check extra: several symlinks may share one target.
                        if real not in paths and real not in extra:
                                extra.append(real)
                paths.extend(extra)

                return paths

        def _prepare_search_regex(self, queries):
                """Create a single regex string out of the queries.

                Regex queries are joined with '|' unchanged. Path queries are
                normalized (relative './' and '../' paths made absolute, symlink
                targets added, repeated and trailing slashes removed), escaped,
                then anchored: absolute paths as '^path$', other paths as
                '/path$' so they match the end of a file path at a '/' boundary.

                @type queries: iterable
                @param queries: file paths or file path regexes
                @rtype: str
                @return: '|'-joined regex string
                """
                queries = list(queries)
                if self.is_regex:
                        return '|'.join(queries)

                # Used to collapse repeated slashes in queries
                slashes = re.compile('/+')
                queries = self.extend_realpaths(self.expand_abspaths(queries))
                result = []
                for query in queries:
                        query = slashes.sub('/', query).rstrip('/')
                        if query.startswith('/'):
                                query = "^%s$" % re.escape(query)
                        else:
                                query = "/%s$" % re.escape(query)
                        result.append(query)
                return "|".join(result)
361
362 # =========
363 # Functions
364 # =========
365
def get_cpvs(predicate=None, include_installed=True):
        """Yield package cpvs from the Portage tree and overlays.

        Example usage:
                >>> from gentoolkit.helpers import get_cpvs
                >>> len(set(get_cpvs()))
                26065
                >>> fn = lambda x: x.startswith('app-portage')
                >>> len(get_cpvs(fn, include_installed=False))
                112

        @type predicate: function
        @param predicate: optional filter applied to each cat/pkg before its
                versions are listed (also applied to installed packages)
        @type include_installed: bool
        @param include_installed:
                If True: yield every tree cpv, then every installed cpv that is
                        not in the tree (the union)
                If False: yield only tree cpvs that are not installed (the
                        difference)
        @rtype: generator
        @return: a generator that yields unsorted cat/pkg-ver strings from the
                Portage tree
        """
        tree_cps = PORTDB.cp_all()
        if predicate:
                tree_cps = (cp for cp in tree_cps if predicate(cp))
        tree_cpvs = chain.from_iterable(PORTDB.cp_list(cp) for cp in tree_cps)
        installed = set(get_installed_cpvs(predicate))

        if not include_installed:
                for cpv in tree_cpvs:
                        if cpv not in installed:
                                yield cpv
                return

        for cpv in tree_cpvs:
                # Drop tree cpvs from the set so the second loop only yields
                # installed packages the tree no longer provides.
                installed.discard(cpv)
                yield cpv
        for cpv in installed:
                yield cpv
408
409
# get_cpvs() with include_installed=False: yields the Portage tree/overlay
# cpvs that are not installed. Accepts the same optional predicate.
# pylint thinks this is a global variable
# pylint: disable-msg=C0103
get_uninstalled_cpvs = partial(get_cpvs, include_installed=False)
413
414
def get_installed_cpvs(predicate=None):
        """Yield the cpv of every installed package.

        @type predicate: function
        @param predicate: optional filter applied to each installed cat/pkg
                before its versions are listed
        @rtype: generator
        @return: a generator that yields unsorted installed cat/pkg-ver strings
                from VARDB
        """
        if not predicate:
                vardb_cps = VARDB.cp_all()
        else:
                vardb_cps = (cp for cp in VARDB.cp_all() if predicate(cp))

        for cp in vardb_cps:
                for cpv in VARDB.cp_list(cp):
                        yield cpv
432
433
def get_bintree_cpvs(predicate=None):
        """Yield the cpv of every available binary package.

        @type predicate: function
        @param predicate: optional filter applied to each binary cat/pkg
                before its versions are listed
        @rtype: generator
        @return: a generator that yields unsorted binary package cat/pkg-ver strings
                from BINDB
        """
        if not predicate:
                binary_cps = BINDB.cp_all()
        else:
                binary_cps = (cp for cp in BINDB.cp_all() if predicate(cp))

        for cp in binary_cps:
                for cpv in BINDB.cp_list(cp):
                        yield cpv
451
452
def print_file(path):
        """Print the contents of the file at path, stripped of surrounding
        whitespace, via pp.uprint. The file is read in binary mode."""

        with open(path, "rb") as handle:
                contents = handle.read()
        pp.uprint(contents.strip())
459
460
def print_sequence(seq):
        """Print each element of seq, in order, with pp.uprint."""

        for element in seq:
                pp.uprint(element)
466
467
def uniqify(seq, preserve_order=True):
        """Return a uniqified list. Optionally preserve order.

        @type seq: iterable
        @param seq: items to deduplicate; every item must be hashable
        @type preserve_order: bool
        @param preserve_order: if True, keep the first occurrence of each item
                in its original relative order; if False, the order is arbitrary
        @rtype: list
        @return: the unique items of seq
        """
        if not preserve_order:
                return list(set(seq))

        # Explicit loop instead of a side-effecting comprehension
        # (`not seen.add(x)`), which hid the set update inside a filter.
        seen = set()
        result = []
        for item in seq:
                if item not in seen:
                        seen.add(item)
                        result.append(item)
        return result
478
479 # vim: set ts=4 sw=4 tw=79: