Spaces:
Paused
Paused
| from __future__ import annotations | |
| import datetime | |
| import os | |
| import threading | |
| from collections import OrderedDict | |
| from collections.abc import Iterator | |
| from copy import copy, deepcopy | |
| from typing import TYPE_CHECKING, Any | |
| if TYPE_CHECKING: | |
| from gradio.blocks import Blocks | |
| from gradio.components import State | |
class StateHolder:
    """Registry mapping session ids to their ``SessionState``, bounded LRU-style.

    ``session_data`` is kept ordered from least- to most-recently used; whenever
    more than ``capacity`` sessions are held, the least-recently used ones are
    dropped. ``lock`` guards the create / reorder / evict steps.
    """

    def __init__(self):
        # Default cap; replaced by ``blocks.state_session_capacity`` in set_blocks().
        self.capacity = 10000
        # Least-recently used session first, most-recently used last.
        self.session_data: OrderedDict[str, SessionState] = OrderedDict()
        # Time of each session's most recent access through __getitem__.
        self.time_last_used: dict[str, datetime.datetime] = {}
        self.lock = threading.Lock()

    def set_blocks(self, blocks: Blocks):
        """Link ``blocks`` and this holder to each other and adopt its session capacity."""
        self.blocks = blocks
        blocks.state_holder = self
        self.capacity = blocks.state_session_capacity

    def reset(self, blocks: Blocks):
        """Reset the state holder with new blocks. Used during reload mode.

        All existing sessions and their last-used timestamps are discarded.
        """
        with self.lock:
            self.session_data = OrderedDict()
            # Timestamps of discarded sessions would otherwise linger forever.
            self.time_last_used = {}
        # Call set blocks again to set new ids
        self.set_blocks(blocks)

    def __getitem__(self, session_id: str) -> SessionState:
        """Return the state for ``session_id``, creating it on first access.

        The session is marked most-recently used, the capacity cap is enforced,
        and its last-used timestamp is refreshed. Creation happens under
        ``lock`` so two concurrent first accesses cannot each build a separate
        ``SessionState`` and have one silently overwrite the other.
        """
        with self.lock:
            session_state = self.session_data.get(session_id)
            if session_state is None:
                session_state = SessionState(self.blocks)
                self.session_data[session_id] = session_state
            self._touch_locked(session_id)
            # The session can only have been evicted just now if capacity < 1;
            # don't leave a timestamp behind for a session that is not stored.
            if session_id in self.session_data:
                self.time_last_used[session_id] = datetime.datetime.now()
        return session_state

    def __contains__(self, session_id: str):
        return session_id in self.session_data

    def update(self, session_id: str):
        """Mark ``session_id`` most-recently used and evict sessions beyond ``capacity``."""
        with self.lock:
            self._touch_locked(session_id)

    def _touch_locked(self, session_id: str) -> None:
        """Move ``session_id`` to the MRU end and evict overflow; caller must hold ``lock``."""
        if session_id in self.session_data:
            self.session_data.move_to_end(session_id)
        # ``while`` rather than ``if`` so a capacity lowered by set_blocks() is fully
        # enforced; the emptiness check avoids popitem() on an empty dict (capacity < 0).
        while self.session_data and len(self.session_data) > self.capacity:
            evicted_id, _ = self.session_data.popitem(last=False)
            # Keep time_last_used in step with session_data so it cannot grow unbounded.
            self.time_last_used.pop(evicted_id, None)

    def delete_all_expired_state(
        self,
    ):
        """Delete every expired State value in every session (see ``delete_state``)."""
        # Iterate over a snapshot of the ids: session_data may change while delete
        # callbacks run (e.g. __getitem__ from another thread), and mutating an
        # OrderedDict during iteration raises RuntimeError.
        for session_id in list(self.session_data):
            self.delete_state(session_id, expired_only=True)

    def delete_state(self, session_id: str, expired_only: bool = False):
        """Delete State values stored for ``session_id``.

        Args:
            session_id: Session to clean up; unknown ids are ignored.
            expired_only: If True, only values reported as expired are removed;
                otherwise every value yielded by ``state_components`` is removed.

        Each removed value is first passed to its component's ``delete_callback``.
        """
        session_state = self.session_data.get(session_id)
        if session_state is None:
            return
        to_delete = []
        for component, value, expired in session_state.state_components:
            if expired or not expired_only:
                component.delete_callback(value)
                to_delete.append(component._id)
        # Removal is deferred until iteration finishes because state_components
        # may be iterating state_data lazily.
        for component_id in to_delete:
            del session_state.state_data[component_id]
class SessionState:
    """Per-session state for a Blocks app.

    Holds a copy of the app's default block config plus the values of stateful
    blocks for one session, and records when each ``State`` value was set so
    expired values can be found through ``state_components``.
    """

    def __init__(self, blocks: Blocks):
        # NOTE(review): shallow copy — whether ``blocks_config.blocks`` is shared with
        # ``blocks.default_config`` depends on that object's copy semantics; confirm.
        self.blocks_config = copy(blocks.default_config)
        # Values of stateful blocks for this session, keyed by block id.
        self.state_data: dict[int, Any] = {}
        # block id -> (time_to_live read from the State block, time the value was set).
        self._state_ttl: dict[int, tuple[Any, datetime.datetime]] = {}
        self.is_closed = False
        # When a session is closed, the state is stored for an hour to give the user time to reopen the session.
        # During testing we set to a lower value to be able to test
        self.STATE_TTL_WHEN_CLOSED = (
            1 if os.getenv("GRADIO_IS_E2E_TEST", None) else 3600
        )

    def __getitem__(self, key: int) -> Any:
        """Return this session's value for a stateful block, or the block itself otherwise.

        A stateful block's value is lazily initialised to a deep copy of the
        block's ``value`` attribute (``None`` if absent), so sessions never
        share a mutable default.
        """
        block = self.blocks_config.blocks[key]
        if block.stateful:
            if key not in self.state_data:
                self.state_data[key] = deepcopy(getattr(block, "value", None))
            return self.state_data[key]
        else:
            return block

    def __setitem__(self, key: int, value: Any):
        """Store ``value`` for block ``key``.

        For ``State`` blocks the value goes into ``state_data`` and its TTL clock
        is (re)started; any other block is replaced in this session's config.
        """
        from gradio.components import State

        block = self.blocks_config.blocks[key]
        if isinstance(block, State):
            self._state_ttl[key] = (
                block.time_to_live,
                datetime.datetime.now(),
            )
            self.state_data[key] = value
        else:
            self.blocks_config.blocks[key] = value

    def __contains__(self, key: int) -> bool:
        """Return whether block ``key`` has a value in this session.

        Stateful blocks count only once a value exists in ``state_data``; other
        blocks count if present in the config. Unknown ids return ``False``
        rather than raising ``KeyError``.
        """
        if key not in self.blocks_config.blocks:
            return False
        if self.blocks_config.blocks[key].stateful:
            return key in self.state_data
        return True

    @property
    def state_components(self) -> Iterator[tuple[State, Any, bool]]:
        """Yield ``(block, value, expired)`` for each ``State`` value with a recorded TTL.

        ``expired`` is True once more than ``time_to_live`` seconds have elapsed
        since the value was set; once the session is closed,
        ``STATE_TTL_WHEN_CLOSED`` replaces the block's own TTL.

        Exposed as a property because ``StateHolder.delete_state`` iterates it
        directly (``for ... in session_state.state_components``).
        """
        from gradio.components import State

        # Iterate a snapshot so state can be added/removed while the generator is live.
        for block_id, value in list(self.state_data.items()):
            block = self.blocks_config.blocks[block_id]
            if not isinstance(block, State) or block_id not in self._state_ttl:
                continue
            time_to_live, created_at = self._state_ttl[block_id]
            if self.is_closed:
                time_to_live = self.STATE_TTL_WHEN_CLOSED
            # total_seconds(), not .seconds: .seconds drops whole days, so a value
            # older than 24 hours could otherwise be reported as not expired.
            elapsed = (datetime.datetime.now() - created_at).total_seconds()
            yield block, value, elapsed > time_to_live