diff --git a/llamacpp-watchdog.py b/llamacpp-watchdog.py
index 0ba4fc6..aed0b85 100644
--- a/llamacpp-watchdog.py
+++ b/llamacpp-watchdog.py
@@ -7,6 +7,12 @@ Detects:
 - Router health endpoint failures
 - Zombie child model-server processes
 - Loaded models that are unreachable through the router
+
+Per-model tracking:
+  - Individual model failures never trigger a full service restart
+  - Newly loaded models get a grace period before probing
+  - Persistently failing models are unloaded and put in cooldown instead of causing restarts
+  - Only router-level health failures trigger service restarts
 """
 
 import subprocess
@@ -14,6 +20,7 @@ import requests
 import time
 import signal
 from datetime import datetime
+from collections import defaultdict
 
 # Configuration
 SERVERS = [
@@ -25,13 +32,25 @@ LLAMA_SERVER_BIN = "/home/aj/llama.cpp/build/bin/llama-server"
 MODELS_DIR = "/home/aj/models"
 MODELS_PRESET = "/home/aj/models/models.ini"
 
-CHECK_INTERVAL = 30  # seconds between health checks
-HEALTH_TIMEOUT = 10  # seconds to wait for health response
-DEEP_CHECK_TIMEOUT = 30  # seconds to wait for model probe
-MAX_CONSECUTIVE_FAILURES = 2  # restart after this many failures
+CHECK_INTERVAL = 30  # seconds between checks
+HEALTH_TIMEOUT = 10  # router health check timeout
+DEEP_CHECK_TIMEOUT = 30  # model probe timeout
+MAX_HEALTH_FAILURES = 2  # restart after N router health failures
+MAX_MODEL_PROBE_FAILURES = 5  # ignore model after N probe failures
+MODEL_LOAD_GRACE_PERIOD = 300  # skip probing models loaded within last 5 min
+MODEL_FAILURE_COOLDOWN = 600  # stop probing a failed model for 10 min
 
-# Track failures per server
-failure_counts = {}
+# Per-server health failure tracking (router-level)
+health_failures = {}
+
+# Per-model failure tracking: {port: {model_name: count}}
+model_failures = defaultdict(lambda: defaultdict(int))
+
+# Per-model first-seen timestamps: {port: {model_name: timestamp}}
+model_first_seen = defaultdict(dict)
+
+# Per-model cooldown timestamps: {port: {model_name: timestamp}}
+model_cooldowns = defaultdict(dict)
 
 
 def log(message):
@@ -49,16 +68,20 @@ def check_health(port):
 
 
 def get_loaded_models(port):
-    """Get list of models the router reports as loaded."""
+    """Get list of models the router reports as loaded, with status info.
+
+    Returns list of dicts: [{"id": "model-name", "status": "loaded"}, ...]
+    """
     try:
         response = requests.get(f"http://localhost:{port}/v1/models", timeout=HEALTH_TIMEOUT)
         if response.status_code != 200:
             return []
         data = response.json()
-        return [
-            m["id"] for m in data.get("data", [])
-            if m.get("status", {}).get("value") == "loaded"
-        ]
+        models = []
+        for m in data.get("data", []):
+            status_value = m.get("status", {}).get("value", "unknown")
+            models.append({"id": m["id"], "status": status_value})
+        return models
     except Exception:
         return []
 
@@ -80,6 +103,19 @@
         return False
 
 
+def unload_model(port, model_name):
+    """Ask the router to unload a specific model."""
+    try:
+        response = requests.post(
+            f"http://localhost:{port}/models/unload",
+            json={"model": model_name},
+            timeout=HEALTH_TIMEOUT,
+        )
+        return response.status_code == 200
+    except requests.exceptions.RequestException:
+        return False
+
+
 def check_zombies():
     """Check for zombie llama-server processes."""
     result = subprocess.run(["ps", "aux"], capture_output=True, text=True)
@@ -145,62 +181,140 @@
     restart_manual(server)
 
 
+def clear_model_tracking(port):
+    """Clear all per-model tracking state for a server after restart."""
+    model_failures[port].clear()
+    model_first_seen[port].clear()
+    model_cooldowns[port].clear()
+
+
 def run_watchdog():
     """Main watchdog loop."""
     log("llama.cpp watchdog starting...")
+    log(f"Config: health_failures_threshold={MAX_HEALTH_FAILURES}, "
+        f"model_probe_failures_threshold={MAX_MODEL_PROBE_FAILURES}, "
+        f"grace_period={MODEL_LOAD_GRACE_PERIOD}s, "
+        f"cooldown={MODEL_FAILURE_COOLDOWN}s")
 
     for server in SERVERS:
-        failure_counts[server["port"]] = 0
+        health_failures[server["port"]] = 0
 
     while True:
         try:
+            now = time.time()
+
             # --- Phase 1: Check for zombie child processes ---
             zombies = check_zombies()
             if zombies:
                 log(f"Found {len(zombies)} zombie llama-server process(es): {zombies}")
                 for server in SERVERS:
                     restart_server(server)
-                    failure_counts[server["port"]] = 0
+                    health_failures[server["port"]] = 0
+                    clear_model_tracking(server["port"])
                 time.sleep(CHECK_INTERVAL)
                 continue
 
-            # --- Phase 2: Basic health checks ---
+            # --- Phase 2: Router health checks ---
            for server in SERVERS:
                 port = server["port"]
                 name = server["name"]
 
                 if not check_health(port):
-                    failure_counts[port] += 1
-                    log(f"{name} health check failed ({failure_counts[port]}/{MAX_CONSECUTIVE_FAILURES})")
+                    health_failures[port] += 1
+                    log(f"{name} health check failed ({health_failures[port]}/{MAX_HEALTH_FAILURES})")
 
-                    if failure_counts[port] >= MAX_CONSECUTIVE_FAILURES:
+                    if health_failures[port] >= MAX_HEALTH_FAILURES:
                         restart_server(server)
-                        failure_counts[port] = 0
+                        health_failures[port] = 0
+                        clear_model_tracking(port)
                     continue
 
-                # --- Phase 3: Deep check - probe loaded models ---
-                loaded = get_loaded_models(port)
-                if loaded:
-                    all_ok = True
-                    for model in loaded:
-                        if not probe_model(port, model):
-                            log(f"{name}: loaded model '{model}' is unreachable!")
-                            all_ok = False
-                            break
+                # Health check passed - reset health failure counter
+                if health_failures[port] > 0:
+                    log(f"{name} router recovered after {health_failures[port]} failure(s)")
+                    health_failures[port] = 0
 
-                    if not all_ok:
-                        failure_counts[port] += 1
-                        log(f"{name} deep check failed ({failure_counts[port]}/{MAX_CONSECUTIVE_FAILURES})")
+                # --- Phase 3: Deep check - probe loaded models individually ---
+                models = get_loaded_models(port)
+                if not models:
+                    continue
 
-                        if failure_counts[port] >= MAX_CONSECUTIVE_FAILURES:
-                            restart_server(server)
-                            failure_counts[port] = 0
+                # Track which models are currently loaded so we can clean up stale entries
+                current_model_ids = {m["id"] for m in models}
+
+                # Clean up tracking for models that are no longer loaded
+                for tracking_dict in [model_first_seen, model_failures, model_cooldowns]:
+                    if port in tracking_dict:
+                        stale = [m for m in tracking_dict[port] if m not in current_model_ids]
+                        for m in stale:
+                            del tracking_dict[port][m]
+
+                for model_info in models:
+                    model_name = model_info["id"]
+                    model_status = model_info["status"]
+
+                    # Skip models that are still loading
+                    if model_status == "loading":
+                        log(f"{name}: model '{model_name}' is still loading, skipping probe")
                         continue
 
-                # All checks passed
-                if failure_counts[port] > 0:
-                    log(f"{name} recovered")
-                    failure_counts[port] = 0
+                    # Skip models that aren't fully loaded
+                    if model_status != "loaded":
+                        continue
+
+                    # Track first-seen time for grace period
+                    if model_name not in model_first_seen[port]:
+                        model_first_seen[port][model_name] = now
+                        log(f"{name}: new model '{model_name}' detected, "
+                            f"grace period {MODEL_LOAD_GRACE_PERIOD}s before probing")
+                        continue
+
+                    # Skip if within grace period
+                    age = now - model_first_seen[port][model_name]
+                    if age < MODEL_LOAD_GRACE_PERIOD:
+                        remaining = int(MODEL_LOAD_GRACE_PERIOD - age)
+                        # Only log occasionally to avoid spam
+                        if int(age) % 60 < CHECK_INTERVAL:
+                            log(f"{name}: model '{model_name}' in grace period ({remaining}s remaining)")
+                        continue
+
+                    # Skip if in cooldown from previous failures
+                    if model_name in model_cooldowns[port]:
+                        cooldown_elapsed = now - model_cooldowns[port][model_name]
+                        if cooldown_elapsed < MODEL_FAILURE_COOLDOWN:
+                            remaining = int(MODEL_FAILURE_COOLDOWN - cooldown_elapsed)
+                            if int(cooldown_elapsed) % 120 < CHECK_INTERVAL:
+                                log(f"{name}: model '{model_name}' in cooldown ({remaining}s remaining)")
+                            continue
+                        else:
+                            # Cooldown expired, give it another chance
+                            log(f"{name}: model '{model_name}' cooldown expired, resuming probes")
+                            del model_cooldowns[port][model_name]
+                            model_failures[port][model_name] = 0
+
+                    # Probe the model
+                    if probe_model(port, model_name):
+                        # Probe succeeded - reset failure counter
+                        if model_failures[port][model_name] > 0:
+                            log(f"{name}: model '{model_name}' recovered after "
+                                f"{model_failures[port][model_name]} failure(s)")
+                            model_failures[port][model_name] = 0
+                    else:
+                        # Probe failed
+                        model_failures[port][model_name] += 1
+                        fail_count = model_failures[port][model_name]
+                        log(f"{name}: model '{model_name}' probe failed "
+                            f"({fail_count}/{MAX_MODEL_PROBE_FAILURES})")
+
+                        if fail_count >= MAX_MODEL_PROBE_FAILURES:
+                            # Try to unload the bad model to free resources
+                            if unload_model(port, model_name):
+                                log(f"{name}: model '{model_name}' persistently unreachable, "
+                                    f"unloaded successfully, cooldown {MODEL_FAILURE_COOLDOWN}s (NO restart)")
+                            else:
+                                log(f"{name}: model '{model_name}' persistently unreachable, "
+                                    f"unload failed, cooldown {MODEL_FAILURE_COOLDOWN}s (NO restart)")
+                            model_cooldowns[port][model_name] = now
 
             time.sleep(CHECK_INTERVAL)
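
For quick manual verification of the two router endpoints this patch leans on (/v1/models and /models/unload), a minimal standalone sketch; the port value and the response shape are assumptions taken from the watchdog code above, not confirmed router behaviour, and the snippet is not part of the patch:

    # manual_check.py - hypothetical helper, mirrors get_loaded_models()/unload_model()
    import requests

    PORT = 8080  # placeholder; use one of the ports configured in SERVERS

    # List models and the status value the watchdog inspects
    resp = requests.get(f"http://localhost:{PORT}/v1/models", timeout=10)
    resp.raise_for_status()
    for m in resp.json().get("data", []):
        status = m.get("status", {}).get("value", "unknown")
        print(f"{m['id']}: {status}")

    # Unload one model by name, as unload_model() does:
    # requests.post(f"http://localhost:{PORT}/models/unload",
    #               json={"model": "<model-name>"}, timeout=10)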