diff --git a/shared/libraries/opcua/opcuaserver/include/opcuaserver/opcuaserver.h b/shared/libraries/opcua/opcuaserver/include/opcuaserver/opcuaserver.h index 1b9002d..6acc486 100644 --- a/shared/libraries/opcua/opcuaserver/include/opcuaserver/opcuaserver.h +++ b/shared/libraries/opcua/opcuaserver/include/opcuaserver/opcuaserver.h @@ -16,11 +16,11 @@ #pragma once +#include #include #include #include #include -#include #include #include @@ -88,6 +88,7 @@ class OpcUaServer final : public daq::utils::ThreadEx void setClientConnectedHandler(const OnClientConnectedCallback& callback); void setClientInfoHandler(const OnSetClientInfoCallback& callback); void setClientDisconnectedHandler(const OnClientDisconnectedCallback& callback); + void scheduleClientInfoChainTask(std::function task); void setAllowBrowsingNodeCallback(const OnAllowBrowsingNodeCallback& callback); void setGetUserRightsMaskCallback(const OnGetUserRightsMaskCallback& callback); void setGetUserAccessLevelCallback(const OnGetUserAccessLevelCallback& callback); @@ -184,6 +185,8 @@ class OpcUaServer final : public daq::utils::ThreadEx const sockaddr_storage& addr, socklen_t addrLen); void waitForPendingClientInfoFutures(); + void processClientInfo(ClientConnectionInfo& info, const sockaddr_storage& addr, socklen_t addrLen); + void continueClientInfoChain(); static UA_StatusCode activateSession(UA_Server* server, UA_AccessControl* ac, @@ -220,8 +223,9 @@ class OpcUaServer final : public daq::utils::ThreadEx OnClientConnectedCallback clientConnectedHandler; OnSetClientInfoCallback clientInfoHandler; OnClientDisconnectedCallback clientDisconnectedHandler; - std::vector> pendingClientInfoFutures; - std::mutex pendingClientInfoFuturesMutex; + std::future clientInfoChain; + std::deque> clientInfoChainWaiting; + std::mutex clientInfoChainMutex; }; END_NAMESPACE_OPENDAQ_OPCUA diff --git a/shared/libraries/opcua/opcuaserver/src/opcuaserver.cpp b/shared/libraries/opcua/opcuaserver/src/opcuaserver.cpp index 0ba5f92..1731e41 100644 --- a/shared/libraries/opcua/opcuaserver/src/opcuaserver.cpp +++ b/shared/libraries/opcua/opcuaserver/src/opcuaserver.cpp @@ -6,7 +6,6 @@ #include #include #include -#include #include #include #include @@ -133,45 +132,72 @@ void OpcUaServer::stop() void OpcUaServer::waitForPendingClientInfoFutures() { - std::lock_guard lock(pendingClientInfoFuturesMutex); - for (auto& future : pendingClientInfoFutures) + std::future chain; { - if (future.valid()) - future.wait(); + std::lock_guard lock(clientInfoChainMutex); + clientInfoChainWaiting.clear(); + chain = std::move(clientInfoChain); } - pendingClientInfoFutures.clear(); + if (chain.valid()) + chain.wait(); } -void OpcUaServer::scheduleClientInfoAsync(ClientConnectionInfo info, - const sockaddr_storage& addr, - socklen_t addrLen) +void OpcUaServer::processClientInfo(ClientConnectionInfo& info, + const sockaddr_storage& addr, + socklen_t addrLen) { const OnSetClientInfoCallback handler = clientInfoHandler; if (!handler) return; - std::lock_guard lock(pendingClientInfoFuturesMutex); - - pendingClientInfoFutures.erase( - std::remove_if(pendingClientInfoFutures.begin(), - pendingClientInfoFutures.end(), - [](std::future& f) { - return f.wait_for(std::chrono::seconds(0)) == std::future_status::ready; - }), - pendingClientInfoFutures.end()); - - pendingClientInfoFutures.push_back( - std::async(std::launch::async, [handler, info = std::move(info), addr, addrLen]() mutable + const auto* sockAddr = reinterpret_cast(&addr); + char ipBuf[NI_MAXHOST] = {}; + char hostBuf[NI_MAXHOST] = {}; + if (getnameinfo(sockAddr, addrLen, ipBuf, sizeof(ipBuf), nullptr, 0, NI_NUMERICHOST) == 0) + info.address = ipBuf; + if (getnameinfo(sockAddr, addrLen, hostBuf, sizeof(hostBuf), nullptr, 0, 0) == 0) + info.hostname = hostBuf; + handler(info); +} + +void OpcUaServer::continueClientInfoChain() +{ + while (true) + { + std::function task; { - const auto* sockAddr = reinterpret_cast(&addr); - char ipBuf[NI_MAXHOST] = {}; - char hostBuf[NI_MAXHOST] = {}; - if (getnameinfo(sockAddr, addrLen, ipBuf, sizeof(ipBuf), nullptr, 0, NI_NUMERICHOST) == 0) - info.address = ipBuf; - if (getnameinfo(sockAddr, addrLen, hostBuf, sizeof(hostBuf), nullptr, 0, 0) == 0) - info.hostname = hostBuf; - handler(info); - })); + std::lock_guard lock(clientInfoChainMutex); + if (clientInfoChainWaiting.empty()) + return; + task = std::move(clientInfoChainWaiting.front()); + clientInfoChainWaiting.pop_front(); + } + task(); + } +} + +void OpcUaServer::scheduleClientInfoChainTask(std::function task) +{ + std::lock_guard lock(clientInfoChainMutex); + const bool chainRunning = !clientInfoChainWaiting.empty() || + (clientInfoChain.valid() && + clientInfoChain.wait_for(std::chrono::seconds(0)) != std::future_status::ready); + clientInfoChainWaiting.push_back(std::move(task)); + if (!chainRunning) + clientInfoChain = std::async(std::launch::async, [this]() { continueClientInfoChain(); }); +} + +void OpcUaServer::scheduleClientInfoAsync(ClientConnectionInfo info, + const sockaddr_storage& addr, + socklen_t addrLen) +{ + if (!clientInfoHandler) + return; + + scheduleClientInfoChainTask([this, info = std::move(info), addr, addrLen]() mutable + { + processClientInfo(info, addr, addrLen); + }); } void OpcUaServer::prepare() diff --git a/shared/libraries/opcuatms/opcuatms_server/include/opcuatms_server/objects/tms_server_object.h b/shared/libraries/opcuatms/opcuatms_server/include/opcuatms_server/objects/tms_server_object.h index 5d52eef..314d2af 100644 --- a/shared/libraries/opcuatms/opcuatms_server/include/opcuatms_server/objects/tms_server_object.h +++ b/shared/libraries/opcuatms/opcuatms_server/include/opcuatms_server/objects/tms_server_object.h @@ -72,27 +72,27 @@ class TmsServerObject : public std::enable_shared_from_this virtual void createNonhierarchicalReferences(); virtual void onCoreEvent(const CoreEventArgsPtr& eventArgs); - static UA_Boolean allowBrowsingNodeCallback(UA_Server* server, + static UA_Boolean AllowBrowsingNodeCallback(UA_Server* server, UA_AccessControl* ac, const UA_NodeId* sessionId, void* sessionContext, const UA_NodeId* nodeId, void* nodeContext); - static UA_UInt32 getUserRightsMaskCallback(UA_Server* server, + static UA_UInt32 GetUserRightsMaskCallback(UA_Server* server, UA_AccessControl* ac, const UA_NodeId* sessionId, void* sessionContext, const UA_NodeId* nodeId, void* nodeContext); - static UA_Byte getUserAccessLevelCallback(UA_Server* server, + static UA_Byte GetUserAccessLevelCallback(UA_Server* server, UA_AccessControl* ac, const UA_NodeId* sessionId, void* sessionContext, const UA_NodeId* nodeId, void* nodeContext); - static UA_Boolean getUserExecutableCallback(UA_Server* server, + static UA_Boolean GetUserExecutableCallback(UA_Server* server, UA_AccessControl* ac, const UA_NodeId* sessionId, void* sessionContext, @@ -110,7 +110,7 @@ class TmsServerObject : public std::enable_shared_from_this return ptr; } - static bool checkPermission(const Permission permission, const UA_NodeId* const nodeId, void* const sessionContext, void* const nodeContext); + static bool CheckPermission(const Permission permission, const UA_NodeId* const nodeId, void* const sessionContext, void* const nodeContext); virtual bool checkPermission(const Permission permission, const UA_NodeId* const nodeId, const OpcUaSession* const sessionContext); std::string readBrowseName(const opcua::OpcUaNodeId& nodeId); diff --git a/shared/libraries/opcuatms/opcuatms_server/include/opcuatms_server/tms_server.h b/shared/libraries/opcuatms/opcuatms_server/include/opcuatms_server/tms_server.h index 82ad376..f3aca1b 100644 --- a/shared/libraries/opcuatms/opcuatms_server/include/opcuatms_server/tms_server.h +++ b/shared/libraries/opcuatms/opcuatms_server/include/opcuatms_server/tms_server.h @@ -38,6 +38,10 @@ class TmsServer void start(); void stop(); +private: + void addConnectedClientInfo(const OpcUaServer::ClientConnectionInfo& connInfo); + void removeConnectedClientInfo(const std::string& clientId); + protected: DevicePtr device; ContextPtr context; diff --git a/shared/libraries/opcuatms/opcuatms_server/src/objects/tms_server_object.cpp b/shared/libraries/opcuatms/opcuatms_server/src/objects/tms_server_object.cpp index 6b9ae24..09553ca 100644 --- a/shared/libraries/opcuatms/opcuatms_server/src/objects/tms_server_object.cpp +++ b/shared/libraries/opcuatms/opcuatms_server/src/objects/tms_server_object.cpp @@ -99,43 +99,43 @@ void TmsServerObject::onCoreEvent(const CoreEventArgsPtr& /*eventArgs*/) { } -UA_Boolean TmsServerObject::allowBrowsingNodeCallback(UA_Server* server, +UA_Boolean TmsServerObject::AllowBrowsingNodeCallback(UA_Server* server, UA_AccessControl* ac, const UA_NodeId* sessionId, void* sessionContext, const UA_NodeId* nodeId, void* nodeContext) { - return checkPermission(Permission::Read, nodeId, sessionContext, nodeContext); + return CheckPermission(Permission::Read, nodeId, sessionContext, nodeContext); } -UA_UInt32 TmsServerObject::getUserRightsMaskCallback(UA_Server *server, UA_AccessControl *ac, +UA_UInt32 TmsServerObject::GetUserRightsMaskCallback(UA_Server *server, UA_AccessControl *ac, const UA_NodeId *sessionId, void *sessionContext, const UA_NodeId *nodeId, void *nodeContext) { - return checkPermission(Permission::Write, nodeId, sessionContext, nodeContext) ? 0xFFFFFFFF : 0; + return CheckPermission(Permission::Write, nodeId, sessionContext, nodeContext) ? 0xFFFFFFFF : 0; } -UA_Byte TmsServerObject::getUserAccessLevelCallback( +UA_Byte TmsServerObject::GetUserAccessLevelCallback( UA_Server* server, UA_AccessControl* ac, const UA_NodeId* sessionId, void* sessionContext, const UA_NodeId* nodeId, void* nodeContext) { constexpr UA_Byte readMask = UA_ACCESSLEVELMASK_READ | UA_ACCESSLEVELMASK_HISTORYREAD; constexpr UA_Byte writeMask = UA_ACCESSLEVELMASK_WRITE | UA_ACCESSLEVELMASK_HISTORYWRITE | UA_ACCESSLEVELMASK_SEMANTICCHANGE | UA_ACCESSLEVELMASK_STATUSWRITE | UA_ACCESSLEVELMASK_TIMESTAMPWRITE; UA_Byte mask = 0xFF; - mask = checkPermission(Permission::Read, nodeId, sessionContext, nodeContext) ? (mask | readMask) : (mask & ~readMask); - mask = checkPermission(Permission::Write, nodeId, sessionContext, nodeContext) ? (mask | writeMask) : (mask & ~writeMask); + mask = CheckPermission(Permission::Read, nodeId, sessionContext, nodeContext) ? (mask | readMask) : (mask & ~readMask); + mask = CheckPermission(Permission::Write, nodeId, sessionContext, nodeContext) ? (mask | writeMask) : (mask & ~writeMask); return mask; } -UA_Boolean TmsServerObject::getUserExecutableCallback(UA_Server* server, +UA_Boolean TmsServerObject::GetUserExecutableCallback(UA_Server* server, UA_AccessControl* ac, const UA_NodeId* sessionId, void* sessionContext, const UA_NodeId* methodId, void* methodContext) { - return checkPermission(Permission::Execute, methodId, sessionContext, methodContext); + return CheckPermission(Permission::Execute, methodId, sessionContext, methodContext); } NodeEventManagerPtr TmsServerObject::addEvent(const StringPtr& nodeName) @@ -206,14 +206,14 @@ bool TmsServerObject::hasChildNode(const std::string& nodeName) const return references.count(nodeName) != 0; } -bool TmsServerObject::checkPermission(const Permission permission, +bool TmsServerObject::CheckPermission(const Permission permission, const UA_NodeId* const nodeId, void* const sessionContext, void* const nodeContext) { if (nodeContext == nullptr || sessionContext == nullptr) return true; - return static_cast(nodeContext)->checkPermission(permission, nodeId, static_cast(sessionContext));; + return static_cast(nodeContext)->checkPermission(permission, nodeId, static_cast(sessionContext)); } bool TmsServerObject::checkPermission(const Permission permission, const UA_NodeId* const nodeId, const OpcUaSession* const sessionContext) diff --git a/shared/libraries/opcuatms/opcuatms_server/src/tms_server.cpp b/shared/libraries/opcuatms/opcuatms_server/src/tms_server.cpp index 20e5015..0674b43 100644 --- a/shared/libraries/opcuatms/opcuatms_server/src/tms_server.cpp +++ b/shared/libraries/opcuatms/opcuatms_server/src/tms_server.cpp @@ -61,57 +61,17 @@ void TmsServer::start() registeredClientIds.insert({clientId, 0}); } ); - server->setClientInfoHandler( - [this](const OpcUaServer::ClientConnectionInfo& connInfo) - { - if (!running.load()) - return; - - std::lock_guard lock(connectedClientsMutex); - const auto it = registeredClientIds.find(connInfo.clientId); - if (it == registeredClientIds.end()) - return; - - const auto loggerComponent = context.getLogger().getOrAddComponent("TmsServer"); - LOG_I("Client address resolved, ID: {}, address: {}, hostname: {}", - connInfo.clientId, - connInfo.address, - connInfo.hostname); - - SizeT clientNumber = 0; - if (device.assigned() && !device.isRemoved()) - { - device.getInfo().asPtr().addConnectedClient( - &clientNumber, - ConnectedClientInfo(connInfo.address, - ProtocolType::Configuration, - "OpenDAQOPCUA", - "Control", - connInfo.hostname)); - } - it->second = clientNumber; - } - ); + server->setClientInfoHandler([this](const OpcUaServer::ClientConnectionInfo& connInfo) { addConnectedClientInfo(connInfo); }); server->setClientDisconnectedHandler( [this](const std::string& clientId) { - std::lock_guard lock(connectedClientsMutex); - if (auto it = registeredClientIds.find(clientId); it != registeredClientIds.end()) - { - const auto loggerComponent = context.getLogger().getOrAddComponent("TmsServer"); - LOG_I("Client disconnected, ID: {}", clientId); - if (device.assigned() && !device.isRemoved() && it->second != 0) - { - device.getInfo().asPtr(true).removeConnectedClient(it->second); - } - registeredClientIds.erase(it); - } + server->scheduleClientInfoChainTask([this, clientId]() { removeConnectedClientInfo(clientId); }); } ); - server->setAllowBrowsingNodeCallback(TmsServerObject::allowBrowsingNodeCallback); - server->setGetUserAccessLevelCallback(TmsServerObject::getUserAccessLevelCallback); - server->setGetUserRightsMaskCallback(TmsServerObject::getUserRightsMaskCallback); - server->setGetUserExecutableCallback(TmsServerObject::getUserExecutableCallback); + server->setAllowBrowsingNodeCallback(TmsServerObject::AllowBrowsingNodeCallback); + server->setGetUserAccessLevelCallback(TmsServerObject::GetUserAccessLevelCallback); + server->setGetUserRightsMaskCallback(TmsServerObject::GetUserRightsMaskCallback); + server->setGetUserExecutableCallback(TmsServerObject::GetUserExecutableCallback); server->prepare(); tmsContext = std::make_shared(context, device); @@ -131,6 +91,53 @@ void TmsServer::start() server->start(); } +void TmsServer::addConnectedClientInfo(const OpcUaServer::ClientConnectionInfo& connInfo) +{ + if (!running.load()) + return; + + std::lock_guard lock(connectedClientsMutex); + const auto it = registeredClientIds.find(connInfo.clientId); + if (it == registeredClientIds.end()) + return; + + const auto loggerComponent = context.getLogger().getOrAddComponent("TmsServer"); + LOG_I("Client address resolved, ID: {}, address: {}, hostname: {}", + connInfo.clientId, + connInfo.address, + connInfo.hostname); + + SizeT clientNumber = 0; + if (device.assigned() && !device.isRemoved()) + { + device.getInfo().asPtr().addConnectedClient( + &clientNumber, + ConnectedClientInfo(connInfo.address, + ProtocolType::Configuration, + "OpenDAQOPCUA", + "Control", + connInfo.hostname)); + } + it->second = clientNumber; +} + +void TmsServer::removeConnectedClientInfo(const std::string& clientId) +{ + if (!running.load()) + return; + + std::lock_guard lock(connectedClientsMutex); + const auto it = registeredClientIds.find(clientId); + if (it == registeredClientIds.end()) + return; + + const auto loggerComponent = context.getLogger().getOrAddComponent("TmsServer"); + LOG_I("Client disconnected, ID: {}", clientId); + if (device.assigned() && !device.isRemoved() && it->second != 0) + device.getInfo().asPtr(true).removeConnectedClient(it->second); + registeredClientIds.erase(it); +} + void TmsServer::stop() { running = false; diff --git a/shared/libraries/opcuatms/opcuatms_server/tests/test_tms_user_access.cpp b/shared/libraries/opcuatms/opcuatms_server/tests/test_tms_user_access.cpp index 2dc0d54..6aedd0e 100644 --- a/shared/libraries/opcuatms/opcuatms_server/tests/test_tms_user_access.cpp +++ b/shared/libraries/opcuatms/opcuatms_server/tests/test_tms_user_access.cpp @@ -33,10 +33,10 @@ class TmsServerObjectTestWithParameterizedServer : public TmsServerObjectTest using namespace daq::opcua::tms; server = std::make_shared(); server->setPort(4840); - server->setAllowBrowsingNodeCallback(TmsServerObject::allowBrowsingNodeCallback); - server->setGetUserAccessLevelCallback(TmsServerObject::getUserAccessLevelCallback); - server->setGetUserRightsMaskCallback(TmsServerObject::getUserRightsMaskCallback); - server->setGetUserExecutableCallback(TmsServerObject::getUserExecutableCallback); + server->setAllowBrowsingNodeCallback(TmsServerObject::AllowBrowsingNodeCallback); + server->setGetUserAccessLevelCallback(TmsServerObject::GetUserAccessLevelCallback); + server->setGetUserRightsMaskCallback(TmsServerObject::GetUserRightsMaskCallback); + server->setGetUserExecutableCallback(TmsServerObject::GetUserExecutableCallback); server->setAuthenticationProvider(StaticAuthenticationProvider(true, test_helpers::CreateUsers())); server->start(); diff --git a/shared/libraries/opcuatms/tests/opcuatms_integration/test_tms_integration.cpp b/shared/libraries/opcuatms/tests/opcuatms_integration/test_tms_integration.cpp index bdfe9fe..08ef73f 100644 --- a/shared/libraries/opcuatms/tests/opcuatms_integration/test_tms_integration.cpp +++ b/shared/libraries/opcuatms/tests/opcuatms_integration/test_tms_integration.cpp @@ -5,8 +5,8 @@ #include #include #include -#include #include +#include #include #include #include @@ -664,3 +664,52 @@ TEST_F(TmsIntegrationTest, ConnectedClientInfoHasAddressAndHostname) ASSERT_FALSE(opcuaClient.getAddress().getLength() == 0) << "Client address must not be empty"; ASSERT_FALSE(opcuaClient.getHostName().getLength() == 0) << "Client hostname must not be empty"; } + +TEST_F(TmsIntegrationTest, RapidMultiThreadClientConnectDisconnectLeavesNoClientInfo) +{ + InstancePtr device = createDevice(); + + TmsServer tmsServer(device); + tmsServer.start(); + + const std::string opcUrl = OPC_URL; + constexpr int threadCount = 10; + constexpr int connectsPerThread = 2; + + std::vector threads; + threads.reserve(threadCount); + for (int t = 0; t < threadCount; ++t) + { + threads.emplace_back([opcUrl, connectsPerThread]() + { + const auto moduleManager = ModuleManager("[[none]]"); + auto logger = Logger(); + auto context = Context(Scheduler(logger, 1), logger, TypeManager(), moduleManager, nullptr); + + for (int i = 0; i < connectsPerThread; ++i) + { + try + { + TmsClient tmsClient(context, nullptr, opcUrl); + const DevicePtr clientDevice = tmsClient.connect(); + } + catch (...) + { + return; + } + } + }); + } + + for (auto& thread : threads) + thread.join(); + + size_t opcuaClientCount = 0; + for (const auto& client : device.getInfo().getConnectedClientsInfo()) + { + if (client.getProtocolName() == "OpenDAQOPCUA") + ++opcuaClientCount; + } + + ASSERT_EQ(opcuaClientCount, 0u) << "Connected OpenDAQOPCUA client info must be empty after all clients disconnected"; +} diff --git a/shared/libraries/opcuatms/tests/opcuatms_integration/tms_object_integration_test.cpp b/shared/libraries/opcuatms/tests/opcuatms_integration/tms_object_integration_test.cpp index 8ebad4c..ef5e174 100644 --- a/shared/libraries/opcuatms/tests/opcuatms_integration/tms_object_integration_test.cpp +++ b/shared/libraries/opcuatms/tests/opcuatms_integration/tms_object_integration_test.cpp @@ -18,10 +18,10 @@ void TmsObjectIntegrationTest::Init() { server = std::make_shared(); server->setPort(4840); - server->setAllowBrowsingNodeCallback(TmsServerObject::allowBrowsingNodeCallback); - server->setGetUserAccessLevelCallback(TmsServerObject::getUserAccessLevelCallback); - server->setGetUserRightsMaskCallback(TmsServerObject::getUserRightsMaskCallback); - server->setGetUserExecutableCallback(TmsServerObject::getUserExecutableCallback); + server->setAllowBrowsingNodeCallback(TmsServerObject::AllowBrowsingNodeCallback); + server->setGetUserAccessLevelCallback(TmsServerObject::GetUserAccessLevelCallback); + server->setGetUserRightsMaskCallback(TmsServerObject::GetUserRightsMaskCallback); + server->setGetUserExecutableCallback(TmsServerObject::GetUserExecutableCallback); server->setAuthenticationProvider(StaticAuthenticationProvider(true, test_helpers::CreateUsers())); server->start(); client = CreateAndConnectTestClient();