rpcrt4: Re-use already registered endpoints for a protocol sequence.
authorRob Shearman <robertshearman@gmail.com>
Thu, 26 Mar 2009 13:35:36 +0000 (13:35 +0000)
committerAlexandre Julliard <julliard@winehq.org>
Thu, 26 Mar 2009 15:10:49 +0000 (16:10 +0100)
Return RPC_S_INVALID_ENDPOINT_FORMAT if a NULL endpoint is passed into
RpcServerUseProtseqEp{,Ex}{A,W}.

dlls/rpcrt4/rpc_server.c
dlls/rpcrt4/tests/rpc_protseq.c

index 12a525d68b647bb40bdef81e88fc1d9cc30d6fdb..4a87bd07dc9590938d97d698c02df60a7dc532ad 100644 (file)
@@ -566,11 +566,32 @@ static void RPCRT4_stop_listen(BOOL auto_listen)
   LeaveCriticalSection(&listen_cs);
 }
 
+static BOOL RPCRT4_protseq_is_endpoint_registered(RpcServerProtseq *protseq, LPCSTR endpoint)
+{
+  RpcConnection *conn;
+  EnterCriticalSection(&protseq->cs);
+  for (conn = protseq->conn; conn; conn = conn->Next)
+  {
+    if (!endpoint || !strcmp(endpoint, conn->Endpoint))
+      break;
+  }
+  LeaveCriticalSection(&protseq->cs);
+  return (conn != NULL);
+}
+
 static RPC_STATUS RPCRT4_use_protseq(RpcServerProtseq* ps, LPSTR endpoint)
 {
   RPC_STATUS status;
 
-  status = ps->ops->open_endpoint(ps, endpoint);
+  EnterCriticalSection(&ps->cs);
+
+  if (RPCRT4_protseq_is_endpoint_registered(ps, endpoint))
+    status = RPC_S_OK;
+  else
+    status = ps->ops->open_endpoint(ps, endpoint);
+
+  LeaveCriticalSection(&ps->cs);
+
   if (status != RPC_S_OK)
     return status;
 
@@ -751,6 +772,9 @@ RPC_STATUS WINAPI RpcServerUseProtseqEpExA( RPC_CSTR Protseq, UINT MaxCalls, RPC
        debugstr_a(szep), SecurityDescriptor,
        lpPolicy->Length, lpPolicy->EndpointFlags, lpPolicy->NICFlags );
 
+  if (!Endpoint)
+    return RPC_S_INVALID_ENDPOINT_FORMAT;
+
   status = RPCRT4_get_or_create_serverprotseq(MaxCalls, RPCRT4_strdupA(szps), &ps);
   if (status != RPC_S_OK)
     return status;
@@ -772,6 +796,9 @@ RPC_STATUS WINAPI RpcServerUseProtseqEpExW( RPC_WSTR Protseq, UINT MaxCalls, RPC
        debugstr_w( Endpoint ), SecurityDescriptor,
        lpPolicy->Length, lpPolicy->EndpointFlags, lpPolicy->NICFlags );
 
+  if (!Endpoint)
+    return RPC_S_INVALID_ENDPOINT_FORMAT;
+
   status = RPCRT4_get_or_create_serverprotseq(MaxCalls, RPCRT4_strdupWtoA(Protseq), &ps);
   if (status != RPC_S_OK)
     return status;
@@ -787,8 +814,16 @@ RPC_STATUS WINAPI RpcServerUseProtseqEpExW( RPC_WSTR Protseq, UINT MaxCalls, RPC
  */
 RPC_STATUS WINAPI RpcServerUseProtseqA(RPC_CSTR Protseq, unsigned int MaxCalls, void *SecurityDescriptor)
 {
+  RPC_STATUS status;
+  RpcServerProtseq* ps;
+
   TRACE("(Protseq == %s, MaxCalls == %d, SecurityDescriptor == ^%p)\n", debugstr_a((char*)Protseq), MaxCalls, SecurityDescriptor);
-  return RpcServerUseProtseqEpA(Protseq, MaxCalls, NULL, SecurityDescriptor);
+
+  status = RPCRT4_get_or_create_serverprotseq(MaxCalls, RPCRT4_strdupA((const char *)Protseq), &ps);
+  if (status != RPC_S_OK)
+    return status;
+
+  return RPCRT4_use_protseq(ps, NULL);
 }
 
 /***********************************************************************
@@ -796,8 +831,16 @@ RPC_STATUS WINAPI RpcServerUseProtseqA(RPC_CSTR Protseq, unsigned int MaxCalls,
  */
 RPC_STATUS WINAPI RpcServerUseProtseqW(RPC_WSTR Protseq, unsigned int MaxCalls, void *SecurityDescriptor)
 {
+  RPC_STATUS status;
+  RpcServerProtseq* ps;
+
   TRACE("Protseq == %s, MaxCalls == %d, SecurityDescriptor == ^%p)\n", debugstr_w(Protseq), MaxCalls, SecurityDescriptor);
-  return RpcServerUseProtseqEpW(Protseq, MaxCalls, NULL, SecurityDescriptor);
+
+  status = RPCRT4_get_or_create_serverprotseq(MaxCalls, RPCRT4_strdupWtoA(Protseq), &ps);
+  if (status != RPC_S_OK)
+    return status;
+
+  return RPCRT4_use_protseq(ps, NULL);
 }
 
 void RPCRT4_destroy_all_protseqs(void)
index a4589d6f157e88baaf07d92584c696f7b78523eb..a9b98a0d5f557f940d4ee21eb7c22bea57cd47c2 100644 (file)
@@ -46,7 +46,6 @@ static void test_RpcServerUseProtseq(void)
     /* show that RpcServerUseProtseqEp(..., NULL, ...) isn't the same as
      * RpcServerUseProtseq(...) */
     status = RpcServerUseProtseqEp(ncalrpc, 0, NULL, NULL);
-    todo_wine
     ok(status == RPC_S_INVALID_ENDPOINT_FORMAT,
        "RpcServerUseProtseqEp with NULL endpoint should have failed with "
        "RPC_S_INVALID_ENDPOINT_FORMAT instead of %d\n", status);
@@ -111,7 +110,6 @@ static void test_RpcServerUseProtseq(void)
     status = RpcServerInqBindings(&bindings);
     ok(status == RPC_S_OK, "RpcServerInqBindings failed with status %d\n", status);
     binding_count_after2 = bindings->Count;
-    todo_wine
     ok(binding_count_after2 == binding_count_after1,
        "bindings should have been re-used - after1: %u after2: %u\n",
        binding_count_after1, binding_count_after2);