diff --git a/include/znc/Client.h b/include/znc/Client.h index 04316b67..54dae022 100644 --- a/include/znc/Client.h +++ b/include/znc/Client.h @@ -227,7 +227,11 @@ class CClient : public CIRCSocket { /** Notifies client about one specific cap which server has just notified us about. */ - void NotifyServerDependentCap(const CString& sCap, bool bValue); + void NotifyServerDependentCap(const CString& sCap, bool bValue, const CString& sValue, + const std::function& handler); + /** Notifies client if the cap is server-dependent, otherwise noop. + */ + void PotentiallyNotifyServerDependentCap(const CString& sCap, bool bValue, const CString& sValue); /** Notifies client that all these caps are now available. * * This function will internally filter only those which are server-dependent. diff --git a/include/znc/IRCNetwork.h b/include/znc/IRCNetwork.h index 73c907d2..e98aa171 100644 --- a/include/znc/IRCNetwork.h +++ b/include/znc/IRCNetwork.h @@ -156,7 +156,9 @@ class CIRCNetwork : private CCoreTranslationMixin { void IRCConnected(); void IRCDisconnected(); void CheckIRCConnect(); - void NotifyServerDependentCap(const CString& sCap, bool bValue); + void PotentiallyNotifyServerDependentCap(const CString& sCap, bool bValue); + void NotifyClientsAboutServerDependentCap(const CString& sCap, bool bValue, const std::function& handler); + bool IsServerCapAccepted(const CString& sCap) const; bool PutIRC(const CString& sLine); bool PutIRC(const CMessage& Message); diff --git a/include/znc/IRCSock.h b/include/znc/IRCSock.h index e40448c5..b9ff5572 100644 --- a/include/znc/IRCSock.h +++ b/include/znc/IRCSock.h @@ -162,6 +162,8 @@ class CIRCSock : public CIRCSocket { bool IsCapAccepted(const CString& sCap) { return 1 == m_ssAcceptedCaps.count(sCap); } + CString GetCapLsValue(const CString& sKey, + const CString& sDefault = "") const; const MCString& GetISupport() const { return m_mISupport; } CString GetISupport(const CString& sKey, const CString& sDefault = "") const; @@ -223,6 +225,7 @@ class CIRCSock : public CIRCSocket { unsigned int m_uCapPaused; SCString m_ssAcceptedCaps; SCString m_ssPendingCaps; + MCString m_msCapLsValues; time_t m_lastCTCP; unsigned int m_uNumCTCP; static const time_t m_uCTCPFloodTime; diff --git a/src/Client.cpp b/src/Client.cpp index b72721aa..443045a1 100644 --- a/src/Client.cpp +++ b/src/Client.cpp @@ -790,7 +790,7 @@ void CClient::HandleCap(const CMessage& Message) { CString sSubCmd = Message.GetParam(0); if (sSubCmd.Equals("LS")) { - m_uCapVersion = Message.GetParam(1).ToInt(); + m_uCapVersion = std::max(m_uCapVersion, Message.GetParam(1).ToUShort()); SCString ssOfferCaps; for (const auto& it : CoreCaps()) { bool bServerDependent = std::get<0>(it.second); @@ -798,7 +798,7 @@ void CClient::HandleCap(const CMessage& Message) { m_ssServerDependentCaps.count(it.first) > 0) ssOfferCaps.insert(it.first); } - GLOBALMODULECALL(OnClientCapLs(this, ssOfferCaps), NOTHING); + NETWORKMODULECALL(OnClientCapLs(this, ssOfferCaps), GetUser(), GetNetwork(), this, NOTHING); VCString vsCaps = MultiLine(ssOfferCaps); m_bInCap = true; if (HasCap302()) { @@ -830,13 +830,13 @@ void CClient::HandleCap(const CMessage& Message) { if (sCap.TrimPrefix("-")) bVal = false; bool bAccepted = false; - const auto& it = CoreCaps().find(sCap); + auto it = CoreCaps().find(sCap); if (CoreCaps().end() != it) { bool bServerDependent = std::get<0>(it->second); bAccepted = !bServerDependent || m_ssServerDependentCaps.count(sCap) > 0; } - GLOBALMODULECALL(IsClientCapSupported(this, sCap, bVal), + NETWORKMODULECALL(IsClientCapSupported(this, sCap, bVal), GetUser(), GetNetwork(), this, &bAccepted); if (!bAccepted) { @@ -857,7 +857,7 @@ void CClient::HandleCap(const CMessage& Message) { const auto& handler = std::get<1>(handler_it->second); handler(this, bVal); } - GLOBALMODULECALL(OnClientCapRequest(this, sCap, bVal), NOTHING); + NETWORKMODULECALL(OnClientCapRequest(this, sCap, bVal), GetUser(), GetNetwork(), this, NOTHING); if (bVal) { m_ssAcceptedCaps.insert(sCap); @@ -936,8 +936,38 @@ void CClient::SetTagSupport(const CString& sTag, bool bState) { } } -void CClient::NotifyServerDependentCap(const CString& sCap, bool bValue) { +void CClient::NotifyServerDependentCap(const CString& sCap, bool bValue, const CString& sValue, + const std::function& handler) { + if (bValue) { + if (HasCapNotify()) { + if (HasCap302() && !sValue.empty()) { + PutClient(":irc.znc.in CAP " + GetNick() + " NEW :" + sCap + "=" + sValue); + } else { + PutClient(":irc.znc.in CAP " + GetNick() + " NEW :" + sCap); + } + } + } else { + if (HasCapNotify()) { + PutClient(":irc.znc.in CAP " + GetNick() + " DEL :" + sCap); + } + handler(this, false); + m_ssAcceptedCaps.erase(sCap); + } +} + +void CClient::PotentiallyNotifyServerDependentCap(const CString& sCap, bool bValue, const CString& sValue) { auto it = CoreCaps().find(sCap); + if (CoreCaps().end() != it) { + const auto& [bServerDependent, handler] = it->second; + if (bServerDependent) { + NotifyServerDependentCap(sCap, bValue, sValue, handler); + } + } + if (!bValue) { + m_ssServerDependentCaps.erase(sCap); + } + return; + if (bValue) { if (CoreCaps().end() != it) { bool bServerDependent = std::get<0>(it->second); diff --git a/src/IRCNetwork.cpp b/src/IRCNetwork.cpp index b4f33618..009f39b3 100644 --- a/src/IRCNetwork.cpp +++ b/src/IRCNetwork.cpp @@ -1417,12 +1417,24 @@ void CIRCNetwork::IRCDisconnected() { CheckIRCConnect(); } -void CIRCNetwork::NotifyServerDependentCap(const CString& sCap, bool bValue) { +void CIRCNetwork::PotentiallyNotifyServerDependentCap(const CString& sCap, bool bValue) { + CString sValue = GetIRCSock() ? GetIRCSock()->GetCapLsValue(sCap) : ""; for (CClient* pClient : m_vClients) { - pClient->NotifyServerDependentCap(sCap, bValue); + pClient->PotentiallyNotifyServerDependentCap(sCap, bValue, sValue); } } +void CIRCNetwork::NotifyClientsAboutServerDependentCap(const CString& sCap, bool bValue, const std::function& handler) { + CString sValue = GetIRCSock() ? GetIRCSock()->GetCapLsValue(sCap) : ""; + for (CClient* pClient : m_vClients) { + pClient->NotifyServerDependentCap(sCap, bValue, sValue, handler); + } +} + +bool CIRCNetwork::IsServerCapAccepted(const CString& sCap) const { + return m_pIRCSock && m_pIRCSock->IsCapAccepted(sCap); +} + void CIRCNetwork::SetIRCConnectEnabled(bool b) { m_bIRCConnectEnabled = b; diff --git a/src/IRCSock.cpp b/src/IRCSock.cpp index 825fbc9a..d91ff565 100644 --- a/src/IRCSock.cpp +++ b/src/IRCSock.cpp @@ -398,7 +398,7 @@ bool CIRCSock::OnCapabilityMessage(CMessage& Message) { m_ssAcceptedCaps.erase(sCap); m_ssPendingCaps.erase(sCap); if (m_bAuthed) { - m_pNetwork->NotifyServerDependentCap(sCap, false); + m_pNetwork->PotentiallyNotifyServerDependentCap(sCap, false); } }; @@ -415,6 +415,7 @@ bool CIRCSock::OnCapabilityMessage(CMessage& Message) { sCap = sToken.substr(0, eq); sValue = sToken.substr(eq + 1); } + m_msCapLsValues[sCap] = sValue; if (OnServerCapAvailable(sCap, sValue) || mSupportedCaps.count(sCap)) { m_ssPendingCaps.insert(sCap); } @@ -428,7 +429,7 @@ bool CIRCSock::OnCapabilityMessage(CMessage& Message) { } m_ssAcceptedCaps.insert(sArgs); if (m_bAuthed) { - m_pNetwork->NotifyServerDependentCap(sArgs, true); + m_pNetwork->PotentiallyNotifyServerDependentCap(sArgs, true); } } else if (sSubCmd == "NAK") { // This should work because there's no [known] @@ -441,6 +442,7 @@ bool CIRCSock::OnCapabilityMessage(CMessage& Message) { for (const CString& sCap : vsTokens) { RemoveCap(sCap); + m_msCapLsValues.erase(sCap); } } @@ -1448,6 +1450,16 @@ CString CIRCSock::GetISupport(const CString& sKey, } } +CString CIRCSock::GetCapLsValue(const CString& sKey, + const CString& sDefault) const { + MCString::const_iterator i = m_msCapLsValues.find(sKey); + if (i == m_msCapLsValues.end()) { + return sDefault; + } else { + return i->second; + } +} + void CIRCSock::SendAltNick(const CString& sBadNick) { const CString& sLastNick = m_Nick.GetNick(); diff --git a/test/integration/tests/core.cpp b/test/integration/tests/core.cpp index 9834c394..979dc069 100644 --- a/test/integration/tests/core.cpp +++ b/test/integration/tests/core.cpp @@ -534,6 +534,117 @@ TEST_F(ZNCTest, CAP302LSValue) { client2.ReadUntil("testcap="); } +TEST_F(ZNCTest, ServerDependentCapInModule) { + auto znc = Run(); + auto ircd = ConnectIRCd(); + auto client = LoginClient(); + InstallModule("testmod.cpp", R"( + #include + #include + #include + #include + class TestModule : public CModule { + public: + MODCONSTRUCTOR(TestModule) {} + void OnClientCapLs(CClient* pClient, SCString& ssCaps) override { + if (GetNetwork() && GetNetwork()->IsServerCapAccepted("testcap")) { + if (pClient->HasCap302()) { + CString sValue = GetNetwork()->GetIRCSock()->GetCapLsValue("testcap"); + if (!sValue.empty()) { + ssCaps.insert("testcap=" + sValue); + } else { + ssCaps.insert("testcap"); + } + } else { + ssCaps.insert("testcap"); + } + } + } + bool IsClientCapSupported(CClient* pClient, const CString& sCap, + bool bState) override { + if (!bState) return false; + if (sCap != "testcap") return false; + return GetNetwork() && GetNetwork()->IsServerCapAccepted("testcap"); + } + void OnClientCapRequest(CClient* pClient, const CString& sCap, + bool bState) override { + PutModule("OnClientCapRequest " + sCap + " " + CString(bState)); + } + bool OnServerCap302Available(const CString& sCap, const CString& sValue) override { + PutModule("OnServerCapAvailable " + sCap + " " + sValue); + return sCap == "testcap"; + } + void OnServerCapResult(const CString& sCap, bool bSuccess) override { + if (sCap == "testcap") { + GetNetwork()->NotifyClientsAboutServerDependentCap("testcap", bSuccess, [=](CClient* pClient, bool bState) { + PutModule("OnServerCapResult " + sCap + " " + CString(bSuccess) + " " + CString(bState)); + }); + } + } + void OnIRCConnected() override { + if (GetNetwork()->IsServerCapAccepted("testcap")) { + GetNetwork()->NotifyClientsAboutServerDependentCap("testcap", true, [=](CClient* pClient, bool bState) { + PutModule("OnIRCConnected " + CString(bState)); + }); + } + } + void OnIRCDisconnected() override { + GetNetwork()->NotifyClientsAboutServerDependentCap("testcap", false, [=](CClient* pClient, bool bState) { + PutModule("OnIRCDisconnected " + CString(bState)); + }); + } + ~TestModule() override { + // TODO user module + GetNetwork()->NotifyClientsAboutServerDependentCap("testcap", false, [=](CClient* pClient, bool bState) { + PutModule("~ " + CString(bState)); + }); + } + }; + MODULEDEFS(TestModule, "Test") + )"); + client.Write("znc loadmod testmod"); + client.ReadUntil("Loaded module testmod"); + client.Close(); + + client = ConnectClient(); + client.Write("CAP LS 302"); + client.Write("PASS :hunter2"); + client.Write("NICK nick"); + client.Write("USER user x x :x"); + client.Write("CAP END"); + client.ReadUntil("Welcome"); + + ircd.Write("001 nick Welcome"); + ircd.Write("CAP nick NEW testcap=value"); + ircd.ReadUntil("CAP REQ :testcap"); + ircd.Write("CAP nick ACK :testcap"); + client.ReadUntil("CAP nick NEW :testcap=value"); + client.Write("CAP REQ testcap"); + client.ReadUntil("CAP nick ACK :testcap"); + + client.Write("CAP LS"); + client.ReadUntil(" testcap=value "); + + ircd.Write("CAP nick DEL testcap"); + client.ReadUntil("CAP nick DEL :testcap"); + client.ReadUntil(":OnServerCapResult testcap false false"); + + ircd.Close(); + ircd = ConnectIRCd(); + ircd.ReadUntil("CAP LS 302"); + ircd.Write("CAP nick LS :testcap=new"); + ircd.ReadUntil("CAP REQ :testcap"); + ircd.Write("CAP nick ACK :testcap"); + ircd.ReadUntil("CAP END"); + // TODO should NEW wait until 001? + // TODO combine multiple NEWs to single line + client.ReadUntil("CAP nick NEW :testcap=new"); + ircd.ReadUntil("001 nick Welcome"); + client.ReadUntil("Welcome2"); + // TODO NEW with new value without DEL + // TODO client.Write("jumpnetwork"); +} + TEST_F(ZNCTest, HashUpgrade) { QFile conf(m_dir.path() + "/configs/znc.conf"); ASSERT_TRUE(conf.open(QIODevice::Append | QIODevice::Text));