diff --git a/include/znc/Client.h b/include/znc/Client.h index 083f2696..82eb1bc0 100644 --- a/include/znc/Client.h +++ b/include/znc/Client.h @@ -116,10 +116,8 @@ class CClient : public CIRCSocket { m_bBatch(false), m_bEchoMessage(false), m_bSelfMessage(false), - m_bSasl(false), - m_bSaslAuthenticating(false), - m_bSaslAuthenticated(false), - m_bSaslMultipart(false), + m_bSASL(false), + m_bSASLAuthenticating(false), m_bPlaybackActive(false), m_pUser(nullptr), m_pNetwork(nullptr), @@ -128,8 +126,9 @@ class CClient : public CIRCSocket { m_sUser(""), m_sNetwork(""), m_sIdentifier(""), - m_sSaslBuffer(""), - m_sSaslMechanism(""), + m_sSASLBuffer(""), + m_sSASLMechanism(""), + m_sSASLUser(""), m_spAuth(), m_ssAcceptedCaps(), m_ssSupportedTags(), @@ -162,7 +161,7 @@ class CClient : public CIRCSocket { }}}, {"extended-join", {true, [this](bool bVal) { m_bExtendedJoin = bVal; }}}, - {"sasl", {false, [this](bool bVal) { m_bSasl = bVal; m_bSaslAuthenticating = bVal; }}}, + {"sasl", {false, [this](bool bVal) { m_bSASL = bVal; m_bSASLAuthenticating = bVal; }}}, }) { EnableReadLine(); // RFC says a line can have 512 chars max, but we are @@ -342,7 +341,12 @@ class CClient : public CIRCSocket { bool OnActionMessage(CActionMessage& Message); void OnAuthenticateMessage(CAuthenticateMessage& Message); - CString EnumerateSaslMechanisms(SCString& ssMechanisms); + /** + * Fills all available SASL mechanisms in the passed set, and returns a comma-joined string of those mechanisms. + * @param ssMechanisms Set of supported mechanisms, filled by this method. + * @return A comma-joined string of supported mechanisms. + */ + CString EnumerateSASLMechanisms(SCString& ssMechanisms); bool OnCTCPMessage(CCTCPMessage& Message); bool OnJoinMessage(CJoinMessage& Message); @@ -373,10 +377,8 @@ class CClient : public CIRCSocket { bool m_bBatch; bool m_bEchoMessage; bool m_bSelfMessage; - bool m_bSasl; - bool m_bSaslAuthenticating; - bool m_bSaslAuthenticated; - bool m_bSaslMultipart; + bool m_bSASL; + bool m_bSASLAuthenticating; bool m_bPlaybackActive; CUser* m_pUser; CIRCNetwork* m_pNetwork; @@ -385,8 +387,9 @@ class CClient : public CIRCSocket { CString m_sUser; CString m_sNetwork; CString m_sIdentifier; - CString m_sSaslBuffer; - CString m_sSaslMechanism; + CString m_sSASLBuffer; + CString m_sSASLMechanism; + CString m_sSASLUser; std::shared_ptr m_spAuth; SCString m_ssAcceptedCaps; SCString m_ssSupportedTags; diff --git a/include/znc/Modules.h b/include/znc/Modules.h index a36531db..a9d454ae 100644 --- a/include/znc/Modules.h +++ b/include/znc/Modules.h @@ -1308,14 +1308,35 @@ class CModule { */ virtual void OnClientCapRequest(CClient* pClient, const CString& sCap, bool bState); - virtual EModRet OnSaslServerChallenge(const CString& sMechanism, + /** Called when a client requests SASL authentication. Use ssMechanisms.insert("mechanism") + * for announcing sASL mechanisms which your module supports. + * @param ssMechanisms The set of supported SASL mechanisms to append to. + */ + virtual void OnGetSASLMechanisms(SCString& ssMechanisms); + /** Called when a client has selected a SASL mechanism for SASL authentication. + * If implementing a SASL authentication mechanism, set sResponse to specify an initial challenge + * message to send to the client. Otherwise, an empty response will be sent. + * @param sMechanism The SASL mechanism selected by the client. + * @param sResponse The optional value of an initial SASL challenge message to send to the client. + */ + virtual EModRet OnSASLServerChallenge(const CString& sMechanism, CString& sResponse); - virtual EModRet OnClientSaslAuthenticate(const CString& sMechanism, + /** Called when a client is sending us a SASL message after the mechanism was selected. + * If implementing a SASL authentication mechanism, check the passed credentials, + * then either request more data by sending a challenge in sMechanismResponse, + * reject authentication by setting bAuthenticationSuccess to false, + * or accept authentication by setting bAuthenticationSuccess to true and setting sUser to the authenticated user name. + * @param sMechanism The SASL mechanism selected by the client. + * @param sBuffer The SASL opaque value/credentials sent by the client. + * @param sUser The optional name of the authenticated user to log in the user as, if authentication is accepted. + * @param sMechanismResponse The optional value of a SASL challenge message to reply to the client to ask for more data. + * @param bAuthenticationSuccess If sMechanismResponse is not set, whether to accept or reject the authentication request. + */ + virtual EModRet OnClientSASLAuthenticate(const CString& sMechanism, const CString& sBuffer, CString& sUser, CString& sMechanismResponse, bool& bAuthenticationSuccess); - virtual void OnGetSaslMechanisms(SCString& ssMechanisms); /** Called when a module is going to be loaded. * @param sModName name of the module. @@ -1595,14 +1616,14 @@ class CModules : public std::vector, private CCoreTranslationMixin { bool IsClientCapSupported(CClient* pClient, const CString& sCap, bool bState); bool OnClientCapRequest(CClient* pClient, const CString& sCap, bool bState); - bool OnSaslServerChallenge(const CString& sMechanism, + bool OnGetSASLMechanisms(SCString& ssMechanisms); + bool OnSASLServerChallenge(const CString& sMechanism, CString& sResponse); - bool OnClientSaslAuthenticate(const CString& sMechanism, + bool OnClientSASLAuthenticate(const CString& sMechanism, const CString& sBuffer, CString& sUser, CString& sResponse, bool& bAuthenticationSuccess); - bool OnGetSaslMechanisms(SCString& ssMechanisms); bool OnModuleLoading(const CString& sModName, const CString& sArgs, CModInfo::EModuleType eType, bool& bSuccess, diff --git a/modules/modpython/functions.in b/modules/modpython/functions.in index f779e577..52707995 100644 --- a/modules/modpython/functions.in +++ b/modules/modpython/functions.in @@ -109,6 +109,9 @@ EModRet OnUnknownUserRaw(CClient* pClient, CString& sLine) EModRet OnUnknownUserRawMessage(CMessage& Message) bool IsClientCapSupported(CClient* pClient, const CString& sCap, bool bState) void OnClientCapRequest(CClient* pClient, const CString& sCap, bool bState) +void OnGetSASLMechanisms(SCString& ssMechanisms) +EModRet OnSASLServerChallenge(const CString& sMechanism, CString& sResponse) +EModRet OnClientSASLAuthenticate(const CString& sMechanism, const CString& sBuffer, CString& sUser, CString& sMechanismResponse, bool& bAuthenticationSuccess) EModRet OnModuleLoading(const CString& sModName, const CString& sArgs, CModInfo::EModuleType eType, bool& bSuccess, CString& sRetMsg) EModRet OnModuleUnloading(CModule* pModule, bool& bSuccess, CString& sRetMsg) EModRet OnGetModInfo(CModInfo& ModInfo, const CString& sModule, bool& bSuccess, CString& sRetMsg) diff --git a/modules/modpython/module.h b/modules/modpython/module.h index a0847a20..d4f20d4d 100644 --- a/modules/modpython/module.h +++ b/modules/modpython/module.h @@ -191,6 +191,14 @@ class ZNC_EXPORT_LIB_EXPORT CPyModule : public CModule { bool bState) override; void OnClientCapRequest(CClient* pClient, const CString& sCap, bool bState) override; + void OnGetSASLMechanisms(SCString& ssMechanisms) override; + EModRet OnSASLServerChallenge(const CString& sMechanism, + CString& sResponse) override; + EModRet OnClientSASLAuthenticate(const CString& sMechanism, + const CString& sBuffer, + CString& sUser, + CString& sMechanismResponse, + bool& bAuthenticationSuccess) override; virtual EModRet OnModuleLoading(const CString& sModName, const CString& sArgs, CModInfo::EModuleType eType, bool& bSuccess, diff --git a/modules/modpython/znc.py b/modules/modpython/znc.py index 4e8fc295..b8c56e80 100644 --- a/modules/modpython/znc.py +++ b/modules/modpython/znc.py @@ -469,6 +469,15 @@ class Module: def OnClientCapRequest(self, pClient, sCap, bState): pass + def OnGetSASLMechanisms(self, ssMechanisms): + pass + + def OnSASLServerChallenge(self, sMechanism, sResponse): + pass + + def OnClientSASLAuthenticate(self, sMechanism, sBuffer, sUser, sResponse, bAuthenticationSuccess): + pass + def OnModuleLoading(self, sModName, sArgs, eType, bSuccess, sRetMsg): pass diff --git a/modules/saslplain.cpp b/modules/saslplain.cpp index ce4cfbb0..3f467397 100644 --- a/modules/saslplain.cpp +++ b/modules/saslplain.cpp @@ -21,7 +21,7 @@ class CSASLMechanismPlain : public CModule { public: MODCONSTRUCTOR(CSASLMechanismPlain) { AddHelpCommand(); } - EModRet OnClientSaslAuthenticate(const CString& sMechanism, + EModRet OnClientSASLAuthenticate(const CString& sMechanism, const CString& sBuffer, CString& sUser, CString& sMechanismResponse, bool& bAuthenticationSuccess) override { @@ -32,11 +32,14 @@ class CSASLMechanismPlain : public CModule { bAuthenticationSuccess = false; CString sNullSeparator = std::string("\0", 1); - auto sAuthzId = sBuffer.Token(0, false, sNullSeparator); - auto sAuthcId = sBuffer.Token(1, false, sNullSeparator); - auto sPassword = sBuffer.Token(2, false, sNullSeparator); + auto sAuthzId = sBuffer.Token(0, false, sNullSeparator, true); + auto sAuthcId = sBuffer.Token(1, false, sNullSeparator, true); + auto sPassword = sBuffer.Token(2, false, sNullSeparator, true); - if (sAuthzId.empty()) sAuthzId = sAuthcId; + if (!sAuthzId.empty() && sAuthzId != sAuthcId) { + // Reject custom SASL plain authorization identifiers + return HALTMODS; + } auto pUser = CZNC::Get().FindUser(sAuthcId); @@ -50,7 +53,7 @@ class CSASLMechanismPlain : public CModule { return HALTMODS; } - void OnGetSaslMechanisms(SCString& ssMechanisms) override { + void OnGetSASLMechanisms(SCString& ssMechanisms) override { ssMechanisms.insert("PLAIN"); } }; diff --git a/src/Client.cpp b/src/Client.cpp index f02676b8..c76b94e8 100644 --- a/src/Client.cpp +++ b/src/Client.cpp @@ -321,10 +321,11 @@ bool CClient::SendMotd() { void CClient::AuthUser() { if (!m_bGotNick || !m_bGotUser || m_bInCap || - (!m_bSaslAuthenticated && !m_bGotPass) || IsAttached()) + (m_sSASLUser.empty() && !m_bGotPass) || IsAttached()) return; - if (m_bSasl && m_bSaslAuthenticated) { + if (m_bSASL && !m_sSASLUser.empty()) { + m_sUser = m_sSASLUser; auto pUser = CZNC::Get().FindUser(m_sUser); AcceptLogin(*pUser); return; @@ -393,7 +394,7 @@ void CClientAuth::AcceptedLogin(CUser& User) { void CClient::AcceptLogin(CUser& User) { m_sPass = ""; m_pUser = &User; - m_bSaslAuthenticating = m_bSasl; + m_bSASLAuthenticating = m_bSASL; // Set our proper timeout and set back our proper timeout mode // (constructor set a different timeout and mode) @@ -705,15 +706,16 @@ void CClient::HandleCap(const CMessage& Message) { CString sSubCmd = Message.GetParam(0); if (sSubCmd.Equals("LS")) { + int iCapVersion = Message.GetParam(1).ToInt(); SCString ssOfferCaps; for (const auto& it : m_mCoreCaps) { bool bServerDependent = std::get<0>(it.second); if (!bServerDependent || m_ssServerDependentCaps.count(it.first) > 0) { - if (it.first.Equals("sasl")) { + if (it.first.Equals("sasl") && iCapVersion >= 302) { SCString ssMechanisms; ssOfferCaps.insert(it.first + "=" + - EnumerateSaslMechanisms(ssMechanisms)); + EnumerateSASLMechanisms(ssMechanisms)); } else { ssOfferCaps.insert(it.first); } @@ -724,23 +726,21 @@ void CClient::HandleCap(const CMessage& Message) { CString(" ").Join(ssOfferCaps.begin(), ssOfferCaps.end()); RespondCap("LS :" + sRes); m_bInCap = true; - if (Message.GetParam(1).ToInt() >= 302) { + if (iCapVersion >= 302) { m_bCapNotify = true; } } else if (sSubCmd.Equals("END")) { m_bInCap = false; if (!IsAttached()) { - if (m_bSasl && !m_bSaslAuthenticated && m_bSaslAuthenticating) { + if (m_bSASL && m_sSASLUser.empty() && m_bSASLAuthenticating) { PutClient(":irc.znc.in 906 " + GetNick() + " :SASL authentication aborted"); - m_sSaslMechanism = ""; - m_bSaslAuthenticated = false; - m_bSaslMultipart = false; - m_bSaslAuthenticating = false; + m_sSASLMechanism = ""; + m_bSASLAuthenticating = false; } if (!m_pUser && m_bGotUser && - (!m_bSaslAuthenticated && !m_bGotPass)) { + (m_sSASLUser.empty() && !m_bGotPass)) { SendRequiredPasswordNotice(); } else { AuthUser(); @@ -998,139 +998,139 @@ bool CClient::OnActionMessage(CActionMessage& Message) { } void CClient::OnAuthenticateMessage(CAuthenticateMessage& Message) { - const auto uiMaxSaslMsgLength = 400u; + const auto uiMaxSASLMsgLength = 400u; auto bAuthenticationSuccess = false; auto sMessage = Message.GetText(); - const auto sBufferSize = sMessage.length(); - SCString ssMechanisms; + const auto iBufferSize = sMessage.length(); - auto SaslReset = [this]() { - m_sSaslMechanism = ""; - m_sSaslBuffer = ""; - m_bSaslMultipart = false; + auto SASLReset = [this]() { + m_sSASLMechanism = ""; + m_sSASLBuffer = ""; }; - auto SaslChallenge = [this](CString sChallenge) { + auto SASLChallenge = [this](CString sChallenge) { sChallenge.Base64Encode(); auto sChallengeSize = sChallenge.length(); - if (sChallengeSize > uiMaxSaslMsgLength) { - for (auto i = 0u; i < sChallengeSize; i += uiMaxSaslMsgLength) { - CString sMsgPart = sChallenge.substr(i, uiMaxSaslMsgLength); + if (sChallengeSize > uiMaxSASLMsgLength) { + for (int i = 0; i < sChallengeSize; i += uiMaxSASLMsgLength) { + CString sMsgPart = sChallenge.substr(i, uiMaxSASLMsgLength); PutClient("AUTHENTICATE " + sMsgPart); } - } else { + } else if (sChallengeSize > 0) { PutClient("AUTHENTICATE " + sChallenge); } + if (sChallengeSize % uiMaxSASLMsgLength == 0) { + PutClient("AUTHENTICATE +"); + } }; - if (!m_bSasl) return; + if (!m_bSASL) return; - if (m_bSaslAuthenticated || IsAttached()) { + if (!m_sSASLUser.empty() || IsAttached()) { PutClient(":irc.znc.in 907 " + GetNick() + " :You have already authenticated using SASL"); return; } - if (!m_bSaslAuthenticating || sMessage.Equals("*")) { + if (!m_bSASLAuthenticating || sMessage.Equals("*")) { PutClient(":irc.znc.in 906 " + GetNick() + " :SASL authentication aborted"); if (!IsAttached()) { - m_bSaslAuthenticating = false; - SaslReset(); + m_bSASLAuthenticating = false; + SASLReset(); } return; } - auto sMechanisms = EnumerateSaslMechanisms(ssMechanisms); - - if (sBufferSize > uiMaxSaslMsgLength) { + if (iBufferSize > uiMaxSASLMsgLength) { PutClient(":irc.znc.in 905 " + GetNick() + " :SASL message too long"); - SaslReset(); + SASLReset(); return; } - if (m_sSaslMechanism.empty()) { + if (m_sSASLMechanism.empty()) { + SCString ssMechanisms; + auto sMechanisms = EnumerateSASLMechanisms(ssMechanisms); + if (ssMechanisms.find(sMessage) == ssMechanisms.end()) { PutClient(":irc.znc.in 908 " + GetNick() + " " + sMechanisms + " :are available SASL mechanisms"); PutClient(":irc.znc.in 904 " + GetNick() + " :SASL authentication failed"); - SaslReset(); + SASLReset(); return; } - m_sSaslMechanism = sMessage; + m_sSASLMechanism = sMessage; auto bResult = false; CString sChallenge; - GLOBALMODULECALL(OnSaslServerChallenge(m_sSaslMechanism, sChallenge), + GLOBALMODULECALL(OnSASLServerChallenge(m_sSASLMechanism, sChallenge), &bResult); if (bResult) { - SaslChallenge(sChallenge); + SASLChallenge(sChallenge); } else { PutClient("AUTHENTICATE +"); } return; } - if (sBufferSize == uiMaxSaslMsgLength) { - m_bSaslMultipart = true; - m_sSaslBuffer.append(sMessage); + if (m_sSASLBuffer.length() + sMessage.length() > 10 * 1024) { + PutClient(":irc.znc.in 904 " + GetNick() + " :SASL response too long"); + SASLReset(); + return; + } + if (iBufferSize == uiMaxSASLMsgLength) { + m_sSASLBuffer.append(sMessage); return; } - if ((m_bSaslMultipart && !sMessage.Equals("+"))) { - m_sSaslBuffer.append(sMessage); - m_bSaslMultipart = false; - } else if (!m_bSaslMultipart && !sMessage.Equals("+")) { - m_sSaslBuffer.assign(sMessage); - } + if (sMessage != "+") { + m_sSASLBuffer += sMessage; + } - m_sSaslBuffer.Base64Decode(); - - auto sAuthcId = m_sUser; - auto sAuthzId = m_sUser; + m_sSASLBuffer.Base64Decode(); CString sResponse; bool bResult; - GLOBALMODULECALL(OnClientSaslAuthenticate( - m_sSaslMechanism, m_sSaslBuffer, sAuthcId, + CString sSASLUser; + GLOBALMODULECALL(OnClientSASLAuthenticate( + m_sSASLMechanism, m_sSASLBuffer, sSASLUser, sResponse, bAuthenticationSuccess), &bResult); + m_sSASLBuffer.clear(); if (bResult && !sResponse.empty()) { - SaslChallenge(sResponse); + SASLChallenge(sResponse); return; } - m_sSaslBuffer.clear(); - - auto pUser = CZNC::Get().FindUser(sAuthcId); + auto pUser = CZNC::Get().FindUser(sSASLUser); if (pUser && bAuthenticationSuccess) { PutClient(":irc.znc.in 900 " + GetNick() + " " + GetNick() + "!" + - pUser->GetIdent() + "@" + GetHostName() + " " + sAuthcId + - " :You are now logged in as " + sAuthzId); + pUser->GetIdent() + "@" + GetHostName() + " " + sSASLUser + + " :You are now logged in as " + sSASLUser); PutClient(":irc.znc.in 903 " + GetNick() + " :SASL authentication successful"); - m_bSaslAuthenticated = true; - m_bSaslAuthenticating = false; + m_sSASLUser = sSASLUser; + m_bSASLAuthenticating = false; } else { PutClient(":irc.znc.in 904 " + GetNick() + " :SASL authentication failed"); - SaslReset(); + SASLReset(); } return; } -CString CClient::EnumerateSaslMechanisms(SCString& ssMechanisms) { +CString CClient::EnumerateSASLMechanisms(SCString& ssMechanisms) { CString sMechanisms; - GLOBALMODULECALL(OnGetSaslMechanisms(ssMechanisms), NOTHING); + GLOBALMODULECALL(OnGetSASLMechanisms(ssMechanisms), NOTHING); if (ssMechanisms.size()) { sMechanisms = diff --git a/src/Modules.cpp b/src/Modules.cpp index 92a89451..a9ecd784 100644 --- a/src/Modules.cpp +++ b/src/Modules.cpp @@ -1076,7 +1076,7 @@ bool CModule::IsClientCapSupported(CClient* pClient, const CString& sCap, void CModule::OnClientCapRequest(CClient* pClient, const CString& sCap, bool bState) {} -CModule::EModRet CModule::OnClientSaslAuthenticate(const CString& sMechanism, +CModule::EModRet CModule::OnClientSASLAuthenticate(const CString& sMechanism, const CString& sBuffer, CString& sUser, CString& sMechanismResponse, @@ -1084,12 +1084,12 @@ CModule::EModRet CModule::OnClientSaslAuthenticate(const CString& sMechanism, return CONTINUE; } -CModule::EModRet CModule::OnSaslServerChallenge(const CString& sMechanism, +CModule::EModRet CModule::OnSASLServerChallenge(const CString& sMechanism, CString& sResponse) { return CONTINUE; } -void CModule::OnGetSaslMechanisms(SCString& ssMechanisms) {} +void CModule::OnGetSASLMechanisms(SCString& ssMechanisms) {} CModule::EModRet CModule::OnModuleLoading(const CString& sModName, const CString& sArgs, @@ -1608,22 +1608,22 @@ bool CModules::OnClientCapRequest(CClient* pClient, const CString& sCap, return false; } -bool CModules::OnClientSaslAuthenticate(const CString& sMechanism, +bool CModules::OnClientSASLAuthenticate(const CString& sMechanism, const CString& sBuffer, CString& sUser, CString& sResponse, bool& bAuthenticationSuccess) { - MODHALTCHK(OnClientSaslAuthenticate(sMechanism, sBuffer, sUser, + MODHALTCHK(OnClientSASLAuthenticate(sMechanism, sBuffer, sUser, sResponse, bAuthenticationSuccess)); } -bool CModules::OnSaslServerChallenge(const CString& sMechanism, +bool CModules::OnSASLServerChallenge(const CString& sMechanism, CString& sResponse) { - MODHALTCHK(OnSaslServerChallenge(sMechanism, sResponse)); + MODHALTCHK(OnSASLServerChallenge(sMechanism, sResponse)); } -bool CModules::OnGetSaslMechanisms(SCString& ssMechanisms) { - MODUNLOADCHK(OnGetSaslMechanisms(ssMechanisms)); +bool CModules::OnGetSASLMechanisms(SCString& ssMechanisms) { + MODUNLOADCHK(OnGetSASLMechanisms(ssMechanisms)); return false; } diff --git a/test/integration/tests/modules.cpp b/test/integration/tests/modules.cpp index 6f2737ca..20afd3d1 100644 --- a/test/integration/tests/modules.cpp +++ b/test/integration/tests/modules.cpp @@ -330,5 +330,25 @@ TEST_F(ZNCTest, SaslMechsNotInit) { ircd.ReadUntil("PONG foo"); } +TEST_F(ZNCTest, SaslPlainModule) { + auto znc = Run(); + auto ircd = ConnectIRCd(); + auto client = LoginClient(); + client.Write("znc loadmod saslplain"); + client.ReadUntil("Loaded module"); + client.Close(); + + auto client2 = ConnectClient(); + client2.Write("NICK foo"); + client2.Write("CAP LS"); + client2.Write("CAP REQ :sasl"); + client2.ReadUntil(":irc.znc.in CAP foo ACK :sasl"); + client2.Write("USER bar"); + client2.Write("AUTHENTICATE PLAIN"); + client2.ReadUntil("AUTHENTICATE +"); + client2.Write("AUTHENTICATE AHVzZXIAaHVudGVyMg=="); // \0user\0hunter2 + client2.ReadUntil(":irc.znc.in 903 foo :SASL authentication successful"); +} + } // namespace } // namespace znc_inttest