File size: 3,300 Bytes
0a72d15
 
 
 
 
 
8aa8219
 
 
 
0a72d15
 
 
 
 
 
 
8aa8219
 
 
 
 
 
 
0a72d15
 
 
 
f69b655
0a72d15
 
 
 
 
f69b655
 
0a72d15
 
 
 
8aa8219
 
 
0a72d15
8aa8219
 
 
f69b655
 
 
 
 
8aa8219
 
 
 
f69b655
8aa8219
 
f69b655
8aa8219
0a72d15
 
 
 
 
 
8aa8219
0a72d15
 
 
f69b655
0a72d15
 
 
 
8aa8219
 
 
 
 
f69b655
0a72d15
8aa8219
0a72d15
8aa8219
0a72d15
 
8aa8219
 
0a72d15
 
 
8aa8219
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# endpoint_utils.py
from __future__ import annotations
from typing import Optional, Tuple, Callable, Dict, Any
from urllib.parse import urlparse
import os, time, requests

def hf_headers():
    """Build the HTTP auth headers from the ``HF_TOKEN`` environment variable.

    Returns a dict holding a single ``Authorization: Bearer <token>`` entry,
    or an empty dict when the variable is unset or contains only whitespace.
    """
    token = os.getenv("HF_TOKEN", "").strip()
    if not token:
        return {}
    return {"Authorization": "Bearer " + token}
    
def _valid_uri(uri: Optional[str]) -> bool:
    """Return True when *uri* is an absolute http(s) URL with a host part.

    Empty/None values, other schemes, scheme-less strings, and strings that
    ``urlparse`` rejects outright all yield False.
    """
    if not uri:
        return False
    try:
        parsed = urlparse(uri)
    except ValueError:
        # urlparse raises on malformed netlocs (e.g. an unbalanced IPv6
        # bracket such as "http://[::1"); a validator should answer False
        # rather than propagate the exception to the caller.
        return False
    return parsed.scheme in {"http", "https"} and bool(parsed.netloc)


def _detail(resp: requests.Response) -> str:
    """Extract a short, stripped error description from a response.

    Looks at the JSON body's ``"error"`` key first, then ``"message"``; if
    both are missing or falsy the result is an empty string.  Whenever the
    body cannot be handled that way (not JSON, not a mapping, value is not a
    string, ...) the stripped raw response text is returned instead.
    """
    try:
        body = resp.json()
        for key in ("error", "message"):
            value = body.get(key)
            if value:
                return value.strip()
        return ""
    except Exception:
        # Deliberately broad: any failure above falls back to the raw body.
        raw = resp.text
        if not raw:
            return ""
        return raw.strip()

def wake_endpoint(
    uri: Optional[str],
    *,
    token: Optional[str] = None,
    max_wait: int = 600,          # was 180 — bumped to 10 minutes
    poll_every: float = 5.0,
    warm_payload: Optional[Dict[str, Any]] = None,
    log: Callable[[str], None] = lambda _: None,
) -> Tuple[bool, Optional[str]]:
    """
    Wake a scale-to-zero HF Inference Endpoint by nudging it, then polling until ready.

    Steps: GET ``<uri>/health`` once; if that is not OK, POST a warm-up
    payload to ``uri`` and then keep POSTing ``{"inputs": "ping"}`` until a
    2xx/3xx answer arrives or ``max_wait`` seconds have elapsed.

    Args:
        uri: Endpoint base URL; must be an absolute http(s) URL.
        token: Bearer token to authenticate with. When omitted or blank, the
            headers from ``hf_headers()`` (the ``HF_TOKEN`` env var) are used.
        max_wait: Upper bound, in seconds, for the polling phase.
        poll_every: Pause, in seconds, between polling attempts.
        warm_payload: JSON body of the initial nudge; defaults to
            ``{"inputs": "wake"}``.
        log: Callback receiving human-readable progress messages.

    Returns:
        (True, None) if ready; otherwise (False, "<last status/message>").
        A 401/403 from either the health probe or a poll aborts immediately.
    """
    if not _valid_uri(uri):
        return False, "invalid or missing URI (expect http(s)://...)"

    # An explicitly passed token takes precedence over the environment.
    explicit = (token or "").strip()
    headers = {"Authorization": f"Bearer {explicit}"} if explicit else hf_headers()
    if not headers:
        log("⚠️ HF_TOKEN not set — POST / will likely return 401/403.")

    last = "no response"

    # /health probe (auth included if available)
    try:
        hr = requests.get(f"{uri.rstrip('/')}/health", headers=headers, timeout=5)
        if hr.ok:
            log("✅ /health reports ready.")
            return True, None
        last = f"HTTP {hr.status_code} – {_detail(hr) or 'warming?'}"
        log(f"[health] {last}")
        if hr.status_code in (401, 403):
            return False, f"Unauthorized (check HF_TOKEN). {last}"
    except requests.RequestException as e:
        last = type(e).__name__
        log(f"[health] {last}")

    # Warm-up nudge: best effort, its outcome does not decide readiness.
    payload = warm_payload if warm_payload is not None else {"inputs": "wake"}
    try:
        requests.post(uri, headers=headers, json=payload, timeout=5)
    except requests.RequestException as e:
        log(f"[warmup] {type(e).__name__}; continuing to poll")

    # Poll until ready. A monotonic clock keeps the deadline immune to
    # wall-clock adjustments while we wait.
    deadline = time.monotonic() + max_wait
    while time.monotonic() < deadline:
        try:
            r = requests.post(uri, headers=headers, json={"inputs": "ping"}, timeout=8)
            if r.ok:
                log("✅ Endpoint is awake and responsive.")
                return True, None

            d = _detail(r)
            last = f"HTTP {r.status_code}" + (f" – {d}" if d else "")

            if r.status_code in (401, 403):
                return False, f"Unauthorized (check HF_TOKEN, org access). {last}"

            # 429/503/504 are the expected "still starting" answers.
            hint = "warming up" if r.status_code in (429, 503, 504) else "unexpected response"
            log(f"[server] {d or hint} (HTTP {r.status_code}); retrying in {poll_every:g}s…")

        except requests.RequestException as e:
            last = type(e).__name__
            log(f"[client] {last}; retrying in {poll_every:g}s…")

        # Never sleep past the deadline.
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        time.sleep(max(0.0, min(poll_every, remaining)))

    return False, f"Timed out after {max_wait}s — last status: {last}"