mirror of
https://github.com/pyMC-dev/pyMC_Repeater.git
synced 2026-06-11 00:34:46 +02:00
692 lines
25 KiB
Python
692 lines
25 KiB
Python
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from typing import Any, Optional
|
|
|
|
from pymc_core.protocol.constants import PAYLOAD_TYPE_GRP_DATA, PAYLOAD_TYPE_GRP_TXT
|
|
from pymc_core.protocol.crypto import CryptoUtils
|
|
|
|
logger = logging.getLogger("PolicyEngine")
|
|
|
|
|
|
SUPPORTED_ACTIONS = {
|
|
"allow",
|
|
"drop",
|
|
"log_only",
|
|
}
|
|
|
|
|
|
def default_policy_engine_config() -> dict[str, Any]:
|
|
return {
|
|
"enabled": False,
|
|
"default_action": "allow",
|
|
"rules": [],
|
|
"objects": {},
|
|
}
|
|
|
|
|
|
@dataclass
|
|
class PolicyDecision:
|
|
action: str = "allow"
|
|
matched: bool = False
|
|
rule_id: Optional[Any] = None
|
|
reason: Optional[str] = None
|
|
|
|
|
|
class PolicyEngine:
|
|
"""Readable top-down rule evaluator for repeater policy decisions."""
|
|
|
|
def __init__(self, policy_config: Optional[dict] = None):
|
|
cfg = default_policy_engine_config()
|
|
if isinstance(policy_config, dict):
|
|
cfg.update(policy_config)
|
|
|
|
self.enabled = bool(cfg.get("enabled", False))
|
|
self.default_action = str(cfg.get("default_action", "allow"))
|
|
if self.default_action not in SUPPORTED_ACTIONS:
|
|
logger.warning(
|
|
"Policy default_action '%s' is not supported, using 'allow'",
|
|
self.default_action,
|
|
)
|
|
self.default_action = "allow"
|
|
|
|
rules = cfg.get("rules")
|
|
self.rules: list[dict[str, Any]] = rules if isinstance(rules, list) else []
|
|
|
|
objects = cfg.get("objects")
|
|
self.objects: dict[str, Any] = objects if isinstance(objects, dict) else {}
|
|
self._channel_decrypt_cache: dict[int, dict[str, Any]] = {}
|
|
self._inline_channel_secrets = self._collect_inline_rule_channel_secrets(self.rules)
|
|
|
|
@classmethod
|
|
def from_runtime_config(cls, runtime_config: Optional[dict]) -> "PolicyEngine":
|
|
if not isinstance(runtime_config, dict):
|
|
return cls()
|
|
return cls(runtime_config.get("policy_engine", {}))
|
|
|
|
def evaluate(self, packet, context: dict) -> PolicyDecision:
|
|
if not self.enabled:
|
|
return PolicyDecision(action="allow", matched=False, reason="policy_disabled")
|
|
|
|
for rule in self.rules:
|
|
if not isinstance(rule, dict):
|
|
continue
|
|
if not bool(rule.get("enabled", True)):
|
|
continue
|
|
|
|
if not self._rule_matches(rule, packet, context):
|
|
continue
|
|
|
|
action = self._resolve_action(rule)
|
|
rule_id = rule.get("id")
|
|
rule_name = rule.get("name") or "unnamed"
|
|
reason = f"Policy rule matched: id={rule_id}, name={rule_name}, action={action}"
|
|
return PolicyDecision(action=action, matched=True, rule_id=rule_id, reason=reason)
|
|
|
|
return PolicyDecision(action=self.default_action, matched=False, reason="default_action")
|
|
|
|
def _resolve_action(self, rule: dict) -> str:
|
|
then_block = rule.get("then", {})
|
|
action = None
|
|
|
|
if isinstance(then_block, dict):
|
|
action = then_block.get("action")
|
|
elif isinstance(then_block, str):
|
|
action = then_block
|
|
|
|
if not action:
|
|
action = rule.get("action")
|
|
|
|
action = str(action or "allow")
|
|
if action not in SUPPORTED_ACTIONS:
|
|
logger.warning("Unsupported policy action '%s', coercing to 'allow'", action)
|
|
return "allow"
|
|
return action
|
|
|
|
def _rule_matches(self, rule: dict, packet, context: dict) -> bool:
|
|
cond = rule.get("if", {})
|
|
|
|
# Support implicit single-condition form.
|
|
if isinstance(cond, dict) and "field" in cond:
|
|
return self._condition_matches(cond, packet, context)
|
|
|
|
if not isinstance(cond, dict):
|
|
return False
|
|
|
|
all_conds = cond.get("all")
|
|
any_conds = cond.get("any")
|
|
|
|
if isinstance(all_conds, list):
|
|
return all(self._condition_matches(c, packet, context) for c in all_conds)
|
|
|
|
if isinstance(any_conds, list):
|
|
return any(self._condition_matches(c, packet, context) for c in any_conds)
|
|
|
|
return False
|
|
|
|
def _condition_matches(self, condition: dict, packet, context: dict) -> bool:
|
|
if not isinstance(condition, dict):
|
|
return False
|
|
|
|
field = condition.get("field")
|
|
op = condition.get("op", "equals")
|
|
try:
|
|
expected = self._resolve_value(condition.get("value"))
|
|
|
|
actual = self._get_field_value(field, packet, context)
|
|
if field == "path_hashes":
|
|
actual = self._normalize_path_hash_values(actual)
|
|
expected = self._normalize_path_hash_values(expected)
|
|
if field == "channel_hash":
|
|
actual = self._normalize_channel_hash_values(actual)
|
|
expected = self._normalize_channel_hash_values(expected)
|
|
result = self._compare(actual, op, expected)
|
|
logger.debug(
|
|
"Condition eval: field=%s op=%s expected=%r actual=%r -> %s",
|
|
field,
|
|
op,
|
|
expected,
|
|
actual,
|
|
"MATCH" if result else "no match",
|
|
)
|
|
return result
|
|
except ValueError as exc:
|
|
logger.debug("Condition eval: field=%s raised ValueError: %s -> no match", field, exc)
|
|
return False
|
|
|
|
def _resolve_value(self, value: Any) -> Any:
|
|
if isinstance(value, str) and value.startswith("@"):
|
|
# Object reference format: @group.name
|
|
ref = value[1:]
|
|
parts = ref.split(".", 1)
|
|
if len(parts) == 2:
|
|
group, key = parts
|
|
group_obj = self.objects.get(group, {})
|
|
if isinstance(group_obj, dict):
|
|
return group_obj.get(key)
|
|
return value
|
|
|
|
def _get_field_value(self, field: Any, packet, context: dict) -> Any:
|
|
if not isinstance(field, str):
|
|
return None
|
|
|
|
# Existing packet/context fields only.
|
|
if field in context:
|
|
return context.get(field)
|
|
|
|
if field == "payload_hex":
|
|
payload = getattr(packet, "payload", None) or b""
|
|
return bytes(payload).hex()
|
|
|
|
if field == "channel_hash":
|
|
decrypted = getattr(packet, "decrypted", None)
|
|
if isinstance(decrypted, dict):
|
|
group_text = decrypted.get("group_text_data", {})
|
|
if isinstance(group_text, dict):
|
|
candidate = group_text.get("channel_hash")
|
|
if candidate is not None:
|
|
return candidate
|
|
|
|
try:
|
|
payload_type = (
|
|
packet.get_payload_type() if hasattr(packet, "get_payload_type") else None
|
|
)
|
|
except Exception:
|
|
payload_type = None
|
|
if payload_type in (PAYLOAD_TYPE_GRP_TXT, PAYLOAD_TYPE_GRP_DATA):
|
|
payload = (
|
|
packet.get_payload()
|
|
if hasattr(packet, "get_payload")
|
|
else getattr(packet, "payload", None)
|
|
)
|
|
if payload and len(payload) >= 1:
|
|
return payload[0]
|
|
|
|
if field == "channel_message_body":
|
|
channel_info = self._get_channel_decrypt_info(packet)
|
|
return channel_info.get("message_body")
|
|
|
|
if field == "channel_sender":
|
|
channel_info = self._get_channel_decrypt_info(packet)
|
|
return channel_info.get("sender")
|
|
|
|
if field == "channel_decryptable":
|
|
channel_info = self._get_channel_decrypt_info(packet)
|
|
return bool(channel_info.get("decryptable", False))
|
|
|
|
if field == "path_hashes":
|
|
if hasattr(packet, "get_path_hashes_hex"):
|
|
return packet.get_path_hashes_hex()
|
|
return []
|
|
|
|
if field == "transport_code_0":
|
|
if hasattr(packet, "transport_codes") and packet.transport_codes:
|
|
return packet.transport_codes[0]
|
|
return None
|
|
|
|
if field == "transport_code_1":
|
|
if hasattr(packet, "transport_codes") and len(packet.transport_codes) > 1:
|
|
return packet.transport_codes[1]
|
|
return None
|
|
|
|
return None
|
|
|
|
def _extract_channel_message_body(self, packet) -> Optional[str]:
|
|
channel_info = self._compute_channel_decrypt_info(packet)
|
|
return channel_info.get("message_body")
|
|
|
|
def _get_channel_decrypt_info(self, packet) -> dict[str, Any]:
|
|
packet_key = id(packet)
|
|
cached = self._channel_decrypt_cache.get(packet_key)
|
|
if isinstance(cached, dict):
|
|
return cached
|
|
|
|
computed = self._compute_channel_decrypt_info(packet)
|
|
self._channel_decrypt_cache[packet_key] = computed
|
|
return computed
|
|
|
|
def _compute_channel_decrypt_info(self, packet) -> dict[str, Any]:
|
|
decrypted = getattr(packet, "decrypted", None)
|
|
if isinstance(decrypted, dict):
|
|
group_text = decrypted.get("group_text_data", {})
|
|
if isinstance(group_text, dict):
|
|
sender = group_text.get("sender")
|
|
text = group_text.get("text")
|
|
if isinstance(text, str):
|
|
if not isinstance(sender, str) or not sender.strip():
|
|
sender, text = self._extract_sender_from_message(text)
|
|
return {
|
|
"decryptable": True,
|
|
"sender": sender,
|
|
"message_body": text,
|
|
}
|
|
|
|
try:
|
|
payload_type = (
|
|
packet.get_payload_type() if hasattr(packet, "get_payload_type") else None
|
|
)
|
|
except Exception:
|
|
return {
|
|
"decryptable": False,
|
|
"message_body": None,
|
|
}
|
|
|
|
if payload_type != PAYLOAD_TYPE_GRP_TXT:
|
|
return {
|
|
"decryptable": False,
|
|
"message_body": None,
|
|
}
|
|
|
|
payload = (
|
|
packet.get_payload()
|
|
if hasattr(packet, "get_payload")
|
|
else getattr(packet, "payload", None)
|
|
)
|
|
if not payload or len(payload) < 4:
|
|
return {
|
|
"decryptable": False,
|
|
"message_body": None,
|
|
}
|
|
|
|
channel_hash = payload[0]
|
|
cipher_mac = bytes(payload[1:3])
|
|
ciphertext = bytes(payload[3:])
|
|
|
|
secrets_tried = 0
|
|
for secret in self._iter_policy_channel_secrets():
|
|
secrets_tried += 1
|
|
derived = self._derive_channel_hash(secret)
|
|
secret_preview = secret[:8] + "..." if len(secret) > 8 else secret
|
|
if derived != channel_hash:
|
|
logger.debug(
|
|
"Channel decrypt: secret %s derived hash 0x%02X != packet hash 0x%02X, skipping",
|
|
secret_preview,
|
|
derived,
|
|
channel_hash,
|
|
)
|
|
continue
|
|
|
|
logger.debug(
|
|
"Channel decrypt: secret %s hash matches 0x%02X, attempting MAC+decrypt",
|
|
secret_preview,
|
|
channel_hash,
|
|
)
|
|
plaintext = self._decrypt_channel_message(secret, cipher_mac, ciphertext)
|
|
if plaintext is None:
|
|
logger.debug(
|
|
"Channel decrypt: secret %s MAC/decrypt failed",
|
|
secret_preview,
|
|
)
|
|
continue
|
|
|
|
parsed = self._parse_channel_plaintext(plaintext)
|
|
if not isinstance(parsed, dict):
|
|
logger.debug(
|
|
"Channel decrypt: secret %s parse failed",
|
|
secret_preview,
|
|
)
|
|
continue
|
|
|
|
content = parsed.get("content")
|
|
if not isinstance(content, str):
|
|
continue
|
|
|
|
sender, message_body = self._extract_sender_from_message(content)
|
|
logger.debug(
|
|
"Channel decrypt: SUCCESS with secret %s, sender=%r, message_body=%r",
|
|
secret_preview,
|
|
sender,
|
|
message_body[:40] if message_body else "",
|
|
)
|
|
return {
|
|
"decryptable": True,
|
|
"sender": sender.rstrip("\x00").rstrip(),
|
|
"message_body": message_body.rstrip("\x00").rstrip(),
|
|
}
|
|
|
|
if secrets_tried == 0:
|
|
logger.debug(
|
|
"Channel decrypt: no policy channel secrets configured "
|
|
"(objects.channels / objects.channel_hash_groups and inline rule secrets are empty/missing); "
|
|
"decryptable=False",
|
|
)
|
|
else:
|
|
logger.debug(
|
|
"Channel decrypt: no matching secret found (tried %d), decryptable=False",
|
|
secrets_tried,
|
|
)
|
|
return {
|
|
"decryptable": False,
|
|
"sender": None,
|
|
"message_body": None,
|
|
}
|
|
|
|
def _iter_policy_channel_secrets(self):
|
|
channels = self.objects.get("channels", {})
|
|
if isinstance(channels, dict):
|
|
channel_items = channels.values()
|
|
elif isinstance(channels, (list, tuple)):
|
|
channel_items = channels
|
|
else:
|
|
logger.debug(
|
|
"Channel decrypt: objects.channels has unsupported type %s",
|
|
type(channels).__name__,
|
|
)
|
|
return
|
|
|
|
for channel_cfg in channel_items:
|
|
if isinstance(channel_cfg, str):
|
|
yield channel_cfg
|
|
continue
|
|
|
|
if isinstance(channel_cfg, dict):
|
|
# Accept common schema variations used across policy/companion exports.
|
|
secret = (
|
|
channel_cfg.get("secret")
|
|
or channel_cfg.get("key")
|
|
or channel_cfg.get("psk")
|
|
or channel_cfg.get("channel_secret")
|
|
)
|
|
if secret:
|
|
yield str(secret)
|
|
else:
|
|
logger.debug(
|
|
"Channel decrypt: channel entry missing secret/key/psk/channel_secret keys",
|
|
)
|
|
continue
|
|
|
|
logger.debug(
|
|
"Channel decrypt: skipping unsupported channel entry type %s",
|
|
type(channel_cfg).__name__,
|
|
)
|
|
|
|
# Also accept full channel secrets provided via policy object groups.
|
|
channel_hash_groups = self.objects.get("channel_hash_groups", {})
|
|
if isinstance(channel_hash_groups, dict):
|
|
for group_values in channel_hash_groups.values():
|
|
values = (
|
|
group_values if isinstance(group_values, (list, tuple, set)) else [group_values]
|
|
)
|
|
for candidate in values:
|
|
secret = self._extract_channel_secret_literal(candidate)
|
|
if secret:
|
|
yield secret
|
|
elif channel_hash_groups not in ({}, None):
|
|
logger.debug(
|
|
"Channel decrypt: objects.channel_hash_groups has unsupported type %s",
|
|
type(channel_hash_groups).__name__,
|
|
)
|
|
|
|
for inline_secret in self._inline_channel_secrets:
|
|
yield inline_secret
|
|
|
|
@staticmethod
|
|
def _secret_bytes_for_hash(channel_secret: str) -> bytes:
|
|
try:
|
|
secret_bytes = bytes.fromhex(channel_secret)
|
|
except ValueError:
|
|
secret_bytes = channel_secret.encode("utf-8")
|
|
if len(secret_bytes) >= 32 and secret_bytes[16:32] == b"\x00" * 16:
|
|
return secret_bytes[:16]
|
|
if len(secret_bytes) > 32:
|
|
return secret_bytes[:32]
|
|
return secret_bytes
|
|
|
|
def _derive_channel_hash(self, channel_secret: str) -> int:
|
|
secret_bytes = self._secret_bytes_for_hash(channel_secret)
|
|
return hashlib.sha256(secret_bytes).digest()[0]
|
|
|
|
@staticmethod
|
|
def _decrypt_channel_message(
|
|
channel_secret: str, mac: bytes, ciphertext: bytes
|
|
) -> Optional[bytes]:
|
|
try:
|
|
try:
|
|
secret_bytes = bytes.fromhex(channel_secret)
|
|
except ValueError:
|
|
secret_bytes = channel_secret.encode("utf-8")
|
|
|
|
if len(secret_bytes) < 32:
|
|
secret_bytes = secret_bytes + b"\x00" * (32 - len(secret_bytes))
|
|
elif len(secret_bytes) > 32:
|
|
secret_bytes = secret_bytes[:32]
|
|
|
|
expected_mac = CryptoUtils._hmac_sha256(secret_bytes, ciphertext)[:2]
|
|
if mac != expected_mac:
|
|
return None
|
|
|
|
return CryptoUtils._aes_decrypt(secret_bytes[:16], ciphertext)
|
|
except Exception:
|
|
return None
|
|
|
|
@staticmethod
|
|
def _parse_channel_plaintext(plaintext: bytes) -> Optional[dict]:
|
|
if len(plaintext) < 5:
|
|
return None
|
|
|
|
try:
|
|
timestamp = int.from_bytes(plaintext[:4], "little")
|
|
flags = plaintext[4]
|
|
raw = plaintext[5:].decode("utf-8", errors="replace")
|
|
message_content = raw.rstrip("\x00")
|
|
|
|
message_type = "unknown"
|
|
if flags == 0x00:
|
|
message_type = "plain_text"
|
|
elif flags == 0x01:
|
|
message_type = "cli_command"
|
|
elif flags == 0x02:
|
|
message_type = "signed_text"
|
|
if len(plaintext) >= 7:
|
|
raw = plaintext[7:].decode("utf-8", errors="replace")
|
|
message_content = raw.rstrip("\x00")
|
|
|
|
return {
|
|
"timestamp": timestamp,
|
|
"flags": flags,
|
|
"message_type": message_type,
|
|
"content": message_content,
|
|
}
|
|
except Exception:
|
|
return None
|
|
|
|
@staticmethod
|
|
def _extract_sender_from_message(message_content: str) -> tuple[str, str]:
|
|
if ": " in message_content:
|
|
parts = message_content.split(": ", 1)
|
|
if len(parts) == 2:
|
|
return parts[0], parts[1]
|
|
return "Unknown", message_content
|
|
|
|
@staticmethod
|
|
def _extract_channel_secret_literal(value: Any) -> Optional[str]:
|
|
if isinstance(value, str):
|
|
raw = value.strip()
|
|
if not raw:
|
|
return None
|
|
normalized_hex = raw[2:] if raw.lower().startswith("0x") else raw
|
|
if len(normalized_hex) in (32, 64) and all(
|
|
ch in "0123456789abcdefABCDEF" for ch in normalized_hex
|
|
):
|
|
return normalized_hex
|
|
return None
|
|
|
|
@classmethod
|
|
def _collect_inline_rule_channel_secrets(cls, rules: list) -> list[str]:
|
|
secrets: list[str] = []
|
|
|
|
def _consume_condition(cond: Any):
|
|
if not isinstance(cond, dict):
|
|
return
|
|
if cond.get("field") != "channel_hash":
|
|
return
|
|
raw_value = cond.get("value")
|
|
values = raw_value if isinstance(raw_value, (list, tuple, set)) else [raw_value]
|
|
for candidate in values:
|
|
secret = cls._extract_channel_secret_literal(candidate)
|
|
if secret:
|
|
secrets.append(secret)
|
|
|
|
for rule in rules:
|
|
if not isinstance(rule, dict):
|
|
continue
|
|
cond = rule.get("if", {})
|
|
if isinstance(cond, dict) and "field" in cond:
|
|
_consume_condition(cond)
|
|
continue
|
|
if not isinstance(cond, dict):
|
|
continue
|
|
for key in ("all", "any"):
|
|
branch = cond.get(key)
|
|
if isinstance(branch, list):
|
|
for item in branch:
|
|
_consume_condition(item)
|
|
|
|
return list(dict.fromkeys(secrets))
|
|
|
|
@staticmethod
|
|
def _normalize_path_hash_value(value: Any) -> Optional[str]:
|
|
if value is None:
|
|
return None
|
|
|
|
if isinstance(value, int):
|
|
parsed = value
|
|
if parsed < 0:
|
|
return None
|
|
if parsed <= 0xFF:
|
|
width = 2
|
|
elif parsed <= 0xFFFF:
|
|
width = 4
|
|
elif parsed <= 0xFFFFFF:
|
|
width = 6
|
|
else:
|
|
raise ValueError("path hash exceeds 3 bytes")
|
|
return f"{parsed:0{width}X}"
|
|
else:
|
|
raw = str(value).strip()
|
|
if not raw:
|
|
return None
|
|
if raw.lower().startswith("0x"):
|
|
raw = raw[2:]
|
|
if not raw:
|
|
return None
|
|
if len(raw) % 2 != 0:
|
|
raise ValueError("path hash hex length must be even")
|
|
if len(raw) not in (2, 4, 6):
|
|
raise ValueError("path hash must be 1, 2, or 3 bytes")
|
|
if not all(ch in "0123456789abcdefABCDEF" for ch in raw):
|
|
raise ValueError("path hash must be hex")
|
|
return raw.upper()
|
|
|
|
@classmethod
|
|
def _normalize_path_hash_values(cls, value: Any) -> Any:
|
|
if isinstance(value, (list, tuple, set)):
|
|
normalized = []
|
|
for item in value:
|
|
normalized_item = cls._normalize_path_hash_value(item)
|
|
if normalized_item is not None:
|
|
normalized.append(normalized_item)
|
|
lengths = {len(item) for item in normalized}
|
|
if len(lengths) > 1:
|
|
raise ValueError("path hashes cannot mix byte lengths")
|
|
return normalized
|
|
|
|
return cls._normalize_path_hash_value(value)
|
|
|
|
@staticmethod
|
|
def _normalize_channel_hash_value(value: Any) -> Optional[str]:
|
|
if value is None:
|
|
return None
|
|
|
|
if isinstance(value, int):
|
|
parsed = value
|
|
else:
|
|
raw = str(value).strip()
|
|
if not raw:
|
|
return None
|
|
normalized_hex = raw[2:] if raw.lower().startswith("0x") else raw
|
|
if len(normalized_hex) in (32, 64) and all(
|
|
ch in "0123456789abcdefABCDEF" for ch in normalized_hex
|
|
):
|
|
secret_bytes = PolicyEngine._secret_bytes_for_hash(normalized_hex)
|
|
parsed = hashlib.sha256(secret_bytes).digest()[0]
|
|
return f"0x{parsed:02X}"
|
|
if raw.lower().startswith("0x"):
|
|
parsed = int(raw, 16)
|
|
elif raw.isdigit():
|
|
parsed = int(raw, 10)
|
|
else:
|
|
parsed = int(raw, 16)
|
|
|
|
if parsed < 0:
|
|
raise ValueError("channel hash must be non-negative")
|
|
if parsed > 0xFF:
|
|
raise ValueError("channel hash must be one byte (0x00-0xFF)")
|
|
return f"0x{parsed:02X}"
|
|
|
|
@classmethod
|
|
def _normalize_channel_hash_values(cls, value: Any) -> Any:
|
|
if isinstance(value, (list, tuple, set)):
|
|
normalized = []
|
|
for item in value:
|
|
normalized_item = cls._normalize_channel_hash_value(item)
|
|
if normalized_item is not None:
|
|
normalized.append(normalized_item)
|
|
return normalized
|
|
|
|
return cls._normalize_channel_hash_value(value)
|
|
|
|
@staticmethod
|
|
def _compare(actual: Any, op: Any, expected: Any) -> bool:
|
|
op_name = str(op or "equals").lower()
|
|
|
|
try:
|
|
if op_name in ("equals", "eq", "=="):
|
|
return actual == expected
|
|
if op_name in ("not_equals", "ne", "!="):
|
|
return actual != expected
|
|
if op_name in ("greater_than", "gt", ">"):
|
|
return actual is not None and expected is not None and actual > expected
|
|
if op_name in ("greater_or_equal", "gte", ">="):
|
|
return actual is not None and expected is not None and actual >= expected
|
|
if op_name in ("less_than", "lt", "<"):
|
|
return actual is not None and expected is not None and actual < expected
|
|
if op_name in ("less_or_equal", "lte", "<="):
|
|
return actual is not None and expected is not None and actual <= expected
|
|
if op_name == "contains":
|
|
if isinstance(actual, (list, tuple, set)):
|
|
return expected in actual
|
|
if isinstance(actual, str) and expected is not None:
|
|
return str(expected) in actual
|
|
return False
|
|
if op_name in ("in", "is_in"):
|
|
if isinstance(expected, (list, tuple, set)):
|
|
return actual in expected
|
|
if isinstance(expected, str) and actual is not None:
|
|
return str(actual) in expected
|
|
return False
|
|
if op_name in ("intersects", "overlaps"):
|
|
if isinstance(actual, (list, tuple, set)) and isinstance(
|
|
expected, (list, tuple, set)
|
|
):
|
|
return len(set(actual).intersection(set(expected))) > 0
|
|
return False
|
|
if op_name == "starts_with":
|
|
return (
|
|
isinstance(actual, str)
|
|
and isinstance(expected, str)
|
|
and actual.startswith(expected)
|
|
)
|
|
if op_name == "ends_with":
|
|
return (
|
|
isinstance(actual, str)
|
|
and isinstance(expected, str)
|
|
and actual.endswith(expected)
|
|
)
|
|
except Exception:
|
|
return False
|
|
|
|
return False
|