# audit_assistant/utils.py
import json
import dataclasses
from uuid import UUID
from typing import Any, Optional
from datetime import datetime, date
import configparser
from torch import cuda
from qdrant_client.http import models as rest
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.cross_encoders import HuggingFaceCrossEncoder


def get_config(fp):
    """Read an INI-style config file and return a ConfigParser instance."""
    config = configparser.ConfigParser()
    # Use a context manager so the file handle is closed after reading.
    with open(fp) as f:
        config.read_file(f)
    return config
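
# get_config() is typically pointed at an INI file with a [retriever] section,
# since get_embeddings_model() below reads "MODEL" and "NORMALIZE" from it.
# A minimal sketch of such a file (the model id shown is only illustrative,
# and any sections or keys beyond these are assumptions):
#
#   [retriever]
#   MODEL = sentence-transformers/all-MiniLM-L6-v2
#   NORMALIZE = 1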


def get_embeddings_model(config):
    """Build a HuggingFaceEmbeddings model from the [retriever] config section."""
    # Run the embedding model on GPU when one is available, otherwise on CPU.
    device = "cuda" if cuda.is_available() else "cpu"
    model_name = config.get("retriever", "MODEL")
    model_kwargs = {"device": device}
    # NORMALIZE is expected to be an integer flag ("0" or "1") in the config.
    normalize_embeddings = bool(int(config.get("retriever", "NORMALIZE")))
    encode_kwargs = {
        "normalize_embeddings": normalize_embeddings,
        "batch_size": 100,
    }
    embeddings = HuggingFaceEmbeddings(
        show_progress=True,
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )
    return embeddings
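
# Usage sketch, assuming a config file like the one outlined above
# (the path "config.ini" and the query text are illustrative only):
#
#   config = get_config("config.ini")
#   embeddings = get_embeddings_model(config)
#   vector = embeddings.embed_query("What does the 2023 audit report cover?")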


# Create a search filter for Qdrant
def create_filter(
    reports: Optional[list] = None,
    sources: Optional[str] = None,
    subtype: Optional[list] = None,
    year: Optional[str] = None,
):
    """Build a Qdrant payload filter.

    If `reports` is given, filter on those filenames directly; otherwise
    filter on `sources` and `subtype`. `year` is currently unused (its
    condition is commented out below).
    """
    if not reports:
        print(f"defining filter for sources: {sources}, subtype: {subtype}")
        search_filter = rest.Filter(
            must=[
                rest.FieldCondition(
                    key="metadata.source", match=rest.MatchValue(value=sources)
                ),
                rest.FieldCondition(
                    key="metadata.filename", match=rest.MatchAny(any=subtype)
                ),
                # rest.FieldCondition(
                #     key="metadata.year", match=rest.MatchAny(any=year)
                # ),
            ]
        )
    else:
        print(f"defining filter for all reports: {reports}")
        search_filter = rest.Filter(
            must=[
                rest.FieldCondition(
                    key="metadata.filename", match=rest.MatchAny(any=reports)
                )
            ]
        )
    return search_filter
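
# Usage sketch (the filenames and source value below are illustrative only):
#
#   # Filter on an explicit list of report filenames
#   by_reports = create_filter(reports=["report_2022.pdf", "report_2023.pdf"])
#
#   # Or filter on a single source plus a list of filename subtypes
#   by_source = create_filter(sources="annual_report", subtype=["report_2023.pdf"])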


def load_json(fp):
    """Load and return the contents of a JSON file."""
    with open(fp, "r") as f:
        docs = json.load(f)
    return docs


def get_timestamp():
    """Return the current time as a YYYYMMDDHHMMSS string."""
    # `datetime` is the class imported above, so it is called directly
    # (not via `datetime.datetime`).
    now = datetime.now()
    timestamp = now.strftime("%Y%m%d%H%M%S")
    return timestamp


# A custom class to help with recursive serialization.
# This approach avoids modifying the original object.
class _RecursiveSerializer(json.JSONEncoder):
    """A custom JSONEncoder that handles complex types by converting them to dicts or strings."""

    def default(self, obj):
        # Prefer the pydantic method if it exists for the most robust serialization.
        if hasattr(obj, "model_dump"):
            return obj.model_dump()
        # Handle dataclasses
        if dataclasses.is_dataclass(obj):
            return dataclasses.asdict(obj)
        # Handle other non-serializable but common types.
        if isinstance(obj, (datetime, date, UUID)):
            return str(obj)
        # Fallback for general objects with a __dict__
        if hasattr(obj, "__dict__"):
            return obj.__dict__
        # Default fallback to JSONEncoder's behavior
        return super().default(obj)


def to_json_string(obj: Any, **kwargs) -> str:
    """
    Serializes a Python object into a JSON-formatted string.

    This function is a comprehensive utility that can handle:
    - Standard Python types (lists, dicts, strings, numbers, bools, None).
    - Pydantic models (using `model_dump()`).
    - Dataclasses (using `dataclasses.asdict()`).
    - Standard library types not natively JSON-serializable (e.g., datetime, UUID).
    - Custom classes with a `__dict__`.

    Args:
        obj (Any): The Python object to serialize.
        **kwargs: Additional keyword arguments to pass to `json.dumps`.

    Returns:
        str: A JSON-formatted string.

    Example:
        >>> from datetime import datetime
        >>> from pydantic import BaseModel
        >>> from dataclasses import dataclass
        >>> class Address(BaseModel):
        ...     street: str
        ...     city: str
        >>> @dataclass
        ... class Product:
        ...     id: int
        ...     name: str
        >>> class Order(BaseModel):
        ...     user_address: Address
        ...     item: Product
        >>> order_obj = Order(
        ...     user_address=Address(street="123 Main St", city="Example City"),
        ...     item=Product(id=1, name="Laptop"),
        ... )
        >>> print(to_json_string(order_obj, indent=2))
        {
          "user_address": {
            "street": "123 Main St",
            "city": "Example City"
          },
          "item": {
            "id": 1,
            "name": "Laptop"
          }
        }
    """
    return json.dumps(obj, cls=_RecursiveSerializer, **kwargs)