Split cci_thread_init into per-process and per-thread portions
authorTom Yu <tlyu@mit.edu>
Mon, 12 Dec 2011 20:46:20 +0000 (20:46 +0000)
committerTom Yu <tlyu@mit.edu>
Mon, 12 Dec 2011 20:46:20 +0000 (20:46 +0000)
Call the per-thread code on thread attach and per-process once per
process.  Previously, while the function was named 'thread', it was
only actually called once per process.  Currently, the per-thread
code does nothing on non-windows platforms and is not even actually
invoked.

Fixes a windows bug when multiple non-main threads try to use ccapi
at the same time.

Signed-off-by: Kevin Wasserman <kevin.wasserman@painless-security.com>
ticket: 7050

git-svn-id: svn://anonsvn.mit.edu/krb5/trunk@25569 dc483132-0cff-0310-8789-dd5450dbe970

src/ccapi/common/win/OldCC/autolock.hxx
src/ccapi/lib/ccapi_context.c
src/ccapi/lib/ccapi_ipc.c
src/ccapi/lib/ccapi_ipc.h
src/ccapi/lib/ccapi_os_ipc.h
src/ccapi/lib/win/OldCC/client.cxx
src/ccapi/lib/win/ccapi_os_ipc.cxx
src/ccapi/lib/win/debug.exports
src/ccapi/lib/win/dllmain.cxx

index bbd773488c0271d9b354647a877e06b19620f64d..45b881e221aa6a3074a937b75c13d588a5c8e5aa 100644 (file)
@@ -35,10 +35,8 @@ public:
     ~CcOsLock()     {DeleteCriticalSection(&cs);       valid = false;}
     void lock()     {if (valid) EnterCriticalSection(&cs);}
     void unlock()   {if (valid) LeaveCriticalSection(&cs);}
-#if 0
     bool trylock()  {return valid ? (TryEnterCriticalSection(&cs) ? true : false)
                                   : false; }
-#endif
 };
 
 class CcAutoLock {
@@ -50,4 +48,13 @@ public:
     ~CcAutoLock() { m_lock.unlock(); }
 };
 
+class CcAutoTryLock {
+    CcOsLock& m_lock;
+    bool m_locked;
+public:
+    CcAutoTryLock(CcOsLock& lock):m_lock(lock) { m_locked = m_lock.trylock(); }
+    ~CcAutoTryLock() { if (m_locked) m_lock.unlock(); m_locked = false; }
+    bool IsLocked() const { return m_locked; }
+};
+
 #endif /* __AUTOLOCK_HXX */
index a16ce0e580b91002ab7e2b278585bd1fc24b513c..cf677fc551426d4309e05db415a09a443eab63c7 100644 (file)
@@ -79,12 +79,12 @@ static cc_int32 cci_context_sync (cci_context_t in_context,
 #pragma mark -
 #endif
 
-MAKE_INIT_FUNCTION(cci_thread_init);
-MAKE_FINI_FUNCTION(cci_thread_fini);
+MAKE_INIT_FUNCTION(cci_process_init);
+MAKE_FINI_FUNCTION(cci_process_fini);
 
 /* ------------------------------------------------------------------------ */
 
-static int cci_thread_init (void)
+static int cci_process_init (void)
 {
     cc_int32 err = ccNoError;
 
@@ -93,7 +93,7 @@ static int cci_thread_init (void)
     }
 
     if (!err) {
-        err = cci_ipc_thread_init ();
+        err = cci_ipc_process_init ();
     }
 
     if (!err) {
@@ -105,9 +105,9 @@ static int cci_thread_init (void)
 
 /* ------------------------------------------------------------------------ */
 
-static void cci_thread_fini (void)
+static void cci_process_fini (void)
 {
-    if (!INITIALIZER_RAN (cci_thread_init) || PROGRAM_EXITING ()) {
+    if (!INITIALIZER_RAN (cci_process_init) || PROGRAM_EXITING ()) {
        return;
     }
 
@@ -134,7 +134,7 @@ cc_int32 cc_initialize (cc_context_t  *out_context,
     if (!out_context) { err = cci_check_error (ccErrBadParam); }
 
     if (!err) {
-        err = CALL_INIT_FUNCTION (cci_thread_init);
+        err = CALL_INIT_FUNCTION (cci_process_init);
     }
 
     if (!err) {
index 66830de8eb57937da112dfe670283e6347a786fe..2c1fcba6102944f68aacdd0ee419defc6ff38d8e 100644 (file)
 
 /* ------------------------------------------------------------------------ */
 
+cc_int32 cci_ipc_process_init (void)
+{
+    return cci_os_ipc_process_init ();
+}
+
+/* ------------------------------------------------------------------------ */
+
 cc_int32 cci_ipc_thread_init (void)
 {
     return cci_os_ipc_thread_init ();
index a23791cf46ef0be86be7337bf002bca8a6624c54..a23772b29f7ba516869327ab9cb578fc2be64843 100644 (file)
@@ -28,6 +28,8 @@
 
 #include "cci_common.h"
 
+cc_int32 cci_ipc_process_init (void);
+
 cc_int32 cci_ipc_thread_init (void);
 
 cc_int32 cci_ipc_send (enum cci_msg_id_t  in_request_name,
index e27ae63c24326fa808d9cb288f438c69a4b16fa8..fe7c87a08c9b58c6c685ef26cf938fc5592f6e9a 100644 (file)
@@ -28,6 +28,8 @@
 
 #include "cci_common.h"
 
+cc_int32 cci_os_ipc_process_init (void);
+
 cc_int32 cci_os_ipc_thread_init (void);
 
 cc_int32 cci_os_ipc (cc_int32      in_launch_server,
index 5a34d38cc50c4851b631ff3008c0a0c00c8b9067..4b2d718cc431db20f6dd022efeecd5d18c7d0974 100644 (file)
@@ -395,10 +395,11 @@ Client::Connect(char* ep OPTIONAL) {
     }
 
 DWORD Client::Initialize(char* ep OPTIONAL) {
-    CcAutoLock AL(Client::sLock);
+    CcAutoTryLock AL(Client::sLock);
+    if (!AL.IsLocked() || s_init)
+        return 0;
     SecureClient s;
     ccs_request_IfHandle  = NULL;
-    if (s_init) return 0;
     DWORD status = Client::Connect(ep);
     if (!status) s_init = true;
     return status;
index 7359eb0bad2367c7152864498001c46cadc1ad9d..8cc9d03bd4f8385c19ad2ee878acb4786d1afec5 100644 (file)
@@ -64,7 +64,7 @@ SECURITY_ATTRIBUTES     sa                  = { 0 };
  */
 
 cc_int32        ccapi_connect(const struct tspdata* tsp);
-static DWORD    handle_exception(DWORD code);
+static DWORD    handle_exception(DWORD code, struct tspdata* ptspdata);
 
 extern "C" {
 cc_int32        cci_os_ipc_msg( cc_int32        in_launch_server,
@@ -75,12 +75,46 @@ cc_int32        cci_os_ipc_msg( cc_int32        in_launch_server,
 
 /* ------------------------------------------------------------------------ */
 
+extern "C" cc_int32 cci_os_ipc_process_init (void) {
+    RPC_STATUS status;
+
+    opts.cMinCalls  = 1;
+    opts.cMaxCalls  = 20;
+    if (!isNT()) {
+        status = RpcServerRegisterIf(ccs_reply_ServerIfHandle,  // interface
+                                     NULL,                      // MgrTypeUuid
+                                     NULL);                     // MgrEpv; null means use default
+        }
+    else {
+        status = RpcServerRegisterIfEx(ccs_reply_ServerIfHandle,  // interface
+                                       NULL,                      // MgrTypeUuid
+                                       NULL,                      // MgrEpv; 0 means default
+                                       RPC_IF_ALLOW_SECURE_ONLY,
+                                       opts.cMaxCalls,
+                                       NULL);                     // No security callback.
+        }
+    cci_check_error(status);
+
+    if (!status) {
+        status = RpcServerRegisterAuthInfo(0, // server principal
+                                           RPC_C_AUTHN_WINNT,
+                                           0,
+                                           0 );
+        cci_check_error(status);
+        }
+
+    return status; // ugh. needs translation
+}
+
+/* ------------------------------------------------------------------------ */
+
 extern "C" cc_int32 cci_os_ipc_thread_init (void) {
     cc_int32                    err         = ccNoError;
     struct tspdata*             ptspdata;
-    HANDLE                      replyEvent;
+    HANDLE                      replyEvent  = NULL;
     UUID __RPC_FAR              uuid;
-    unsigned char __RPC_FAR*    uuidString  = NULL;
+    RPC_CSTR __RPC_FAR          uuidString  = NULL;
+    char*                       endpoint    = NULL;
 
     if (!GetTspData(GetTlsIndex(), &ptspdata)) return ccErrNoMem;
 
@@ -91,10 +125,18 @@ extern "C" cc_int32 cci_os_ipc_thread_init (void) {
     err   = cci_check_error(UuidCreate(&uuid)); // Get a UUID
     if (err == RPC_S_OK) {                      // Convert to string
         err = UuidToString(&uuid, &uuidString);
+        cci_check_error(err);
         }
     if (!err) {                                 // Save in thread local storage
         tspdata_setUUID(ptspdata, uuidString);
+        endpoint = clientEndpoint((const char *)uuidString);
+        err = RpcServerUseProtseqEp((RPC_CSTR)"ncalrpc",
+                                    opts.cMaxCalls,
+                                    (RPC_CSTR)endpoint,
+                                    sa.lpSecurityDescriptor);  // SD
+        cci_check_error(err);
         }
+
 #if 0
     cci_debug_printf("%s UUID:<%s>", __FUNCTION__, tspdata_getUUID(ptspdata));
 #endif
@@ -109,6 +151,17 @@ extern "C" cc_int32 cci_os_ipc_thread_init (void) {
         replyEvent = createThreadEvent((char*)uuidString, REPLY_SUFFIX);
         }
 
+    if (!err) {
+        static bool bListening = false;
+        if (!bListening) {
+            err = RpcServerListen(opts.cMinCalls,
+                                  opts.cMaxCalls,
+                                  TRUE);
+            cci_check_error(err);
+            }
+            bListening = err == 0;
+        }
+
     if (replyEvent) tspdata_setReplyEvent(ptspdata, replyEvent);
     else            err = cci_check_error(GetLastError());
 
@@ -159,6 +212,10 @@ extern "C" cc_int32 cci_os_ipc_msg( cc_int32        in_launch_server,
     sst              = tspdata_getSST (ptspdata);
     uuid             = tspdata_getUUID(ptspdata);
 
+    // Initialize old CCAPI if necessary:
+    if (!err) if (!Init::  Initialized()) err = cci_check_error(Init::  Initialize( ));
+    if (!err) if (!Client::Initialized()) err = cci_check_error(Client::Initialize(0));
+
     // The lazy connection to the server has been put off as long as possible!
     // ccapi_connect starts listening for replies as an RPC server and then
     //   calls ccs_rpc_connect.
@@ -183,10 +240,6 @@ extern "C" cc_int32 cci_os_ipc_msg( cc_int32        in_launch_server,
     CcAutoLock*     a = 0;
     CcAutoLock::Start(a, Client::sLock);
 
-    // Initialize old CCAPI if necessary:
-    if (!err) if (!Init::  Initialized()) err = cci_check_error(Init::  Initialize( ));
-    if (!err) if (!Client::Initialized()) err = cci_check_error(Client::Initialize(0));
-
     // New code using new RPC procedures for sending the data and receiving a reply:
     if (!err) {
         RpcTryExcept {
@@ -209,7 +262,7 @@ extern "C" cc_int32 cci_os_ipc_msg( cc_int32        in_launch_server,
                 (long*)(&err) );                /* Return code */
             }
         RpcExcept(1) {
-            handle_exception(RpcExceptionCode());
+            err = handle_exception(RpcExceptionCode(), ptspdata);
             }
         RpcEndExcept;
         }
@@ -247,12 +300,13 @@ extern "C" cc_int32 cci_os_ipc_msg( cc_int32        in_launch_server,
 
 
 
-static DWORD handle_exception(DWORD code) {
+static DWORD handle_exception(DWORD code, struct tspdata* ptspdata) {
     cci_debug_printf("%s code %u; ccs_request_IfHandle:0x%X", __FUNCTION__, code, ccs_request_IfHandle);
     if ( (code == RPC_S_SERVER_UNAVAILABLE) || (code == RPC_S_INVALID_BINDING) ) {
-        Client::Reconnect(0);
+        Client::Cleanup();
+        tspdata_setConnected(ptspdata, FALSE);
         }
-    return 4;
+    return code;
     }
 
 
@@ -262,7 +316,6 @@ static DWORD handle_exception(DWORD code) {
  */
 cc_int32 ccapi_connect(const struct tspdata* tsp) {
     BOOL                    bListen     = TRUE;
-    char*                   endpoint    = NULL;
     HANDLE                  replyEvent  = 0;
     RPC_STATUS              status      = FALSE;
     char*                   uuid        = NULL;
@@ -275,56 +328,13 @@ cc_int32 ccapi_connect(const struct tspdata* tsp) {
     /* Build complete RPC uuid using previous CCAPI implementation: */
     replyEvent      = tspdata_getReplyEvent(tsp);
     uuid            = tspdata_getUUID(tsp);
-    endpoint        = clientEndpoint(uuid);
-    cci_debug_printf("%s Registering endpoint %s", __FUNCTION__, endpoint);
 
     opts.cMinCalls  = 1;
     opts.cMaxCalls  = 20;
     opts.fDontWait  = TRUE;
 
-    if (!status) {
-        status = RpcServerUseProtseqEp((RPC_CSTR)"ncalrpc",
-                                       opts.cMaxCalls,
-                                       (RPC_CSTR)endpoint,
-                                       sa.lpSecurityDescriptor);  // SD
-        cci_check_error(status);
-        }
-
-    if (!status) {
-        status = RpcServerRegisterAuthInfo(0, // server principal
-                                           RPC_C_AUTHN_WINNT,
-                                           0,
-                                           0 );
-        cci_check_error(status);
-        }
-
     cci_debug_printf("%s is listening ...", __FUNCTION__);
 
-    if (!status) {
-        if (!isNT()) {
-            status = RpcServerRegisterIf(ccs_reply_ServerIfHandle,  // interface 
-                                         NULL,                      // MgrTypeUuid
-                                         NULL);                     // MgrEpv; null means use default
-            } 
-        else {
-            status = RpcServerRegisterIfEx(ccs_reply_ServerIfHandle,// interface
-                                         NULL,                      // MgrTypeUuid
-                                         NULL,                      // MgrEpv; 0 means default
-                                         RPC_IF_ALLOW_SECURE_ONLY,
-                                         opts.cMaxCalls,
-                                         NULL);                     // No security callback.
-            }
-
-        cci_check_error(status);
-
-        if (!status) {
-            status = RpcServerListen(opts.cMinCalls,
-                                     opts.cMaxCalls,
-                                     TRUE);
-            cci_check_error(status);
-            }
-        }
-
     // Clear replyEvent so we can detect when a reply to our connect request has been received:
     ResetEvent(replyEvent);
 
index 583e9ca1ee6473d99b174d898a6603864b8e86b8..6dc1fc083a2a601047a5cc06a36db6292b45f2c0 100644 (file)
@@ -8,3 +8,4 @@
     krb5int_ipc_stream_new
 
     ccs_authenticate
+    cci_os_ipc_process_init
index e37a9ad6bfcde82c657a749de5d2b0ce57c6e62a..3141e190e7acfbac571a3706a0170e05b250c463 100644 (file)
@@ -32,9 +32,10 @@ extern "C" {
 #include "tls.h"
 #include "cci_debugging.h"
 #include "ccapi_context.h"
+#include "ccapi_ipc.h"
 #include "client.h"
 
-void cci_thread_init__auxinit();
+void cci_process_init__auxinit();
     }
 
 
@@ -91,10 +92,8 @@ BOOL WINAPI DllMain(HINSTANCE hinstDLL,     // DLL module handle
 
             // Allocate a TLS index:
             if ((dwTlsIndex = TlsAlloc()) == TLS_OUT_OF_INDEXES) return FALSE; 
-            // Initialize CCAPI once per DLL load:
-            firstThreadID = GetCurrentThreadId();
 
+            cci_process_init__auxinit();
             // Don't break; fallthrough: Initialize the TLS index for first thread.
  
         // The attached process creates a new thread:
@@ -107,8 +106,8 @@ BOOL WINAPI DllMain(HINSTANCE hinstDLL,     // DLL module handle
 
             memset(ptspdata, 0, sizeof(struct tspdata));
 
-            // Initialize CCAPI once per DLL load:
-            if (GetCurrentThreadId() == firstThreadID) cci_thread_init__auxinit();
+            // Initialize CCAPI thread data:
+            cci_ipc_thread_init();
 
             break;