diff --git a/app/fanout/bot.py b/app/fanout/bot.py index 7384325..202fe06 100644 --- a/app/fanout/bot.py +++ b/app/fanout/bot.py @@ -10,6 +10,36 @@ from app.fanout.base import FanoutModule logger = logging.getLogger(__name__) +def _derive_path_bytes_per_hop(paths: object, path_value: str | None) -> int | None: + """Derive hop width from the first serialized message path when possible.""" + if not isinstance(path_value, str) or not path_value: + return None + if not isinstance(paths, list) or not paths: + return None + + first_path = paths[0] + if not isinstance(first_path, dict): + return None + + path_hops = first_path.get("path_len") + if not isinstance(path_hops, int) or path_hops <= 0: + return None + + path_hex_chars = len(path_value) + if path_hex_chars % 2 != 0: + return None + + path_bytes = path_hex_chars // 2 + if path_bytes % path_hops != 0: + return None + + hop_width = path_bytes // path_hops + if hop_width not in (1, 2, 3): + return None + + return hop_width + + class BotModule(FanoutModule): """Wraps a single bot's code execution and response routing. @@ -101,11 +131,11 @@ class BotModule(FanoutModule): sender_timestamp = data.get("sender_timestamp") path_value = data.get("path") + paths = data.get("paths") # Message model serializes paths as list of dicts; extract first path string - if path_value is None: - paths = data.get("paths") - if paths and isinstance(paths, list) and len(paths) > 0: - path_value = paths[0].get("path") if isinstance(paths[0], dict) else None + if path_value is None and paths and isinstance(paths, list) and len(paths) > 0: + path_value = paths[0].get("path") if isinstance(paths[0], dict) else None + path_bytes_per_hop = _derive_path_bytes_per_hop(paths, path_value) # Wait for message to settle (allows retransmissions to be deduped) await asyncio.sleep(2) @@ -130,6 +160,7 @@ class BotModule(FanoutModule): sender_timestamp, path_value, is_outgoing, + path_bytes_per_hop, ), timeout=BOT_EXECUTION_TIMEOUT, ) diff --git a/app/fanout/bot_exec.py b/app/fanout/bot_exec.py index 8ca1200..ba77d28 100644 --- a/app/fanout/bot_exec.py +++ b/app/fanout/bot_exec.py @@ -15,6 +15,7 @@ import inspect import logging import time from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass from typing import Any from fastapi import HTTPException @@ -39,6 +40,89 @@ _bot_send_lock = asyncio.Lock() _last_bot_send_time: float = 0.0 +@dataclass(frozen=True) +class BotCallPlan: + """How to call a validated bot() function.""" + + call_style: str + keyword_args: tuple[str, ...] = () + + +def _analyze_bot_signature(bot_func_or_sig) -> BotCallPlan: + """Validate bot() signature and return a supported call plan.""" + try: + sig = ( + bot_func_or_sig + if isinstance(bot_func_or_sig, inspect.Signature) + else inspect.signature(bot_func_or_sig) + ) + except (ValueError, TypeError) as exc: + raise ValueError("Bot function signature could not be inspected") from exc + + params = sig.parameters + param_values = tuple(params.values()) + positional_params = [ + p + for p in param_values + if p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD) + ] + has_varargs = any(p.kind == inspect.Parameter.VAR_POSITIONAL for p in param_values) + has_kwargs = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in param_values) + explicit_optional_names = tuple( + name for name in ("is_outgoing", "path_bytes_per_hop") if name in params + ) + unsupported_required_kwonly = [ + p.name + for p in param_values + if p.kind == inspect.Parameter.KEYWORD_ONLY + and p.default is inspect.Parameter.empty + and p.name not in {"is_outgoing", "path_bytes_per_hop"} + ] + if unsupported_required_kwonly: + raise ValueError( + "Bot function signature is not supported. Unsupported required keyword-only " + "parameters: " + ", ".join(unsupported_required_kwonly) + ) + + positional_capacity = len(positional_params) + if not has_varargs and positional_capacity < 8: + raise ValueError( + "Bot function must accept at least 8 positional parameters before optional extras" + ) + + base_args = [object()] * 8 + candidate_specs: list[tuple[str, list[object], dict[str, object]]] = [] + if has_kwargs or explicit_optional_names: + kwargs: dict[str, object] = {} + if has_kwargs or "is_outgoing" in params: + kwargs["is_outgoing"] = False + if has_kwargs or "path_bytes_per_hop" in params: + kwargs["path_bytes_per_hop"] = 1 + candidate_specs.append(("keyword", base_args, kwargs)) + else: + if has_varargs or positional_capacity >= 10: + candidate_specs.append(("positional_10", base_args + [False, 1], {})) + if has_varargs or positional_capacity >= 9: + candidate_specs.append(("positional_9", base_args + [False], {})) + candidate_specs.append(("legacy", base_args, {})) + + for call_style, args, kwargs in candidate_specs: + try: + sig.bind(*args, **kwargs) + except TypeError: + continue + if call_style == "keyword": + return BotCallPlan(call_style="keyword", keyword_args=tuple(kwargs.keys())) + return BotCallPlan(call_style=call_style) + + raise ValueError( + "Bot function signature is not supported. Use the default bot template as a reference. " + "Supported trailing parameters are: path; path + is_outgoing; " + "path + path_bytes_per_hop; path + is_outgoing + path_bytes_per_hop; " + "or use **kwargs for forward compatibility." + ) + + def execute_bot_code( code: str, sender_name: str | None, @@ -50,17 +134,18 @@ def execute_bot_code( sender_timestamp: int | None, path: str | None, is_outgoing: bool = False, + path_bytes_per_hop: int | None = None, ) -> str | list[str] | None: """ Execute user-provided bot code with message context. The code should define a function: - `bot(sender_name, sender_key, message_text, is_dm, channel_key, channel_name, sender_timestamp, path, is_outgoing)` + `bot(sender_name, sender_key, message_text, is_dm, channel_key, channel_name, sender_timestamp, path, is_outgoing, path_bytes_per_hop)` that returns either None (no response), a string (single response message), or a list of strings (multiple messages sent in order). - Legacy bot functions with 8 parameters (without is_outgoing) are detected - via inspect and called without the new parameter for backward compatibility. + Legacy bot functions with older signatures are detected via inspect and + called without the newer parameters for backward compatibility. Args: code: Python code defining the bot function @@ -73,6 +158,7 @@ def execute_bot_code( sender_timestamp: Sender's timestamp from the message (may be None) path: Hex-encoded routing path (may be None) is_outgoing: True if this is our own outgoing message + path_bytes_per_hop: Number of bytes per routing hop (1, 2, or 3), if known Returns: Response string, list of strings, or None. @@ -100,30 +186,28 @@ def execute_bot_code( return None bot_func = namespace["bot"] - - # Detect whether the bot function accepts is_outgoing (new 9-param signature) - # or uses the legacy 8-param signature, for backward compatibility. - # Three cases: explicit is_outgoing param or 9+ params (positional), - # **kwargs (pass as keyword), or legacy 8-param (omit). - call_style = "legacy" # "positional", "keyword", or "legacy" try: - sig = inspect.signature(bot_func) - params = sig.parameters - non_variadic = [ - p - for p in params.values() - if p.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) - ] - if "is_outgoing" in params or len(non_variadic) >= 9: - call_style = "positional" - elif any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()): - call_style = "keyword" - except (ValueError, TypeError): - pass + call_plan = _analyze_bot_signature(bot_func) + except ValueError as exc: + logger.error("%s", exc) + return None try: # Call the bot function with appropriate signature - if call_style == "positional": + if call_plan.call_style == "positional_10": + result = bot_func( + sender_name, + sender_key, + message_text, + is_dm, + channel_key, + channel_name, + sender_timestamp, + path, + is_outgoing, + path_bytes_per_hop, + ) + elif call_plan.call_style == "positional_9": result = bot_func( sender_name, sender_key, @@ -135,7 +219,12 @@ def execute_bot_code( path, is_outgoing, ) - elif call_style == "keyword": + elif call_plan.call_style == "keyword": + keyword_args: dict[str, Any] = {} + if "is_outgoing" in call_plan.keyword_args: + keyword_args["is_outgoing"] = is_outgoing + if "path_bytes_per_hop" in call_plan.keyword_args: + keyword_args["path_bytes_per_hop"] = path_bytes_per_hop result = bot_func( sender_name, sender_key, @@ -145,7 +234,7 @@ def execute_bot_code( channel_name, sender_timestamp, path, - is_outgoing=is_outgoing, + **keyword_args, ) else: result = bot_func( diff --git a/app/routers/fanout.py b/app/routers/fanout.py index 76349e9..676f52b 100644 --- a/app/routers/fanout.py +++ b/app/routers/fanout.py @@ -1,5 +1,7 @@ """REST API for fanout config CRUD.""" +import ast +import inspect import logging import re import string @@ -8,6 +10,7 @@ from fastapi import APIRouter, HTTPException from pydantic import BaseModel, Field from app.config import settings as server_settings +from app.fanout.bot_exec import _analyze_bot_signature from app.repository.fanout import FanoutConfigRepository logger = logging.getLogger(__name__) @@ -144,18 +147,78 @@ def _validate_mqtt_community_config(config: dict) -> None: def _validate_bot_config(config: dict) -> None: - """Validate bot config blob (syntax-check the code).""" + """Validate bot config blob (syntax-check the code and supported signature).""" code = config.get("code", "") if not code or not code.strip(): raise HTTPException(status_code=400, detail="Bot code cannot be empty") try: - compile(code, "", "exec") + tree = ast.parse(code, filename="", mode="exec") except SyntaxError as e: raise HTTPException( status_code=400, detail=f"Bot code has syntax error at line {e.lineno}: {e.msg}", ) from None + bot_def = next( + ( + node + for node in tree.body + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == "bot" + ), + None, + ) + if bot_def is None: + raise HTTPException( + status_code=400, + detail=( + "Bot code must define a callable bot() function. " + "Use the default bot template as a reference." + ), + ) + + try: + parameters: list[inspect.Parameter] = [] + positional_args = [ + *((arg, inspect.Parameter.POSITIONAL_ONLY) for arg in bot_def.args.posonlyargs), + *((arg, inspect.Parameter.POSITIONAL_OR_KEYWORD) for arg in bot_def.args.args), + ] + positional_defaults_start = len(positional_args) - len(bot_def.args.defaults) + sentinel_default = object() + + for index, (arg, kind) in enumerate(positional_args): + has_default = index >= positional_defaults_start + parameters.append( + inspect.Parameter( + arg.arg, + kind=kind, + default=sentinel_default if has_default else inspect.Parameter.empty, + ) + ) + if bot_def.args.vararg is not None: + parameters.append( + inspect.Parameter(bot_def.args.vararg.arg, kind=inspect.Parameter.VAR_POSITIONAL) + ) + for kwonly_arg, kw_default in zip( + bot_def.args.kwonlyargs, bot_def.args.kw_defaults, strict=True + ): + parameters.append( + inspect.Parameter( + kwonly_arg.arg, + kind=inspect.Parameter.KEYWORD_ONLY, + default=( + sentinel_default if kw_default is not None else inspect.Parameter.empty + ), + ) + ) + if bot_def.args.kwarg is not None: + parameters.append( + inspect.Parameter(bot_def.args.kwarg.arg, kind=inspect.Parameter.VAR_KEYWORD) + ) + + _analyze_bot_signature(inspect.Signature(parameters)) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from None + def _validate_apprise_config(config: dict) -> None: """Validate apprise config blob.""" diff --git a/frontend/src/components/settings/SettingsFanoutSection.tsx b/frontend/src/components/settings/SettingsFanoutSection.tsx index 31e6a51..c06228a 100644 --- a/frontend/src/components/settings/SettingsFanoutSection.tsx +++ b/frontend/src/components/settings/SettingsFanoutSection.tsx @@ -74,35 +74,33 @@ function getDefaultIntegrationName(type: string, configs: FanoutConfig[]) { return `${label} #${nextIndex}`; } -const DEFAULT_BOT_CODE = `def bot( - sender_name: str | None, - sender_key: str | None, - message_text: str, - is_dm: bool, - channel_key: str | None, - channel_name: str | None, - sender_timestamp: int | None, - path: str | None, - is_outgoing: bool = False, -) -> str | list[str] | None: +const DEFAULT_BOT_CODE = `def bot(**kwargs) -> str | list[str] | None: """ Process messages and optionally return a reply. Args: - sender_name: Display name of sender (may be None) - sender_key: 64-char hex public key (None for channel msgs) - message_text: The message content - is_dm: True for direct messages, False for channel - channel_key: 32-char hex key for channels, None for DMs - channel_name: Channel name with hash (e.g. "#bot"), None for DMs - sender_timestamp: Sender's timestamp (unix seconds, may be None) - path: Hex-encoded routing path (may be None) - is_outgoing: True if this is our own outgoing message + kwargs keys currently provided: + sender_name: Display name of sender (may be None) + sender_key: 64-char hex public key (None for channel msgs) + message_text: The message content + is_dm: True for direct messages, False for channel + channel_key: 32-char hex key for channels, None for DMs + channel_name: Channel name with hash (e.g. "#bot"), None for DMs + sender_timestamp: Sender's timestamp (unix seconds, may be None) + path: Hex-encoded routing path (may be None) + is_outgoing: True if this is our own outgoing message + path_bytes_per_hop: Bytes per hop in path (1, 2, or 3) when known Returns: None for no reply, a string for a single reply, or a list of strings to send multiple messages in order """ + sender_name = kwargs.get("sender_name") + message_text = kwargs.get("message_text", "") + channel_name = kwargs.get("channel_name") + is_outgoing = kwargs.get("is_outgoing", False) + path_bytes_per_hop = kwargs.get("path_bytes_per_hop") + # Don't reply to our own outgoing messages if is_outgoing: return None diff --git a/tests/test_bot.py b/tests/test_bot.py index ccc604b..3b2bf7b 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -312,6 +312,67 @@ def bot(sender_name, sender_key, message_text, is_dm, channel_key, channel_name, ) assert result == "outgoing=True" + def test_new_10_param_bot_receives_path_bytes_per_hop(self): + """Bots that declare path_bytes_per_hop receive it positionally.""" + code = """ +def bot(sender_name, sender_key, message_text, is_dm, channel_key, channel_name, sender_timestamp, path, is_outgoing, path_bytes_per_hop): + return f"bytes={path_bytes_per_hop}" +""" + result = execute_bot_code( + code=code, + sender_name="Alice", + sender_key="abc123", + message_text="Hi", + is_dm=True, + channel_key=None, + channel_name=None, + sender_timestamp=None, + path="aabb", + path_bytes_per_hop=2, + ) + assert result == "bytes=2" + + def test_9_param_bot_with_path_bytes_only_receives_it(self): + """Bots may opt into path_bytes_per_hop without also declaring is_outgoing.""" + code = """ +def bot(sender_name, sender_key, message_text, is_dm, channel_key, channel_name, sender_timestamp, path, path_bytes_per_hop): + return f"bytes={path_bytes_per_hop}" +""" + result = execute_bot_code( + code=code, + sender_name="Alice", + sender_key="abc123", + message_text="Hi", + is_dm=True, + channel_key=None, + channel_name=None, + sender_timestamp=None, + path="aabb", + is_outgoing=True, + path_bytes_per_hop=2, + ) + assert result == "bytes=2" + + def test_legacy_bot_with_kwargs_receives_path_bytes_per_hop(self): + """Bots using **kwargs receive the new path_bytes_per_hop field.""" + code = """ +def bot(sender_name, sender_key, message_text, is_dm, channel_key, channel_name, sender_timestamp, path, **kwargs): + return f"bytes={kwargs.get('path_bytes_per_hop', 'missing')}" +""" + result = execute_bot_code( + code=code, + sender_name="Alice", + sender_key="abc123", + message_text="Hi", + is_dm=True, + channel_key=None, + channel_name=None, + sender_timestamp=None, + path="aabb", + path_bytes_per_hop=2, + ) + assert result == "bytes=2" + def test_channel_message_with_none_sender_key(self): """Channel messages correctly pass None for sender_key.""" code = """ @@ -419,7 +480,14 @@ class TestBotCodeValidation: from app.routers.fanout import _validate_bot_config # Should not raise - _validate_bot_config({"code": "def bot(): return 'hello'"}) + _validate_bot_config( + { + "code": ( + "def bot(sender_name, sender_key, message_text, is_dm, channel_key, " + "channel_name, sender_timestamp, path):\n return 'hello'" + ) + } + ) def test_syntax_error_raises(self): """Syntax error in code raises HTTPException.""" @@ -456,6 +524,38 @@ class TestBotCodeValidation: assert exc_info.value.status_code == 400 + def test_missing_bot_function_raises(self): + """Code must define a callable bot() function.""" + from fastapi import HTTPException + + from app.routers.fanout import _validate_bot_config + + with pytest.raises(HTTPException) as exc_info: + _validate_bot_config({"code": "def helper():\n return 'hello'"}) + + assert exc_info.value.status_code == 400 + assert "callable bot() function" in exc_info.value.detail + + def test_unsupported_signature_raises(self): + """Unsupported bot signatures are rejected with guidance.""" + from fastapi import HTTPException + + from app.routers.fanout import _validate_bot_config + + with pytest.raises(HTTPException) as exc_info: + _validate_bot_config( + { + "code": ( + "def bot(sender_name, sender_key, message_text, is_dm, channel_key, " + "channel_name, sender_timestamp, path, *, extra_required):\n" + " return extra_required" + ) + } + ) + + assert exc_info.value.status_code == 400 + assert "signature is not supported" in exc_info.value.detail.lower() + class TestBotMessageRateLimiting: """Test bot message rate limiting for repeater compatibility.""" diff --git a/tests/test_fanout_hitlist.py b/tests/test_fanout_hitlist.py index 4cee09a..fe60090 100644 --- a/tests/test_fanout_hitlist.py +++ b/tests/test_fanout_hitlist.py @@ -36,6 +36,7 @@ class TestBotModuleParameterExtraction: sender_timestamp, path, is_outgoing, + path_bytes_per_hop, ): captured["is_outgoing"] = is_outgoing captured["is_dm"] = is_dm @@ -84,6 +85,7 @@ class TestBotModuleParameterExtraction: sender_timestamp, path, is_outgoing, + path_bytes_per_hop, ): captured["is_outgoing"] = is_outgoing return None @@ -129,8 +131,10 @@ class TestBotModuleParameterExtraction: sender_timestamp, path, is_outgoing, + path_bytes_per_hop, ): captured["path"] = path + captured["path_bytes_per_hop"] = path_bytes_per_hop return None mod = BotModule("test", {"code": "def bot(**k): pass"}, name="Test") @@ -150,11 +154,12 @@ class TestBotModuleParameterExtraction: "type": "PRIV", "conversation_key": "pk1", "text": "hello", - "paths": [{"path": "aabb", "rssi": -50}], + "paths": [{"path": "aabbccdd", "path_len": 2, "rssi": -50}], } ) - assert captured["path"] == "aabb" + assert captured["path"] == "aabbccdd" + assert captured["path_bytes_per_hop"] == 2 @pytest.mark.asyncio async def test_channel_sender_prefix_stripped(self): @@ -174,6 +179,7 @@ class TestBotModuleParameterExtraction: sender_timestamp, path, is_outgoing, + path_bytes_per_hop, ): captured["message_text"] = message_text captured["sender_name"] = sender_name @@ -221,6 +227,7 @@ class TestBotModuleParameterExtraction: sender_timestamp, path, is_outgoing, + path_bytes_per_hop, ): captured["channel_name"] = channel_name return None @@ -267,6 +274,7 @@ class TestBotModuleParameterExtraction: sender_timestamp, path, is_outgoing, + path_bytes_per_hop, ): captured["sender_name"] = sender_name captured["sender_key"] = sender_key