diff --git a/include/znc/Modules.h b/include/znc/Modules.h index bad32c1f..ca251d24 100644 --- a/include/znc/Modules.h +++ b/include/znc/Modules.h @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -166,6 +167,19 @@ class CFPTimer; class CSockManager; // !Forward Declarations +class CCapability { + public: + virtual ~CCapability() = default; + virtual void OnServerChangedSupport(CIRCNetwork* pNetwork, bool bState) {} + virtual void OnClientChangedSupport(CClient* pClient, bool bState) {} + + CModule* GetModule() { return m_pModule; } + void SetModule(CModule* p) { m_pModule = p; } + + protected: + CModule* m_pModule = nullptr; +}; + class CTimer : public CCron { public: CTimer(CModule* pModule, unsigned int uInterval, unsigned int uCycles, @@ -1292,10 +1306,15 @@ class CModule { virtual EModRet OnUnknownUserRaw(CClient* pClient, CString& sLine); virtual EModRet OnUnknownUserRawMessage(CMessage& Message); - /** Called after login, upon disconnect, and also during JumpNetwork. */ + /** Called after login, and also during JumpNetwork. */ virtual void OnClientAttached(); + /** Called upon disconnect, and also during JumpNetwork. */ virtual void OnClientDetached(); +#ifndef SWIG + void AddCapability(const CString& sName, std::unique_ptr pCap); +#endif + /** Called when a client told us CAP LS. Use ssCaps.insert("cap-name") * for announcing capabilities which your module supports. * @param pClient The client which requested the list. @@ -1391,6 +1410,7 @@ class CModule { CString m_sArgs; CString m_sModPath; CTranslationDomainRefHolder m_Translation; + std::map> m_mCaps; private: MCString diff --git a/src/Modules.cpp b/src/Modules.cpp index 35a8d970..25b1e351 100644 --- a/src/Modules.cpp +++ b/src/Modules.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include #include @@ -160,6 +161,37 @@ CModule::CModule(ModHandle pDLL, CUser* pUser, CIRCNetwork* pNetwork, } CModule::~CModule() { + for (const auto& [sName, pCap] : m_mCaps) { + switch (GetType()) { + case CModInfo::NetworkModule: + GetNetwork()->NotifyClientsAboutServerDependentCap( + sName, false, [&](CClient* pClient, bool bState) {}); + for (CClient* pClient : GetNetwork()->GetClients()) { + pCap->OnClientChangedSupport(pClient, false); + } + break; + case CModInfo::UserModule: + for (CIRCNetwork* pNetwork : GetUser()->GetNetworks()) { + pNetwork->NotifyClientsAboutServerDependentCap( + sName, false, [&](CClient* pClient, bool bState) {}); + for (CClient* pClient : pNetwork->GetClients()) { + pCap->OnClientChangedSupport(pClient, false); + } + } + break; + case CModInfo::GlobalModule: + for (auto& [_, pUser] : CZNC::Get().GetUserMap()) { + for (CIRCNetwork* pNetwork : pUser->GetNetworks()) { + pNetwork->NotifyClientsAboutServerDependentCap( + sName, false, [&](CClient* pClient, bool bState) {}); + for (CClient* pClient : pNetwork->GetClients()) { + pCap->OnClientChangedSupport(pClient, false); + } + } + } + } + } + while (!m_sTimers.empty()) { RemTimer(*m_sTimers.begin()); } @@ -605,8 +637,22 @@ bool CModule::OnLoad(const CString& sArgs, CString& sMessage) { bool CModule::OnBoot() { return true; } void CModule::OnPreRehash() {} void CModule::OnPostRehash() {} -void CModule::OnIRCDisconnected() {} -void CModule::OnIRCConnected() {} +void CModule::OnIRCDisconnected() { + for (const auto& [sName, pCap] : m_mCaps) { + GetNetwork()->NotifyClientsAboutServerDependentCap( + sName, false, [&](CClient* pClient, bool bState) {}); + for (CClient* pClient : GetNetwork()->GetClients()) { + pCap->OnClientChangedSupport(pClient, false); + } + } +} +void CModule::OnIRCConnected() { + for (const auto& [sName, pCap] : m_mCaps) { + if (GetNetwork()->IsServerCapAccepted(sName)) { + GetNetwork()->NotifyClientsAboutServerDependentCap(sName, true, [](CClient* pClient, bool bState){}); + } + } +} CModule::EModRet CModule::OnIRCConnecting(CIRCSock* IRCSock) { return CONTINUE; } @@ -997,14 +1043,48 @@ CModule::EModRet CModule::OnSendToIRC(CString& sLine) { return CONTINUE; } CModule::EModRet CModule::OnSendToIRCMessage(CMessage& Message) { return CONTINUE; } -void CModule::OnClientAttached() {} -void CModule::OnClientDetached() {} +void CModule::OnClientAttached() { + if (!GetNetwork()) return; + for (const auto& [sName, pCap] : m_mCaps) { + if (GetNetwork()->IsServerCapAccepted(sName)) { + GetClient()->NotifyServerDependentCap(sName, true, GetNetwork()->GetIRCSock()->GetCapLsValue(sName), nullptr); + } + } +} +void CModule::OnClientDetached() { + for (const auto& [sName, pCap] : m_mCaps) { + GetClient()->NotifyServerDependentCap(sName, false, "", + [](CClient*, bool) {}); + pCap->OnClientChangedSupport(GetClient(), false); + } +} bool CModule::OnServerCapAvailable(const CString& sCap) { return false; } -bool CModule::OnServerCap302Available(const CString& sCap, const CString& sValue) { - return OnServerCapAvailable(sCap); +bool CModule::OnServerCap302Available(const CString& sCap, + const CString& sValue) { + auto it = m_mCaps.find(sCap); + if (it == m_mCaps.end()) return OnServerCapAvailable(sCap); + if (GetNetwork()->IsServerCapAccepted(sCap)) { + // This can happen when server sent CAP NEW with another value. + GetNetwork()->NotifyClientsAboutServerDependentCap(sCap, true, nullptr); + // It's enabled already, no need to REQ it again. + return false; + } + return true; +} +void CModule::OnServerCapResult(const CString& sCap, bool bSuccess) { + auto it = m_mCaps.find(sCap); + if (it == m_mCaps.end()) return; + it->second->OnServerChangedSupport(GetNetwork(), bSuccess); + if (GetNetwork()->GetIRCSock()->IsAuthed()) { + GetNetwork()->NotifyClientsAboutServerDependentCap( + sCap, bSuccess, [&](CClient* pClient, bool bState) {}); + if (!bSuccess) + for (CClient* pClient : GetNetwork()->GetClients()) { + it->second->OnClientChangedSupport(pClient, false); + } + } } -void CModule::OnServerCapResult(const CString& sCap, bool bSuccess) {} bool CModule::PutIRC(const CString& sLine) { return m_pNetwork ? m_pNetwork->PutIRC(sLine) : false; @@ -1073,13 +1153,36 @@ CModule::EModRet CModule::OnUnknownUserRaw(CClient* pClient, CString& sLine) { CModule::EModRet CModule::OnUnknownUserRawMessage(CMessage& Message) { return CONTINUE; } -void CModule::OnClientCapLs(CClient* pClient, SCString& ssCaps) {} +void CModule::OnClientCapLs(CClient* pClient, SCString& ssCaps) { + for (const auto& [sName, pCap] : m_mCaps) { + if (GetNetwork() && GetNetwork()->IsServerCapAccepted(sName)) { + if (pClient->HasCap302()) { + CString sValue = + GetNetwork()->GetIRCSock()->GetCapLsValue(sName); + if (!sValue.empty()) { + ssCaps.insert(sName + '=' + sValue); + } else { + ssCaps.insert(sName); + } + } else { + ssCaps.insert(sName); + } + } + } +} bool CModule::IsClientCapSupported(CClient* pClient, const CString& sCap, bool bState) { - return false; + auto it = m_mCaps.find(sCap); + if (it == m_mCaps.end()) return false; + if (!bState) return true; + return GetNetwork() && GetNetwork()->IsServerCapAccepted(sCap); } void CModule::OnClientCapRequest(CClient* pClient, const CString& sCap, - bool bState) {} + bool bState) { + auto it = m_mCaps.find(sCap); + if (it == m_mCaps.end()) return; + it->second->OnClientChangedSupport(pClient, bState); +} CModule::EModRet CModule::OnModuleLoading(const CString& sModName, const CString& sArgs, CModInfo::EModuleType eType, @@ -1097,6 +1200,11 @@ CModule::EModRet CModule::OnGetModInfo(CModInfo& ModInfo, } void CModule::OnGetAvailableMods(set& ssMods, CModInfo::EModuleType eType) {} +void CModule::AddCapability(const CString& sName, + std::unique_ptr pCap) { + pCap->SetModule(this); + m_mCaps[sName] = std::move(pCap); +} CModules::CModules() : m_pUser(nullptr), m_pNetwork(nullptr), m_pClient(nullptr) {} diff --git a/test/integration/tests/core.cpp b/test/integration/tests/core.cpp index 078fc3fb..bde99d43 100644 --- a/test/integration/tests/core.cpp +++ b/test/integration/tests/core.cpp @@ -540,105 +540,20 @@ TEST_F(ZNCTest, ServerDependentCapInModule) { auto client = LoginClient(); InstallModule("testmod.cpp", R"( #include - #include - #include - #include - #include class TestModule : public CModule { + class TestCap : public CCapability { + public: + using CCapability::CCapability; + void OnServerChangedSupport(CIRCNetwork* pNetwork, bool bState) override { + GetModule()->PutModule("Server changed support: " + CString(bState)); + } + void OnClientChangedSupport(CClient* pClient, bool bState) override { + GetModule()->PutModule("Client changed support: " + CString(bState)); + } + }; public: - MODCONSTRUCTOR(TestModule) {} - void OnClientCapLs(CClient* pClient, SCString& ssCaps) override { - if (GetNetwork() && GetNetwork()->IsServerCapAccepted("testcap")) { - if (pClient->HasCap302()) { - CString sValue = GetNetwork()->GetIRCSock()->GetCapLsValue("testcap"); - if (!sValue.empty()) { - ssCaps.insert("testcap=" + sValue); - } else { - ssCaps.insert("testcap"); - } - } else { - ssCaps.insert("testcap"); - } - } - } - bool IsClientCapSupported(CClient* pClient, const CString& sCap, - bool bState) override { - if (!bState) return false; - if (sCap != "testcap") return false; - return GetNetwork() && GetNetwork()->IsServerCapAccepted("testcap"); - } - void OnClientCapRequest(CClient* pClient, const CString& sCap, - bool bState) override { - PutModule("OnClientCapRequest " + sCap + " " + CString(bState)); - } - bool OnServerCap302Available(const CString& sCap, const CString& sValue) override { - PutModule("OnServerCapAvailable " + sCap + " " + sValue); - if (sCap == "testcap") { - if (GetNetwork()->IsServerCapAccepted("testcap")) { - // This can happen when server sent CAP NEW with another value. - for (CClient* pClient : GetNetwork()->GetClients()) { - pClient->NotifyServerDependentCap("testcap", true, sValue, nullptr); - } - } - return true; - } - return false; - } - void OnServerCapResult(const CString& sCap, bool bSuccess) override { - if (sCap == "testcap") { - PutModule("OnServerCapResult " + sCap + " " + CString(bSuccess)); - if (GetNetwork()->GetIRCSock()->IsAuthed()) { - GetNetwork()->NotifyClientsAboutServerDependentCap("testcap", bSuccess, [=](CClient* pClient, bool bState) { - PutModule("OnServerCapResult " + sCap + " " + CString(bSuccess) + " " + CString(bState)); - }); - } - } - } - void OnIRCConnected() override { - if (GetNetwork()->IsServerCapAccepted("testcap")) { - PutModule("OnIRCConnected"); - GetNetwork()->NotifyClientsAboutServerDependentCap("testcap", true, [=](CClient* pClient, bool bState) { - PutModule("OnIRCConnected " + CString(bState)); - }); - } - } - void OnIRCDisconnected() override { - GetNetwork()->NotifyClientsAboutServerDependentCap("testcap", false, [=](CClient* pClient, bool bState) { - PutModule("OnIRCDisconnected " + CString(bState)); - }); - } - void OnClientAttached() override { - if (!GetNetwork()) return; - if (GetNetwork()->IsServerCapAccepted("testcap")) { - GetClient()->NotifyServerDependentCap("testcap", true, GetNetwork()->GetIRCSock()->GetCapLsValue("testcap"), nullptr); - } - } - void OnClientDetached() override { - GetClient()->NotifyServerDependentCap("testcap", false, "", [](CClient*, bool) {}); - } - ~TestModule() override { - switch (GetType()) { - case CModInfo::NetworkModule: - GetNetwork()->NotifyClientsAboutServerDependentCap("testcap", false, [=](CClient* pClient, bool bState) { - PutModule("~ " + CString(bState)); - }); - break; - case CModInfo::UserModule: - for (CIRCNetwork* pNetwork : GetUser()->GetNetworks()) { - pNetwork->NotifyClientsAboutServerDependentCap("testcap", false, [=](CClient* pClient, bool bState) { - PutModule("~ " + CString(bState)); - }); - } - break; - case CModInfo::GlobalModule: - for (auto& [_, pUser] : CZNC::Get().GetUserMap()) { - for (CIRCNetwork* pNetwork : pUser->GetNetworks()) { - pNetwork->NotifyClientsAboutServerDependentCap("testcap", false, [=](CClient* pClient, bool bState) { - PutModule("~ " + CString(bState)); - }); - } - } - } + MODCONSTRUCTOR(TestModule) { + AddCapability("testcap", std::make_unique()); } }; MODULEDEFS(TestModule, "Test") @@ -668,8 +583,8 @@ TEST_F(ZNCTest, ServerDependentCapInModule) { client.ReadUntil(" testcap=value "); ircd.Write("CAP nick DEL testcap"); + client.ReadUntil(":Server changed support: false"); client.ReadUntil("CAP nick DEL :testcap"); - client.ReadUntil(":OnServerCapResult testcap false false"); ircd.Close(); ircd = ConnectIRCd();