Spaces:
Running
Running
File size: 4,715 Bytes
3f55f46 a715158 728e522 a715158 0ed3c3f 728e522 98fb113 a715158 0ed3c3f 98fb113 0ed3c3f a715158 3f55f46 728e522 a715158 3f55f46 0ed3c3f 3f55f46 0ed3c3f 3f55f46 0ed3c3f 3f55f46 728e522 0ed3c3f 728e522 3f55f46 728e522 3f55f46 a715158 0ed3c3f 728e522 3f55f46 728e522 3f55f46 728e522 0ed3c3f 728e522 3f55f46 728e522 a715158 728e522 0ed3c3f 728e522 0ed3c3f 728e522 0ed3c3f 728e522 a715158 728e522 3f55f46 a715158 0ed3c3f 728e522 0ed3c3f 3f55f46 728e522 0ed3c3f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
# core/connection_manager.py (FINAL, CORRECTED, AND ROBUST VERSION)
"""
Manages WebSocket connections and broadcasts messages to all clients.
Uses a dedicated broadcast queue per client to prevent blocking.
"""
import asyncio
import logging
from typing import Dict, Union
from fastapi import WebSocket, WebSocketDisconnect
from performance_config import PERF_CONFIG
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Upper bound on simultaneously registered connections, read once at import
# time from PERF_CONFIG; ConnectionManager.connect() compares against it.
MAX_CONNECTIONS = PERF_CONFIG["max_connections"]
class ConnectionManager:
    """Manages active WebSocket connections and message broadcasting.

    Every registered connection owns a bounded ``asyncio.Queue`` plus a
    dedicated task that drains it. The synchronous ``broadcast_*`` methods
    only enqueue (dropping the message when a client's queue is full), so a
    slow client never stalls the producer.
    """

    def __init__(self):
        # Maps a WebSocket object to a dictionary containing its state:
        # {"station": str, "queue": asyncio.Queue, "task": asyncio.Task}.
        self.active_connections: Dict[WebSocket, Dict] = {}

    async def connect(self, websocket: WebSocket, station_name: str) -> bool:
        """Register a connection that has ALREADY been accepted.

        Args:
            websocket: The accepted WebSocket to register.
            station_name: Station whose broadcasts this client receives.

        Returns:
            True if the connection is kept; False if the manager is at
            capacity, in which case a "wait" message is sent and the socket
            is closed.
        """
        # STEP 1: Check capacity. If full, tell the client to wait, close,
        # and report failure.
        if len(self.active_connections) >= MAX_CONNECTIONS:
            logger.warning(
                "Connection refused: Maximum capacity of %s reached.",
                MAX_CONNECTIONS,
            )
            await websocket.send_json({
                "type": "wait",
                "payload": {
                    "message": "The session is currently full. Please try again later."
                },
            })
            await websocket.close()
            return False

        # STEP 2: Capacity is available: create the per-client queue and the
        # task that drains it. There is no await between the capacity check
        # and the registration below, so the limit cannot be overshot.
        broadcast_queue = asyncio.Queue(maxsize=500)
        broadcaster_task = asyncio.create_task(
            self._broadcaster_loop(websocket, broadcast_queue)
        )
        self.active_connections[websocket] = {
            "station": station_name,
            "queue": broadcast_queue,
            "task": broadcaster_task,
        }

        # STEP 3: Announce the new user count (the new client is already
        # registered, so it receives the count too) and report success.
        self.broadcast_user_count()
        return True

    async def disconnect(self, websocket: WebSocket):
        """Unregister a WebSocket and stop its broadcaster task.

        Calling this for a socket that is not (or no longer) registered is a
        no-op, so it is safe to call more than once.
        """
        state = self.active_connections.pop(websocket, None)
        if state is None:
            return

        # The broadcaster loop calls disconnect() from its own cleanup path.
        # In that case the task is already finishing; cancelling it would
        # only mark a normally-ending task as cancelled.
        task = state["task"]
        if task is not asyncio.current_task():
            task.cancel()

        # Announce the new user count to the remaining clients.
        self.broadcast_user_count()

    def broadcast_to_station(self, station_name: str, message: Union[dict, bytes]):
        """Queue a message for every client listening to ``station_name``."""
        for conn, data in self.active_connections.items():
            if data["station"] != station_name:
                continue
            try:
                # Non-blocking put: the caller must never stall on a client.
                data["queue"].put_nowait(message)
            except asyncio.QueueFull:
                logger.warning(
                    "Broadcast queue full for client %s. Dropping message.",
                    conn.client,
                )

    def broadcast_to_all(self, message: Union[dict, bytes]):
        """Queue a message for every connected client."""
        for data in self.active_connections.values():
            try:
                data["queue"].put_nowait(message)
            except asyncio.QueueFull:
                logger.warning(
                    "Broadcast queue full for a client during broadcast_to_all. "
                    "Dropping message."
                )

    def broadcast_user_count(self):
        """Broadcast the current number of connected users to everyone."""
        self.broadcast_to_all({
            "type": "user_count",
            "payload": {"count": self.get_connection_count()},
        })

    async def _broadcaster_loop(self, websocket: WebSocket, queue: asyncio.Queue):
        """Send queued messages to a single client until it goes away.

        ``dict`` messages are sent as JSON and ``bytes`` messages as binary
        frames. The loop ends on cancellation, on client disconnect, or on
        any send error; in every case the connection is then unregistered.
        """
        try:
            while True:
                message = await queue.get()
                if isinstance(message, dict):
                    await websocket.send_json(message)
                elif isinstance(message, bytes):
                    await websocket.send_bytes(message)
                else:
                    # Nothing can be sent for other types; make the drop
                    # visible instead of discarding it silently.
                    logger.warning(
                        "Dropping message of unsupported type %s for %s.",
                        type(message).__name__,
                        websocket.client,
                    )
                queue.task_done()
        except (WebSocketDisconnect, asyncio.CancelledError):
            logger.info(
                "Broadcaster for %s cancelled or client disconnected.",
                websocket.client,
            )
        except Exception as e:
            logger.error(
                "Broadcast loop error for %s: %s",
                websocket.client,
                e,
                exc_info=True,
            )

        # Whatever ended the loop, make sure this client is unregistered.
        await self.disconnect(websocket)

    def get_connection_count(self) -> int:
        """Return the number of currently registered connections."""
        return len(self.active_connections)