File size: 3,256 Bytes
0a72d15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f69b655
0a72d15
 
 
 
 
f69b655
 
0a72d15
 
 
 
 
 
 
 
 
f69b655
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a72d15
 
 
 
 
 
 
 
 
 
f69b655
0a72d15
 
 
 
f69b655
 
 
 
 
 
 
 
 
0a72d15
f69b655
0a72d15
f69b655
0a72d15
 
f69b655
 
0a72d15
 
 
f69b655
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# endpoint_utils.py
from __future__ import annotations
from typing import Optional, Tuple, Callable, Dict, Any
from urllib.parse import urlparse
import os, time, requests


def _valid_uri(uri: Optional[str]) -> bool:
    """Return True if *uri* is an absolute http(s) URL with a non-empty host.

    ``None``, the empty string, other schemes, URLs without a netloc, and
    strings that ``urlparse`` rejects as malformed (e.g. an unbalanced IPv6
    bracket such as ``"http://[abc"``) all yield False instead of raising.
    """
    if not uri:
        return False
    try:
        p = urlparse(uri)
    except ValueError:
        # urlparse raises on malformed netlocs; for a validity predicate that
        # simply means "not a valid URI".
        return False
    return p.scheme in {"http", "https"} and bool(p.netloc)


def _response_detail(resp: "requests.Response") -> str:
    """Return the most helpful human-readable message carried by *resp*.

    If the body is a JSON object, its ``error`` (or else ``message``) field is
    used, and an empty string is returned when neither is present. If the body
    is not JSON, or is JSON but not an object (list, string, number), the
    stripped raw text is returned. Always returns a ``str``.
    """
    try:
        data = resp.json()
    except ValueError:
        # requests' JSON decode errors subclass ValueError.
        data = None
    if isinstance(data, dict):
        detail = data.get("error") or data.get("message") or ""
        return str(detail)
    return (resp.text or "").strip()


def wake_endpoint(
    uri: Optional[str],
    *,
    token: Optional[str] = None,
    max_wait: int = 600,          # seconds allowed for the polling phase (10 minutes)
    poll_every: float = 5.0,
    warm_payload: Optional[Dict[str, Any]] = None,
    log: Callable[[str], None] = lambda _: None,
) -> Tuple[bool, Optional[str]]:
    """
    Wake a scale-to-zero HF Inference Endpoint by nudging it, then polling until ready.

    Steps:
      0. GET ``<uri>/health`` once; a successful (2xx) answer means the
         endpoint is already up and the function returns immediately.
      1. POST ``warm_payload`` (default ``{"inputs": "wake"}``) once to
         trigger start-up; any request error is ignored.
      2. POST ``{"inputs": "ping"}`` every ``poll_every`` seconds until a
         successful response arrives or ``max_wait`` seconds elapse.

    Args:
        uri: Base http(s) URL of the endpoint.
        token: Bearer token; falls back to the ``HF_TOKEN`` environment variable.
        max_wait: Upper bound, in seconds, on the polling phase.
        poll_every: Seconds to wait between polling attempts.
        warm_payload: JSON body for the single nudge request.
        log: Callback receiving human-readable progress lines.

    Returns:
        ``(True, None)`` if ready; otherwise ``(False, "<last status/message>")``.
    """
    if not _valid_uri(uri):
        return False, "invalid or missing URI (expect http(s)://...)"

    headers: Dict[str, str] = {}
    tok = token or os.environ.get("HF_TOKEN")
    if tok:
        headers["Authorization"] = f"Bearer {tok}"

    # 0) Try a quick health check first (cheap)
    last_detail = "no response"
    try:
        hr = requests.get(f"{uri.rstrip('/')}/health", headers=headers, timeout=5)
    except requests.RequestException as e:
        last_detail = type(e).__name__
        log(f"[health] {last_detail}")
    else:
        if hr.ok:
            log("✅ /health reports ready.")
            return True, None
        detail = _response_detail(hr)
        last_detail = f"HTTP {hr.status_code}" + (f" – {detail}" if detail else "")
        log(f"[health] HTTP {hr.status_code} – {detail or 'warming?'}")

    # 1) Initial nudge. Deliberately best-effort: a cold endpoint commonly
    #    errors or times out here, and the polling phase below handles that.
    payload = warm_payload if warm_payload is not None else {"inputs": "wake"}
    try:
        requests.post(uri, headers=headers, json=payload, timeout=5)
    except requests.RequestException:
        pass

    # 2) Poll until a request succeeds or the deadline passes. monotonic() is
    #    immune to wall-clock adjustments that would stretch/shrink the wait.
    # NOTE(review): polls always send {"inputs": "ping"}, even when a custom
    # warm_payload was supplied; an endpoint whose handler rejects that body
    # will never report ready here — confirm this is intended.
    deadline = time.monotonic() + max_wait
    while time.monotonic() < deadline:
        try:
            r = requests.post(uri, headers=headers, json={"inputs": "ping"}, timeout=8)
        except requests.RequestException as e:
            last_detail = type(e).__name__
            log(f"[client] {last_detail}; retrying in {int(poll_every)}s…")
        else:
            if r.ok:
                log("✅ Endpoint is awake and responsive.")
                return True, None

            # Extract any helpful server message (tolerates non-dict JSON).
            detail = _response_detail(r)
            last_detail = f"HTTP {r.status_code}" + (f" – {detail}" if detail else "")
            # 429/503/504 are treated as "still starting up"; anything else is
            # logged as unexpected but is retried all the same.
            fallback = "warming up" if r.status_code in (429, 503, 504) else "unexpected response"
            log(f"[server] {detail or fallback} (HTTP {r.status_code}); retrying in {int(poll_every)}s…")

        # Never sleep past the deadline: that would only delay the timeout
        # result without allowing another attempt.
        time.sleep(max(0.0, min(poll_every, deadline - time.monotonic())))

    return False, f"Timed out after {max_wait}s — last status: {last_detail}"