From ce8cdfd48dadfdb1dc90097cf362810590d79cad Mon Sep 17 00:00:00 2001 From: Lloyd Date: Mon, 9 Mar 2026 10:47:52 +0000 Subject: [PATCH] Refactor version handling in update endpoints to use importlib.metadata for installed version retrieval --- repeater/web/update_endpoints.py | 174 +++++++++++++++++++++---------- 1 file changed, 119 insertions(+), 55 deletions(-) diff --git a/repeater/web/update_endpoints.py b/repeater/web/update_endpoints.py index cf16b6d..e9f00dc 100644 --- a/repeater/web/update_endpoints.py +++ b/repeater/web/update_endpoints.py @@ -29,8 +29,6 @@ from typing import List, Optional import cherrypy -from repeater import __version__ as _installed_version - logger = logging.getLogger("HTTPServer") # --------------------------------------------------------------------------- @@ -45,6 +43,21 @@ PACKAGE_NAME = "pymc_repeater" # How long (seconds) before a cached check result expires CHECK_CACHE_TTL = 600 # 10 minutes + +def _get_installed_version() -> str: + """Read the currently installed package version fresh from importlib.metadata.""" + try: + from importlib.metadata import version as _pkg_ver + return _pkg_ver(PACKAGE_NAME) + except Exception: + pass + # Fallback: repeater __version__ + try: + from repeater import __version__ + return __version__ + except Exception: + return "unknown" + # Channels file – persisted so the choice survives daemon restarts _CHANNELS_FILE = "/var/lib/pymc_repeater/.update_channel" @@ -58,7 +71,7 @@ class _UpdateState: def __init__(self): self._lock = threading.Lock() # version info - self.current_version: str = _installed_version + self.current_version: str = _get_installed_version() self.latest_version: Optional[str] = None self.has_update: bool = False self.channel: str = self._load_channel() @@ -97,6 +110,13 @@ class _UpdateState: # ------------------------------------------------------------------ # def snapshot(self) -> dict: with self._lock: + # Always read installed version fresh so it reflects post-restart state + fresh_current = _get_installed_version() + if fresh_current != "unknown": + self.current_version = fresh_current + # Recompute has_update with fresh installed version + if self.latest_version is not None: + self.has_update = _has_update(self.current_version, self.latest_version) return { "current_version": self.current_version, "latest_version": self.latest_version, @@ -127,8 +147,10 @@ class _UpdateState: def _finish_check(self, latest: str) -> None: with self._lock: self.latest_version = latest - self.current_version = _installed_version # refresh in case just updated - self.has_update = (latest != self.current_version) + fresh = _get_installed_version() + if fresh != "unknown": + self.current_version = fresh + self.has_update = _has_update(self.current_version, latest) self.last_checked = datetime.utcnow() self.state = "idle" self.error_message = None @@ -155,13 +177,8 @@ class _UpdateState: self.error_message = None if success else msg if success: self.progress_lines.append(f"[pyMC updater] ✓ {msg}") - # Refresh installed version from importlib metadata - try: - from importlib.metadata import version as _pkg_ver - self.current_version = _pkg_ver(PACKAGE_NAME) - except Exception: - pass self.has_update = False + # current_version will be refreshed on next snapshot() call else: self.progress_lines.append(f"[pyMC updater] ✗ {msg}") @@ -179,64 +196,108 @@ _state = _UpdateState() def _fetch_url(url: str, timeout: int = 10) -> str: """Perform a simple GET and return text body, or raise on failure.""" - req = urllib.request.Request(url, headers={"User-Agent": f"pymc-repeater/{_installed_version}"}) + installed = _get_installed_version() + req = urllib.request.Request(url, headers={"User-Agent": f"pymc-repeater/{installed}"}) with urllib.request.urlopen(req, timeout=timeout) as resp: return resp.read().decode("utf-8", errors="replace") -def _latest_version_from_raw(channel: str) -> str: - """ - Fetch repeater/__init__.py from *channel* on GitHub and extract __version__. - Falls back to scanning the _version.py stub. - """ - raw_url = f"{GITHUB_RAW_BASE}/{channel}/repeater/__init__.py" - text = _fetch_url(raw_url) - # __init__.py doesn't embed version directly; it imports from _version.py - # Try to read _version.py on the same channel instead - try: - ver_url = f"{GITHUB_RAW_BASE}/{channel}/repeater/_version.py" - ver_text = _fetch_url(ver_url) - m = re.search(r'version\s*=\s*["\']([^"\']+)["\']', ver_text) - if m: - return m.group(1) - except Exception: - pass +def _get_latest_tag() -> str: + """Return the most recent semver tag from the repo, or raise.""" + tags_url = f"{GITHUB_API_BASE}/tags?per_page=10" + body = _fetch_url(tags_url, timeout=8) + tags = json.loads(body) + for tag in tags: + name = tag.get("name", "").lstrip("v") + if re.match(r'^\d+\.\d+', name): + return name + raise RuntimeError("No semver tags found in repository") - # Last resort: try pyproject.toml static version field + +def _branch_is_dynamic(channel: str) -> bool: + """Return True if the branch uses setuptools_scm dynamic versioning.""" try: toml_url = f"{GITHUB_RAW_BASE}/{channel}/pyproject.toml" - toml_text = _fetch_url(toml_url) - m = re.search(r'^version\s*=\s*["\']([^"\']+)["\']', toml_text, re.MULTILINE) - if m: - return m.group(1) + toml_text = _fetch_url(toml_url, timeout=8) + # Static pin looks like: version = "1.0.5" + if re.search(r'^version\s*=\s*["\'][0-9]', toml_text, re.MULTILINE): + return False + # Dynamic looks like: dynamic = ["version"] + if re.search(r'^dynamic\s*=', toml_text, re.MULTILINE): + return True except Exception: pass + return True # assume dynamic if we can't tell - return "unknown" + +def _next_dev_version(base_tag: str, ahead_by: int) -> str: + """ + Generate a display version string for a dynamic branch. + e.g. base_tag="1.0.5", ahead_by=191 -> "1.0.6.dev191" + Mirrors what setuptools_scm guess-next-dev produces. + """ + parts = base_tag.split(".") + try: + parts[-1] = str(int(parts[-1]) + 1) + except (ValueError, IndexError): + parts.append("1") + return ".".join(parts) + f".dev{ahead_by}" + + +def _parse_dev_number(version_str: str) -> Optional[int]: + """Extract the dev commit count from a setuptools_scm version like 1.0.6.dev118.""" + m = re.search(r'\.dev(\d+)', version_str) + return int(m.group(1)) if m else None + + +def _has_update(installed: str, latest: str) -> bool: + """ + Compare installed vs latest version. + + For dev-versioned strings (both contain .devN): + Compare the dev numbers numerically so that + installed=1.0.6.dev200 vs latest=1.0.6.dev191 → False (already ahead). + + For static versions or mismatched types: + Simple string inequality. + """ + if installed == latest: + return False + installed_dev = _parse_dev_number(installed) + latest_dev = _parse_dev_number(latest) + if installed_dev is not None and latest_dev is not None: + return latest_dev > installed_dev + # Static release comparison + return installed != latest def _fetch_latest_version(channel: str) -> str: """ - Multi-strategy version fetch. Returns version string or raises. - Strategy: - 1. GitHub Releases API (only works when tagged releases exist) - 2. Raw _version.py / pyproject.toml from the branch - """ - # Strategy 1: releases API (works for stable main branch tags) - if channel == "main": - try: - api_url = f"{GITHUB_API_BASE}/releases/latest" - body = _fetch_url(api_url, timeout=8) - data = json.loads(body) - tag = data.get("tag_name", "") - ver = tag.lstrip("v") - if ver: - return ver - except Exception: - pass + Return the latest available version string for *channel*. - # Strategy 2: raw source files on the branch - return _latest_version_from_raw(channel) + For static-versioned channels (e.g. main after a release commit): + Uses the GitHub tags API -> e.g. "1.0.5" + + For dynamic-versioned channels (dev, feature branches using setuptools_scm): + Uses GET /compare/{tag}...{channel} to count commits ahead of the + last tag, then returns a version like "1.0.6.dev191" that mirrors + what setuptools_scm would produce on that branch. + has_update is then True when branch_dev_number > installed_dev_number. + """ + base_tag = _get_latest_tag() # always needed; single API call + + if _branch_is_dynamic(channel): + compare_url = f"{GITHUB_API_BASE}/compare/{base_tag}...{channel}" + try: + body = _fetch_url(compare_url, timeout=10) + data = json.loads(body) + ahead_by = int(data.get("ahead_by", 0)) + return _next_dev_version(base_tag, ahead_by) + except Exception: + return base_tag # fallback: show the tag + + # Static version channel — the tag IS the release version + return base_tag def _fetch_branches() -> List[str]: @@ -259,7 +320,10 @@ def _do_check() -> None: try: latest = _fetch_latest_version(channel) _state._finish_check(latest) - logger.info(f"[Update] Check complete – current={_state.current_version} latest={latest} channel={channel}") + logger.info( + f"[Update] Check complete – installed={_state.current_version} " + f"latest={latest} channel={channel} has_update={_state.has_update}" + ) except Exception as exc: msg = str(exc) _state._fail_check(msg)