Pierre Chapuis
commited on
improve API client
Browse files- pyproject.toml +1 -0
- src/app.py +8 -14
- src/fg.py +192 -60
- typings/pillow_heif/__init__.pyi +2 -0
pyproject.toml
CHANGED
|
@@ -51,3 +51,4 @@ select = [
|
|
| 51 |
[tool.pyright]
|
| 52 |
include = ["src"]
|
| 53 |
exclude = ["**/__pycache__"]
|
|
|
|
|
|
| 51 |
[tool.pyright]
|
| 52 |
include = ["src"]
|
| 53 |
exclude = ["**/__pycache__"]
|
| 54 |
+
strict = ["src/fg.py"]
|
src/app.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import dataclasses as dc
|
| 2 |
import io
|
|
|
|
| 3 |
from typing import Any
|
| 4 |
|
| 5 |
import gradio as gr
|
|
@@ -24,6 +25,7 @@ with env.prefixed("ERASER_"):
|
|
| 24 |
CA_BUNDLE: str | None = env.str("CA_BUNDLE", None)
|
| 25 |
|
| 26 |
|
|
|
|
| 27 |
def _ctx() -> EditorAPIContext:
|
| 28 |
assert API_USER is not None
|
| 29 |
assert API_PASSWORD is not None
|
|
@@ -51,13 +53,7 @@ class ProcessParams:
|
|
| 51 |
async def _process(ctx: EditorAPIContext, params: ProcessParams) -> Image.Image:
|
| 52 |
with io.BytesIO() as f:
|
| 53 |
params.image.save(f, format="JPEG")
|
| 54 |
-
|
| 55 |
-
response = await client.post(
|
| 56 |
-
f"{ctx.uri}/state/upload",
|
| 57 |
-
files={"file": f},
|
| 58 |
-
headers=ctx.auth_headers,
|
| 59 |
-
)
|
| 60 |
-
response.raise_for_status()
|
| 61 |
st_input = response.json()["state"]
|
| 62 |
|
| 63 |
if params.bbox:
|
|
@@ -74,13 +70,11 @@ async def _process(ctx: EditorAPIContext, params: ProcessParams) -> Image.Image:
|
|
| 74 |
st_mask = await ctx.call_skill(f"segment/{segment_input_st}", segment_params)
|
| 75 |
st_erased = await ctx.call_skill(f"erase/{st_input}/{st_mask}", {"mode": "free"})
|
| 76 |
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
)
|
| 83 |
-
response.raise_for_status()
|
| 84 |
f = io.BytesIO()
|
| 85 |
f.write(response.content)
|
| 86 |
f.seek(0)
|
|
|
|
| 1 |
import dataclasses as dc
|
| 2 |
import io
|
| 3 |
+
from functools import cache
|
| 4 |
from typing import Any
|
| 5 |
|
| 6 |
import gradio as gr
|
|
|
|
| 25 |
CA_BUNDLE: str | None = env.str("CA_BUNDLE", None)
|
| 26 |
|
| 27 |
|
| 28 |
+
@cache
|
| 29 |
def _ctx() -> EditorAPIContext:
|
| 30 |
assert API_USER is not None
|
| 31 |
assert API_PASSWORD is not None
|
|
|
|
| 53 |
async def _process(ctx: EditorAPIContext, params: ProcessParams) -> Image.Image:
|
| 54 |
with io.BytesIO() as f:
|
| 55 |
params.image.save(f, format="JPEG")
|
| 56 |
+
response = await ctx.request("POST", "state/upload", files={"file": f})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
st_input = response.json()["state"]
|
| 58 |
|
| 59 |
if params.bbox:
|
|
|
|
| 70 |
st_mask = await ctx.call_skill(f"segment/{segment_input_st}", segment_params)
|
| 71 |
st_erased = await ctx.call_skill(f"erase/{st_input}/{st_mask}", {"mode": "free"})
|
| 72 |
|
| 73 |
+
response = await ctx.request(
|
| 74 |
+
"GET",
|
| 75 |
+
f"state/image/{st_erased}",
|
| 76 |
+
params={"format": "JPEG", "resolution": "DISPLAY"},
|
| 77 |
+
)
|
|
|
|
|
|
|
| 78 |
f = io.BytesIO()
|
| 79 |
f.write(response.content)
|
| 80 |
f.seek(0)
|
src/fg.py
CHANGED
|
@@ -1,18 +1,46 @@
|
|
| 1 |
import asyncio
|
| 2 |
import dataclasses as dc
|
| 3 |
import json
|
|
|
|
| 4 |
from collections import defaultdict
|
| 5 |
-
from collections.abc import Awaitable, Callable
|
| 6 |
-
from typing import Any, Literal
|
| 7 |
|
| 8 |
import httpx
|
| 9 |
import httpx_sse
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
Priority = Literal["low", "standard", "high"]
|
| 12 |
|
| 13 |
|
| 14 |
-
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
@dc.dataclass(kw_only=True)
|
|
@@ -23,18 +51,33 @@ class EditorAPIContext:
|
|
| 23 |
priority: Priority = "standard"
|
| 24 |
token: str | None = None
|
| 25 |
verify: bool | str = True
|
| 26 |
-
|
|
|
|
|
|
|
| 27 |
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
async def __aenter__(self) -> httpx.AsyncClient:
|
| 31 |
if self._client:
|
|
|
|
|
|
|
| 32 |
return self._client
|
|
|
|
| 33 |
self._client = httpx.AsyncClient(verify=self.verify)
|
|
|
|
| 34 |
return self._client
|
| 35 |
|
| 36 |
async def __aexit__(self, *args: Any) -> None:
|
| 37 |
-
if self._client:
|
|
|
|
|
|
|
|
|
|
| 38 |
await self._client.__aexit__(*args)
|
| 39 |
self._client = None
|
| 40 |
|
|
@@ -49,62 +92,153 @@ class EditorAPIContext:
|
|
| 49 |
f"{self.uri}/auth/login",
|
| 50 |
json={"username": self.user, "password": self.password},
|
| 51 |
)
|
| 52 |
-
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
-
async def sse_loop(self) -> None:
|
| 56 |
async with self as client:
|
| 57 |
-
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
sub_token = response.json()["token"]
|
| 60 |
url = f"{self.uri}/sub/{sub_token}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
async with (
|
| 62 |
httpx.AsyncClient(timeout=None, verify=self.verify) as c,
|
| 63 |
-
httpx_sse.aconnect_sse(c, "GET", url) as es,
|
| 64 |
):
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
|
|
|
| 79 |
|
| 80 |
-
async def
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
)
|
| 86 |
-
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
self,
|
| 91 |
co: Callable[["EditorAPIContext", Tin], Awaitable[Tout]],
|
| 92 |
params: Tin,
|
| 93 |
) -> Tout:
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
r = tg.create_task(outer_co(params))
|
| 106 |
-
|
| 107 |
-
return r.result()
|
| 108 |
|
| 109 |
def run_one_sync[Tin, Tout](
|
| 110 |
self,
|
|
@@ -116,18 +250,16 @@ class EditorAPIContext:
|
|
| 116 |
except RuntimeError:
|
| 117 |
loop = asyncio.new_event_loop()
|
| 118 |
asyncio.set_event_loop(loop)
|
|
|
|
| 119 |
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
|
|
|
|
|
|
|
|
|
| 123 |
params = {"priority": self.priority} | (params or {})
|
| 124 |
-
|
| 125 |
-
response = await client.post(
|
| 126 |
-
f"{self.uri}/skills/{uri}",
|
| 127 |
-
json=params,
|
| 128 |
-
headers=self.auth_headers,
|
| 129 |
-
)
|
| 130 |
-
response.raise_for_status()
|
| 131 |
state_id = response.json()["state"]
|
| 132 |
-
await self.sse_await(state_id)
|
| 133 |
return state_id
|
|
|
|
| 1 |
import asyncio
|
| 2 |
import dataclasses as dc
|
| 3 |
import json
|
| 4 |
+
import logging
|
| 5 |
from collections import defaultdict
|
| 6 |
+
from collections.abc import Awaitable, Callable, Mapping
|
| 7 |
+
from typing import Any, Literal, cast
|
| 8 |
|
| 9 |
import httpx
|
| 10 |
import httpx_sse
|
| 11 |
+
from httpx._types import QueryParamTypes, RequestFiles
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
|
| 15 |
Priority = Literal["low", "standard", "high"]
|
| 16 |
|
| 17 |
|
| 18 |
+
class SSELoopStopped(RuntimeError):
|
| 19 |
+
pass
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Futures[T]:
|
| 23 |
+
@classmethod
|
| 24 |
+
def create_future(cls) -> asyncio.Future[T]:
|
| 25 |
+
return asyncio.get_running_loop().create_future()
|
| 26 |
+
|
| 27 |
+
def __init__(self, capacity: int = 256) -> None:
|
| 28 |
+
self.futures = defaultdict[str, asyncio.Future[T]](self.create_future)
|
| 29 |
+
self.capacity = capacity
|
| 30 |
+
|
| 31 |
+
def cull(self) -> None:
|
| 32 |
+
while len(self.futures) >= self.capacity:
|
| 33 |
+
del self.futures[next(iter(self.futures))]
|
| 34 |
+
|
| 35 |
+
def __getitem__(self, key: str) -> asyncio.Future[T]:
|
| 36 |
+
self.cull()
|
| 37 |
+
return self.futures[key]
|
| 38 |
+
|
| 39 |
+
def __delitem__(self, key: str) -> None:
|
| 40 |
+
try:
|
| 41 |
+
del self.futures[key]
|
| 42 |
+
except KeyError:
|
| 43 |
+
pass
|
| 44 |
|
| 45 |
|
| 46 |
@dc.dataclass(kw_only=True)
|
|
|
|
| 51 |
priority: Priority = "standard"
|
| 52 |
token: str | None = None
|
| 53 |
verify: bool | str = True
|
| 54 |
+
default_timeout: float = 60.0
|
| 55 |
+
logger: logging.Logger = logger
|
| 56 |
+
max_sse_failures: int = 5
|
| 57 |
|
| 58 |
+
_client: httpx.AsyncClient | None = None
|
| 59 |
+
_client_ctx_depth: int = 0
|
| 60 |
+
_sse_futures: Futures[dict[str, Any]] = dc.field(default_factory=Futures)
|
| 61 |
+
_sse_task: asyncio.Task[None] | None = None
|
| 62 |
+
_sse_failures: int = 0
|
| 63 |
+
_sse_last_event_id: str = ""
|
| 64 |
+
_sse_retry_ms: int = 0
|
| 65 |
|
| 66 |
async def __aenter__(self) -> httpx.AsyncClient:
|
| 67 |
if self._client:
|
| 68 |
+
assert self._client_ctx_depth > 0
|
| 69 |
+
self._client_ctx_depth += 1
|
| 70 |
return self._client
|
| 71 |
+
assert self._client_ctx_depth == 0
|
| 72 |
self._client = httpx.AsyncClient(verify=self.verify)
|
| 73 |
+
self._client_ctx_depth = 1
|
| 74 |
return self._client
|
| 75 |
|
| 76 |
async def __aexit__(self, *args: Any) -> None:
|
| 77 |
+
if (not self._client) or self._client_ctx_depth <= 0:
|
| 78 |
+
raise RuntimeError("unbalanced __aexit__")
|
| 79 |
+
self._client_ctx_depth -= 1
|
| 80 |
+
if self._client_ctx_depth == 0:
|
| 81 |
await self._client.__aexit__(*args)
|
| 82 |
self._client = None
|
| 83 |
|
|
|
|
| 92 |
f"{self.uri}/auth/login",
|
| 93 |
json={"username": self.user, "password": self.password},
|
| 94 |
)
|
| 95 |
+
response.raise_for_status()
|
| 96 |
+
self.logger.debug(f"logged in as {self.user}")
|
| 97 |
+
self.token = response.json()["token"]
|
| 98 |
+
|
| 99 |
+
async def request(
|
| 100 |
+
self,
|
| 101 |
+
method: Literal["GET", "POST"],
|
| 102 |
+
url: str,
|
| 103 |
+
files: RequestFiles | None = None,
|
| 104 |
+
params: QueryParamTypes | None = None,
|
| 105 |
+
json: dict[str, Any] | None = None,
|
| 106 |
+
headers: Mapping[str, str] | None = None,
|
| 107 |
+
raise_for_status: bool = True,
|
| 108 |
+
) -> httpx.Response:
|
| 109 |
+
async def _q() -> httpx.Response:
|
| 110 |
+
return await client.request(
|
| 111 |
+
method,
|
| 112 |
+
f"{self.uri}/{url}",
|
| 113 |
+
headers=dict(headers or {}) | self.auth_headers,
|
| 114 |
+
files=files,
|
| 115 |
+
params=params,
|
| 116 |
+
json=json,
|
| 117 |
+
)
|
| 118 |
|
|
|
|
| 119 |
async with self as client:
|
| 120 |
+
r = await _q()
|
| 121 |
+
if r.status_code == 401:
|
| 122 |
+
self.logger.debug("renewing token")
|
| 123 |
+
await self.login()
|
| 124 |
+
r = await _q()
|
| 125 |
+
|
| 126 |
+
if raise_for_status:
|
| 127 |
+
r.raise_for_status()
|
| 128 |
+
return r
|
| 129 |
+
|
| 130 |
+
@classmethod
|
| 131 |
+
def decode_json(cls, data: str) -> dict[str, Any] | None:
|
| 132 |
+
try:
|
| 133 |
+
r = json.loads(data)
|
| 134 |
+
except json.JSONDecodeError:
|
| 135 |
+
return None
|
| 136 |
+
if type(r) is not dict:
|
| 137 |
+
return None
|
| 138 |
+
return cast(dict[str, Any], r)
|
| 139 |
+
|
| 140 |
+
async def _sse_loop(self) -> None:
|
| 141 |
+
response = await self.request("POST", "sub-auth")
|
| 142 |
sub_token = response.json()["token"]
|
| 143 |
url = f"{self.uri}/sub/{sub_token}"
|
| 144 |
+
headers = {"Accept": "text/event-stream"}
|
| 145 |
+
if self._sse_last_event_id:
|
| 146 |
+
retry_ms = self._sse_retry_ms + 1000 * 2**self._sse_failures
|
| 147 |
+
self.logger.info(f"resuming SSE from event {self._sse_last_event_id} in {retry_ms} ms")
|
| 148 |
+
await asyncio.sleep(retry_ms / 1000)
|
| 149 |
+
headers["Last-Event-ID"] = self._sse_last_event_id
|
| 150 |
async with (
|
| 151 |
httpx.AsyncClient(timeout=None, verify=self.verify) as c,
|
| 152 |
+
httpx_sse.aconnect_sse(c, "GET", url, headers=headers) as es,
|
| 153 |
):
|
| 154 |
+
es.response.raise_for_status()
|
| 155 |
+
self._sse_futures["_sse_loop"].set_result({"status": "ok"})
|
| 156 |
+
try:
|
| 157 |
+
async for sse in es.aiter_sse():
|
| 158 |
+
self._sse_last_event_id = sse.id
|
| 159 |
+
self._sse_retry_ms = sse.retry or 0
|
| 160 |
+
jdata = self.decode_json(sse.data)
|
| 161 |
+
if (jdata is None) or ("state" not in jdata):
|
| 162 |
+
# Note: when the server restarts we typically get an
|
| 163 |
+
# empty string here, then the loop exits.
|
| 164 |
+
self.logger.warning(f"unexpected SSE data: {sse.data}")
|
| 165 |
+
continue
|
| 166 |
+
self._sse_futures[jdata["state"]].set_result(jdata)
|
| 167 |
+
except asyncio.CancelledError:
|
| 168 |
+
pass
|
| 169 |
|
| 170 |
+
async def sse_start(self) -> None:
|
| 171 |
+
assert self._sse_task is None
|
| 172 |
+
self._sse_last_event_id = ""
|
| 173 |
+
self._sse_retry_ms = 0
|
| 174 |
+
self._sse_task = asyncio.create_task(self._sse_loop())
|
| 175 |
+
await self.sse_await("_sse_loop")
|
| 176 |
+
self._sse_failures = 0
|
| 177 |
+
|
| 178 |
+
async def sse_recover(self) -> bool:
|
| 179 |
+
while True:
|
| 180 |
+
if self._sse_failures > self.max_sse_failures:
|
| 181 |
+
return False
|
| 182 |
+
self._sse_task = asyncio.create_task(self._sse_loop())
|
| 183 |
+
try:
|
| 184 |
+
await self.sse_await("_sse_loop")
|
| 185 |
+
return True
|
| 186 |
+
except SSELoopStopped:
|
| 187 |
+
pass
|
| 188 |
+
|
| 189 |
+
async def sse_stop(self) -> None:
|
| 190 |
+
assert self._sse_task
|
| 191 |
+
self._sse_task.cancel()
|
| 192 |
+
await self._sse_task
|
| 193 |
+
self._sse_task = None
|
| 194 |
+
|
| 195 |
+
async def sse_await(self, state_id: str, timeout: float | None = None) -> None:
|
| 196 |
+
assert self._sse_task
|
| 197 |
+
future = self._sse_futures[state_id]
|
| 198 |
+
|
| 199 |
+
while True:
|
| 200 |
+
done, _ = await asyncio.wait(
|
| 201 |
+
{future, self._sse_task},
|
| 202 |
+
timeout=timeout or self.default_timeout,
|
| 203 |
+
return_when=asyncio.FIRST_COMPLETED,
|
| 204 |
)
|
| 205 |
+
if not done:
|
| 206 |
+
raise TimeoutError(f"state {state_id} timed out after {timeout}")
|
| 207 |
+
if self._sse_task in done:
|
| 208 |
+
self._sse_failures += 1
|
| 209 |
+
if state_id != "_sse_loop" and (await self.sse_recover()):
|
| 210 |
+
self._sse_failures = 0
|
| 211 |
+
continue
|
| 212 |
+
exception = self._sse_task.exception()
|
| 213 |
+
raise SSELoopStopped(f"SSE loop stopped while waiting for state {state_id}") from exception
|
| 214 |
+
break
|
| 215 |
|
| 216 |
+
assert done == {future}
|
| 217 |
+
|
| 218 |
+
jdata = future.result()
|
| 219 |
+
del self._sse_futures[state_id]
|
| 220 |
+
assert jdata["status"] == "ok", f"state {state_id} is {jdata['status']}"
|
| 221 |
+
|
| 222 |
+
async def get_meta(self, state_id: str) -> dict[str, Any]:
|
| 223 |
+
response = await self.request("GET", f"state/meta/{state_id}")
|
| 224 |
+
return response.json()
|
| 225 |
+
|
| 226 |
+
async def _run_one[Tin, Tout](
|
| 227 |
self,
|
| 228 |
co: Callable[["EditorAPIContext", Tin], Awaitable[Tout]],
|
| 229 |
params: Tin,
|
| 230 |
) -> Tout:
|
| 231 |
+
# This wraps the coroutine in the SSE loop.
|
| 232 |
+
# This is mostly useful if you use synchronous Python,
|
| 233 |
+
# otherwise you can call the functions directly.
|
| 234 |
+
if not self.token:
|
| 235 |
+
await self.login()
|
| 236 |
+
await self.sse_start()
|
| 237 |
+
try:
|
| 238 |
+
r = await co(self, params)
|
| 239 |
+
return r
|
| 240 |
+
finally:
|
| 241 |
+
await self.sse_stop()
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
def run_one_sync[Tin, Tout](
|
| 244 |
self,
|
|
|
|
| 250 |
except RuntimeError:
|
| 251 |
loop = asyncio.new_event_loop()
|
| 252 |
asyncio.set_event_loop(loop)
|
| 253 |
+
return loop.run_until_complete(self._run_one(co, params))
|
| 254 |
|
| 255 |
+
async def call_skill(
|
| 256 |
+
self,
|
| 257 |
+
uri: str,
|
| 258 |
+
params: dict[str, Any] | None,
|
| 259 |
+
timeout: float | None = None,
|
| 260 |
+
) -> str:
|
| 261 |
params = {"priority": self.priority} | (params or {})
|
| 262 |
+
response = await self.request("POST", f"skills/{uri}", json=params)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
state_id = response.json()["state"]
|
| 264 |
+
await self.sse_await(state_id, timeout=timeout)
|
| 265 |
return state_id
|
typings/pillow_heif/__init__.pyi
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def register_heif_opener() -> None: ...
|
| 2 |
+
def register_avif_opener() -> None: ...
|