diff --git a/include/znc/Client.h b/include/znc/Client.h index c56961ab..2c77fa17 100644 --- a/include/znc/Client.h +++ b/include/znc/Client.h @@ -258,14 +258,25 @@ class CClient : public CIRCSocket { void AcceptSASLLogin(CUser& User); /** Start potentially asynchronous process of checking the credentials. * When finished, will send the success/failure SASL numerics to the - * client. This is mostly useful for SASL PLAIN. */ - void StartSASLPasswordCheck(const CString& sUser, const CString& sPassword); + * client. This is mostly useful for SASL PLAIN. + * sAuthorizationId is internally passed through ParseUser() to extract + * network and client id. + * Currently sUser should match the username from + * sAuthorizationId: either in full, or just the username part; but in a + * future version we may add an ability to actually login as a different + * user, but with your password. + */ + void StartSASLPasswordCheck(const CString& sUser, const CString& sPassword, + const CString& sAuthorizationId); + /** Gathers username, client id, network name, if present. Returns username + * cleaned from client id and network name. + */ + CString ParseUser(const CString& sAuthLine); private: void HandleCap(const CMessage& Message); void RespondCap(const CString& sResponse); void ParsePass(const CString& sAuthLine); - void ParseUser(const CString& sAuthLine); void ParseIdentifier(const CString& sAuthLine); template @@ -322,6 +333,7 @@ class CClient : public CIRCSocket { CIRCNetwork* m_pNetwork; CString m_sNick; CString m_sPass; + // User who didn't necessarily login yet, or might not even exist. CString m_sUser; CString m_sNetwork; CString m_sIdentifier; diff --git a/include/znc/Modules.h b/include/znc/Modules.h index dcdb9e93..7bdfbcba 100644 --- a/include/znc/Modules.h +++ b/include/znc/Modules.h @@ -1388,6 +1388,9 @@ class CModule { * GetClient()->SendSASLChallenge(), or reject authentication by calling * GetClient()->RefuseSASLLogin(), or accept it by calling * GetClient()->AcceptSASLLogin(). + * At some point before accepting the login, you should call + * GetClient()->ParseUser(authz-id) to let it know the network name to + * attach to and the client id. * @param sMechanism The SASL mechanism selected by the client. * @param sMessage The SASL opaque value/credentials sent by the client, * after debase64ing and concatenating if it was split. diff --git a/modules/saslplainauth.cpp b/modules/saslplainauth.cpp index 68bc062e..4cf7a7e4 100644 --- a/modules/saslplainauth.cpp +++ b/modules/saslplainauth.cpp @@ -36,13 +36,11 @@ class CSASLMechanismPlain : public CModule { CString sAuthcId = sMessage.Token(1, false, sNullSeparator, true); CString sPassword = sMessage.Token(2, false, sNullSeparator, true); - if (!sAuthzId.empty() && sAuthzId != sAuthcId) { - // Reject custom SASL plain authorization identifiers - GetClient()->RefuseSASLLogin("No support for custom AuthzId"); - return HALTMODS; + if (sAuthzId.empty()) { + sAuthzId = sAuthcId; } - GetClient()->StartSASLPasswordCheck(sAuthcId, sPassword); + GetClient()->StartSASLPasswordCheck(sAuthcId, sPassword, sAuthzId); return HALTMODS; } }; diff --git a/src/Client.cpp b/src/Client.cpp index 7796bd7c..c21925c4 100644 --- a/src/Client.cpp +++ b/src/Client.cpp @@ -387,8 +387,14 @@ class CClientSASLAuth : public CClientAuth { void RefusedLogin(const CString& sReason) override; }; -void CClient::StartSASLPasswordCheck(const CString& sUser, const CString& sPassword) { - m_spAuth = std::make_shared(this, sUser, sPassword); +void CClient::StartSASLPasswordCheck(const CString& sUser, + const CString& sPassword, const CString& sAuthorizationId) { + ParseUser(sAuthorizationId); + if (sUser != m_sUser && sUser != sAuthorizationId) { + RefuseSASLLogin("No support for custom AuthzId"); + } + + m_spAuth = std::make_shared(this, m_sUser, sPassword); CZNC::Get().AuthUser(m_spAuth); } @@ -973,7 +979,7 @@ void CClient::ParsePass(const CString& sAuthLine) { } } -void CClient::ParseUser(const CString& sAuthLine) { +CString CClient::ParseUser(const CString& sAuthLine) { // user[@identifier][/network] const size_t uSlash = sAuthLine.rfind("/"); @@ -984,6 +990,8 @@ void CClient::ParseUser(const CString& sAuthLine) { } else { ParseIdentifier(sAuthLine); } + + return m_sUser; } void CClient::ParseIdentifier(const CString& sAuthLine) {