This commit is contained in:
Alexey Sokolov
2024-01-05 00:45:41 +00:00
parent 4cbccac707
commit 66137bd89a
7 changed files with 186 additions and 12 deletions

View File

@@ -227,7 +227,11 @@ class CClient : public CIRCSocket {
/** Notifies client about one specific cap which server has just notified us about.
*/
void NotifyServerDependentCap(const CString& sCap, bool bValue);
void NotifyServerDependentCap(const CString& sCap, bool bValue, const CString& sValue,
const std::function<void(CClient*, bool)>& handler);
/** Notifies client if the cap is server-dependent, otherwise noop.
*/
void PotentiallyNotifyServerDependentCap(const CString& sCap, bool bValue, const CString& sValue);
/** Notifies client that all these caps are now available.
*
* This function will internally filter only those which are server-dependent.

View File

@@ -156,7 +156,9 @@ class CIRCNetwork : private CCoreTranslationMixin {
void IRCConnected();
void IRCDisconnected();
void CheckIRCConnect();
void NotifyServerDependentCap(const CString& sCap, bool bValue);
void PotentiallyNotifyServerDependentCap(const CString& sCap, bool bValue);
void NotifyClientsAboutServerDependentCap(const CString& sCap, bool bValue, const std::function<void(CClient*, bool)>& handler);
bool IsServerCapAccepted(const CString& sCap) const;
bool PutIRC(const CString& sLine);
bool PutIRC(const CMessage& Message);

View File

@@ -162,6 +162,8 @@ class CIRCSock : public CIRCSocket {
bool IsCapAccepted(const CString& sCap) {
return 1 == m_ssAcceptedCaps.count(sCap);
}
CString GetCapLsValue(const CString& sKey,
const CString& sDefault = "") const;
const MCString& GetISupport() const { return m_mISupport; }
CString GetISupport(const CString& sKey,
const CString& sDefault = "") const;
@@ -223,6 +225,7 @@ class CIRCSock : public CIRCSocket {
unsigned int m_uCapPaused;
SCString m_ssAcceptedCaps;
SCString m_ssPendingCaps;
MCString m_msCapLsValues;
time_t m_lastCTCP;
unsigned int m_uNumCTCP;
static const time_t m_uCTCPFloodTime;

View File

@@ -790,7 +790,7 @@ void CClient::HandleCap(const CMessage& Message) {
CString sSubCmd = Message.GetParam(0);
if (sSubCmd.Equals("LS")) {
m_uCapVersion = Message.GetParam(1).ToInt();
m_uCapVersion = std::max(m_uCapVersion, Message.GetParam(1).ToUShort());
SCString ssOfferCaps;
for (const auto& it : CoreCaps()) {
bool bServerDependent = std::get<0>(it.second);
@@ -798,7 +798,7 @@ void CClient::HandleCap(const CMessage& Message) {
m_ssServerDependentCaps.count(it.first) > 0)
ssOfferCaps.insert(it.first);
}
GLOBALMODULECALL(OnClientCapLs(this, ssOfferCaps), NOTHING);
NETWORKMODULECALL(OnClientCapLs(this, ssOfferCaps), GetUser(), GetNetwork(), this, NOTHING);
VCString vsCaps = MultiLine(ssOfferCaps);
m_bInCap = true;
if (HasCap302()) {
@@ -830,13 +830,13 @@ void CClient::HandleCap(const CMessage& Message) {
if (sCap.TrimPrefix("-")) bVal = false;
bool bAccepted = false;
const auto& it = CoreCaps().find(sCap);
auto it = CoreCaps().find(sCap);
if (CoreCaps().end() != it) {
bool bServerDependent = std::get<0>(it->second);
bAccepted = !bServerDependent ||
m_ssServerDependentCaps.count(sCap) > 0;
}
GLOBALMODULECALL(IsClientCapSupported(this, sCap, bVal),
NETWORKMODULECALL(IsClientCapSupported(this, sCap, bVal), GetUser(), GetNetwork(), this,
&bAccepted);
if (!bAccepted) {
@@ -857,7 +857,7 @@ void CClient::HandleCap(const CMessage& Message) {
const auto& handler = std::get<1>(handler_it->second);
handler(this, bVal);
}
GLOBALMODULECALL(OnClientCapRequest(this, sCap, bVal), NOTHING);
NETWORKMODULECALL(OnClientCapRequest(this, sCap, bVal), GetUser(), GetNetwork(), this, NOTHING);
if (bVal) {
m_ssAcceptedCaps.insert(sCap);
@@ -936,8 +936,38 @@ void CClient::SetTagSupport(const CString& sTag, bool bState) {
}
}
void CClient::NotifyServerDependentCap(const CString& sCap, bool bValue) {
void CClient::NotifyServerDependentCap(const CString& sCap, bool bValue, const CString& sValue,
const std::function<void(CClient*, bool)>& handler) {
if (bValue) {
if (HasCapNotify()) {
if (HasCap302() && !sValue.empty()) {
PutClient(":irc.znc.in CAP " + GetNick() + " NEW :" + sCap + "=" + sValue);
} else {
PutClient(":irc.znc.in CAP " + GetNick() + " NEW :" + sCap);
}
}
} else {
if (HasCapNotify()) {
PutClient(":irc.znc.in CAP " + GetNick() + " DEL :" + sCap);
}
handler(this, false);
m_ssAcceptedCaps.erase(sCap);
}
}
void CClient::PotentiallyNotifyServerDependentCap(const CString& sCap, bool bValue, const CString& sValue) {
auto it = CoreCaps().find(sCap);
if (CoreCaps().end() != it) {
const auto& [bServerDependent, handler] = it->second;
if (bServerDependent) {
NotifyServerDependentCap(sCap, bValue, sValue, handler);
}
}
if (!bValue) {
m_ssServerDependentCaps.erase(sCap);
}
return;
if (bValue) {
if (CoreCaps().end() != it) {
bool bServerDependent = std::get<0>(it->second);

View File

@@ -1417,12 +1417,24 @@ void CIRCNetwork::IRCDisconnected() {
CheckIRCConnect();
}
void CIRCNetwork::NotifyServerDependentCap(const CString& sCap, bool bValue) {
void CIRCNetwork::PotentiallyNotifyServerDependentCap(const CString& sCap, bool bValue) {
CString sValue = GetIRCSock() ? GetIRCSock()->GetCapLsValue(sCap) : "";
for (CClient* pClient : m_vClients) {
pClient->NotifyServerDependentCap(sCap, bValue);
pClient->PotentiallyNotifyServerDependentCap(sCap, bValue, sValue);
}
}
void CIRCNetwork::NotifyClientsAboutServerDependentCap(const CString& sCap, bool bValue, const std::function<void(CClient*, bool)>& handler) {
CString sValue = GetIRCSock() ? GetIRCSock()->GetCapLsValue(sCap) : "";
for (CClient* pClient : m_vClients) {
pClient->NotifyServerDependentCap(sCap, bValue, sValue, handler);
}
}
bool CIRCNetwork::IsServerCapAccepted(const CString& sCap) const {
return m_pIRCSock && m_pIRCSock->IsCapAccepted(sCap);
}
void CIRCNetwork::SetIRCConnectEnabled(bool b) {
m_bIRCConnectEnabled = b;

View File

@@ -398,7 +398,7 @@ bool CIRCSock::OnCapabilityMessage(CMessage& Message) {
m_ssAcceptedCaps.erase(sCap);
m_ssPendingCaps.erase(sCap);
if (m_bAuthed) {
m_pNetwork->NotifyServerDependentCap(sCap, false);
m_pNetwork->PotentiallyNotifyServerDependentCap(sCap, false);
}
};
@@ -415,6 +415,7 @@ bool CIRCSock::OnCapabilityMessage(CMessage& Message) {
sCap = sToken.substr(0, eq);
sValue = sToken.substr(eq + 1);
}
m_msCapLsValues[sCap] = sValue;
if (OnServerCapAvailable(sCap, sValue) || mSupportedCaps.count(sCap)) {
m_ssPendingCaps.insert(sCap);
}
@@ -428,7 +429,7 @@ bool CIRCSock::OnCapabilityMessage(CMessage& Message) {
}
m_ssAcceptedCaps.insert(sArgs);
if (m_bAuthed) {
m_pNetwork->NotifyServerDependentCap(sArgs, true);
m_pNetwork->PotentiallyNotifyServerDependentCap(sArgs, true);
}
} else if (sSubCmd == "NAK") {
// This should work because there's no [known]
@@ -441,6 +442,7 @@ bool CIRCSock::OnCapabilityMessage(CMessage& Message) {
for (const CString& sCap : vsTokens) {
RemoveCap(sCap);
m_msCapLsValues.erase(sCap);
}
}
@@ -1448,6 +1450,16 @@ CString CIRCSock::GetISupport(const CString& sKey,
}
}
CString CIRCSock::GetCapLsValue(const CString& sKey,
const CString& sDefault) const {
MCString::const_iterator i = m_msCapLsValues.find(sKey);
if (i == m_msCapLsValues.end()) {
return sDefault;
} else {
return i->second;
}
}
void CIRCSock::SendAltNick(const CString& sBadNick) {
const CString& sLastNick = m_Nick.GetNick();

View File

@@ -534,6 +534,117 @@ TEST_F(ZNCTest, CAP302LSValue) {
client2.ReadUntil("testcap=");
}
TEST_F(ZNCTest, ServerDependentCapInModule) {
auto znc = Run();
auto ircd = ConnectIRCd();
auto client = LoginClient();
InstallModule("testmod.cpp", R"(
#include <znc/Modules.h>
#include <znc/Client.h>
#include <znc/IRCNetwork.h>
#include <znc/IRCSock.h>
class TestModule : public CModule {
public:
MODCONSTRUCTOR(TestModule) {}
void OnClientCapLs(CClient* pClient, SCString& ssCaps) override {
if (GetNetwork() && GetNetwork()->IsServerCapAccepted("testcap")) {
if (pClient->HasCap302()) {
CString sValue = GetNetwork()->GetIRCSock()->GetCapLsValue("testcap");
if (!sValue.empty()) {
ssCaps.insert("testcap=" + sValue);
} else {
ssCaps.insert("testcap");
}
} else {
ssCaps.insert("testcap");
}
}
}
bool IsClientCapSupported(CClient* pClient, const CString& sCap,
bool bState) override {
if (!bState) return false;
if (sCap != "testcap") return false;
return GetNetwork() && GetNetwork()->IsServerCapAccepted("testcap");
}
void OnClientCapRequest(CClient* pClient, const CString& sCap,
bool bState) override {
PutModule("OnClientCapRequest " + sCap + " " + CString(bState));
}
bool OnServerCap302Available(const CString& sCap, const CString& sValue) override {
PutModule("OnServerCapAvailable " + sCap + " " + sValue);
return sCap == "testcap";
}
void OnServerCapResult(const CString& sCap, bool bSuccess) override {
if (sCap == "testcap") {
GetNetwork()->NotifyClientsAboutServerDependentCap("testcap", bSuccess, [=](CClient* pClient, bool bState) {
PutModule("OnServerCapResult " + sCap + " " + CString(bSuccess) + " " + CString(bState));
});
}
}
void OnIRCConnected() override {
if (GetNetwork()->IsServerCapAccepted("testcap")) {
GetNetwork()->NotifyClientsAboutServerDependentCap("testcap", true, [=](CClient* pClient, bool bState) {
PutModule("OnIRCConnected " + CString(bState));
});
}
}
void OnIRCDisconnected() override {
GetNetwork()->NotifyClientsAboutServerDependentCap("testcap", false, [=](CClient* pClient, bool bState) {
PutModule("OnIRCDisconnected " + CString(bState));
});
}
~TestModule() override {
// TODO user module
GetNetwork()->NotifyClientsAboutServerDependentCap("testcap", false, [=](CClient* pClient, bool bState) {
PutModule("~ " + CString(bState));
});
}
};
MODULEDEFS(TestModule, "Test")
)");
client.Write("znc loadmod testmod");
client.ReadUntil("Loaded module testmod");
client.Close();
client = ConnectClient();
client.Write("CAP LS 302");
client.Write("PASS :hunter2");
client.Write("NICK nick");
client.Write("USER user x x :x");
client.Write("CAP END");
client.ReadUntil("Welcome");
ircd.Write("001 nick Welcome");
ircd.Write("CAP nick NEW testcap=value");
ircd.ReadUntil("CAP REQ :testcap");
ircd.Write("CAP nick ACK :testcap");
client.ReadUntil("CAP nick NEW :testcap=value");
client.Write("CAP REQ testcap");
client.ReadUntil("CAP nick ACK :testcap");
client.Write("CAP LS");
client.ReadUntil(" testcap=value ");
ircd.Write("CAP nick DEL testcap");
client.ReadUntil("CAP nick DEL :testcap");
client.ReadUntil(":OnServerCapResult testcap false false");
ircd.Close();
ircd = ConnectIRCd();
ircd.ReadUntil("CAP LS 302");
ircd.Write("CAP nick LS :testcap=new");
ircd.ReadUntil("CAP REQ :testcap");
ircd.Write("CAP nick ACK :testcap");
ircd.ReadUntil("CAP END");
// TODO should NEW wait until 001?
// TODO combine multiple NEWs to single line
client.ReadUntil("CAP nick NEW :testcap=new");
ircd.ReadUntil("001 nick Welcome");
client.ReadUntil("Welcome2");
// TODO NEW with new value without DEL
// TODO client.Write("jumpnetwork");
}
TEST_F(ZNCTest, HashUpgrade) {
QFile conf(m_dir.path() + "/configs/znc.conf");
ASSERT_TRUE(conf.open(QIODevice::Append | QIODevice::Text));