Files
pyMC_Repeater/scripts/check_openapi_contract.py
T

331 lines
10 KiB
Python

#!/usr/bin/env python3
"""Check OpenAPI contract coverage against CherryPy exposed endpoints.
This check enforces both directions:
- every OpenAPI path must exist in code
- every API endpoint in code must exist in OpenAPI OR be explicitly allowlisted
Method checks are warning-only by default and can be made strict.
"""
from __future__ import annotations
import argparse
import ast
import re
from dataclasses import dataclass
from pathlib import Path
import yaml
ROOT = Path(__file__).resolve().parents[1]
WEB_DIR = ROOT / "repeater" / "web"
OPENAPI_PATH = WEB_DIR / "openapi.yaml"
ALLOWLIST_PATH = ROOT / "scripts" / "openapi_contract_allowlist.yaml"
@dataclass
class RouteInfo:
methods: set[str]
confident: bool
@dataclass
class Allowlist:
exact: set[str]
prefixes: tuple[str, ...]
def _normalize_path(path: str) -> str:
path = path.strip()
if not path:
return "/"
if not path.startswith("/"):
path = "/" + path
path = re.sub(r"/{2,}", "/", path)
path = re.sub(r"\{[^/}]+\}", "{}", path)
if len(path) > 1 and path.endswith("/"):
path = path[:-1]
return path
def _load_openapi() -> dict[str, set[str]]:
with OPENAPI_PATH.open("r", encoding="utf-8") as f:
doc = yaml.safe_load(f) or {}
paths = doc.get("paths", {})
out: dict[str, set[str]] = {}
for raw_path, ops in paths.items():
if not isinstance(ops, dict):
continue
methods = {
m.lower()
for m in ops.keys()
if m.lower() in {"get", "post", "put", "delete", "patch", "options", "head"}
}
# We do not enforce OPTIONS/HEAD in this checker.
methods.discard("options")
methods.discard("head")
out[_normalize_path(str(raw_path))] = methods
return out
def _load_allowlist() -> Allowlist:
if not ALLOWLIST_PATH.exists():
return Allowlist(exact=set(), prefixes=tuple())
with ALLOWLIST_PATH.open("r", encoding="utf-8") as f:
data = yaml.safe_load(f) or {}
exact_raw = data.get("exact_paths", [])
prefix_raw = data.get("path_prefixes", [])
exact = {_normalize_path(str(p)) for p in exact_raw if str(p).strip()}
prefixes = tuple(_normalize_path(str(p)) for p in prefix_raw if str(p).strip())
return Allowlist(exact=exact, prefixes=prefixes)
def _is_allowlisted(path: str, allowlist: Allowlist) -> bool:
if path in allowlist.exact:
return True
for prefix in allowlist.prefixes:
if path == prefix or path.startswith(prefix + "/"):
return True
return False
def _is_cherrypy_request_method_expr(node: ast.AST) -> bool:
# cherrypy.request.method
return (
isinstance(node, ast.Attribute)
and node.attr == "method"
and isinstance(node.value, ast.Attribute)
and node.value.attr == "request"
and isinstance(node.value.value, ast.Name)
and node.value.value.id == "cherrypy"
)
def _extract_method_strings(node: ast.AST) -> set[str]:
if isinstance(node, ast.Constant) and isinstance(node.value, str):
return {node.value.upper()}
if isinstance(node, (ast.Tuple, ast.List, ast.Set)):
vals: set[str] = set()
for e in node.elts:
vals |= _extract_method_strings(e)
return vals
return set()
def _infer_methods(fn: ast.FunctionDef) -> tuple[set[str], bool]:
methods: set[str] = set()
confidence = False
has_require_post = False
saw_method_compare = False
for node in ast.walk(fn):
if isinstance(node, ast.Call):
# self._require_post()
if (
isinstance(node.func, ast.Attribute)
and node.func.attr == "_require_post"
and isinstance(node.func.value, ast.Name)
and node.func.value.id == "self"
):
has_require_post = True
methods.add("POST")
confidence = True
if not isinstance(node, ast.Compare):
continue
if not _is_cherrypy_request_method_expr(node.left):
continue
saw_method_compare = True
if not node.ops or not node.comparators:
continue
op = node.ops[0]
rhs_vals = _extract_method_strings(node.comparators[0])
if not rhs_vals:
continue
# Treat equality and inequality guards as declared allowed methods.
if isinstance(op, (ast.Eq, ast.In, ast.NotEq, ast.NotIn)):
methods |= rhs_vals
confidence = True
methods.discard("OPTIONS")
methods.discard("HEAD")
# If a handler branches on request.method but is not explicitly POST-only,
# CherryPy's default method for uncovered branches is typically GET.
if saw_method_compare and not has_require_post and methods and "POST" in methods:
methods.add("GET")
if not methods:
return {"GET"}, False
return methods, confidence
def _has_expose_decorator(fn: ast.FunctionDef) -> bool:
for d in fn.decorator_list:
# @cherrypy.expose
if isinstance(d, ast.Attribute):
if d.attr == "expose" and isinstance(d.value, ast.Name) and d.value.id == "cherrypy":
return True
return False
def _fn_params(fn: ast.FunctionDef) -> list[str]:
params = [a.arg for a in fn.args.args if a.arg != "self"]
return [p for p in params if p not in {"kwargs", "args"}]
def _candidate_suffixes(fn: ast.FunctionDef) -> list[str]:
name = fn.name
params = _fn_params(fn)
if name == "index":
return [""]
if name == "default":
if params:
return ["/{}"]
return ["/{path}"]
base = f"/{name}"
# Keep named endpoints canonical. Parameters are often query parameters,
# so adding path-segment variants here produces false positives.
return [base]
def _collect_class_routes(
module_path: Path, class_name: str, prefixes: list[str]
) -> dict[str, RouteInfo]:
tree = ast.parse(module_path.read_text(encoding="utf-8"), filename=str(module_path))
cls = next(
(n for n in tree.body if isinstance(n, ast.ClassDef) and n.name == class_name),
None,
)
if cls is None:
return {}
routes: dict[str, RouteInfo] = {}
for node in cls.body:
if not isinstance(node, ast.FunctionDef):
continue
if class_name == "APIEndpoints" and node.name == "default":
# Catch-all handlers are not part of API contract surface.
continue
if not _has_expose_decorator(node):
continue
methods, confident = _infer_methods(node)
suffixes = _candidate_suffixes(node)
for prefix in prefixes:
for suffix in suffixes:
path = _normalize_path(prefix + suffix)
cur = routes.get(path)
if cur is None:
routes[path] = RouteInfo(methods=set(methods), confident=confident)
else:
cur.methods |= methods
cur.confident = cur.confident or confident
return routes
def _collect_routes() -> dict[str, RouteInfo]:
route_map: dict[str, RouteInfo] = {}
class_specs = [
# /api/* methods are described in OpenAPI as /<endpoint>
(WEB_DIR / "api_endpoints.py", "APIEndpoints", [""]),
# Nested /api/companion/* endpoints are described as /companion/*.
(WEB_DIR / "companion_endpoints.py", "CompanionAPIEndpoints", ["/companion"]),
# Nested /api/update/* endpoints are described as /update/* when documented.
(WEB_DIR / "update_endpoints.py", "UpdateAPIEndpoints", ["/update"]),
# Auth top-level endpoints are mounted at /auth/*
(WEB_DIR / "auth_endpoints.py", "AuthEndpoints", ["/auth"]),
# Token sub-resource is exposed both under /auth and /api/auth in current routing.
(WEB_DIR / "auth_endpoints.py", "TokensAPIEndpoint", ["/auth/tokens", "/api/auth/tokens"]),
]
for file_path, class_name, prefixes in class_specs:
class_routes = _collect_class_routes(file_path, class_name, prefixes)
for path, info in class_routes.items():
cur = route_map.get(path)
if cur is None:
route_map[path] = info
else:
cur.methods |= info.methods
cur.confident = cur.confident or info.confident
return route_map
def main() -> int:
parser = argparse.ArgumentParser(description="Check OpenAPI contract coverage.")
parser.add_argument(
"--strict-methods",
action="store_true",
help="Fail when inferred HTTP methods differ from OpenAPI methods.",
)
args = parser.parse_args()
if not OPENAPI_PATH.exists():
print(f"ERROR: OpenAPI spec not found at {OPENAPI_PATH}")
return 1
spec = _load_openapi()
allowlist = _load_allowlist()
code_routes = _collect_routes()
errors: list[str] = []
warnings: list[str] = []
for path, spec_methods in sorted(spec.items()):
code = code_routes.get(path)
if code is None:
errors.append(f"Missing endpoint in code for OpenAPI path: {path}")
continue
if spec_methods and code.confident:
code_methods = {m.lower() for m in code.methods}
missing_methods = sorted(m for m in spec_methods if m not in code_methods)
if missing_methods:
msg = (
f"Method mismatch for {path}: OpenAPI has {sorted(spec_methods)}, "
f"code inference has {sorted(code_methods)}"
)
if args.strict_methods:
errors.append(msg)
else:
warnings.append(msg)
# Enforce code -> OpenAPI (unless allowlisted)
for path in sorted(code_routes.keys()):
if path in spec:
continue
if _is_allowlisted(path, allowlist):
continue
errors.append(f"Undocumented endpoint in code (not in OpenAPI and not allowlisted): {path}")
if warnings:
print("OpenAPI contract warnings:")
for w in warnings:
print(f"- {w}")
if errors:
print("OpenAPI contract check failed:")
for e in errors:
print(f"- {e}")
return 1
print("OpenAPI contract check passed.")
return 0
if __name__ == "__main__":
raise SystemExit(main())