File size: 4,715 Bytes
3f55f46
a715158
 
728e522
a715158
 
0ed3c3f
728e522
 
98fb113
a715158
0ed3c3f
 
98fb113
0ed3c3f
a715158
 
 
 
3f55f46
728e522
a715158
3f55f46
 
 
 
 
 
0ed3c3f
 
3f55f46
0ed3c3f
3f55f46
0ed3c3f
3f55f46
728e522
 
0ed3c3f
728e522
 
 
 
 
3f55f46
 
728e522
3f55f46
a715158
0ed3c3f
728e522
 
3f55f46
728e522
 
3f55f46
 
728e522
0ed3c3f
728e522
 
 
 
3f55f46
 
 
 
 
728e522
 
 
 
 
 
 
 
a715158
728e522
0ed3c3f
728e522
0ed3c3f
 
728e522
0ed3c3f
728e522
 
a715158
 
728e522
 
 
 
 
 
 
3f55f46
a715158
0ed3c3f
728e522
 
0ed3c3f
3f55f46
 
728e522
0ed3c3f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# core/connection_manager.py (FINAL, CORRECTED, AND ROBUST VERSION)
"""
Manages WebSocket connections and broadcasts messages to all clients.
Uses a dedicated broadcast queue per client to prevent blocking.
"""
import asyncio
import logging
from typing import Dict, Union
from fastapi import WebSocket, WebSocketDisconnect
from performance_config import PERF_CONFIG

# Module-level logger, named after this module so log records can be filtered by origin.
logger = logging.getLogger(__name__)

# Maximum number of simultaneously registered connections, read from the shared
# performance configuration. ConnectionManager.connect rejects new clients once
# this many connections are registered.
MAX_CONNECTIONS = PERF_CONFIG["max_connections"]

class ConnectionManager:
    """Manages active WebSocket connections and message broadcasting.

    Every registered client owns a bounded ``asyncio.Queue`` and a dedicated
    task (``_broadcaster_loop``) that drains that queue into the socket.
    Producers only ever call ``put_nowait``, so a slow client can lose its own
    messages (when its queue is full) but cannot stall the producers.
    """

    def __init__(self) -> None:
        # Maps a WebSocket object to a dictionary containing its state:
        #   "station" -> station name the client listens to
        #   "queue"   -> per-client asyncio.Queue of pending outbound messages
        #   "task"    -> asyncio.Task running _broadcaster_loop for the client
        self.active_connections: Dict[WebSocket, Dict] = {}

    async def connect(self, websocket: WebSocket, station_name: str) -> bool:
        """Register a new connection, or reject it when the manager is full.

        The websocket is expected to have been accepted by the caller already;
        this method sends on it without accepting it.

        Args:
            websocket: The client connection to register.
            station_name: Station whose broadcasts this client should receive.

        Returns:
            True if the connection was registered, False if it was rejected
            because ``MAX_CONNECTIONS`` clients are already registered (in
            which case a "wait" message is sent and the websocket is closed).
        """
        # STEP 1: Check capacity. If full, send a "wait" message, close, and report failure.
        if len(self.active_connections) >= MAX_CONNECTIONS:
            logger.warning(
                "Connection refused: Maximum capacity of %s reached.", MAX_CONNECTIONS
            )
            await websocket.send_json(
                {
                    "type": "wait",
                    "payload": {
                        "message": "The session is currently full. Please try again later."
                    },
                }
            )
            await websocket.close()
            return False

        # STEP 2: If capacity is available, create resources for the new connection.
        # The queue is bounded so a stalled client cannot grow memory without limit.
        broadcast_queue = asyncio.Queue(maxsize=500)
        broadcaster_task = asyncio.create_task(
            self._broadcaster_loop(websocket, broadcast_queue)
        )

        self.active_connections[websocket] = {
            "station": station_name,
            "queue": broadcast_queue,
            "task": broadcaster_task,
        }

        # STEP 3: Announce the new user count and report success.
        self.broadcast_user_count()
        return True

    async def disconnect(self, websocket: WebSocket) -> None:
        """Unregister a WebSocket and stop its broadcaster task.

        Safe to call more than once for the same websocket: only the call that
        actually removes the connection cancels the task and announces the new
        user count. Repeated calls (for example from the broadcaster loop's own
        cleanup after it was cancelled by an earlier ``disconnect``) are no-ops,
        so clients do not receive duplicate ``user_count`` messages.
        """
        state = self.active_connections.pop(websocket, None)
        if state is None:
            # Not registered (never connected, rejected, or already removed):
            # the count did not change, so there is nothing to announce.
            return

        task = state["task"]
        # When this is called from inside the broadcaster loop's own cleanup,
        # the task is already finishing; cancelling the running task would only
        # flag a task that is about to return anyway.
        if task is not asyncio.current_task():
            task.cancel()

        # Announce the new user count to the remaining clients.
        self.broadcast_user_count()

    def broadcast_to_station(self, station_name: str, message: Union[dict, bytes]) -> None:
        """Queue a message for each client listening to a specific station.

        If a client's queue is full the message is dropped for that client
        only, and a warning is logged.
        """
        for conn, data in self.active_connections.items():
            if data["station"] != station_name:
                continue
            try:
                # Use a non-blocking put so the calling service never stalls.
                data["queue"].put_nowait(message)
            except asyncio.QueueFull:
                logger.warning(
                    "Broadcast queue full for client %s. Dropping message.", conn.client
                )

    def broadcast_to_all(self, message: Union[dict, bytes]) -> None:
        """Queue a message for every connected client.

        If a client's queue is full the message is dropped for that client
        only, and a warning is logged.
        """
        for data in self.active_connections.values():
            try:
                data["queue"].put_nowait(message)
            except asyncio.QueueFull:
                logger.warning(
                    "Broadcast queue full for a client during broadcast_to_all. Dropping message."
                )

    def broadcast_user_count(self) -> None:
        """Broadcast the current number of connected users to all clients."""
        self.broadcast_to_all(
            {
                "type": "user_count",
                "payload": {"count": self.get_connection_count()},
            }
        )

    async def _broadcaster_loop(self, websocket: WebSocket, queue: asyncio.Queue) -> None:
        """Send queued messages to a single client until it goes away.

        ``dict`` messages are sent with ``send_json`` and ``bytes`` messages
        with ``send_bytes``. Any other type is dropped with a warning. The
        loop ends when the task is cancelled, the client disconnects, or a
        send fails; in every case ``disconnect`` is then run for this client.
        """
        while True:
            try:
                message = await queue.get()
                if isinstance(message, dict):
                    await websocket.send_json(message)
                elif isinstance(message, bytes):
                    await websocket.send_bytes(message)
                else:
                    # Previously such messages vanished silently; make it visible.
                    logger.warning(
                        "Unsupported broadcast message type %s for %s. Dropping message.",
                        type(message).__name__,
                        websocket.client,
                    )
                queue.task_done()
            except (WebSocketDisconnect, asyncio.CancelledError):
                logger.info(
                    "Broadcaster for %s cancelled or client disconnected.", websocket.client
                )
                break
            except Exception as e:
                logger.error(
                    "Broadcast loop error for %s: %s", websocket.client, e, exc_info=True
                )
                break

        # When the loop breaks, ensure the disconnect logic is run for this client.
        # This is a no-op if disconnect() already removed the connection.
        await self.disconnect(websocket)

    def get_connection_count(self) -> int:
        """Return the number of currently registered connections."""
        return len(self.active_connections)