Spaces:
Running
Running
| import json | |
| from collections.abc import Sequence | |
| from uuid import UUID | |
| from langchain_core.chat_history import BaseChatMessageHistory | |
| from langchain_core.messages import BaseMessage | |
| from loguru import logger | |
| from sqlalchemy import delete | |
| from sqlmodel import Session, col, select | |
| from sqlmodel.ext.asyncio.session import AsyncSession | |
| from langflow.schema.message import Message | |
| from langflow.services.database.models.message.model import MessageRead, MessageTable | |
| from langflow.services.deps import async_session_scope, session_scope | |
| from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_USER | |
def _get_variable_query(
    sender: str | None = None,
    sender_name: str | None = None,
    session_id: str | None = None,
    order_by: str | None = "timestamp",
    order: str | None = "DESC",
    flow_id: UUID | None = None,
    limit: int | None = None,
):
    """Build a SELECT over ``MessageTable`` applying the optional filters.

    Rows whose ``error`` flag is set are always excluded. Every other filter is
    applied only when its argument is truthy.

    Args:
        sender: Only return messages from this sender.
        sender_name: Only return messages with this sender name.
        session_id: Only return messages belonging to this session.
        order_by: Name of the ``MessageTable`` attribute to sort by; no ordering
            is applied when falsy.
        order: ``"DESC"`` (case-insensitive) sorts descending; any other value,
            including ``None``, sorts ascending.
        flow_id: Only return messages belonging to this flow.
        limit: Maximum number of rows to return; no limit when falsy.

    Returns:
        The composed ``select`` statement (not yet executed).

    Raises:
        AttributeError: If ``order_by`` does not name an attribute of ``MessageTable``.
    """
    stmt = select(MessageTable).where(MessageTable.error == False)  # noqa: E712
    if sender:
        stmt = stmt.where(MessageTable.sender == sender)
    if sender_name:
        stmt = stmt.where(MessageTable.sender_name == sender_name)
    if session_id:
        stmt = stmt.where(MessageTable.session_id == session_id)
    if flow_id:
        stmt = stmt.where(MessageTable.flow_id == flow_id)
    if order_by:
        # Named ``order_column`` rather than ``col`` so the module-level
        # ``sqlmodel.col`` import is not shadowed inside this function.
        order_column = getattr(MessageTable, order_by)
        descending = isinstance(order, str) and order.upper() == "DESC"
        stmt = stmt.order_by(order_column.desc() if descending else order_column.asc())
    if limit:
        stmt = stmt.limit(limit)
    return stmt
def get_messages(
    sender: str | None = None,
    sender_name: str | None = None,
    session_id: str | None = None,
    order_by: str | None = "timestamp",
    order: str | None = "DESC",
    flow_id: UUID | None = None,
    limit: int | None = None,
) -> list[Message]:
    """Retrieve stored messages matching the provided filters.

    Messages flagged as errors are never returned.

    Args:
        sender (Optional[str]): The sender of the messages (e.g., "Machine" or "User").
        sender_name (Optional[str]): The name of the sender.
        session_id (Optional[str]): The session ID associated with the messages.
        order_by (Optional[str]): The field to order the messages by. Defaults to "timestamp".
        order (Optional[str]): The order in which to retrieve the messages. Defaults to "DESC".
        flow_id (Optional[UUID]): The flow ID associated with the messages.
        limit (Optional[int]): The maximum number of messages to retrieve.

    Returns:
        list[Message]: The matching rows, each converted to a ``Message``.
    """
    with session_scope() as session:
        # Keyword arguments guard against silently mis-ordering the many
        # same-typed optional filters.
        stmt = _get_variable_query(
            sender=sender,
            sender_name=sender_name,
            session_id=session_id,
            order_by=order_by,
            order=order,
            flow_id=flow_id,
            limit=limit,
        )
        messages = session.exec(stmt)
        # Build the results while the session is still open.
        return [Message(**d.model_dump()) for d in messages]
async def aget_messages(
    sender: str | None = None,
    sender_name: str | None = None,
    session_id: str | None = None,
    order_by: str | None = "timestamp",
    order: str | None = "DESC",
    flow_id: UUID | None = None,
    limit: int | None = None,
) -> list[Message]:
    """Asynchronously retrieve stored messages matching the provided filters.

    Messages flagged as errors are never returned.

    Args:
        sender (Optional[str]): The sender of the messages (e.g., "Machine" or "User").
        sender_name (Optional[str]): The name of the sender.
        session_id (Optional[str]): The session ID associated with the messages.
        order_by (Optional[str]): The field to order the messages by. Defaults to "timestamp".
        order (Optional[str]): The order in which to retrieve the messages. Defaults to "DESC".
        flow_id (Optional[UUID]): The flow ID associated with the messages.
        limit (Optional[int]): The maximum number of messages to retrieve.

    Returns:
        list[Message]: The matching rows, each built via ``Message.create``.
    """
    async with async_session_scope() as session:
        # Keyword arguments guard against silently mis-ordering the many
        # same-typed optional filters.
        stmt = _get_variable_query(
            sender=sender,
            sender_name=sender_name,
            session_id=session_id,
            order_by=order_by,
            order=order,
            flow_id=flow_id,
            limit=limit,
        )
        messages = await session.exec(stmt)
        # Build the results while the session is still open.
        return [await Message.create(**d.model_dump()) for d in messages]
def add_messages(messages: Message | list[Message], flow_id: str | None = None):
    """Persist one or more messages and return them as they were stored.

    Args:
        messages: A single ``Message`` or a list of them.
        flow_id: Optional flow identifier passed to ``MessageTable.from_message``
            for every row.

    Returns:
        The stored rows, rebuilt as ``Message`` objects.

    Raises:
        ValueError: If any item is not a ``Message`` instance.
    """
    batch = messages if isinstance(messages, list) else [messages]
    if any(not isinstance(item, Message) for item in batch):
        found = ", ".join(str(type(item)) for item in batch)
        msg = f"The messages must be instances of Message. Found: {found}"
        raise ValueError(msg)

    try:
        rows = [MessageTable.from_message(item, flow_id=flow_id) for item in batch]
        with session_scope() as session:
            saved = add_messagetables(rows, session)
        return [Message(**row.model_dump()) for row in saved]
    except Exception as e:
        # Log with traceback, then let the caller handle the failure.
        logger.exception(e)
        raise
async def aadd_messages(messages: Message | list[Message], flow_id: str | None = None):
    """Asynchronously persist one or more messages and return them as stored.

    Args:
        messages: A single ``Message`` or a list of them.
        flow_id: Optional flow identifier passed to ``MessageTable.from_message``
            for every row.

    Returns:
        The stored rows, rebuilt via ``Message.create``.

    Raises:
        ValueError: If any item is not a ``Message`` instance.
    """
    batch = messages if isinstance(messages, list) else [messages]
    if any(not isinstance(item, Message) for item in batch):
        found = ", ".join(str(type(item)) for item in batch)
        msg = f"The messages must be instances of Message. Found: {found}"
        raise ValueError(msg)

    try:
        rows = [MessageTable.from_message(item, flow_id=flow_id) for item in batch]
        async with async_session_scope() as session:
            saved = await aadd_messagetables(rows, session)
        return [await Message.create(**row.model_dump()) for row in saved]
    except Exception as e:
        # Log with traceback, then let the caller handle the failure.
        logger.exception(e)
        raise
def update_messages(messages: Message | list[Message]) -> list[Message]:
    """Apply each message's set, non-None fields to its existing stored row.

    Rows are looked up by ``message.id``. Unknown ids are logged as a warning
    and skipped. Each successful update is committed and refreshed individually.

    Args:
        messages: A single ``Message`` or a list of them.

    Returns:
        The updated rows, validated as ``MessageRead`` objects.
    """
    batch = messages if isinstance(messages, list) else [messages]
    with session_scope() as session:
        updated_rows: list[MessageTable] = []
        for incoming in batch:
            row = session.get(MessageTable, incoming.id)
            if not row:
                logger.warning(f"Message with id {incoming.id} not found")
                continue
            row.sqlmodel_update(incoming.model_dump(exclude_unset=True, exclude_none=True))
            session.add(row)
            session.commit()
            session.refresh(row)
            updated_rows.append(row)
        # Validate while the session is still open so row attributes are loadable.
        return [MessageRead.model_validate(row, from_attributes=True) for row in updated_rows]
async def aupdate_messages(messages: Message | list[Message]) -> list[Message]:
    """Asynchronously apply each message's set, non-None fields to its stored row.

    Rows are looked up by ``message.id``. Unknown ids are logged as a warning
    and skipped. Each successful update is committed and refreshed individually.

    Args:
        messages: A single ``Message`` or a list of them.

    Returns:
        The updated rows, validated as ``MessageRead`` objects.
    """
    batch = messages if isinstance(messages, list) else [messages]
    async with async_session_scope() as session:
        updated_rows: list[MessageTable] = []
        for incoming in batch:
            row = await session.get(MessageTable, incoming.id)
            if not row:
                logger.warning(f"Message with id {incoming.id} not found")
                continue
            row.sqlmodel_update(incoming.model_dump(exclude_unset=True, exclude_none=True))
            session.add(row)
            await session.commit()
            await session.refresh(row)
            updated_rows.append(row)
        # Validate while the session is still open so row attributes are loadable.
        return [MessageRead.model_validate(row, from_attributes=True) for row in updated_rows]
def add_messagetables(messages: list[MessageTable], session: Session):
    """Insert ``MessageTable`` rows and return them as ``MessageRead`` objects.

    All rows are added and committed in one transaction, matching
    ``aadd_messagetables``, then refreshed so database-generated values are loaded.

    Args:
        messages: The rows to insert.
        session: An open synchronous session.

    Returns:
        The inserted rows validated as ``MessageRead`` objects. JSON-string
        ``properties`` and ``content_blocks`` entries are decoded, and a falsy
        ``category`` is normalized to ``""``.

    Raises:
        Exception: Any error from add/commit/refresh is logged and re-raised.
    """
    try:
        # A single commit avoids one round-trip per row and makes the batch
        # all-or-nothing instead of leaving earlier rows committed on failure.
        session.add_all(messages)
        session.commit()
        for message in messages:
            session.refresh(message)
    except Exception as e:
        logger.exception(e)
        raise

    # NOTE(review): these assignments mutate rows still attached to the caller's
    # session; if that session commits again they may be flushed — confirm intended.
    for msg in messages:
        msg.properties = json.loads(msg.properties) if isinstance(msg.properties, str) else msg.properties  # type: ignore[arg-type]
        msg.content_blocks = [json.loads(j) if isinstance(j, str) else j for j in msg.content_blocks]  # type: ignore[arg-type]
        msg.category = msg.category or ""
    return [MessageRead.model_validate(message, from_attributes=True) for message in messages]
async def aadd_messagetables(messages: list[MessageTable], session: AsyncSession):
    """Asynchronously insert ``MessageTable`` rows in one commit.

    After committing, each row is refreshed, and its JSON-string fields are decoded.

    Args:
        messages: The rows to insert.
        session: An open asynchronous session.

    Returns:
        The inserted rows validated as ``MessageRead`` objects. A falsy
        ``category`` is normalized to ``""``.

    Raises:
        Exception: Any error from add/commit/refresh is logged and re-raised.
    """
    try:
        for row in messages:
            session.add(row)
        await session.commit()
        for row in messages:
            await session.refresh(row)
    except Exception as e:
        logger.exception(e)
        raise

    decoded: list[MessageTable] = []
    for row in messages:
        raw_props = row.properties
        row.properties = json.loads(raw_props) if isinstance(raw_props, str) else raw_props  # type: ignore[arg-type]
        row.content_blocks = [  # type: ignore[arg-type]
            json.loads(block) if isinstance(block, str) else block for block in row.content_blocks
        ]
        row.category = row.category or ""
        decoded.append(row)
    return [MessageRead.model_validate(row, from_attributes=True) for row in decoded]
def delete_messages(session_id: str) -> None:
    """Remove every stored message belonging to ``session_id``.

    Args:
        session_id (str): The session whose messages should be deleted.
    """
    stmt = (
        delete(MessageTable)
        .where(col(MessageTable.session_id) == session_id)
        .execution_options(synchronize_session="fetch")
    )
    with session_scope() as session:
        # No explicit commit here; presumably session_scope commits on exit — confirm.
        session.exec(stmt)
async def adelete_messages(session_id: str) -> None:
    """Asynchronously remove every stored message belonging to ``session_id``.

    Args:
        session_id (str): The session whose messages should be deleted.
    """
    async with async_session_scope() as session:
        # No explicit commit here; presumably async_session_scope commits on exit — confirm.
        await session.exec(
            delete(MessageTable)
            .where(col(MessageTable.session_id) == session_id)
            .execution_options(synchronize_session="fetch")
        )
async def delete_message(id_: str) -> None:
    """Delete the single stored message whose primary key is ``id_``.

    Does nothing when no such message exists.

    Args:
        id_ (str): The ID of the message to delete.
    """
    async with async_session_scope() as session:
        target = await session.get(MessageTable, id_)
        if not target:
            return
        await session.delete(target)
        await session.commit()
def store_message(
    message: Message,
    flow_id: str | None = None,
) -> list[Message]:
    """Store a message in the memory, updating it when it already has an ``id``.

    Args:
        message (Message): The message to store.
        flow_id (Optional[str]): The flow ID associated with the message.
            When running from the CustomComponent you can access this using `self.graph.flow_id`.

    Returns:
        List[Message]: The stored (or updated) message. Returns an empty list
        when ``message`` is falsy.

    Raises:
        ValueError: If any of session_id, sender, or sender_name is missing;
            the message names each missing field.
    """
    if not message:
        logger.warning("No message provided.")
        return []

    absent = [name for name in ("session_id", "sender", "sender_name") if not getattr(message, name)]
    if absent:
        hints = {
            "session_id": "session_id (unique conversation identifier)",
            "sender": f"sender (e.g., '{MESSAGE_SENDER_USER}' or '{MESSAGE_SENDER_AI}')",
            "sender_name": "sender_name (display name, e.g., 'User' or 'Assistant')",
        }
        described = ", ".join(hints[name] for name in absent)
        msg = (
            f"It looks like we're missing some important information: {described}. "
            "Please ensure that your message includes all the required fields."
        )
        raise ValueError(msg)

    # A truthy id means the message already exists, so update instead of insert.
    if getattr(message, "id", None):
        return update_messages([message])
    return add_messages([message], flow_id=flow_id)
async def astore_message(
    message: Message,
    flow_id: str | None = None,
) -> list[Message]:
    """Asynchronously store a message in the memory.

    A message with a truthy ``id`` is routed to ``aupdate_messages``; any other
    message is inserted via ``aadd_messages``.

    Args:
        message (Message): The message to store.
        flow_id (Optional[str]): The flow ID associated with the message.
            When running from the CustomComponent you can access this using `self.graph.flow_id`.

    Returns:
        List[Message]: A list of data containing the stored message. Returns an
        empty list when ``message`` is falsy.

    Raises:
        ValueError: If any of the required parameters (session_id, sender, sender_name) is not provided.
            The error message names each missing field.
    """
    if not message:
        logger.warning("No message provided.")
        return []

    required_fields = ["session_id", "sender", "sender_name"]
    missing_fields = [field for field in required_fields if not getattr(message, field)]
    if missing_fields:
        # Same detailed error as the synchronous ``store_message`` so both entry
        # points tell the caller exactly which fields are missing.
        missing_descriptions = {
            "session_id": "session_id (unique conversation identifier)",
            "sender": f"sender (e.g., '{MESSAGE_SENDER_USER}' or '{MESSAGE_SENDER_AI}')",
            "sender_name": "sender_name (display name, e.g., 'User' or 'Assistant')",
        }
        missing = ", ".join(missing_descriptions[field] for field in missing_fields)
        msg = (
            f"It looks like we're missing some important information: {missing}. "
            "Please ensure that your message includes all the required fields."
        )
        raise ValueError(msg)

    if hasattr(message, "id") and message.id:
        return await aupdate_messages([message])
    return await aadd_messages([message], flow_id=flow_id)
class LCBuiltinChatMemory(BaseChatMessageHistory):
    """LangChain chat-message history backed by this module's message store.

    Reads and clears are scoped to ``session_id``. Writes stamp ``session_id``
    on each message and pass ``flow_id`` through to the store functions.
    """

    def __init__(
        self,
        flow_id: str,
        session_id: str,
    ) -> None:
        self.flow_id = flow_id
        self.session_id = session_id

    @property
    def messages(self) -> list[BaseMessage]:
        """Return this session's non-error messages as LangChain messages.

        Exposed as a property because ``BaseChatMessageHistory`` consumers read
        ``history.messages`` as an attribute, not as a method call.
        """
        # NOTE(review): reads filter by session_id only, not flow_id — confirm intended.
        messages = get_messages(
            session_id=self.session_id,
        )
        return [m.to_lc_message() for m in messages if not m.error]  # Exclude error messages

    async def aget_messages(self) -> list[BaseMessage]:
        """Asynchronously return this session's non-error messages as LangChain messages."""
        messages = await aget_messages(
            session_id=self.session_id,
        )
        return [m.to_lc_message() for m in messages if not m.error]  # Exclude error messages

    def add_messages(self, messages: Sequence[BaseMessage]) -> None:
        """Convert each LangChain message, tag it with this session, and store it."""
        for lc_message in messages:
            message = Message.from_lc_message(lc_message)
            message.session_id = self.session_id
            store_message(message, flow_id=self.flow_id)

    async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
        """Asynchronously convert, tag, and store each LangChain message."""
        for lc_message in messages:
            message = Message.from_lc_message(lc_message)
            message.session_id = self.session_id
            await astore_message(message, flow_id=self.flow_id)

    def clear(self) -> None:
        """Delete every stored message for this session."""
        delete_messages(self.session_id)

    async def aclear(self) -> None:
        """Asynchronously delete every stored message for this session."""
        await adelete_messages(self.session_id)