Reabse and address PR comments

This commit is contained in:
delthas
2023-08-31 11:24:53 +02:00
parent d27e2cce5c
commit 1dd995ef77
9 changed files with 166 additions and 99 deletions
+17 -14
View File
@@ -116,10 +116,8 @@ class CClient : public CIRCSocket {
m_bBatch(false),
m_bEchoMessage(false),
m_bSelfMessage(false),
m_bSasl(false),
m_bSaslAuthenticating(false),
m_bSaslAuthenticated(false),
m_bSaslMultipart(false),
m_bSASL(false),
m_bSASLAuthenticating(false),
m_bPlaybackActive(false),
m_pUser(nullptr),
m_pNetwork(nullptr),
@@ -128,8 +126,9 @@ class CClient : public CIRCSocket {
m_sUser(""),
m_sNetwork(""),
m_sIdentifier(""),
m_sSaslBuffer(""),
m_sSaslMechanism(""),
m_sSASLBuffer(""),
m_sSASLMechanism(""),
m_sSASLUser(""),
m_spAuth(),
m_ssAcceptedCaps(),
m_ssSupportedTags(),
@@ -162,7 +161,7 @@ class CClient : public CIRCSocket {
}}},
{"extended-join",
{true, [this](bool bVal) { m_bExtendedJoin = bVal; }}},
{"sasl", {false, [this](bool bVal) { m_bSasl = bVal; m_bSaslAuthenticating = bVal; }}},
{"sasl", {false, [this](bool bVal) { m_bSASL = bVal; m_bSASLAuthenticating = bVal; }}},
}) {
EnableReadLine();
// RFC says a line can have 512 chars max, but we are
@@ -342,7 +341,12 @@ class CClient : public CIRCSocket {
bool OnActionMessage(CActionMessage& Message);
void OnAuthenticateMessage(CAuthenticateMessage& Message);
CString EnumerateSaslMechanisms(SCString& ssMechanisms);
/**
* Fills all available SASL mechanisms in the passed set, and returns a comma-joined string of those mechanisms.
* @param ssMechanisms Set of supported mechanisms, filled by this method.
* @return A comma-joined string of supported mechanisms.
*/
CString EnumerateSASLMechanisms(SCString& ssMechanisms);
bool OnCTCPMessage(CCTCPMessage& Message);
bool OnJoinMessage(CJoinMessage& Message);
@@ -373,10 +377,8 @@ class CClient : public CIRCSocket {
bool m_bBatch;
bool m_bEchoMessage;
bool m_bSelfMessage;
bool m_bSasl;
bool m_bSaslAuthenticating;
bool m_bSaslAuthenticated;
bool m_bSaslMultipart;
bool m_bSASL;
bool m_bSASLAuthenticating;
bool m_bPlaybackActive;
CUser* m_pUser;
CIRCNetwork* m_pNetwork;
@@ -385,8 +387,9 @@ class CClient : public CIRCSocket {
CString m_sUser;
CString m_sNetwork;
CString m_sIdentifier;
CString m_sSaslBuffer;
CString m_sSaslMechanism;
CString m_sSASLBuffer;
CString m_sSASLMechanism;
CString m_sSASLUser;
std::shared_ptr<CAuthBase> m_spAuth;
SCString m_ssAcceptedCaps;
SCString m_ssSupportedTags;
+27 -6
View File
@@ -1308,14 +1308,35 @@ class CModule {
*/
virtual void OnClientCapRequest(CClient* pClient, const CString& sCap,
bool bState);
virtual EModRet OnSaslServerChallenge(const CString& sMechanism,
/** Called when a client requests SASL authentication. Use ssMechanisms.insert("mechanism")
* for announcing sASL mechanisms which your module supports.
* @param ssMechanisms The set of supported SASL mechanisms to append to.
*/
virtual void OnGetSASLMechanisms(SCString& ssMechanisms);
/** Called when a client has selected a SASL mechanism for SASL authentication.
* If implementing a SASL authentication mechanism, set sResponse to specify an initial challenge
* message to send to the client. Otherwise, an empty response will be sent.
* @param sMechanism The SASL mechanism selected by the client.
* @param sResponse The optional value of an initial SASL challenge message to send to the client.
*/
virtual EModRet OnSASLServerChallenge(const CString& sMechanism,
CString& sResponse);
virtual EModRet OnClientSaslAuthenticate(const CString& sMechanism,
/** Called when a client is sending us a SASL message after the mechanism was selected.
* If implementing a SASL authentication mechanism, check the passed credentials,
* then either request more data by sending a challenge in sMechanismResponse,
* reject authentication by setting bAuthenticationSuccess to false,
* or accept authentication by setting bAuthenticationSuccess to true and setting sUser to the authenticated user name.
* @param sMechanism The SASL mechanism selected by the client.
* @param sBuffer The SASL opaque value/credentials sent by the client.
* @param sUser The optional name of the authenticated user to log in the user as, if authentication is accepted.
* @param sMechanismResponse The optional value of a SASL challenge message to reply to the client to ask for more data.
* @param bAuthenticationSuccess If sMechanismResponse is not set, whether to accept or reject the authentication request.
*/
virtual EModRet OnClientSASLAuthenticate(const CString& sMechanism,
const CString& sBuffer,
CString& sUser,
CString& sMechanismResponse,
bool& bAuthenticationSuccess);
virtual void OnGetSaslMechanisms(SCString& ssMechanisms);
/** Called when a module is going to be loaded.
* @param sModName name of the module.
@@ -1595,14 +1616,14 @@ class CModules : public std::vector<CModule*>, private CCoreTranslationMixin {
bool IsClientCapSupported(CClient* pClient, const CString& sCap,
bool bState);
bool OnClientCapRequest(CClient* pClient, const CString& sCap, bool bState);
bool OnSaslServerChallenge(const CString& sMechanism,
bool OnGetSASLMechanisms(SCString& ssMechanisms);
bool OnSASLServerChallenge(const CString& sMechanism,
CString& sResponse);
bool OnClientSaslAuthenticate(const CString& sMechanism,
bool OnClientSASLAuthenticate(const CString& sMechanism,
const CString& sBuffer,
CString& sUser,
CString& sResponse,
bool& bAuthenticationSuccess);
bool OnGetSaslMechanisms(SCString& ssMechanisms);
bool OnModuleLoading(const CString& sModName, const CString& sArgs,
CModInfo::EModuleType eType, bool& bSuccess,
+3
View File
@@ -109,6 +109,9 @@ EModRet OnUnknownUserRaw(CClient* pClient, CString& sLine)
EModRet OnUnknownUserRawMessage(CMessage& Message)
bool IsClientCapSupported(CClient* pClient, const CString& sCap, bool bState)
void OnClientCapRequest(CClient* pClient, const CString& sCap, bool bState)
void OnGetSASLMechanisms(SCString& ssMechanisms)
EModRet OnSASLServerChallenge(const CString& sMechanism, CString& sResponse)
EModRet OnClientSASLAuthenticate(const CString& sMechanism, const CString& sBuffer, CString& sUser, CString& sMechanismResponse, bool& bAuthenticationSuccess)
EModRet OnModuleLoading(const CString& sModName, const CString& sArgs, CModInfo::EModuleType eType, bool& bSuccess, CString& sRetMsg)
EModRet OnModuleUnloading(CModule* pModule, bool& bSuccess, CString& sRetMsg)
EModRet OnGetModInfo(CModInfo& ModInfo, const CString& sModule, bool& bSuccess, CString& sRetMsg)
+8
View File
@@ -191,6 +191,14 @@ class ZNC_EXPORT_LIB_EXPORT CPyModule : public CModule {
bool bState) override;
void OnClientCapRequest(CClient* pClient, const CString& sCap,
bool bState) override;
void OnGetSASLMechanisms(SCString& ssMechanisms) override;
EModRet OnSASLServerChallenge(const CString& sMechanism,
CString& sResponse) override;
EModRet OnClientSASLAuthenticate(const CString& sMechanism,
const CString& sBuffer,
CString& sUser,
CString& sMechanismResponse,
bool& bAuthenticationSuccess) override;
virtual EModRet OnModuleLoading(const CString& sModName,
const CString& sArgs,
CModInfo::EModuleType eType, bool& bSuccess,
+9
View File
@@ -469,6 +469,15 @@ class Module:
def OnClientCapRequest(self, pClient, sCap, bState):
pass
def OnGetSASLMechanisms(self, ssMechanisms):
pass
def OnSASLServerChallenge(self, sMechanism, sResponse):
pass
def OnClientSASLAuthenticate(self, sMechanism, sBuffer, sUser, sResponse, bAuthenticationSuccess):
pass
def OnModuleLoading(self, sModName, sArgs, eType, bSuccess, sRetMsg):
pass
+9 -6
View File
@@ -21,7 +21,7 @@ class CSASLMechanismPlain : public CModule {
public:
MODCONSTRUCTOR(CSASLMechanismPlain) { AddHelpCommand(); }
EModRet OnClientSaslAuthenticate(const CString& sMechanism,
EModRet OnClientSASLAuthenticate(const CString& sMechanism,
const CString& sBuffer, CString& sUser,
CString& sMechanismResponse,
bool& bAuthenticationSuccess) override {
@@ -32,11 +32,14 @@ class CSASLMechanismPlain : public CModule {
bAuthenticationSuccess = false;
CString sNullSeparator = std::string("\0", 1);
auto sAuthzId = sBuffer.Token(0, false, sNullSeparator);
auto sAuthcId = sBuffer.Token(1, false, sNullSeparator);
auto sPassword = sBuffer.Token(2, false, sNullSeparator);
auto sAuthzId = sBuffer.Token(0, false, sNullSeparator, true);
auto sAuthcId = sBuffer.Token(1, false, sNullSeparator, true);
auto sPassword = sBuffer.Token(2, false, sNullSeparator, true);
if (sAuthzId.empty()) sAuthzId = sAuthcId;
if (!sAuthzId.empty() && sAuthzId != sAuthcId) {
// Reject custom SASL plain authorization identifiers
return HALTMODS;
}
auto pUser = CZNC::Get().FindUser(sAuthcId);
@@ -50,7 +53,7 @@ class CSASLMechanismPlain : public CModule {
return HALTMODS;
}
void OnGetSaslMechanisms(SCString& ssMechanisms) override {
void OnGetSASLMechanisms(SCString& ssMechanisms) override {
ssMechanisms.insert("PLAIN");
}
};
+64 -64
View File
@@ -321,10 +321,11 @@ bool CClient::SendMotd() {
void CClient::AuthUser() {
if (!m_bGotNick || !m_bGotUser || m_bInCap ||
(!m_bSaslAuthenticated && !m_bGotPass) || IsAttached())
(m_sSASLUser.empty() && !m_bGotPass) || IsAttached())
return;
if (m_bSasl && m_bSaslAuthenticated) {
if (m_bSASL && !m_sSASLUser.empty()) {
m_sUser = m_sSASLUser;
auto pUser = CZNC::Get().FindUser(m_sUser);
AcceptLogin(*pUser);
return;
@@ -393,7 +394,7 @@ void CClientAuth::AcceptedLogin(CUser& User) {
void CClient::AcceptLogin(CUser& User) {
m_sPass = "";
m_pUser = &User;
m_bSaslAuthenticating = m_bSasl;
m_bSASLAuthenticating = m_bSASL;
// Set our proper timeout and set back our proper timeout mode
// (constructor set a different timeout and mode)
@@ -705,15 +706,16 @@ void CClient::HandleCap(const CMessage& Message) {
CString sSubCmd = Message.GetParam(0);
if (sSubCmd.Equals("LS")) {
int iCapVersion = Message.GetParam(1).ToInt();
SCString ssOfferCaps;
for (const auto& it : m_mCoreCaps) {
bool bServerDependent = std::get<0>(it.second);
if (!bServerDependent ||
m_ssServerDependentCaps.count(it.first) > 0) {
if (it.first.Equals("sasl")) {
if (it.first.Equals("sasl") && iCapVersion >= 302) {
SCString ssMechanisms;
ssOfferCaps.insert(it.first + "=" +
EnumerateSaslMechanisms(ssMechanisms));
EnumerateSASLMechanisms(ssMechanisms));
} else {
ssOfferCaps.insert(it.first);
}
@@ -724,23 +726,21 @@ void CClient::HandleCap(const CMessage& Message) {
CString(" ").Join(ssOfferCaps.begin(), ssOfferCaps.end());
RespondCap("LS :" + sRes);
m_bInCap = true;
if (Message.GetParam(1).ToInt() >= 302) {
if (iCapVersion >= 302) {
m_bCapNotify = true;
}
} else if (sSubCmd.Equals("END")) {
m_bInCap = false;
if (!IsAttached()) {
if (m_bSasl && !m_bSaslAuthenticated && m_bSaslAuthenticating) {
if (m_bSASL && m_sSASLUser.empty() && m_bSASLAuthenticating) {
PutClient(":irc.znc.in 906 " + GetNick() +
" :SASL authentication aborted");
m_sSaslMechanism = "";
m_bSaslAuthenticated = false;
m_bSaslMultipart = false;
m_bSaslAuthenticating = false;
m_sSASLMechanism = "";
m_bSASLAuthenticating = false;
}
if (!m_pUser && m_bGotUser &&
(!m_bSaslAuthenticated && !m_bGotPass)) {
(m_sSASLUser.empty() && !m_bGotPass)) {
SendRequiredPasswordNotice();
} else {
AuthUser();
@@ -998,139 +998,139 @@ bool CClient::OnActionMessage(CActionMessage& Message) {
}
void CClient::OnAuthenticateMessage(CAuthenticateMessage& Message) {
const auto uiMaxSaslMsgLength = 400u;
const auto uiMaxSASLMsgLength = 400u;
auto bAuthenticationSuccess = false;
auto sMessage = Message.GetText();
const auto sBufferSize = sMessage.length();
SCString ssMechanisms;
const auto iBufferSize = sMessage.length();
auto SaslReset = [this]() {
m_sSaslMechanism = "";
m_sSaslBuffer = "";
m_bSaslMultipart = false;
auto SASLReset = [this]() {
m_sSASLMechanism = "";
m_sSASLBuffer = "";
};
auto SaslChallenge = [this](CString sChallenge) {
auto SASLChallenge = [this](CString sChallenge) {
sChallenge.Base64Encode();
auto sChallengeSize = sChallenge.length();
if (sChallengeSize > uiMaxSaslMsgLength) {
for (auto i = 0u; i < sChallengeSize; i += uiMaxSaslMsgLength) {
CString sMsgPart = sChallenge.substr(i, uiMaxSaslMsgLength);
if (sChallengeSize > uiMaxSASLMsgLength) {
for (int i = 0; i < sChallengeSize; i += uiMaxSASLMsgLength) {
CString sMsgPart = sChallenge.substr(i, uiMaxSASLMsgLength);
PutClient("AUTHENTICATE " + sMsgPart);
}
} else {
} else if (sChallengeSize > 0) {
PutClient("AUTHENTICATE " + sChallenge);
}
if (sChallengeSize % uiMaxSASLMsgLength == 0) {
PutClient("AUTHENTICATE +");
}
};
if (!m_bSasl) return;
if (!m_bSASL) return;
if (m_bSaslAuthenticated || IsAttached()) {
if (!m_sSASLUser.empty() || IsAttached()) {
PutClient(":irc.znc.in 907 " + GetNick() +
" :You have already authenticated using SASL");
return;
}
if (!m_bSaslAuthenticating || sMessage.Equals("*")) {
if (!m_bSASLAuthenticating || sMessage.Equals("*")) {
PutClient(":irc.znc.in 906 " + GetNick() +
" :SASL authentication aborted");
if (!IsAttached()) {
m_bSaslAuthenticating = false;
SaslReset();
m_bSASLAuthenticating = false;
SASLReset();
}
return;
}
auto sMechanisms = EnumerateSaslMechanisms(ssMechanisms);
if (sBufferSize > uiMaxSaslMsgLength) {
if (iBufferSize > uiMaxSASLMsgLength) {
PutClient(":irc.znc.in 905 " + GetNick() + " :SASL message too long");
SaslReset();
SASLReset();
return;
}
if (m_sSaslMechanism.empty()) {
if (m_sSASLMechanism.empty()) {
SCString ssMechanisms;
auto sMechanisms = EnumerateSASLMechanisms(ssMechanisms);
if (ssMechanisms.find(sMessage) == ssMechanisms.end()) {
PutClient(":irc.znc.in 908 " + GetNick() + " " + sMechanisms +
" :are available SASL mechanisms");
PutClient(":irc.znc.in 904 " + GetNick() +
" :SASL authentication failed");
SaslReset();
SASLReset();
return;
}
m_sSaslMechanism = sMessage;
m_sSASLMechanism = sMessage;
auto bResult = false;
CString sChallenge;
GLOBALMODULECALL(OnSaslServerChallenge(m_sSaslMechanism, sChallenge),
GLOBALMODULECALL(OnSASLServerChallenge(m_sSASLMechanism, sChallenge),
&bResult);
if (bResult) {
SaslChallenge(sChallenge);
SASLChallenge(sChallenge);
} else {
PutClient("AUTHENTICATE +");
}
return;
}
if (sBufferSize == uiMaxSaslMsgLength) {
m_bSaslMultipart = true;
m_sSaslBuffer.append(sMessage);
if (m_sSASLBuffer.length() + sMessage.length() > 10 * 1024) {
PutClient(":irc.znc.in 904 " + GetNick() + " :SASL response too long");
SASLReset();
return;
}
if (iBufferSize == uiMaxSASLMsgLength) {
m_sSASLBuffer.append(sMessage);
return;
}
if ((m_bSaslMultipart && !sMessage.Equals("+"))) {
m_sSaslBuffer.append(sMessage);
m_bSaslMultipart = false;
} else if (!m_bSaslMultipart && !sMessage.Equals("+")) {
m_sSaslBuffer.assign(sMessage);
}
if (sMessage != "+") {
m_sSASLBuffer += sMessage;
}
m_sSaslBuffer.Base64Decode();
auto sAuthcId = m_sUser;
auto sAuthzId = m_sUser;
m_sSASLBuffer.Base64Decode();
CString sResponse;
bool bResult;
GLOBALMODULECALL(OnClientSaslAuthenticate(
m_sSaslMechanism, m_sSaslBuffer, sAuthcId,
CString sSASLUser;
GLOBALMODULECALL(OnClientSASLAuthenticate(
m_sSASLMechanism, m_sSASLBuffer, sSASLUser,
sResponse, bAuthenticationSuccess),
&bResult);
m_sSASLBuffer.clear();
if (bResult && !sResponse.empty()) {
SaslChallenge(sResponse);
SASLChallenge(sResponse);
return;
}
m_sSaslBuffer.clear();
auto pUser = CZNC::Get().FindUser(sAuthcId);
auto pUser = CZNC::Get().FindUser(sSASLUser);
if (pUser && bAuthenticationSuccess) {
PutClient(":irc.znc.in 900 " + GetNick() + " " + GetNick() + "!" +
pUser->GetIdent() + "@" + GetHostName() + " " + sAuthcId +
" :You are now logged in as " + sAuthzId);
pUser->GetIdent() + "@" + GetHostName() + " " + sSASLUser +
" :You are now logged in as " + sSASLUser);
PutClient(":irc.znc.in 903 " + GetNick() +
" :SASL authentication successful");
m_bSaslAuthenticated = true;
m_bSaslAuthenticating = false;
m_sSASLUser = sSASLUser;
m_bSASLAuthenticating = false;
} else {
PutClient(":irc.znc.in 904 " + GetNick() + " :SASL authentication failed");
SaslReset();
SASLReset();
}
return;
}
CString CClient::EnumerateSaslMechanisms(SCString& ssMechanisms) {
CString CClient::EnumerateSASLMechanisms(SCString& ssMechanisms) {
CString sMechanisms;
GLOBALMODULECALL(OnGetSaslMechanisms(ssMechanisms), NOTHING);
GLOBALMODULECALL(OnGetSASLMechanisms(ssMechanisms), NOTHING);
if (ssMechanisms.size()) {
sMechanisms =
+9 -9
View File
@@ -1076,7 +1076,7 @@ bool CModule::IsClientCapSupported(CClient* pClient, const CString& sCap,
void CModule::OnClientCapRequest(CClient* pClient, const CString& sCap,
bool bState) {}
CModule::EModRet CModule::OnClientSaslAuthenticate(const CString& sMechanism,
CModule::EModRet CModule::OnClientSASLAuthenticate(const CString& sMechanism,
const CString& sBuffer,
CString& sUser,
CString& sMechanismResponse,
@@ -1084,12 +1084,12 @@ CModule::EModRet CModule::OnClientSaslAuthenticate(const CString& sMechanism,
return CONTINUE;
}
CModule::EModRet CModule::OnSaslServerChallenge(const CString& sMechanism,
CModule::EModRet CModule::OnSASLServerChallenge(const CString& sMechanism,
CString& sResponse) {
return CONTINUE;
}
void CModule::OnGetSaslMechanisms(SCString& ssMechanisms) {}
void CModule::OnGetSASLMechanisms(SCString& ssMechanisms) {}
CModule::EModRet CModule::OnModuleLoading(const CString& sModName,
const CString& sArgs,
@@ -1608,22 +1608,22 @@ bool CModules::OnClientCapRequest(CClient* pClient, const CString& sCap,
return false;
}
bool CModules::OnClientSaslAuthenticate(const CString& sMechanism,
bool CModules::OnClientSASLAuthenticate(const CString& sMechanism,
const CString& sBuffer,
CString& sUser,
CString& sResponse,
bool& bAuthenticationSuccess) {
MODHALTCHK(OnClientSaslAuthenticate(sMechanism, sBuffer, sUser,
MODHALTCHK(OnClientSASLAuthenticate(sMechanism, sBuffer, sUser,
sResponse, bAuthenticationSuccess));
}
bool CModules::OnSaslServerChallenge(const CString& sMechanism,
bool CModules::OnSASLServerChallenge(const CString& sMechanism,
CString& sResponse) {
MODHALTCHK(OnSaslServerChallenge(sMechanism, sResponse));
MODHALTCHK(OnSASLServerChallenge(sMechanism, sResponse));
}
bool CModules::OnGetSaslMechanisms(SCString& ssMechanisms) {
MODUNLOADCHK(OnGetSaslMechanisms(ssMechanisms));
bool CModules::OnGetSASLMechanisms(SCString& ssMechanisms) {
MODUNLOADCHK(OnGetSASLMechanisms(ssMechanisms));
return false;
}
+20
View File
@@ -330,5 +330,25 @@ TEST_F(ZNCTest, SaslMechsNotInit) {
ircd.ReadUntil("PONG foo");
}
TEST_F(ZNCTest, SaslPlainModule) {
auto znc = Run();
auto ircd = ConnectIRCd();
auto client = LoginClient();
client.Write("znc loadmod saslplain");
client.ReadUntil("Loaded module");
client.Close();
auto client2 = ConnectClient();
client2.Write("NICK foo");
client2.Write("CAP LS");
client2.Write("CAP REQ :sasl");
client2.ReadUntil(":irc.znc.in CAP foo ACK :sasl");
client2.Write("USER bar");
client2.Write("AUTHENTICATE PLAIN");
client2.ReadUntil("AUTHENTICATE +");
client2.Write("AUTHENTICATE AHVzZXIAaHVudGVyMg=="); // \0user\0hunter2
client2.ReadUntil(":irc.znc.in 903 foo :SASL authentication successful");
}
} // namespace
} // namespace znc_inttest