Spaces:
Sleeping
Sleeping
| from typing import Optional | |
| from qdrant_client import QdrantClient | |
| from qdrant_client.http.models import PointStruct, Filter, FieldCondition, MatchValue, PointIdsList | |
| from fastembed import TextEmbedding, SparseTextEmbedding | |
| import logging | |
| import uuid | |
| from .output_files_generator import generate_yaml_file, generate_markdown_files | |
| from .config import config | |
| from .exceptions import ConfigurationError | |
| from .database import validate_point_payload, get_dense_vector_name, get_sparse_vector_name | |
# Logger used by this module, looked up by the fixed name 'fabric_to_espanso'.
logger = logging.getLogger('fabric_to_espanso')
# TODO: Make a summary of the prompts using a call to an LLM for every prompt and store that in the purpose field
# of the database instead of the extracted purpose from the markdown files and use that summary to create the embeddings
def get_embedding(text: str) -> tuple:
    """
    Generate dense and sparse embeddings for the given text using FastEmbed.

    Both FastEmbed models are created on first use and cached as attributes
    of this function, so later calls reuse the already-loaded models.

    Args:
        text (str): Text to generate embeddings for.

    Returns:
        tuple: ``(dense_embedding, sparse_embedding)``.
            ``dense_embedding`` is the first item produced by the dense
            model's ``embed``. ``sparse_embedding`` is a dict with
            ``'indices'`` and ``'values'`` lists, converted via ``tolist()``
            from the first item produced by the sparse model.

    Raises:
        ConfigurationError: If ``config.embedding.use_fastembed`` is False.
    """
    if not config.embedding.use_fastembed:
        msg = "Embedding model not initialized. Set use_fastembed to True in the configuration."
        logger.error(msg)
        raise ConfigurationError(msg)

    # Models are lazily initialized only when needed and cached on the
    # function object so repeated calls do not reload them.
    if not hasattr(get_embedding, '_dense_model'):
        get_embedding._dense_model = TextEmbedding(model_name=config.embedding.dense_model_name)
    if not hasattr(get_embedding, '_sparse_model'):
        get_embedding._sparse_model = SparseTextEmbedding(model_name=config.embedding.sparse_model_name)

    # embed() returns an iterable of results; only the first one is used.
    dense_embedding = list(get_embedding._dense_model.embed(text))[0]
    sparse_embedding = list(get_embedding._sparse_model.embed(text, return_dense=False))[0]

    # The sparse result is converted to plain lists so it can be passed to
    # Qdrant as an {'indices': [...], 'values': [...]} mapping.
    return dense_embedding, {
        'indices': sparse_embedding.indices.tolist(),
        'values': sparse_embedding.values.tolist()
    }
def _build_point(point_id, payload: dict, dense_vector_name: str, sparse_vector_name: str) -> PointStruct:
    """
    Create a PointStruct from a validated payload, embedding its 'purpose' once.

    Args:
        point_id: Id for the point (a new UUID string or an existing point id).
        payload (dict): Validated payload with 'filename', 'content', 'purpose',
            'last_modified', 'filesize' and 'trigger' keys.
        dense_vector_name (str): Name of the collection's dense vector.
        sparse_vector_name (str): Name of the collection's sparse vector.

    Returns:
        PointStruct: Point whose dense and sparse vectors both come from a
        single get_embedding(payload['purpose']) call.
    """
    # One call yields both vectors; calling get_embedding once per vector
    # would run both models twice for every point.
    dense_vector, sparse_vector = get_embedding(payload['purpose'])
    return PointStruct(
        id=point_id,
        vector={
            dense_vector_name: dense_vector,
            sparse_vector_name: sparse_vector,
        },
        payload={
            "filename": payload['filename'],
            "content": payload['content'],
            "purpose": payload['purpose'],
            "date": payload['last_modified'],
            "filesize": payload['filesize'],
            "trigger": payload['trigger'],
        },
    )


def _find_point_id(client: QdrantClient, collection_name: str, filename: str):
    """
    Return the id of the first point whose 'filename' payload equals *filename*.

    Returns:
        The point id, or None if no point matches.
    """
    matches = client.scroll(
        collection_name=collection_name,
        scroll_filter=Filter(
            must=[FieldCondition(key="filename", match=MatchValue(value=filename))]
        ),
        limit=1,
    )[0]
    # TODO: Add handling of cases of multiple entries with the same filename
    # Test the list, not the id: a point id may be a falsy integer such as 0.
    return matches[0].id if matches else None


def update_qdrant_database(client: QdrantClient, collection_name: str, new_files: list, modified_files: list, deleted_files: list):
    """
    Update the Qdrant database based on detected file changes, then regenerate
    the espanso YAML file and the Obsidian markdown files.

    Does nothing (beyond an INFO log) when ``config.embedding.use_fastembed``
    is False.

    Args:
        client (QdrantClient): An initialized Qdrant client.
        collection_name (str): Collection to update. The YAML and markdown
            outputs are generated from this same collection.
        new_files (list): File dicts to add; each gets a freshly generated
            UUID point id.
        modified_files (list): File dicts whose existing point (found by
            'filename') is overwritten in place.
        deleted_files (list): Filenames whose point is removed.

    Raises:
        Exception: Any error not handled per file is logged with a traceback
            and re-raised. A ConfigurationError while processing a single new
            or modified file is logged and that file is skipped.
    """
    if not config.embedding.use_fastembed:
        msg = "Embedding model not initialized. Set use_fastembed to True in the configuration."
        logger.info(msg)
        return

    vector_names: Optional[tuple] = None

    def resolve_vector_names() -> tuple:
        # Query the collection's vector names once per update instead of once
        # per file. It is still called inside each per-file try, so a
        # ConfigurationError only skips that file, and a failed lookup is
        # retried for the next file.
        nonlocal vector_names
        if vector_names is None:
            vector_names = (
                get_dense_vector_name(client, collection_name),
                get_sparse_vector_name(client, collection_name),
            )
        return vector_names

    try:
        # Add new files
        for file in new_files:
            try:
                payload_new = validate_point_payload(file)
                dense_vector_name, sparse_vector_name = resolve_vector_names()
                point = _build_point(
                    str(uuid.uuid4()),  # Generate a new UUID for each point
                    payload_new,
                    dense_vector_name,
                    sparse_vector_name,
                )
                client.upsert(collection_name=collection_name, points=[point])  # Update the database with the new file
                logger.info(f"Added new file to database: {file['filename']}")
            except ConfigurationError as e:
                logger.error(f"Skipping new file: {str(e)}")

        # Update modified files
        for file in modified_files:
            try:
                point_id = _find_point_id(client, collection_name, file['filename'])
                if point_id is None:
                    logger.warning(f"File not found in database for update: {file['filename']}")
                    continue
                payload_current = validate_point_payload(file, point_id)
                dense_vector_name, sparse_vector_name = resolve_vector_names()
                # Build every field from the validated payload, as the
                # new-file branch does. The stored 'purpose' then always
                # matches the text the vectors were computed from.
                point = _build_point(point_id, payload_current, dense_vector_name, sparse_vector_name)
                client.upsert(collection_name=collection_name, points=[point])
                logger.info(f"Updated modified file in database: {payload_current['filename']}")
            except ConfigurationError as e:
                logger.error(f"Skipping modified file: {str(e)}")

        # Delete removed files
        for filename in deleted_files:
            point_id = _find_point_id(client, collection_name, filename)
            if point_id is None:
                logger.warning(f"File not found in database for deletion: {filename}")
                continue
            client.delete(
                collection_name=collection_name,
                points_selector=PointIdsList(points=[point_id])
            )
            logger.info(f"Deleted file from database: {filename}")

        logger.info("Database update completed successfully")

        # Regenerate the outputs from the collection that was just updated,
        # not from whatever collection the config names.
        # Generate new YAML file for use with espanso after database update
        print("Generating YAML file...")
        generate_yaml_file(client, collection_name, config.yaml_output_folder)
        # Generate markdown files for use with obsidian after database update
        print("Generating markdown files...")
        generate_markdown_files(client, collection_name, config.obsidian_output_folder)
    except Exception as e:
        logger.error(f"Error updating Qdrant database: {str(e)}", exc_info=True)
        raise