"""Utility functions for interacting with the SQLite database.""" import io import logging import sqlite3 from typing import Any, List, Optional import torch def select_tensors( db_path: str, table_name: str, keys: List[str] = ['layer', 'pooling_method', 'tensor_dim', 'tensor'], sql_where: Optional[str] = None, ) -> List[Any]: """Select and return all tensors from the specified SQLite database and table. Args: db_path (str): Path to the SQLite database file. table_name (str): Name of the table to query. keys (List[str]): List of keys to select from the database. sql_where (str): Optional SQL WHERE clause to filter results. Returns: List[Any]: A list of tensors retrieved from the database. """ if 'tensor' not in keys: logging.warning("'tensor' key should be included to retrieve tensors; automatically adding it.") keys.append('tensor') final_results = [] with sqlite3.connect(db_path) as connection: cursor = connection.cursor() query = f'SELECT {", ".join(keys)} FROM {table_name}' if sql_where: assert sql_where.strip().lower().startswith('where'), "sql_where should start with 'WHERE'" query += f' {sql_where}' cursor.execute(query) results = cursor.fetchall() for row in results: result_item = {key: value for key, value in zip(keys, row)} result_item['tensor'] = torch.load(io.BytesIO(result_item['tensor']), map_location='cpu') final_results.append(result_item) return final_results