Spaces:
Running
Running
| import enum | |
| from langchain_core.messages import BaseMessage | |
| from pydantic import BaseModel, field_validator, model_validator | |
| from typing_extensions import TypedDict | |
| from langflow.base.data.utils import IMG_FILE_TYPES, TEXT_FILE_TYPES | |
| from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_NAME_AI | |
class File(TypedDict):
    """Typed dictionary describing a single file entry.

    All three keys are strings.
    """

    # Location of the file.
    path: str
    # Name of the file.
    name: str
    # File type identifier.
    type: str
class ChatOutputResponse(BaseModel):
    """Chat output response schema.

    Validation performed by this model:

    * ``files`` entries are normalised *before* field validation: a missing
      ``name`` or ``type`` is derived from ``path`` (see ``validate_files``).
    * When ``sender`` equals ``MESSAGE_SENDER_AI`` and ``message`` is a string,
      its newlines are normalised for markdown (see ``validate_message``).
    """

    # Either a plain string or a list whose items are strings or dicts.
    message: str | list[str | dict]
    sender: str | None = MESSAGE_SENDER_AI
    sender_name: str | None = MESSAGE_SENDER_NAME_AI
    session_id: str | None = None
    stream_url: str | None = None
    component_id: str | None = None
    # Pydantic copies mutable defaults per instance, so the shared-list
    # pitfall of plain Python defaults does not apply here.
    files: list[File] = []
    # Required: there is no default, so callers must always supply it.
    type: str

    @field_validator("files", mode="before")
    @classmethod
    def validate_files(cls, files):
        """Validate files and fill in a missing ``name``/``type`` from ``path``.

        Runs in ``before`` mode so that incomplete dictionaries can be
        completed before they are checked against the ``File`` schema.
        The dictionaries in ``files`` are updated in place.

        Args:
            files: The raw value supplied for the ``files`` field.

        Returns:
            The same ``files`` object, with missing keys filled in.

        Raises:
            ValueError: If an entry is not a dict, has no ``path``, or its
                type is missing and cannot be derived from the path.
        """
        if not files:
            return files

        # Build the lookup structures once instead of once per file.
        # The ordered sequence keeps the substring fallback deterministic
        # (iterating a set of strings depends on hash ordering).
        ordered_types = TEXT_FILE_TYPES + IMG_FILE_TYPES
        known_types = set(ordered_types)

        for file in files:
            if not isinstance(file, dict):
                msg = "Files must be a list of dictionaries."
                raise ValueError(msg)  # noqa: TRY004

            if all(key in file for key in ("path", "name", "type")):
                continue

            # If any of the keys are missing, we should extract the
            # values from the file path
            path = file.get("path")
            if not path:
                msg = "File path is required."
                raise ValueError(msg)

            if not file.get("name"):
                file["name"] = path.split("/")[-1]

            type_ = file.get("type")
            if not type_:
                # get the file type from the path
                extension = path.split(".")[-1]
                if extension and extension in known_types:
                    type_ = extension
                else:
                    # Fall back to the first known type that appears
                    # anywhere in the path.
                    type_ = next((ft for ft in ordered_types if ft in path), None)
                if not type_:
                    msg = "File type is required."
                    raise ValueError(msg)
                file["type"] = type_

        return files

    @classmethod
    def from_message(
        cls,
        message: BaseMessage,
        sender: str | None = MESSAGE_SENDER_AI,
        sender_name: str | None = MESSAGE_SENDER_NAME_AI,
        **kwargs,
    ):
        """Build chat output response from message.

        Args:
            message: Message whose ``content`` becomes the response message.
            sender: Value for the ``sender`` field.
            sender_name: Value for the ``sender_name`` field.
            **kwargs: Additional model fields forwarded to the constructor.
                ``type`` is a required field with no default, so it has to be
                supplied here for the model to validate.

        Returns:
            A new instance of ``cls``.
        """
        content = message.content
        return cls(message=content, sender=sender, sender_name=sender_name, **kwargs)

    @model_validator(mode="after")
    def validate_message(self):
        """Normalise newlines in an AI-sent string message for markdown.

        Returns:
            The model instance itself (possibly with ``message`` rewritten).
        """
        # The idea here is ensure the \n in message
        # is compliant with markdown if sender is machine
        # so, for example:
        # \n\n -> \n\n
        # \n -> \n\n
        if self.sender != MESSAGE_SENDER_AI:
            return self

        # ``message`` may also be a list (see the field type); only strings
        # have newlines to normalise, so leave anything else untouched.
        if not isinstance(self.message, str):
            return self

        # We need to make sure we don't duplicate \n
        # in the message
        message = self.message.replace("\n\n", "\n")
        self.message = message.replace("\n", "\n\n")
        return self
class DataOutputResponse(BaseModel):
    """Schema for a data output response."""

    # Each entry is either a dictionary or None.
    data: list[dict | None]
class ContainsEnumMeta(enum.EnumMeta):
    """Enum metaclass whose ``in`` operator accepts raw values.

    ``item in SomeEnum`` is true exactly when ``SomeEnum(item)`` succeeds,
    and false when that lookup raises ``ValueError``.
    """

    def __contains__(cls, item) -> bool:
        """Return whether ``item`` can be resolved to a member of ``cls``."""
        try:
            cls(item)
        except ValueError:
            is_member = False
        else:
            is_member = True
        return is_member