VLM-Lens / src /utils.py
marstin's picture
[martin-dev] add demo v1 test
d425e71
"""Utility functions for interacting with the SQLite database."""
import io
import logging
import sqlite3
from typing import Any, List, Optional
import torch
def select_tensors(
    db_path: str,
    table_name: str,
    keys: Optional[List[str]] = None,
    sql_where: Optional[str] = None,
) -> List[Any]:
    """Select rows from a SQLite table and deserialize their tensor column.

    Args:
        db_path (str): Path to the SQLite database file.
        table_name (str): Name of the table to query. It is interpolated
            verbatim into the SQL text, so it must come from a trusted source.
        keys (List[str], optional): Column names to select. Defaults to
            ``['layer', 'pooling_method', 'tensor_dim', 'tensor']``. If
            ``'tensor'`` is missing, a warning is logged and it is appended to
            an internal copy; the caller's list is never modified. Column
            names are interpolated verbatim into the SQL text.
        sql_where (str, optional): Raw SQL clause starting with ``WHERE``
            (case-insensitive, leading whitespace allowed) used to filter
            rows. It is interpolated verbatim, so it must be trusted.

    Returns:
        List[Any]: One dict per matching row, mapping each selected key to its
        column value. The ``'tensor'`` entry holds the object returned by
        ``torch.load`` on the stored bytes, with ``map_location='cpu'``.

    Raises:
        ValueError: If ``sql_where`` is given and does not start with ``WHERE``.
        sqlite3.Error: If the query fails (e.g. unknown table or column).
    """
    # Work on a copy so neither the caller's list nor a shared default is mutated.
    if keys is None:
        keys = ['layer', 'pooling_method', 'tensor_dim', 'tensor']
    else:
        keys = list(keys)
    if 'tensor' not in keys:
        logging.warning("'tensor' key should be included to retrieve tensors; automatically adding it.")
        keys.append('tensor')

    # NOTE: identifiers and the WHERE clause cannot be bound as SQL parameters,
    # so they are interpolated as-is; only pass trusted values here.
    query = f'SELECT {", ".join(keys)} FROM {table_name}'
    if sql_where:
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not sql_where.strip().lower().startswith('where'):
            raise ValueError("sql_where should start with 'WHERE'")
        query += f' {sql_where}'

    # Using a sqlite3 connection as a context manager only commits/rolls back;
    # it does not close the connection, so close it explicitly.
    connection = sqlite3.connect(db_path)
    try:
        rows = connection.execute(query).fetchall()
    finally:
        connection.close()

    final_results = []
    for row in rows:
        result_item = dict(zip(keys, row))
        # SECURITY: torch.load unpickles the stored bytes, which can run
        # arbitrary code; only read databases from trusted sources.
        result_item['tensor'] = torch.load(io.BytesIO(result_item['tensor']), map_location='cpu')
        final_results.append(result_item)
    return final_results