Ticio committed on
Commit 13769af (verified)
1 Parent(s): 55b1ed1

Update inference.py

Files changed (1)
  1. inference.py +197 -170
inference.py CHANGED
@@ -1,171 +1,198 @@
- import vecs
- from dotenv import load_dotenv
- import os
- import threading
- import base64
- import os
- from google import genai
- from google.genai import types
- from sentence_transformers.SentenceTransformer import SentenceTransformer
-
- load_dotenv()
-
- user = os.getenv("user")
- password = os.getenv("password")
- host = os.getenv("host")
- port = os.getenv("port")
- db_name = "postgres"
- DB_CONNECTION = f"postgresql://{user}:{password}@{host}:{port}/{db_name}"
- vx = vecs.create_client(DB_CONNECTION)
- model = SentenceTransformer('Snowflake/snowflake-arctic-embed-xs', device="cpu")
- client = genai.Client(api_key=os.getenv('GEMINI_API_KEY'))
-
- def query_db(query, limit = 5, filters = {}, measure = "cosine_distance", include_value = True, include_metadata=True, table = "2023"):
-     query_embeds = vx.get_or_create_collection(name= table, dimension=384)
-     ans = query_embeds.query(
-         data=query,
-         limit=limit,
-         filters=filters,
-         measure=measure,
-         include_value=include_value,
-         include_metadata=include_metadata,
-     )
-     return ans
-
- def sort_by_score(item):
-     return item[1]
-
- def infaño(rad):
-     a = int(rad[len(rad)-2::])
-     if a > 89:
-         return a + 1900
-     else:
-         return a + 2000
-
- def thread_query(query, target, year):
-     return target.extend(query_db(query, table=str(year)))
-
-
- def vector_query(query, start = 1992, end = 2024):
-     results = []
-     vector_query = model.encode(query)
-     threads = []
-     for i in range(start, end + 1):
-         t = threading.Thread(target=thread_query, args=(vector_query, results, i))
-         threads.append(t)
-         t.start()
-     threads[-1].join()
-     results.sort(key=sort_by_score)
-     q = {}
-     for i in results:
-         if i[2]['sentencia'] not in q.keys():
-             q[i[2]['sentencia']] = 1
-         else:
-             q[i[2]['sentencia']] += 1
-     judgements = []
-
-     for i in q.keys():
-         if q[i] > 1:
-             judgements.append(i)
-     print(query, judgements)
-     return judgements
-
- def context_builder_prompt_constructor(judgement):
-     return judgement
-
- def context_builder(context_prompt, target):
-     model = "gemini-2.5-flash-lite"
-     contents = [
-         types.Content(
-             role="user",
-             parts=[
-                 types.Part.from_text(text=context_prompt),
-             ],
-         ),
-     ]
-     tools = [
-         types.Tool(googleSearch=types.GoogleSearch(
-         )),]
-     generate_content_config = types.GenerateContentConfig(
-         thinking_config = types.ThinkingConfig(
-             thinking_budget=0,
-         ),
-         tools=tools,
-         system_instruction=[
-             types.Part.from_text(text=f"""resume el contenido de la sentencia de forma detallada, mencionando todos los puntos considerados en la sentencia"""),
-         ],
-     )
-
-     response = client.models.generate_content(
-         model=model,
-         contents=contents,
-         config=generate_content_config,
-     )
-     return target.append(response.text)
-
- def context_draft(judgements, query):
-     context = []
-     threads = []
-     for i in judgements:
-         t = threading.Thread(target=context_builder, args=(context_builder_prompt_constructor(i), context))
-         threads.append(t)
-         t.start()
-
-     while len(context) < len(threads):
-         pass
-
-     draft = ''
-     for i in context:
-         draft += i + '\n'
-     return draft
-
- def generate(query, context, message_history):
-     model = "gemini-2.5-flash-lite"
-
-     # Convert Hugging Face style message history to Gemini API format
-     gemini_contents = []
-     for message in message_history:
-         role = "user" if message["role"] == "user" else "model"
-         gemini_contents.append(
-             types.Content(
-                 role=role,
-                 parts=[types.Part.from_text(text=message["content"])],
-             )
-         )
-
-     # Add the current user query to the contents
-     gemini_contents.append(
-         types.Content(
-             role="user",
-             parts=[
-                 types.Part.from_text(text=query),
-             ],
-         )
-     )
-
-     generate_content_config = types.GenerateContentConfig(
-         thinking_config = types.ThinkingConfig(
-             thinking_budget=0,
-         ),
-         system_instruction=[
-             types.Part.from_text(text=f"""Eres Ticio un asistente de investigación de jurisprudencia colombiana. Tienes acceso a un contexto especialmente diseñado para esta conversación. Tu tarea es contestar a las preguntas del usuario referenciando siempre las sentencias de donde viene la información como si fueras un investigador experto.
- {context}
-
- """)]
-     )
-
-     response = client.models.generate_content(
-         model=model,
-         contents=gemini_contents,
-         config=generate_content_config,
-     )
-     return response.text
-
- def inference(query, history, context):
-     if context == None or len(context) <= 0 or len(history) <= 0:
-         vector_query_results = vector_query(query)
-         context = context_draft(vector_query_results, query)
-         return generate(query, context, history), context
-     else:
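Two threading quirks in this version carry over unchanged into the new one below: vector_query starts one worker per year but joins only the last thread it launched (threads[-1].join()), so earlier workers may still be appending to results when the list is sorted, and context_draft spins in a busy-wait (while len(context) < len(threads): pass) that never exits if a worker raises. A minimal join-all sketch, reusing the module's own thread_query and sort_by_score helpers; this is an illustration, not part of the commit:

import threading

# Sketch only (not in the commit): join every worker before touching the
# shared list. Assumes thread_query and sort_by_score from inference.py.
def vector_query_joined(query_embedding, start=1992, end=2024):
    results = []
    threads = [
        threading.Thread(target=thread_query, args=(query_embedding, results, year))
        for year in range(start, end + 1)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()  # wait for every worker, not just the last one started
    results.sort(key=sort_by_score)
    return results

The same join loop would replace the busy-wait in context_draft, which also stops burning a CPU core while the Gemini calls are in flight.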
 
+ import vecs
+ from dotenv import load_dotenv
+ import os
+ import threading
+ import base64
+ import os
+ import threading
+ from google import genai
+ from google.genai import types
+ from sentence_transformers.SentenceTransformer import SentenceTransformer
+ import smtplib, ssl
+ from email.mime.text import MIMEText
+
+ load_dotenv()
+ user = os.getenv("user")
+ password = os.getenv("password")
+ host = os.getenv("host")
+ port = os.getenv("port")
+ mail = os.getenv("email")
+ APP_PASSWORD = os.getenv("app_password")
+ db_name = "postgres"
+ DB_CONNECTION = f"postgresql://{user}:{password}@{host}:{port}/{db_name}"
+ vx = vecs.create_client(DB_CONNECTION)
+ model = SentenceTransformer('Snowflake/snowflake-arctic-embed-xs', device="cpu")
+ client = genai.Client(api_key=os.getenv('GEMINI_API_KEY'))
+
+
+ def feedback(query, context, history):
+     msg = MIMEText(f"""Hola, envío automático con SMTP.
+
+ query: {query}
+
+ context: {context}
+
+ history: {history}
+ """, "plain", "utf-8")
+     msg["Subject"] = "feedback"
+     msg["From"] = mail
+     msg["To"] = mail
+     try:
+         with smtplib.SMTP_SSL("smtp.gmail.com", 465, context=ssl.create_default_context()) as server:
+             server.login(mail, APP_PASSWORD)
+             server.send_message(msg)
+     except:
+         print(f'envio fallido {query}')
+
+ def query_db(query, limit = 5, filters = {}, measure = "cosine_distance", include_value = True, include_metadata=True, table = "2023"):
+     query_embeds = vx.get_or_create_collection(name= table, dimension=384)
+     ans = query_embeds.query(
+         data=query,
+         limit=limit,
+         filters=filters,
+         measure=measure,
+         include_value=include_value,
+         include_metadata=include_metadata,
+     )
+     return ans
+
+ def sort_by_score(item):
+     return item[1]
+
+ def infaño(rad):
+     a = int(rad[len(rad)-2::])
+     if a > 89:
+         return a + 1900
+     else:
+         return a + 2000
+
+ def thread_query(query, target, year):
+     return target.extend(query_db(query, table=str(year)))
+
+
+ def vector_query(query, start = 1992, end = 2024):
+     results = []
+     vector_query = model.encode(query)
+     threads = []
+     for i in range(start, end + 1):
+         t = threading.Thread(target=thread_query, args=(vector_query, results, i))
+         threads.append(t)
+         t.start()
+     threads[-1].join()
+     results.sort(key=sort_by_score)
+     q = {}
+     for i in results:
+         if i[2]['sentencia'] not in q.keys():
+             q[i[2]['sentencia']] = 1
+         else:
+             q[i[2]['sentencia']] += 1
+     judgements = []
+
+     for i in q.keys():
+         if q[i] > 1:
+             judgements.append(i)
+     print(query, judgements)
+     return judgements
+
+ def context_builder_prompt_constructor(judgement):
+     return judgement
+
+ def context_builder(context_prompt, target):
+     model = "gemini-2.5-flash-lite"
+     contents = [
+         types.Content(
+             role="user",
+             parts=[
+                 types.Part.from_text(text=context_prompt),
+             ],
+         ),
+     ]
+     tools = [
+         types.Tool(googleSearch=types.GoogleSearch(
+         )),]
+     generate_content_config = types.GenerateContentConfig(
+         thinking_config = types.ThinkingConfig(
+             thinking_budget=0,
+         ),
+         tools=tools,
+         system_instruction=[
+             types.Part.from_text(text=f"""resume el contenido de la sentencia de forma detallada, mencionando todos los puntos considerados en la sentencia"""),
+         ],
+     )
+
+     response = client.models.generate_content(
+         model=model,
+         contents=contents,
+         config=generate_content_config,
+     )
+     return target.append(response.text)
+
+ def context_draft(judgements, query):
+     context = []
+     threads = []
+     for i in judgements:
+         t = threading.Thread(target=context_builder, args=(context_builder_prompt_constructor(i), context))
+         threads.append(t)
+         t.start()
+
+     while len(context) < len(threads):
+         pass
+
+     draft = ''
+     for i in context:
+         draft += i + '\n'
+     return draft
+
+ def generate(query, context, message_history):
+     model = "gemini-2.5-flash-lite"
+
+     # Convert Hugging Face style message history to Gemini API format
+     gemini_contents = []
+     for message in message_history:
+         role = "user" if message["role"] == "user" else "model"
+         gemini_contents.append(
+             types.Content(
+                 role=role,
+                 parts=[types.Part.from_text(text=message["content"])],
+             )
+         )
+
+     # Add the current user query to the contents
+     gemini_contents.append(
+         types.Content(
+             role="user",
+             parts=[
+                 types.Part.from_text(text=query),
+             ],
+         )
+     )
+
+     generate_content_config = types.GenerateContentConfig(
+         thinking_config = types.ThinkingConfig(
+             thinking_budget=0,
+         ),
+         system_instruction=[
+             types.Part.from_text(text=f"""Eres Ticio un asistente de investigación de jurisprudencia colombiana. Tienes acceso a un contexto especialmente diseñado para esta conversación. Tu tarea es contestar a las preguntas del usuario referenciando siempre las sentencias de donde viene la información como si fueras un investigador experto.
+ {context}
+
+ """)]
+     )
+
+     response = client.models.generate_content(
+         model=model,
+         contents=gemini_contents,
+         config=generate_content_config,
+     )
+     return response.text
+
+ def inference(query, history, context):
+     if context == None or len(context) <= 0 or len(history) <= 0:
+         vector_query_results = vector_query(query)
+         context = context_draft(vector_query_results, query)
+         t = threading.Thread(target=feedback, args=(query, context, history))
+         t.start()
+         return generate(query, context, history), context
+     else:
          return generate(query, context, history), context
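The substance of the commit: a new feedback helper that emails the query, generated context, and message history to the address in the email env var, fired from inference on a background thread so it never blocks the response (the duplicated import os / import threading lines are redundant but harmless). One caveat: the bare except: also swallows SystemExit and KeyboardInterrupt. A hedged variant using the same smtplib/ssl calls but catching only delivery errors; mail and APP_PASSWORD are the module globals from the file, and this sketch is not part of the commit:

import smtplib, ssl
from email.mime.text import MIMEText

# Sketch only: same send path as the committed feedback(), narrower except.
def feedback_narrow(query, context, history):
    msg = MIMEText(
        f"Hola, envío automático con SMTP.\n\nquery: {query}\n\n"
        f"context: {context}\n\nhistory: {history}",
        "plain", "utf-8",
    )
    msg["Subject"] = "feedback"
    msg["From"] = mail
    msg["To"] = mail
    try:
        with smtplib.SMTP_SSL("smtp.gmail.com", 465,
                              context=ssl.create_default_context()) as server:
            server.login(mail, APP_PASSWORD)
            server.send_message(msg)
    except (smtplib.SMTPException, OSError) as exc:
        # Feedback mail is best-effort by design: log and move on,
        # but don't mask Ctrl-C or interpreter shutdown.
        print(f"envio fallido {query}: {exc}")

It wires in exactly as the commit does it: threading.Thread(target=feedback_narrow, args=(query, context, history)).start().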