_selinux/spawn_wrapper: setexec *after* fork
authorZac Medico <zmedico@gentoo.org>
Fri, 27 Jul 2012 02:42:51 +0000 (19:42 -0700)
committerZac Medico <zmedico@gentoo.org>
Fri, 27 Jul 2012 02:42:51 +0000 (19:42 -0700)
This avoids any interference with concurrent threads in the calling
process.

pym/portage/_selinux.py

index 9470978c4e404e93915b25f0d81f644dca63e23d..173714515386e977c45f2788f1e93c0ad6979a8d 100644 (file)
@@ -95,20 +95,32 @@ def setfscreate(ctx="\n"):
                raise OSError(
                        _("setfscreate: Failed setting fs create context \"%s\".") % ctx)
 
-def spawn_wrapper(spawn_func, selinux_type):
-
-       selinux_type = _unicode_encode(selinux_type,
-               encoding=_encodings['content'], errors='strict')
-
-       def wrapper_func(*args, **kwargs):
-               con = settype(selinux_type)
-               setexec(con)
-               try:
-                       return spawn_func(*args, **kwargs)
-               finally:
-                       setexec()
-
-       return wrapper_func
+class spawn_wrapper(object):
+       """
+       Create a wrapper function for the given spawn function. When the wrapper
+       is called, it will adjust the arguments such that setexec() to be called
+       *after* the fork (thereby avoiding any interference with concurrent
+       threads in the calling process).
+       """
+       __slots__ = ("_con", "_spawn_func")
+
+       def __init__(self, spawn_func, selinux_type):
+               self._spawn_func = spawn_func
+               selinux_type = _unicode_encode(selinux_type,
+                       encoding=_encodings['content'], errors='strict')
+               self._con = settype(selinux_type)
+
+       def __call__(self, *args, **kwargs):
+
+               pre_exec = kwargs.get("pre_exec")
+
+               def _pre_exec():
+                       if pre_exec is not None:
+                               pre_exec()
+                       setexec(self._con)
+
+               kwargs["pre_exec"] = _pre_exec
+               return self._spawn_func(*args, **kwargs)
 
 def symlink(target, link, reflnk):
        target = _unicode_encode(target, encoding=_encodings['fs'], errors='strict')