#!/usr/bin/env python3 """Check OpenAPI contract coverage against CherryPy exposed endpoints. This check enforces both directions: - every OpenAPI path must exist in code - every API endpoint in code must exist in OpenAPI OR be explicitly allowlisted Method checks are warning-only by default and can be made strict. """ from __future__ import annotations import ast import argparse import re import sys from dataclasses import dataclass from pathlib import Path import yaml ROOT = Path(__file__).resolve().parents[1] WEB_DIR = ROOT / "repeater" / "web" OPENAPI_PATH = WEB_DIR / "openapi.yaml" ALLOWLIST_PATH = ROOT / "scripts" / "openapi_contract_allowlist.yaml" @dataclass class RouteInfo: methods: set[str] confident: bool @dataclass class Allowlist: exact: set[str] prefixes: tuple[str, ...] def _normalize_path(path: str) -> str: path = path.strip() if not path: return "/" if not path.startswith("/"): path = "/" + path path = re.sub(r"/{2,}", "/", path) path = re.sub(r"\{[^/}]+\}", "{}", path) if len(path) > 1 and path.endswith("/"): path = path[:-1] return path def _load_openapi() -> dict[str, set[str]]: with OPENAPI_PATH.open("r", encoding="utf-8") as f: doc = yaml.safe_load(f) or {} paths = doc.get("paths", {}) out: dict[str, set[str]] = {} for raw_path, ops in paths.items(): if not isinstance(ops, dict): continue methods = { m.lower() for m in ops.keys() if m.lower() in {"get", "post", "put", "delete", "patch", "options", "head"} } # We do not enforce OPTIONS/HEAD in this checker. methods.discard("options") methods.discard("head") out[_normalize_path(str(raw_path))] = methods return out def _load_allowlist() -> Allowlist: if not ALLOWLIST_PATH.exists(): return Allowlist(exact=set(), prefixes=tuple()) with ALLOWLIST_PATH.open("r", encoding="utf-8") as f: data = yaml.safe_load(f) or {} exact_raw = data.get("exact_paths", []) prefix_raw = data.get("path_prefixes", []) exact = {_normalize_path(str(p)) for p in exact_raw if str(p).strip()} prefixes = tuple(_normalize_path(str(p)) for p in prefix_raw if str(p).strip()) return Allowlist(exact=exact, prefixes=prefixes) def _is_allowlisted(path: str, allowlist: Allowlist) -> bool: if path in allowlist.exact: return True for prefix in allowlist.prefixes: if path == prefix or path.startswith(prefix + "/"): return True return False def _is_cherrypy_request_method_expr(node: ast.AST) -> bool: # cherrypy.request.method return ( isinstance(node, ast.Attribute) and node.attr == "method" and isinstance(node.value, ast.Attribute) and node.value.attr == "request" and isinstance(node.value.value, ast.Name) and node.value.value.id == "cherrypy" ) def _extract_method_strings(node: ast.AST) -> set[str]: if isinstance(node, ast.Constant) and isinstance(node.value, str): return {node.value.upper()} if isinstance(node, (ast.Tuple, ast.List, ast.Set)): vals: set[str] = set() for e in node.elts: vals |= _extract_method_strings(e) return vals return set() def _infer_methods(fn: ast.FunctionDef) -> tuple[set[str], bool]: methods: set[str] = set() confidence = False has_require_post = False saw_method_compare = False for node in ast.walk(fn): if isinstance(node, ast.Call): # self._require_post() if ( isinstance(node.func, ast.Attribute) and node.func.attr == "_require_post" and isinstance(node.func.value, ast.Name) and node.func.value.id == "self" ): has_require_post = True methods.add("POST") confidence = True if not isinstance(node, ast.Compare): continue if not _is_cherrypy_request_method_expr(node.left): continue saw_method_compare = True if not node.ops or not node.comparators: continue op = node.ops[0] rhs_vals = _extract_method_strings(node.comparators[0]) if not rhs_vals: continue # Treat equality and inequality guards as declared allowed methods. if isinstance(op, (ast.Eq, ast.In, ast.NotEq, ast.NotIn)): methods |= rhs_vals confidence = True methods.discard("OPTIONS") methods.discard("HEAD") # If a handler branches on request.method but is not explicitly POST-only, # CherryPy's default method for uncovered branches is typically GET. if saw_method_compare and not has_require_post and methods and "POST" in methods: methods.add("GET") if not methods: return {"GET"}, False return methods, confidence def _has_expose_decorator(fn: ast.FunctionDef) -> bool: for d in fn.decorator_list: # @cherrypy.expose if isinstance(d, ast.Attribute): if ( d.attr == "expose" and isinstance(d.value, ast.Name) and d.value.id == "cherrypy" ): return True return False def _fn_params(fn: ast.FunctionDef) -> list[str]: params = [a.arg for a in fn.args.args if a.arg != "self"] return [p for p in params if p not in {"kwargs", "args"}] def _candidate_suffixes(fn: ast.FunctionDef) -> list[str]: name = fn.name params = _fn_params(fn) if name == "index": return [""] if name == "default": if params: return ["/{}"] return ["/{path}"] base = f"/{name}" # Keep named endpoints canonical. Parameters are often query parameters, # so adding path-segment variants here produces false positives. return [base] def _collect_class_routes(module_path: Path, class_name: str, prefixes: list[str]) -> dict[str, RouteInfo]: tree = ast.parse(module_path.read_text(encoding="utf-8"), filename=str(module_path)) cls = next( ( n for n in tree.body if isinstance(n, ast.ClassDef) and n.name == class_name ), None, ) if cls is None: return {} routes: dict[str, RouteInfo] = {} for node in cls.body: if not isinstance(node, ast.FunctionDef): continue if class_name == "APIEndpoints" and node.name == "default": # Catch-all handlers are not part of API contract surface. continue if not _has_expose_decorator(node): continue methods, confident = _infer_methods(node) suffixes = _candidate_suffixes(node) for prefix in prefixes: for suffix in suffixes: path = _normalize_path(prefix + suffix) cur = routes.get(path) if cur is None: routes[path] = RouteInfo(methods=set(methods), confident=confident) else: cur.methods |= methods cur.confident = cur.confident or confident return routes def _collect_routes() -> dict[str, RouteInfo]: route_map: dict[str, RouteInfo] = {} class_specs = [ # /api/* methods are described in OpenAPI as / (WEB_DIR / "api_endpoints.py", "APIEndpoints", [""]), # Nested /api/companion/* endpoints are described as /companion/*. (WEB_DIR / "companion_endpoints.py", "CompanionAPIEndpoints", ["/companion"]), # Nested /api/update/* endpoints are described as /update/* when documented. (WEB_DIR / "update_endpoints.py", "UpdateAPIEndpoints", ["/update"]), # Auth top-level endpoints are mounted at /auth/* (WEB_DIR / "auth_endpoints.py", "AuthEndpoints", ["/auth"]), # Token sub-resource is exposed both under /auth and /api/auth in current routing. (WEB_DIR / "auth_endpoints.py", "TokensAPIEndpoint", ["/auth/tokens", "/api/auth/tokens"]), ] for file_path, class_name, prefixes in class_specs: class_routes = _collect_class_routes(file_path, class_name, prefixes) for path, info in class_routes.items(): cur = route_map.get(path) if cur is None: route_map[path] = info else: cur.methods |= info.methods cur.confident = cur.confident or info.confident return route_map def main() -> int: parser = argparse.ArgumentParser(description="Check OpenAPI contract coverage.") parser.add_argument( "--strict-methods", action="store_true", help="Fail when inferred HTTP methods differ from OpenAPI methods.", ) args = parser.parse_args() if not OPENAPI_PATH.exists(): print(f"ERROR: OpenAPI spec not found at {OPENAPI_PATH}") return 1 spec = _load_openapi() allowlist = _load_allowlist() code_routes = _collect_routes() errors: list[str] = [] warnings: list[str] = [] for path, spec_methods in sorted(spec.items()): code = code_routes.get(path) if code is None: errors.append(f"Missing endpoint in code for OpenAPI path: {path}") continue if spec_methods and code.confident: code_methods = {m.lower() for m in code.methods} missing_methods = sorted(m for m in spec_methods if m not in code_methods) if missing_methods: msg = ( f"Method mismatch for {path}: OpenAPI has {sorted(spec_methods)}, " f"code inference has {sorted(code_methods)}" ) if args.strict_methods: errors.append(msg) else: warnings.append(msg) # Enforce code -> OpenAPI (unless allowlisted) for path in sorted(code_routes.keys()): if path in spec: continue if _is_allowlisted(path, allowlist): continue errors.append( f"Undocumented endpoint in code (not in OpenAPI and not allowlisted): {path}" ) if warnings: print("OpenAPI contract warnings:") for w in warnings: print(f"- {w}") if errors: print("OpenAPI contract check failed:") for e in errors: print(f"- {e}") return 1 print("OpenAPI contract check passed.") return 0 if __name__ == "__main__": raise SystemExit(main())