import os
import threading

import vecs
from dotenv import load_dotenv
from google import genai
from google.genai import types
from sentence_transformers import SentenceTransformer

load_dotenv()

# Database credentials for the vecs (Postgres/pgvector) store come from the environment.
user = os.getenv("user")
password = os.getenv("password")
host = os.getenv("host")
port = os.getenv("port")
db_name = "postgres"
DB_CONNECTION = f"postgresql://{user}:{password}@{host}:{port}/{db_name}"

vx = vecs.create_client(DB_CONNECTION)
# 384-dimensional embedding model; its output size must match the collections below.
model = SentenceTransformer('Snowflake/snowflake-arctic-embed-xs', device="cpu")
client = genai.Client(api_key=os.getenv('GEMINI_API_KEY'))
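
# Illustrative .env contents this module reads (all values are placeholders):
#   user=postgres
#   password=<db-password>
#   host=<db-host>
#   port=5432
#   GEMINI_API_KEY=<api-key>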

def query_db(query, limit=5, filters=None, measure="cosine_distance", include_value=True, include_metadata=True, table="2023"):
    """Run a similarity search against one per-year vecs collection.

    `query` is a precomputed embedding, not raw text; `filters` defaults to
    None rather than a mutable `{}` to avoid the shared-default pitfall.
    """
    collection = vx.get_or_create_collection(name=table, dimension=384)
    return collection.query(
        data=query,
        limit=limit,
        filters=filters or {},
        measure=measure,
        include_value=include_value,
        include_metadata=include_metadata,
    )
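
# Illustrative result shape (hypothetical values): with include_value and
# include_metadata enabled, vecs returns tuples of (id, distance, metadata),
# e.g. [("2023_0042", 0.31, {"sentencia": "T-123-23"}), ...]; the "sentencia"
# metadata key is relied on by vector_query below.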

def sort_by_score(item):
    # Results are (id, distance, metadata) tuples; sort ascending by distance.
    return item[1]

def infaño(rad):
    """Expand a two-digit year suffix: values above 89 map to 19xx, the rest to 20xx."""
    a = int(rad[-2:])
    if a > 89:
        return a + 1900
    return a + 2000
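
# Illustrative behavior (hypothetical inputs):
#   infaño("xx95") -> 1995
#   infaño("xx23") -> 2023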

def thread_query(query, target, year):
    # Thread worker: extend the shared results list with one year's matches.
    # list.extend is atomic under CPython's GIL, so no explicit lock is used.
    target.extend(query_db(query, table=str(year)))

def vector_query(query, start=1992, end=2024):
    """Search every per-year collection in parallel and keep the judgements
    that appear more than once across the combined results."""
    results = []
    query_embedding = model.encode(query)
    threads = []
    for year in range(start, end + 1):
        t = threading.Thread(target=thread_query, args=(query_embedding, results, year))
        threads.append(t)
        t.start()
    # Join every worker, not just the last one started, before reading results.
    for t in threads:
        t.join()
    results.sort(key=sort_by_score)
    counts = {}
    for _, _, metadata in results:
        counts[metadata['sentencia']] = counts.get(metadata['sentencia'], 0) + 1
    judgements = [sentencia for sentencia, n in counts.items() if n > 1]
    print(query, judgements)
    return judgements
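
# Illustrative call (hypothetical query and ids):
#   vector_query("tutela contra providencias judiciales")
#   -> ['C-590-05', 'T-123-10']   # ids that recurred in the combined top results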

def context_builder_prompt_constructor(judgement):
    # Placeholder hook: the judgement identifier is currently used verbatim as the prompt.
    return judgement

def context_builder(context_prompt, target):
    """Summarize one judgement with Gemini (grounded via Google Search) and
    append the summary to the shared target list."""
    model_name = "gemini-2.5-flash-lite"
    contents = [
        types.Content(
            role="user",
            parts=[
                types.Part.from_text(text=context_prompt),
            ],
        ),
    ]
    tools = [
        types.Tool(googleSearch=types.GoogleSearch()),
    ]
    generate_content_config = types.GenerateContentConfig(
        thinking_config=types.ThinkingConfig(
            thinking_budget=0,
        ),
        tools=tools,
        # Spanish system prompt: "summarize the ruling's content in detail,
        # mentioning every point the ruling considered".
        system_instruction=[
            types.Part.from_text(text="resume el contenido de la sentencia de forma detallada, mencionando todos los puntos considerados en la sentencia"),
        ],
    )
    response = client.models.generate_content(
        model=model_name,
        contents=contents,
        config=generate_content_config,
    )
    target.append(response.text)

def context_draft(judgements, query):
    """Build the conversation context by summarizing every judgement
    concurrently (`query` is accepted but currently unused)."""
    context = []
    threads = []
    for judgement in judgements:
        t = threading.Thread(
            target=context_builder,
            args=(context_builder_prompt_constructor(judgement), context),
        )
        threads.append(t)
        t.start()
    # Join the workers instead of busy-waiting on the shared list, which
    # spins a CPU core and hangs forever if any worker raises.
    for t in threads:
        t.join()
    draft = ''
    for summary in context:
        draft += summary + '\n'
    return draft
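
# A minimal alternative sketch (not used by this module): the same fan-out/join
# pattern via the standard library's ThreadPoolExecutor, collecting return
# values instead of mutating a shared list. `summarize` is a hypothetical
# helper that returns a summary string rather than appending to a target.
#
#   from concurrent.futures import ThreadPoolExecutor
#
#   def context_draft_pooled(judgements, summarize):
#       prompts = [context_builder_prompt_constructor(j) for j in judgements]
#       with ThreadPoolExecutor(max_workers=8) as pool:
#           summaries = pool.map(summarize, prompts)
#       return ''.join(s + '\n' for s in summaries)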

def generate(query, context, message_history):
    """Answer the user's query with Gemini, injecting the prebuilt context
    into the system instruction."""
    model_name = "gemini-2.5-flash-lite"
    # Convert Hugging Face-style message history to the Gemini API format.
    gemini_contents = []
    for message in message_history:
        role = "user" if message["role"] == "user" else "model"
        gemini_contents.append(
            types.Content(
                role=role,
                parts=[types.Part.from_text(text=message["content"])],
            )
        )
    # Append the current user query.
    gemini_contents.append(
        types.Content(
            role="user",
            parts=[
                types.Part.from_text(text=query),
            ],
        )
    )
    generate_content_config = types.GenerateContentConfig(
        thinking_config=types.ThinkingConfig(
            thinking_budget=0,
        ),
        # Spanish system prompt: "You are Ticio, a Colombian case-law research
        # assistant with access to a context built for this conversation; answer
        # the user's questions, always citing the rulings the information comes
        # from, like an expert researcher." The retrieved context follows it.
        system_instruction=[
            types.Part.from_text(text=f"""Eres Ticio, un asistente de investigación de jurisprudencia colombiana. Tienes acceso a un contexto especialmente diseñado para esta conversación. Tu tarea es contestar a las preguntas del usuario referenciando siempre las sentencias de donde viene la información como si fueras un investigador experto.
{context}
"""),
        ],
    )
    response = client.models.generate_content(
        model=model_name,
        contents=gemini_contents,
        config=generate_content_config,
    )
    return response.text
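
# Illustrative message-history shape expected by generate() (values are
# hypothetical); any role other than "user" is mapped to Gemini's "model" role:
#   [
#       {"role": "user", "content": "¿Qué dijo la Corte sobre ...?"},
#       {"role": "assistant", "content": "Según la sentencia T-123-10, ..."},
#   ]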

def inference(query, history, context):
    """Entry point: on the first turn (or when no context was carried over),
    retrieve judgements and draft a fresh context, then answer with it."""
    if not context or not history:
        judgements = vector_query(query)
        context = context_draft(judgements, query)
    return generate(query, context, history), context
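
# A minimal usage sketch, assuming the environment variables above are set and
# the per-year vecs collections already exist (the query is hypothetical):
if __name__ == "__main__":
    answer, built_context = inference(
        "¿Cuándo procede la tutela contra providencias judiciales?", [], None
    )
    print(answer)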