Update exception handling to use `except X as Y` syntax.
[update-copyright.git] / update_copyright / utils.py
index 7cceb222f39eba297c1ee288207a744345548fba..117235149b469d0cb992e6c64dad59d5ee402e19 100644 (file)
@@ -1,30 +1,35 @@
-# Copyright (C) 2012 W. Trevor King
+# Copyright (C) 2012 W. Trevor King <wking@tremily.us>
 #
 # This file is part of update-copyright.
 #
-# update-copyright is free software: you can redistribute it and/or
-# modify it under the terms of the GNU General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
+# update-copyright is free software: you can redistribute it and/or modify it
+# under the terms of the GNU General Public License as published by the Free
+# Software Foundation, either version 3 of the License, or (at your option) any
+# later version.
 #
-# update-copyright is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
-# General Public License for more details.
+# update-copyright is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
+# more details.
 #
-# You should have received a copy of the GNU General Public License
-# along with update-copyright.  If not, see
-# <http://www.gnu.org/licenses/>.
+# You should have received a copy of the GNU General Public License along with
+# update-copyright.  If not, see <http://www.gnu.org/licenses/>.
 
+import codecs as _codecs
 import difflib as _difflib
+import locale as _locale
 import os as _os
 import os.path as _os_path
+import sys as _sys
 import textwrap as _textwrap
 import time as _time
 
 from . import LOG as _LOG
 
 
+ENCODING = _locale.getpreferredencoding() or _sys.getdefaultencoding()
+
+
 def long_author_formatter(copyright_year_string, authors):
     """
     >>> print '\\n'.join(long_author_formatter(
@@ -51,12 +56,12 @@ def short_author_formatter(copyright_year_string, authors):
 
 def copyright_string(original_year, final_year, authors, text, info={},
                      author_format_fn=long_author_formatter,
-                     formatter_kwargs={}, prefix='', wrap=True,
+                     formatter_kwargs={}, prefix=('', '', None), wrap=True,
                      **wrap_kwargs):
     """
     >>> print(copyright_string(original_year=2005, final_year=2005,
     ...                        authors=['A <a@a.com>', 'B <b@b.edu>'],
-    ...                        text=['BLURB',], prefix='# '
+    ...                        text=['BLURB',], prefix=('# ', '# ', None),
     ...                        )) # doctest: +REPORT_UDIFF
     # Copyright (C) 2005 A <a@a.com>
     #                    B <b@b.edu>
@@ -64,6 +69,15 @@ def copyright_string(original_year, final_year, authors, text, info={},
     # BLURB
     >>> print(copyright_string(original_year=2005, final_year=2009,
     ...                        authors=['A <a@a.com>', 'B <b@b.edu>'],
+    ...                        text=['BLURB',], prefix=('/* ', ' * ', ' */'),
+    ...                        )) # doctest: +REPORT_UDIFF
+    /* Copyright (C) 2005-2009 A <a@a.com>
+     *                         B <b@b.edu>
+     *
+     * BLURB
+     */
+    >>> print(copyright_string(original_year=2005, final_year=2009,
+    ...                        authors=['A <a@a.com>', 'B <b@b.edu>'],
     ...                        text=['BLURB',]
     ...                        )) # doctest: +REPORT_UDIFF
     Copyright (C) 2005-2009 A <a@a.com>
@@ -95,7 +109,7 @@ def copyright_string(original_year, final_year, authors, text, info={},
     """
     for key in ['initial_indent', 'subsequent_indent']:
         if key not in wrap_kwargs:
-            wrap_kwargs[key] = prefix
+            wrap_kwargs[key] = prefix[1]
 
     if original_year == final_year:
         date_range = '%s' % original_year
@@ -106,16 +120,19 @@ def copyright_string(original_year, final_year, authors, text, info={},
     lines = author_format_fn(copyright_year_string, authors,
                              **formatter_kwargs)
     for i,line in enumerate(lines):
-        lines[i] = prefix + line
+        if i == 0:
+            lines[i] = prefix[0] + line
+        else:
+            lines[i] = prefix[1] + line
 
     for i,paragraph in enumerate(text):
         try:
             text[i] = paragraph % info
-        except ValueError, e:
+        except ValueError as e:
             _LOG.error(
                 "{}: can't format {} with {}".format(e, paragraph, info))
             raise
-        except TypeError, e:
+        except TypeError as e:
             _LOG.error(
                 ('{}: copright text must be a list of paragraph strings, '
                  'not {}').format(e, repr(text)))
@@ -126,10 +143,13 @@ def copyright_string(original_year, final_year, authors, text, info={},
     else:
         assert wrap_kwargs['subsequent_indent'] == '', \
             wrap_kwargs['subsequent_indent']
-    sep = '\n%s\n' % prefix.rstrip()
-    return sep.join(['\n'.join(lines)] + text)
+    sep = '\n{}\n'.format(prefix[1].rstrip())
+    ret = sep.join(['\n'.join(lines)] + text)
+    if prefix[2]:
+        ret += ('\n{}'.format(prefix[2]))
+    return ret
 
-def tag_copyright(contents, tag=None):
+def tag_copyright(contents, prefix=('# ', '# ', None), tag=None):
     """
     >>> contents = '''Some file
     ... bla bla
@@ -146,20 +166,45 @@ def tag_copyright(contents, tag=None):
     (copyright ends)
     bla bla bla
     <BLANKLINE>
+    >>> contents = '''Some file
+    ... bla bla
+    ... /* Copyright (copyright begins)
+    ...  * (copyright continues)
+    ...  *
+    ...  * bla bla bla
+    ...  */
+    ... (copyright ends)
+    ... bla bla bla
+    ... '''
+    >>> print tag_copyright(
+    ...     contents, prefix=('/* ', ' * ', ' */'), tag='-xyz-CR-zyx-')
+    Some file
+    bla bla
+    -xyz-CR-zyx-
+    (copyright ends)
+    bla bla bla
+    <BLANKLINE>
     """
     lines = []
     incopy = False
+    start = prefix[0] + 'Copyright'
+    middle = prefix[1].rstrip()
+    end = prefix[2]
     for line in contents.splitlines():
-        if incopy == False and line.startswith('# Copyright'):
+        if not incopy and line.startswith(start):
             incopy = True
             lines.append(tag)
-        elif incopy == True and not line.startswith('#'):
+        elif incopy and not line.startswith(middle):
+            if end:
+                assert line.startswith(end), line
             incopy = False
-        if incopy == False:
+        if not incopy:
             lines.append(line.rstrip('\n'))
+        if incopy and end and line.startswith(end):
+            incopy = False
     return '\n'.join(lines)+'\n'
 
-def update_copyright(contents, tag=None, **kwargs):
+def update_copyright(contents, prefix=('# ', '# ', None), tag=None, **kwargs):
     """
     >>> contents = '''Some file
     ... bla bla
@@ -169,9 +214,9 @@ def update_copyright(contents, tag=None, **kwargs):
     ... (copyright ends)
     ... bla bla bla
     ... '''
-    >>> print update_copyright(contents, original_year=2008,
-    ...                        authors=['Jack', 'Jill'],
-    ...                        text=['BLURB',], prefix='# ', tag='--tag--'
+    >>> print update_copyright(
+    ...     contents, original_year=2008, authors=['Jack', 'Jill'],
+    ...     text=['BLURB',], prefix=('# ', '# ', None), tag='--tag--'
     ...     ) # doctest: +ELLIPSIS, +REPORT_UDIFF
     Some file
     bla bla
@@ -184,28 +229,35 @@ def update_copyright(contents, tag=None, **kwargs):
     <BLANKLINE>
     """
     current_year = _time.gmtime()[0]
-    string = copyright_string(final_year=current_year, **kwargs)
-    contents = tag_copyright(contents=contents, tag=tag)
+    string = copyright_string(final_year=current_year, prefix=prefix, **kwargs)
+    contents = tag_copyright(contents=contents, prefix=prefix, tag=tag)
     return contents.replace(tag, string)
 
-def get_contents(filename):
+def get_contents(filename, unicode=False, encoding=None):
     if _os_path.isfile(filename):
-        f = open(filename, 'r')
+        if unicode:
+            if encoding is None:
+                encoding = ENCODING
+            f = _codecs.open(filename, 'r', encoding=encoding)
+        else:
+            f = open(filename, 'r')
         contents = f.read()
         f.close()
         return contents
     return None
 
-def set_contents(filename, contents, original_contents=None, dry_run=False):
+def set_contents(filename, contents, original_contents=None, unicode=False,
+                 encoding=None, dry_run=False):
     if original_contents is None:
-        original_contents = get_contents(filename=filename)
+        original_contents = get_contents(
+            filename=filename, unicode=unicode, encoding=encoding)
     _LOG.debug('check contents of {}'.format(filename))
     if contents != original_contents:
         if original_contents is None:
             _LOG.info('creating {}'.format(filename))
         else:
             _LOG.info('updating {}'.format(filename))
-            _LOG.debug('\n'.join(
+            _LOG.debug(u'\n'.join(
                     _difflib.unified_diff(
                         original_contents.splitlines(), contents.splitlines(),
                         fromfile=_os_path.normpath(
@@ -213,7 +265,12 @@ def set_contents(filename, contents, original_contents=None, dry_run=False):
                         tofile=_os_path.normpath(_os_path.join('b', filename)),
                         n=3, lineterm='')))
         if dry_run == False:
-            f = file(filename, 'w')
+            if unicode:
+                if encoding is None:
+                    encoding = ENCODING
+                f = _codecs.open(filename, 'w', encoding=encoding)
+            else:
+                f = file(filename, 'w')
             f.write(contents)
             f.close()
     _LOG.debug('no change in {}'.format(filename))
@@ -221,4 +278,4 @@ def set_contents(filename, contents, original_contents=None, dry_run=False):
 def list_files(root='.'):
     for dirpath,dirnames,filenames in _os.walk(root):
         for filename in filenames:
-            yield _os_path.join(root, dirpath, filename)
+            yield _os_path.normpath(_os_path.join(root, dirpath, filename))