diff --git a/ClientCommand.cpp b/ClientCommand.cpp index 80decb72..06f1b7b4 100644 --- a/ClientCommand.cpp +++ b/ClientCommand.cpp @@ -591,13 +591,13 @@ void CClient::UserCommand(CString& sLine) { switch (ModInfo.GetType()) { case ModuleTypeGlobal: if (m_pUser->IsAdmin()) { - b = CZNC::Get().GetModules().LoadModule(sMod, sArgs, NULL, sModRet); + b = CZNC::Get().GetModules().LoadModule(sMod, sArgs, ModuleTypeGlobal, NULL, sModRet); } else { sModRet = "Unable to load global module [" + sMod + "] Access Denied."; } break; case ModuleTypeUser: - b = m_pUser->GetModules().LoadModule(sMod, sArgs, m_pUser, sModRet); + b = m_pUser->GetModules().LoadModule(sMod, sArgs, ModuleTypeUser, m_pUser, sModRet); break; default: sModRet = "Unable to load module [" + sMod + "] Unknown module type"; diff --git a/Modules.cpp b/Modules.cpp index 85db0283..07e2c07b 100644 --- a/Modules.cpp +++ b/Modules.cpp @@ -577,7 +577,7 @@ void CModule::OnClientCapLs(SCString& ssCaps) {} bool CModule::IsClientCapSupported(const CString& sCap, bool bState) { return false; } void CModule::OnClientCapRequest(const CString& sCap, bool bState) {} CModule::EModRet CModule::OnModuleLoading(const CString& sModName, const CString& sArgs, - bool& bSuccess, CString& sRetMsg) { return CONTINUE; } + EModuleType eType, bool& bSuccess, CString& sRetMsg) { return CONTINUE; } CModule::EModRet CModule::OnModuleUnloading(CModule* pModule, bool& bSuccess, CString& sRetMsg) { return CONTINUE; } @@ -770,8 +770,8 @@ bool CModules::OnClientCapRequest(const CString& sCap, bool bState) { } bool CModules::OnModuleLoading(const CString& sModName, const CString& sArgs, - bool& bSuccess, CString& sRetMsg) { - MODHALTCHK(OnModuleLoading(sModName, sArgs, bSuccess, sRetMsg)); + EModuleType eType, bool& bSuccess, CString& sRetMsg) { + MODHALTCHK(OnModuleLoading(sModName, sArgs, eType, bSuccess, sRetMsg)); } bool CModules::OnModuleUnloading(CModule* pModule, bool& bSuccess, CString& sRetMsg) { @@ -799,7 +799,7 @@ CModule* CModules::FindModule(const CString& sModule) const { return NULL; } -bool CModules::LoadModule(const CString& sModule, const CString& sArgs, CUser* pUser, CString& sRetMsg) { +bool CModules::LoadModule(const CString& sModule, const CString& sArgs, EModuleType eType, CUser* pUser, CString& sRetMsg) { sRetMsg = ""; if (FindModule(sModule) != NULL) { @@ -808,7 +808,7 @@ bool CModules::LoadModule(const CString& sModule, const CString& sArgs, CUser* p } bool bSuccess; - GLOBALMODULECALL(OnModuleLoading(sModule, sArgs, bSuccess, sRetMsg), pUser, NULL, return bSuccess); + GLOBALMODULECALL(OnModuleLoading(sModule, sArgs, eType, bSuccess, sRetMsg), pUser, NULL, return bSuccess); CString sModPath, sDataPath; bool bVersionMismatch; @@ -830,20 +830,31 @@ bool CModules::LoadModule(const CString& sModule, const CString& sArgs, CUser* p return false; } - if ((pUser == NULL) != (Info.GetType() == ModuleTypeGlobal)) { + if (!Info.SupportsModule(eType)) { dlclose(p); - sRetMsg = "Module [" + sModule + "] is "; - sRetMsg += (Info.GetType() == ModuleTypeGlobal) ? "" : "not "; - sRetMsg += "a global module."; + sRetMsg = "Module [ + sModule + ] does not support module type."; + return false; + } + + if (!pUser && eType == ModuleTypeUser) { + dlclose(p); + sRetMsg = "Module [" + sModule + "] require a user."; return false; } CModule* pModule = NULL; - if (pUser) { - pModule = Info.GetLoader()(p, pUser, sModule, sDataPath); - } else { - pModule = Info.GetGlobalLoader()(p, sModule, sDataPath); + switch (eType) { + case ModuleTypeUser: + pModule = Info.GetLoader()(p, pUser, sModule, sDataPath); + break; + case ModuleTypeGlobal: + pModule = Info.GetGlobalLoader()(p, sModule, sDataPath); + break; + default: + dlclose(p); + sRetMsg = "Unsupported module type"; + return false; } pModule->SetDescription(Info.GetDescription()); @@ -918,12 +929,22 @@ bool CModules::UnloadModule(const CString& sModule, CString& sRetMsg) { bool CModules::ReloadModule(const CString& sModule, const CString& sArgs, CUser* pUser, CString& sRetMsg) { CString sMod = sModule; // Make a copy incase the reference passed in is from CModule::GetModName() + CModule *pModule = FindModule(sMod); + + if (!pModule) { + sRetMsg = "Module [" + sMod + "] not loaded"; + return false; + } + + EModuleType eType = pModule->GetType(); + pModule = NULL; + sRetMsg = ""; if (!UnloadModule(sMod, sRetMsg)) { return false; } - if (!LoadModule(sMod, sArgs, pUser, sRetMsg)) { + if (!LoadModule(sMod, sArgs, eType, pUser, sRetMsg)) { return false; } diff --git a/Modules.h b/Modules.h index 67ffa568..f1bc9d24 100644 --- a/Modules.h +++ b/Modules.h @@ -195,6 +195,10 @@ public: return (GetName() < Info.GetName()); } + bool SupportsModule(EModuleType eType) { + return eType == m_eType; + } + // Getters const CString& GetName() const { return m_sName; } const CString& GetPath() const { return m_sPath; } @@ -940,7 +944,7 @@ public: * @return See CModule::EModRet. */ virtual EModRet OnModuleLoading(const CString& sModName, const CString& sArgs, - bool& bSuccess, CString& sRetMsg); + EModuleType eType, bool& bSuccess, CString& sRetMsg); /** Called when a module is going to be unloaded. * @param pModule the module. * @param[out] bSuccess the module was unloaded successfully @@ -1062,7 +1066,7 @@ public: bool OnServerCapResult(const CString& sCap, bool bSuccess); CModule* FindModule(const CString& sModule) const; - bool LoadModule(const CString& sModule, const CString& sArgs, CUser* pUser, CString& sRetMsg); + bool LoadModule(const CString& sModule, const CString& sArgs, EModuleType eType, CUser* pUser, CString& sRetMsg); bool UnloadModule(const CString& sModule); bool UnloadModule(const CString& sModule, CString& sRetMsg); bool ReloadModule(const CString& sModule, const CString& sArgs, CUser* pUser, CString& sRetMsg); @@ -1091,7 +1095,7 @@ public: bool IsClientCapSupported(const CString& sCap, bool bState); bool OnClientCapRequest(const CString& sCap, bool bState); bool OnModuleLoading(const CString& sModName, const CString& sArgs, - bool& bSuccess, CString& sRetMsg); + EModuleType eType, bool& bSuccess, CString& sRetMsg); bool OnModuleUnloading(CModule* pModule, bool& bSuccess, CString& sRetMsg); bool OnGetModInfo(CModInfo& ModInfo, const CString& sModule, bool& bSuccess, CString& sRetMsg); diff --git a/User.cpp b/User.cpp index a906ed91..e86f6296 100644 --- a/User.cpp +++ b/User.cpp @@ -187,7 +187,7 @@ bool CUser::ParseConfig(CConfig* pConfig, CString& sError) { if (sValue.ToBool()) { CUtils::PrintAction("Loading Module [bouncedcc]"); CString sModRet; - bool bModRet = GetModules().LoadModule("bouncedcc", "", this, sModRet); + bool bModRet = GetModules().LoadModule("bouncedcc", "", ModuleTypeUser, this, sModRet); CUtils::PrintStatus(bModRet, sModRet); if (!bModRet) { @@ -318,7 +318,7 @@ bool CUser::ParseConfig(CConfig* pConfig, CString& sError) { CString sModRet; CString sArgs = sValue.Token(1, true); - bool bModRet = GetModules().LoadModule(sModName, sArgs, this, sModRet); + bool bModRet = GetModules().LoadModule(sModName, sArgs, ModuleTypeUser, this, sModRet); CUtils::PrintStatus(bModRet, sModRet); if (!bModRet) { @@ -353,7 +353,7 @@ bool CUser::UpdateModule(const CString &sModule) { CString sErr; for (it2 = Affected.begin(); it2 != Affected.end(); ++it2) { - if (!it2->first->GetModules().LoadModule(sModule, it2->second, it2->first, sErr)) { + if (!it2->first->GetModules().LoadModule(sModule, it2->second, ModuleTypeUser, it2->first, sErr)) { error = true; DEBUG("Failed to reload [" << sModule << "] for [" << it2->first->GetUserName() << "]: " << sErr); @@ -696,7 +696,7 @@ bool CUser::Clone(const CUser& User, CString& sErrorRet, bool bCloneChans) { CModule* pCurMod = vCurMods.FindModule(pNewMod->GetModName()); if (!pCurMod) { - vCurMods.LoadModule(pNewMod->GetModName(), pNewMod->GetArgs(), this, sModRet); + vCurMods.LoadModule(pNewMod->GetModName(), pNewMod->GetArgs(), ModuleTypeUser, this, sModRet); } else if (pNewMod->GetArgs() != pCurMod->GetArgs()) { vCurMods.ReloadModule(pNewMod->GetModName(), pNewMod->GetArgs(), this, sModRet); } diff --git a/modules/admin.cpp b/modules/admin.cpp index aaaf6a37..a3413278 100644 --- a/modules/admin.cpp +++ b/modules/admin.cpp @@ -741,7 +741,7 @@ class CAdminMod : public CModule { CModule *pMod = (pUser)->GetModules().FindModule(sModName); if (!pMod) { - if (!(pUser)->GetModules().LoadModule(sModName, sArgs, pUser, sModRet)) { + if (!(pUser)->GetModules().LoadModule(sModName, sArgs, ModuleTypeUser, pUser, sModRet)) { PutModule("Unable to load module [" + sModName + "] [" + sModRet + "]"); } else { PutModule("Loaded module [" + sModName + "]"); diff --git a/modules/modperl.cpp b/modules/modperl.cpp index 690f7397..7798620f 100644 --- a/modules/modperl.cpp +++ b/modules/modperl.cpp @@ -75,8 +75,8 @@ public: } virtual EModRet OnModuleLoading(const CString& sModName, const CString& sArgs, - bool& bSuccess, CString& sRetMsg) { - if (!GetUser()) { + EModuleType eType, bool& bSuccess, CString& sRetMsg) { + if (!GetUser() || eType != ModuleTypeUser) { return CONTINUE; } EModRet result = HALT; diff --git a/modules/modpython.cpp b/modules/modpython.cpp index f0d635a4..6cba90d8 100644 --- a/modules/modpython.cpp +++ b/modules/modpython.cpp @@ -127,7 +127,7 @@ public: } virtual EModRet OnModuleLoading(const CString& sModName, const CString& sArgs, - bool& bSuccess, CString& sRetMsg) { + EModuleType eType, bool& bSuccess, CString& sRetMsg) { PyObject* pyFunc = PyObject_GetAttrString(m_PyZNCModule, "load_module"); if (!pyFunc) { sRetMsg = GetPyExceptionStr(); @@ -138,7 +138,7 @@ public: PyObject* pyRes = PyObject_CallFunction(pyFunc, const_cast("ssNNN"), sModName.c_str(), sArgs.c_str(), - SWIG_NewInstanceObj(GetUser(), SWIG_TypeQuery("CUser*"), 0), + (eType == ModuleTypeUser ? SWIG_NewInstanceObj(GetUser(), SWIG_TypeQuery("CUser*"), 0) : NULL), CPyRetString::wrap(sRetMsg), SWIG_NewInstanceObj(reinterpret_cast(this), SWIG_TypeQuery("CModule*"), 0)); if (!pyRes) { diff --git a/modules/modpython/functions.in b/modules/modpython/functions.in index fe55c4b1..2c835e30 100644 --- a/modules/modpython/functions.in +++ b/modules/modpython/functions.in @@ -61,4 +61,4 @@ void OnServerCapResult(const CString& sCap, bool bSuccess) EModRet OnTimerAutoJoin(CChan& Channel) bool OnEmbeddedWebRequest(CWebSock& WebSock, const CString& sPageName, CTemplate& Tmpl)=false -EModRet OnModuleLoading(const CString& sModName, const CString& sArgs, bool& bSuccess, CString& sRetMsg) +EModRet OnModuleLoading(const CString& sModName, const CString& sArgs, EModuleType eType, bool& bSuccess, CString& sRetMsg) diff --git a/modules/modpython/module.h b/modules/modpython/module.h index 11da0351..7bb78414 100644 --- a/modules/modpython/module.h +++ b/modules/modpython/module.h @@ -115,7 +115,7 @@ public: virtual void OnServerCapResult(const CString& sCap, bool bSuccess); virtual EModRet OnTimerAutoJoin(CChan& Channel); bool OnEmbeddedWebRequest(CWebSock&, const CString&, CTemplate&); -EModRet OnModuleLoading(const CString& sModName, const CString& sArgs, bool& bSuccess, CString& sRetMsg); +EModRet OnModuleLoading(const CString& sModName, const CString& sArgs, EModuleType eType, bool& bSuccess, CString& sRetMsg); }; static inline CPyModule* AsPyModule(CModule* p) { diff --git a/modules/webadmin.cpp b/modules/webadmin.cpp index 1418cfb2..4bcf3df9 100644 --- a/modules/webadmin.cpp +++ b/modules/webadmin.cpp @@ -272,7 +272,7 @@ public: CString sArgs = WebSock.GetParam("modargs_" + sModName); try { - if (!pNewUser->GetModules().LoadModule(sModName, sArgs, pNewUser, sModRet)) { + if (!pNewUser->GetModules().LoadModule(sModName, sArgs, ModuleTypeUser, pNewUser, sModRet)) { sModLoadError = "Unable to load module [" + sModName + "] [" + sModRet + "]"; } } catch (...) { @@ -295,7 +295,7 @@ public: CString sModLoadError; try { - if (!pNewUser->GetModules().LoadModule(sModName, sArgs, pNewUser, sModRet)) { + if (!pNewUser->GetModules().LoadModule(sModName, sArgs, ModuleTypeUser, pNewUser, sModRet)) { sModLoadError = "Unable to load module [" + sModName + "] [" + sModRet + "]"; } } catch (...) { @@ -1071,7 +1071,7 @@ public: CModule *pMod = CZNC::Get().GetModules().FindModule(sModName); if (!pMod) { - if (!CZNC::Get().GetModules().LoadModule(sModName, sArgs, NULL, sModRet)) { + if (!CZNC::Get().GetModules().LoadModule(sModName, sArgs, ModuleTypeGlobal, NULL, sModRet)) { sModLoadError = "Unable to load module [" + sModName + "] [" + sModRet + "]"; } } else if (pMod->GetArgs() != sArgs) { diff --git a/znc.cpp b/znc.cpp index fca9dda8..bb776d68 100644 --- a/znc.cpp +++ b/znc.cpp @@ -1058,7 +1058,7 @@ bool CZNC::DoRehash(CString& sError) if (!pOldMod) { CUtils::PrintAction("Loading Global Module [" + sModName + "]"); - bool bModRet = GetModules().LoadModule(sModName, sArgs, NULL, sModRet); + bool bModRet = GetModules().LoadModule(sModName, sArgs, ModuleTypeGlobal, NULL, sModRet); CUtils::PrintStatus(bModRet, sModRet); if (!bModRet) { @@ -1090,7 +1090,7 @@ bool CZNC::DoRehash(CString& sError) CUtils::PrintAction("Loading Global Module [identfile]"); CString sModRet; - bool bModRet = GetModules().LoadModule("identfile", "", NULL, sModRet); + bool bModRet = GetModules().LoadModule("identfile", "", ModuleTypeGlobal, NULL, sModRet); CUtils::PrintStatus(bModRet, sModRet); if (!bModRet) {