| """Utility functions for interacting with the SQLite database.""" | |
| import io | |
| import logging | |
| import sqlite3 | |
| from typing import Any, List, Optional | |
| import torch | |
def select_tensors(
    db_path: str,
    table_name: str,
    keys: Optional[List[str]] = None,
    sql_where: Optional[str] = None,
) -> List[Any]:
    """Select rows from a SQLite table and deserialize their ``tensor`` column.

    Each selected row is returned as a dict mapping the requested column
    names to the row's values. The ``tensor`` value is read as a blob
    and deserialized with ``torch.load`` onto the CPU.

    Args:
        db_path (str): Path to the SQLite database file.
        table_name (str): Name of the table to query. Interpolated into the
            SQL text as-is, so it must come from trusted code.
        keys (List[str]): Columns to select. Defaults to
            ``['layer', 'pooling_method', 'tensor_dim', 'tensor']``.
            ``'tensor'`` is appended (with a warning) when missing. The
            caller's list is never modified.
        sql_where (str): Optional SQL clause appended to the query; it must
            start with ``WHERE``. Interpolated as-is, so it must come from
            trusted code.

    Returns:
        List[Any]: One dict per matching row, keyed by column name, with the
        ``'tensor'`` entry replaced by the deserialized object.

    Raises:
        ValueError: If ``sql_where`` is given but does not start with
            ``WHERE``.
    """
    # Work on a private copy: appending to the argument would leak into the
    # caller's list (or into a shared default) across calls.
    if keys is None:
        columns = ['layer', 'pooling_method', 'tensor_dim', 'tensor']
    else:
        columns = list(keys)
    if 'tensor' not in columns:
        logging.warning("'tensor' key should be included to retrieve tensors; automatically adding it.")
        columns.append('tensor')

    # NOTE: identifiers cannot be bound as SQL parameters, so the column list,
    # table name and WHERE clause are spliced into the statement verbatim.
    # Never pass untrusted input through these arguments.
    query = f'SELECT {", ".join(columns)} FROM {table_name}'
    if sql_where:
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if not sql_where.strip().lower().startswith('where'):
            raise ValueError("sql_where should start with 'WHERE'")
        query += f' {sql_where}'

    # ``with sqlite3.connect(...)`` only scopes a transaction; it does not
    # close the connection, so close it explicitly once the rows are fetched.
    connection = sqlite3.connect(db_path)
    try:
        rows = connection.execute(query).fetchall()
    finally:
        connection.close()

    final_results = []
    for row in rows:
        result_item = dict(zip(columns, row))
        # NOTE(security): torch.load is pickle-based; only read databases
        # whose contents are trusted.
        result_item['tensor'] = torch.load(io.BytesIO(result_item['tensor']), map_location='cpu')
        final_results.append(result_item)
    return final_results