diff --git a/include/znc/Client.h b/include/znc/Client.h index 5428cf44..083f2696 100644 --- a/include/znc/Client.h +++ b/include/znc/Client.h @@ -116,6 +116,10 @@ class CClient : public CIRCSocket { m_bBatch(false), m_bEchoMessage(false), m_bSelfMessage(false), + m_bSasl(false), + m_bSaslAuthenticating(false), + m_bSaslAuthenticated(false), + m_bSaslMultipart(false), m_bPlaybackActive(false), m_pUser(nullptr), m_pNetwork(nullptr), @@ -124,6 +128,8 @@ class CClient : public CIRCSocket { m_sUser(""), m_sNetwork(""), m_sIdentifier(""), + m_sSaslBuffer(""), + m_sSaslMechanism(""), m_spAuth(), m_ssAcceptedCaps(), m_ssSupportedTags(), @@ -156,6 +162,7 @@ class CClient : public CIRCSocket { }}}, {"extended-join", {true, [this](bool bVal) { m_bExtendedJoin = bVal; }}}, + {"sasl", {false, [this](bool bVal) { m_bSasl = bVal; m_bSaslAuthenticating = bVal; }}}, }) { EnableReadLine(); // RFC says a line can have 512 chars max, but we are @@ -333,6 +340,10 @@ class CClient : public CIRCSocket { unsigned int DetachChans(const std::set& sChans); bool OnActionMessage(CActionMessage& Message); + void OnAuthenticateMessage(CAuthenticateMessage& Message); + + CString EnumerateSaslMechanisms(SCString& ssMechanisms); + bool OnCTCPMessage(CCTCPMessage& Message); bool OnJoinMessage(CJoinMessage& Message); bool OnModeMessage(CModeMessage& Message); @@ -362,6 +373,10 @@ class CClient : public CIRCSocket { bool m_bBatch; bool m_bEchoMessage; bool m_bSelfMessage; + bool m_bSasl; + bool m_bSaslAuthenticating; + bool m_bSaslAuthenticated; + bool m_bSaslMultipart; bool m_bPlaybackActive; CUser* m_pUser; CIRCNetwork* m_pNetwork; @@ -370,6 +385,8 @@ class CClient : public CIRCSocket { CString m_sUser; CString m_sNetwork; CString m_sIdentifier; + CString m_sSaslBuffer; + CString m_sSaslMechanism; std::shared_ptr m_spAuth; SCString m_ssAcceptedCaps; SCString m_ssSupportedTags; diff --git a/include/znc/Message.h b/include/znc/Message.h index 0b6d374f..064a6984 100644 --- a/include/znc/Message.h +++ b/include/znc/Message.h @@ -78,6 +78,7 @@ class CMessage { Unknown, Account, Action, + Authenticate, Away, Capability, CTCP, @@ -250,6 +251,13 @@ class CActionMessage : public CTargetMessage { }; REGISTER_ZNC_MESSAGE(CActionMessage); +class CAuthenticateMessage : public CMessage { + public: + CString GetText() const { return GetParam(0); } + void SetText(const CString& sText) { SetParam(0, sText); } +}; +REGISTER_ZNC_MESSAGE(CAuthenticateMessage); + class CCTCPMessage : public CTargetMessage { public: bool IsReply() const { return GetCommand().Equals("NOTICE"); } diff --git a/include/znc/Modules.h b/include/znc/Modules.h index ef8f2219..a36531db 100644 --- a/include/znc/Modules.h +++ b/include/znc/Modules.h @@ -1308,6 +1308,14 @@ class CModule { */ virtual void OnClientCapRequest(CClient* pClient, const CString& sCap, bool bState); + virtual EModRet OnSaslServerChallenge(const CString& sMechanism, + CString& sResponse); + virtual EModRet OnClientSaslAuthenticate(const CString& sMechanism, + const CString& sBuffer, + CString& sUser, + CString& sMechanismResponse, + bool& bAuthenticationSuccess); + virtual void OnGetSaslMechanisms(SCString& ssMechanisms); /** Called when a module is going to be loaded. * @param sModName name of the module. @@ -1587,6 +1595,15 @@ class CModules : public std::vector, private CCoreTranslationMixin { bool IsClientCapSupported(CClient* pClient, const CString& sCap, bool bState); bool OnClientCapRequest(CClient* pClient, const CString& sCap, bool bState); + bool OnSaslServerChallenge(const CString& sMechanism, + CString& sResponse); + bool OnClientSaslAuthenticate(const CString& sMechanism, + const CString& sBuffer, + CString& sUser, + CString& sResponse, + bool& bAuthenticationSuccess); + bool OnGetSaslMechanisms(SCString& ssMechanisms); + bool OnModuleLoading(const CString& sModName, const CString& sArgs, CModInfo::EModuleType eType, bool& bSuccess, CString& sRetMsg); diff --git a/modules/saslplain.cpp b/modules/saslplain.cpp new file mode 100644 index 00000000..ce4cfbb0 --- /dev/null +++ b/modules/saslplain.cpp @@ -0,0 +1,65 @@ +/* + * Copyright (C) 2004-2018 ZNC, see the NOTICE file for details. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +class CSASLMechanismPlain : public CModule { + public: + MODCONSTRUCTOR(CSASLMechanismPlain) { AddHelpCommand(); } + + EModRet OnClientSaslAuthenticate(const CString& sMechanism, + const CString& sBuffer, CString& sUser, + CString& sMechanismResponse, + bool& bAuthenticationSuccess) override { + if (!sMechanism.Equals("PLAIN")) { + return CONTINUE; + } + + bAuthenticationSuccess = false; + + CString sNullSeparator = std::string("\0", 1); + auto sAuthzId = sBuffer.Token(0, false, sNullSeparator); + auto sAuthcId = sBuffer.Token(1, false, sNullSeparator); + auto sPassword = sBuffer.Token(2, false, sNullSeparator); + + if (sAuthzId.empty()) sAuthzId = sAuthcId; + + auto pUser = CZNC::Get().FindUser(sAuthcId); + + if (!sAuthcId.empty() && !sPassword.empty()) { + if (pUser->CheckPass(sPassword)) { + bAuthenticationSuccess = true; + sUser = sAuthcId; + } + } + + return HALTMODS; + } + + void OnGetSaslMechanisms(SCString& ssMechanisms) override { + ssMechanisms.insert("PLAIN"); + } +}; + +template <> +void TModInfo(CModInfo& Info) { + Info.SetWikiPage("saslplain"); +} + +GLOBALMODULEDEFS( + CSASLMechanismPlain, + t_s("Allows users to authenticate via the PLAIN SASL mechanism.")) diff --git a/src/Client.cpp b/src/Client.cpp index 6d215218..f02676b8 100644 --- a/src/Client.cpp +++ b/src/Client.cpp @@ -179,6 +179,12 @@ void CClient::ReadLine(const CString& sData) { return; } + if (Message.GetType() == CMessage::Type::Authenticate) { + OnAuthenticateMessage(Message); + + return; + } + if (!m_pUser) { // Only CAP, NICK, USER and PASS are allowed before login return; @@ -314,9 +320,16 @@ bool CClient::SendMotd() { } void CClient::AuthUser() { - if (!m_bGotNick || !m_bGotUser || !m_bGotPass || m_bInCap || IsAttached()) + if (!m_bGotNick || !m_bGotUser || m_bInCap || + (!m_bSaslAuthenticated && !m_bGotPass) || IsAttached()) return; + if (m_bSasl && m_bSaslAuthenticated) { + auto pUser = CZNC::Get().FindUser(m_sUser); + AcceptLogin(*pUser); + return; + } + m_spAuth = std::make_shared(this, m_sUser, m_sPass); CZNC::Get().AuthUser(m_spAuth); @@ -380,6 +393,7 @@ void CClientAuth::AcceptedLogin(CUser& User) { void CClient::AcceptLogin(CUser& User) { m_sPass = ""; m_pUser = &User; + m_bSaslAuthenticating = m_bSasl; // Set our proper timeout and set back our proper timeout mode // (constructor set a different timeout and mode) @@ -695,8 +709,15 @@ void CClient::HandleCap(const CMessage& Message) { for (const auto& it : m_mCoreCaps) { bool bServerDependent = std::get<0>(it.second); if (!bServerDependent || - m_ssServerDependentCaps.count(it.first) > 0) + m_ssServerDependentCaps.count(it.first) > 0) { + if (it.first.Equals("sasl")) { + SCString ssMechanisms; + ssOfferCaps.insert(it.first + "=" + + EnumerateSaslMechanisms(ssMechanisms)); + } else { ssOfferCaps.insert(it.first); + } + } } GLOBALMODULECALL(OnClientCapLs(this, ssOfferCaps), NOTHING); CString sRes = @@ -709,7 +730,17 @@ void CClient::HandleCap(const CMessage& Message) { } else if (sSubCmd.Equals("END")) { m_bInCap = false; if (!IsAttached()) { - if (!m_pUser && m_bGotUser && !m_bGotPass) { + if (m_bSasl && !m_bSaslAuthenticated && m_bSaslAuthenticating) { + PutClient(":irc.znc.in 906 " + GetNick() + + " :SASL authentication aborted"); + m_sSaslMechanism = ""; + m_bSaslAuthenticated = false; + m_bSaslMultipart = false; + m_bSaslAuthenticating = false; + } + + if (!m_pUser && m_bGotUser && + (!m_bSaslAuthenticated && !m_bGotPass)) { SendRequiredPasswordNotice(); } else { AuthUser(); @@ -966,6 +997,149 @@ bool CClient::OnActionMessage(CActionMessage& Message) { return true; } +void CClient::OnAuthenticateMessage(CAuthenticateMessage& Message) { + const auto uiMaxSaslMsgLength = 400u; + auto bAuthenticationSuccess = false; + auto sMessage = Message.GetText(); + const auto sBufferSize = sMessage.length(); + SCString ssMechanisms; + + auto SaslReset = [this]() { + m_sSaslMechanism = ""; + m_sSaslBuffer = ""; + m_bSaslMultipart = false; + }; + + auto SaslChallenge = [this](CString sChallenge) { + sChallenge.Base64Encode(); + auto sChallengeSize = sChallenge.length(); + + if (sChallengeSize > uiMaxSaslMsgLength) { + for (auto i = 0u; i < sChallengeSize; i += uiMaxSaslMsgLength) { + CString sMsgPart = sChallenge.substr(i, uiMaxSaslMsgLength); + PutClient("AUTHENTICATE " + sMsgPart); + } + } else { + PutClient("AUTHENTICATE " + sChallenge); + } + }; + + if (!m_bSasl) return; + + if (m_bSaslAuthenticated || IsAttached()) { + PutClient(":irc.znc.in 907 " + GetNick() + + " :You have already authenticated using SASL"); + return; + } + + if (!m_bSaslAuthenticating || sMessage.Equals("*")) { + PutClient(":irc.znc.in 906 " + GetNick() + + " :SASL authentication aborted"); + if (!IsAttached()) { + m_bSaslAuthenticating = false; + SaslReset(); + } + return; + } + + auto sMechanisms = EnumerateSaslMechanisms(ssMechanisms); + + if (sBufferSize > uiMaxSaslMsgLength) { + PutClient(":irc.znc.in 905 " + GetNick() + " :SASL message too long"); + SaslReset(); + return; + } + + if (m_sSaslMechanism.empty()) { + if (ssMechanisms.find(sMessage) == ssMechanisms.end()) { + PutClient(":irc.znc.in 908 " + GetNick() + " " + sMechanisms + + " :are available SASL mechanisms"); + PutClient(":irc.znc.in 904 " + GetNick() + + " :SASL authentication failed"); + SaslReset(); + + return; + } + + m_sSaslMechanism = sMessage; + + auto bResult = false; + CString sChallenge; + GLOBALMODULECALL(OnSaslServerChallenge(m_sSaslMechanism, sChallenge), + &bResult); + if (bResult) { + SaslChallenge(sChallenge); + } else { + PutClient("AUTHENTICATE +"); + } + return; + } + + if (sBufferSize == uiMaxSaslMsgLength) { + m_bSaslMultipart = true; + m_sSaslBuffer.append(sMessage); + + return; + } + + if ((m_bSaslMultipart && !sMessage.Equals("+"))) { + m_sSaslBuffer.append(sMessage); + m_bSaslMultipart = false; + } else if (!m_bSaslMultipart && !sMessage.Equals("+")) { + m_sSaslBuffer.assign(sMessage); + } + + m_sSaslBuffer.Base64Decode(); + + auto sAuthcId = m_sUser; + auto sAuthzId = m_sUser; + + CString sResponse; + bool bResult; + + GLOBALMODULECALL(OnClientSaslAuthenticate( + m_sSaslMechanism, m_sSaslBuffer, sAuthcId, + sResponse, bAuthenticationSuccess), + &bResult); + + if (bResult && !sResponse.empty()) { + SaslChallenge(sResponse); + return; + } + + m_sSaslBuffer.clear(); + + auto pUser = CZNC::Get().FindUser(sAuthcId); + + if (pUser && bAuthenticationSuccess) { + PutClient(":irc.znc.in 900 " + GetNick() + " " + GetNick() + "!" + + pUser->GetIdent() + "@" + GetHostName() + " " + sAuthcId + + " :You are now logged in as " + sAuthzId); + PutClient(":irc.znc.in 903 " + GetNick() + + " :SASL authentication successful"); + m_bSaslAuthenticated = true; + m_bSaslAuthenticating = false; + } else { + PutClient(":irc.znc.in 904 " + GetNick() + " :SASL authentication failed"); + SaslReset(); + } + + return; +} + +CString CClient::EnumerateSaslMechanisms(SCString& ssMechanisms) { + CString sMechanisms; + + GLOBALMODULECALL(OnGetSaslMechanisms(ssMechanisms), NOTHING); + + if (ssMechanisms.size()) { + sMechanisms = + CString(",").Join(ssMechanisms.begin(), ssMechanisms.end()); + } + + return sMechanisms; +} + bool CClient::OnCTCPMessage(CCTCPMessage& Message) { CString sTargets = Message.GetTarget(); diff --git a/src/Message.cpp b/src/Message.cpp index 6a6af073..2956e31b 100644 --- a/src/Message.cpp +++ b/src/Message.cpp @@ -267,6 +267,7 @@ void CMessage::InitType() { } else { std::map mTypes = { {"ACCOUNT", Type::Account}, + {"AUTHENTICATE", Type::Authenticate}, {"AWAY", Type::Away}, {"CAP", Type::Capability}, {"ERROR", Type::Error}, diff --git a/src/Modules.cpp b/src/Modules.cpp index d69d6050..92a89451 100644 --- a/src/Modules.cpp +++ b/src/Modules.cpp @@ -1075,6 +1075,22 @@ bool CModule::IsClientCapSupported(CClient* pClient, const CString& sCap, } void CModule::OnClientCapRequest(CClient* pClient, const CString& sCap, bool bState) {} + +CModule::EModRet CModule::OnClientSaslAuthenticate(const CString& sMechanism, + const CString& sBuffer, + CString& sUser, + CString& sMechanismResponse, + bool& bAuthenticationSuccess) { + return CONTINUE; +} + +CModule::EModRet CModule::OnSaslServerChallenge(const CString& sMechanism, + CString& sResponse) { + return CONTINUE; +} + +void CModule::OnGetSaslMechanisms(SCString& ssMechanisms) {} + CModule::EModRet CModule::OnModuleLoading(const CString& sModName, const CString& sArgs, CModInfo::EModuleType eType, @@ -1592,6 +1608,25 @@ bool CModules::OnClientCapRequest(CClient* pClient, const CString& sCap, return false; } +bool CModules::OnClientSaslAuthenticate(const CString& sMechanism, + const CString& sBuffer, + CString& sUser, + CString& sResponse, + bool& bAuthenticationSuccess) { + MODHALTCHK(OnClientSaslAuthenticate(sMechanism, sBuffer, sUser, + sResponse, bAuthenticationSuccess)); +} + +bool CModules::OnSaslServerChallenge(const CString& sMechanism, + CString& sResponse) { + MODHALTCHK(OnSaslServerChallenge(sMechanism, sResponse)); +} + +bool CModules::OnGetSaslMechanisms(SCString& ssMechanisms) { + MODUNLOADCHK(OnGetSaslMechanisms(ssMechanisms)); + return false; +} + bool CModules::OnModuleLoading(const CString& sModName, const CString& sArgs, CModInfo::EModuleType eType, bool& bSuccess, CString& sRetMsg) {