helpers: update package counts in get_cpvs doctest
[gentoolkit.git] / pym / gentoolkit / helpers.py
1 # Copyright(c) 2009-2010, Gentoo Foundation
2 #
3 # Licensed under the GNU General Public License, v2 or higher
4 #
5 # $Header$
6
7 """Miscellaneous helper functions and classes.
8
9 @note: find_* functions that previously lived here have moved to
10            the query module, where they are called as: Query('portage').find_*().
11 """
12
# Public names exported by ``from gentoolkit.helpers import *``.
__all__ = (
    'ChangeLog', 'FileOwner',
    'get_cpvs', 'get_installed_cpvs',
    'get_uninstalled_cpvs', 'get_bintree_cpvs',
    'uniqify',
)
# Docstrings in this module use epytext fields (@type, @param, @rtype, ...).
__docformat__ = 'epytext'
23
24 # =======
25 # Imports
26 # =======
27
28 import os
29 import re
30 import codecs
31 from functools import partial
32 from itertools import chain
33
34 import portage
35
36 from gentoolkit import pprinter as pp
37 from gentoolkit import errors
38 from gentoolkit.atom import Atom
39 from gentoolkit.cpv import CPV
40 from gentoolkit.versionmatch import VersionMatch
41 # This has to be imported below to stop circular import.
42 #from gentoolkit.package import Package
43
44 # =======
45 # Classes
46 # =======
47
class ChangeLog(object):
    """Provides methods for working with a Gentoo ChangeLog file.

    The file is split into entries (one per '*'-prefixed release header) when
    the instance is created. The following attributes are then available:
        - entries: list of entry strings in file order
        - indexed_entries: list of (VersionMatch, entry) tuples
        - full: the same list object as entries
        - latest: the first entry in the file

    Example usage:
        >>> from gentoolkit.helpers import ChangeLog
        >>> portage = ChangeLog('/usr/portage/sys-apps/portage/ChangeLog')
        >>> print(portage.latest.strip())
        *portage-2.2.0_alpha142 (26 Oct 2012)
        <BLANKLINE>
          26 Oct 2012; Zac Medico <zmedico@gentoo.org> +portage-2.2.0_alpha142.ebuild:
          2.2.0_alpha142 version bump. This includes all of the fixes in 2.1.11.31. Bug
          #210077 tracks all bugs fixed since portage-2.1.x.
        >>> len(portage.full)
        75
        >>> len(portage.entries_matching_range(
        ...     from_ver='2.2_rc40',
        ...     to_ver='2.2_rc50'))
        11

    """
    def __init__(self, changelog_path, invalid_entry_is_fatal=False):
        """Read and index the ChangeLog at changelog_path.

        @type changelog_path: str
        @param changelog_path: path to a readable ChangeLog file
        @type invalid_entry_is_fatal: bool
        @param invalid_entry_is_fatal: if True, an entry header that is not a
            valid package version aborts indexing instead of being skipped
        @raise errors.GentoolkitFatalError: if changelog_path is not a regular
            file or is not readable
        """
        if not (os.path.isfile(changelog_path) and
                os.access(changelog_path, os.R_OK)):
            raise errors.GentoolkitFatalError(
                "%s does not exist or is unreadable" % pp.path(changelog_path)
            )
        self.changelog_path = changelog_path
        self.invalid_entry_is_fatal = invalid_entry_is_fatal

        # Process the ChangeLog:
        self.entries = self._split_changelog()
        self.indexed_entries = self._index_changelog()
        self.full = self.entries
        # _split_changelog() always returns at least one item, so this index
        # is safe even for an empty file.
        self.latest = self.entries[0]

    def __repr__(self):
        return "<%s %r>" % (self.__class__.__name__, self.changelog_path)

    def entries_matching_atom(self, atom):
        """Return entries whose header versions match atom's version.

        @type atom: L{gentoolkit.atom.Atom} or str
        @param atom: an atom to find matching entries against
        @rtype: list
        @return: entries matching atom
        @raise errors.GentoolkitInvalidAtom: if atom is a string and malformed
        """
        if not isinstance(atom, Atom):
            atom = Atom(atom)

        result = []
        for i, entry in self.indexed_entries:
            # VersionMatch doesn't store .cp, so we'll force it to match here:
            i.cp = atom.cp
            if atom.intersects(i):
                result.append(entry)

        return result

    def entries_matching_range(self, from_ver=None, to_ver=None):
        """Return entries whose header versions are within a range of versions.

        @type from_ver: str
        @param from_ver: valid Gentoo version (inclusive lower bound)
        @type to_ver: str
        @param to_ver: valid Gentoo version (inclusive upper bound)
        @rtype: list
        @return: entries between from_ver and to_ver
        @raise errors.GentoolkitFatalError: if neither vers are set
        @raise errors.GentoolkitInvalidVersion: if either ver is invalid
        """
        # Make sure we have at least one version set
        if not (from_ver or to_ver):
            raise errors.GentoolkitFatalError(
                "Need to specifiy 'from_ver' or 'to_ver'"
            )

        # An unset bound yields None and is simply not checked below.
        from_restriction = self._version_restriction(from_ver, '>=')
        to_restriction = self._version_restriction(to_ver, '<=')

        # Add entry to result if version ranges intersect it
        result = []
        for i, entry in self.indexed_entries:
            if from_restriction and not from_restriction.match(i):
                continue
            if to_restriction and not to_restriction.match(i):
                # TODO: is it safe to break here?
                continue
            result.append(entry)

        return result

    @staticmethod
    def _version_restriction(version, op):
        """Build a VersionMatch restriction out of a bare version string.

        @type version: str or None
        @param version: valid Gentoo version, or a false value for "no bound"
        @type op: str
        @param op: comparison operator handed to VersionMatch
        @rtype: L{gentoolkit.versionmatch.VersionMatch} or None
        @return: the restriction, or None if version is not set
        @raise errors.GentoolkitInvalidVersion: if version is invalid
        """
        if not version:
            return None
        try:
            # CPV needs a package name, so prepend a dummy one.
            ver_rev = CPV("null-%s" % version)
        except errors.GentoolkitInvalidCPV:
            raise errors.GentoolkitInvalidVersion(version)
        return VersionMatch(ver_rev, op=op)

    def _index_changelog(self):
        """Use the output of L{self._split_changelog} to create an index list
        of L{gentoolkit.versionmatch.VersionMatch} objects.

        @rtype: list
        @return: tuples containing a VersionMatch instance for the release
            version of each entry header as the first item and the entire entry
            as the second item
        @raise ValueError: if self.invalid_entry_is_fatal is True and we hit an
            invalid entry
        """

        result = []
        for entry in self.entries:
            # Extract the package name from the entry header, ex:
            # *xterm-242 (07 Mar 2009) => xterm-242
            pkg_name = entry.split(' ', 1)[0].lstrip('*')
            if not pkg_name.strip():
                continue
            try:
                entry_ver = CPV(pkg_name, validate=True)
            except errors.GentoolkitInvalidCPV:
                if self.invalid_entry_is_fatal:
                    # Report the offending header text; entry_ver was never
                    # assigned for this entry because CPV() itself failed.
                    raise ValueError(pkg_name)
                continue

            result.append((VersionMatch(entry_ver, op='='), entry))

        return result

    def _split_changelog(self):
        """Split the ChangeLog into individual entries.

        Lines starting with '#' are dropped. A line starting with '*' opens a
        new entry; every other line is added to the entry currently being
        collected.

        @rtype: list
        @return: individual ChangeLog entries; never empty, because the text
            collected after the last header is always appended, even when it
            is empty
        """

        result = []
        partial_entries = []
        with codecs.open(self.changelog_path, encoding="utf-8",
                errors="replace") as log:
            for line in log:
                if line.startswith('#'):
                    continue
                if line.startswith('*'):
                    # Append last entry to result...
                    entry = ''.join(partial_entries)
                    if entry and not entry.isspace():
                        result.append(entry)
                    # ... and start a new entry
                    partial_entries = [line]
                else:
                    partial_entries.append(line)

        # Append the final entry
        result.append(''.join(partial_entries))

        return result
218
219
class FileOwner(object):
    """Creates a function for locating the owner of filename queries.

    Example usage:
        >>> from gentoolkit.helpers import FileOwner
        >>> findowner = FileOwner()
        >>> findowner(('/bin/grep',))
        [(<Package 'sys-apps/grep-2.12'>, '/bin/grep')]
    """
    def __init__(self, is_regex=False, early_out=False, printer_fn=None):
        """Instantiate function.

        @type is_regex: bool
        @param is_regex: function args are regular expressions
        @type early_out: bool
        @param early_out: return when first result is found (safe)
        @type printer_fn: callable
        @param printer_fn: If defined, will be passed useful information for
            printing each result as it is found.
        """
        self.is_regex = is_regex
        self.early_out = early_out
        self.printer_fn = printer_fn

    def __call__(self, queries):
        """Run the function against all installed packages.

        @type queries: iterable
        @param queries: filepaths or filepath regexes
        @rtype: list
        @return: (Package, filepath) tuples, one per matching file
        @raise errors.GentoolkitInvalidRegex: if the queries do not compile
            to a valid regular expression
        """
        query_re_string = self._prepare_search_regex(queries)
        try:
            query_re = re.compile(query_re_string)
        except (TypeError, re.error) as err:
            raise errors.GentoolkitInvalidRegex(err)

        use_match = self._can_use_match(query_re_string)
        pkgset = get_installed_cpvs()

        return self.find_owners(query_re, use_match=use_match, pkgset=pkgset)

    def _can_use_match(self, query_re_string):
        """Decide whether re.match can be used instead of re.search.

        If we were passed a regex or a single path starting with root,
        we can use re.match, else use re.search.

        @type query_re_string: str
        @param query_re_string: output of L{self._prepare_search_regex}
        @rtype: bool
        """
        if '|' in query_re_string:
            return False
        # re.escape() only escapes '/' on Python < 3.7, so a rooted path may
        # start with either '^/' or '^\/'.
        rooted = query_re_string.startswith(('^/', '^\\/'))
        return bool(self.is_regex) or rooted

    def find_owners(self, query_re, use_match=False, pkgset=None):
        """Find owners and feed data to supplied output function.

        @type query_re: _sre.SRE_Pattern
        @param query_re: file regex
        @type use_match: bool
        @param use_match: use re.match or re.search
        @type pkgset: iterable or None
        @param pkgset: list of packages to look through; if None, all
            installed packages are searched
        @rtype: list
        @return: (Package, filepath) tuples, one per matching file; only the
            first match if self.early_out is set
        """
        # FIXME: Remove when lazyimport supports objects:
        from gentoolkit.package import Package

        if pkgset is None:
            # Same default that __call__ uses.
            pkgset = get_installed_cpvs()

        query_fn = query_re.match if use_match else query_re.search

        results = []
        for pkg in sorted(Package(x) for x in pkgset):
            for cfile in pkg.parsed_contents():
                if not query_fn(cfile):
                    continue
                results.append((pkg, cfile))
                if self.printer_fn is not None:
                    self.printer_fn(pkg, cfile)
                if self.early_out:
                    return results
        return results

    @staticmethod
    def expand_abspaths(paths):
        """Expand any relative paths (./file) to their absolute paths.

        Only paths starting with './' are expanded; everything else is
        passed through untouched.

        @type paths: list
        @param paths: file path strs
        @rtype: list
        @return: a new list with any './' paths made absolute
        """
        return [
            os.path.abspath(path) if path.startswith('./') else path
            for path in paths
        ]

    @staticmethod
    def extend_realpaths(paths):
        """Extend a list of paths with the realpaths for any symlinks.

        The list is modified in place.

        @type paths: list
        @param paths: file path strs
        @rtype: list
        @return: the original list plus the realpaths for any symlinks
            so long as the realpath doesn't already exist in the list
        @raise AttributeError: if paths does not have attribute 'extend'
        """
        known = set(paths)
        extra = []
        for path in paths:
            real = os.path.realpath(path)
            if real not in known:
                # Remember it so two links to one target add it only once.
                known.add(real)
                extra.append(real)
        paths.extend(extra)

        return paths

    def _prepare_search_regex(self, queries):
        """Create a regex out of the queries.

        Regex queries are joined with '|' as they are. Path queries are
        expanded, have repeated and trailing slashes removed, are escaped,
        and are anchored: rooted paths on both ends, other names on a
        leading '/' and the end of the string.

        @type queries: iterable
        @param queries: filepaths or filepath regexes
        @rtype: str
        @return: a single regular expression string
        """
        queries = list(queries)
        if self.is_regex:
            return '|'.join(queries)

        # Trim trailing and multiple slashes from queries
        slashes = re.compile('/+')
        queries = self.expand_abspaths(queries)
        queries = self.extend_realpaths(queries)
        result = []
        for query in queries:
            query = slashes.sub('/', query).rstrip('/')
            if query.startswith('/'):
                query = "^%s$" % re.escape(query)
            else:
                query = "/%s$" % re.escape(query)
            result.append(query)
        return "|".join(result)
362
363 # =========
364 # Functions
365 # =========
366
def get_cpvs(predicate=None, include_installed=True):
    """Get all packages in the Portage tree and overlays. Optionally apply a
    predicate.

    The predicate is called with each cat/pkg name (no version) and is also
    handed on to L{get_installed_cpvs}.

    Example usage:
        >>> from gentoolkit.helpers import get_cpvs
        >>> len(set(get_cpvs()))
        33518
        >>> fn = lambda x: x.startswith('app-portage')
        >>> len(set(get_cpvs(fn, include_installed=False)))
        137

    @type predicate: function
    @param predicate: a function to filter the package list with
    @type include_installed: bool
    @param include_installed:
        If True: Return the union of all_cpvs and all_installed_cpvs
        If False: Return the difference of all_cpvs and all_installed_cpvs
    @rtype: generator
    @return: a generator that yields unsorted cat/pkg-ver strings from the
        Portage tree
    """

    portdb = portage.db[portage.root]["porttree"].dbapi

    tree_cps = portdb.cp_all()
    if predicate:
        tree_cps = (cp for cp in tree_cps if predicate(cp))

    # Versions are looked up lazily, one cat/pkg at a time.
    tree_cpvs = chain.from_iterable(portdb.cp_list(cp) for cp in tree_cps)
    installed = set(get_installed_cpvs(predicate))

    if not include_installed:
        for cpv in tree_cpvs:
            if cpv not in installed:
                yield cpv
        return

    for cpv in tree_cpvs:
        # Whatever is left in the set afterwards is installed but no longer
        # in the tree.
        installed.discard(cpv)
        yield cpv
    for cpv in installed:
        yield cpv
414
415
# pylint treats this module-level assignment as a variable and rejects the
# name (C0103), although it is used as a function.
# pylint: disable=C0103
get_uninstalled_cpvs = partial(get_cpvs, include_installed=False)
# A partial object does not pick up the wrapped function's docstring, so give
# it one of its own.
get_uninstalled_cpvs.__doc__ = (
    """Get all packages in the Portage tree and overlays that are not
    installed. Optionally apply a predicate.

    Equivalent to calling get_cpvs with include_installed defaulting to
    False.

    @type predicate: function
    @param predicate: a function to filter the package list with
    @rtype: generator
    @return: a generator that yields unsorted cat/pkg-ver strings from the
        Portage tree
    """
)
419
420
def get_installed_cpvs(predicate=None):
    """Get all installed packages. Optionally apply a predicate.

    The predicate is called with each installed cat/pkg name (no version);
    only names for which it returns a true value are expanded to versions.

    @type predicate: function
    @param predicate: a function to filter the package list with
    @rtype: generator
    @return: a generator that yields unsorted installed cat/pkg-ver strings
        from VARDB
    """

    vardb = portage.db[portage.root]["vartree"].dbapi

    cps = vardb.cp_all()
    if predicate:
        cps = (cp for cp in cps if predicate(cp))

    for cp in cps:
        for cpv in vardb.cp_list(cp):
            yield cpv
444
445
def get_bintree_cpvs(predicate=None):
    """Get all binary packages available. Optionally apply a predicate.

    The predicate is called with each binary package's cat/pkg name (no
    version); only names for which it returns a true value are expanded to
    versions.

    @type predicate: function
    @param predicate: a function to filter the package list with
    @rtype: generator
    @return: a generator that yields unsorted binary package cat/pkg-ver strings
        from BINDB
    """

    bindb = portage.db[portage.root]["bintree"].dbapi

    binary_cps = bindb.cp_all()
    if predicate:
        binary_cps = (cp for cp in binary_cps if predicate(cp))

    for cp in binary_cps:
        for cpv in bindb.cp_list(cp):
            yield cpv
469
470
def print_file(path):
    """Display the contents of a file.

    The file is read in binary mode in one go and handed to pp.uprint with
    leading and trailing whitespace stripped.

    @type path: str
    @param path: path of the file to display
    """

    with open(path, "rb") as handle:
        pp.uprint(handle.read().strip())
477
478
def print_sequence(seq):
    """Print every item of a sequence.

    @type seq: iterable
    @param seq: items to print, each with its own pp.uprint call
    """

    uprint = pp.uprint
    for element in seq:
        uprint(element)
484
485
def uniqify(seq, preserve_order=True):
    """Return a uniqified list. Optionally preserve order.

    @type seq: iterable
    @param seq: hashable items, possibly containing duplicates
    @type preserve_order: bool
    @param preserve_order: if True, keep the first occurrence of each item in
        its original position; if False, the order of the result is arbitrary
    @rtype: list
    @return: the items of seq with duplicates removed
    """

    if not preserve_order:
        return list(set(seq))

    seen = set()
    unique = []
    for item in seq:
        if item in seen:
            continue
        seen.add(item)
        unique.append(item)

    return unique
496
497 # vim: set ts=4 sw=4 tw=79: