ALVHB95 committed on
Commit
77dfcc0
·
1 Parent(s): b745bb1
Files changed (1)
  1. app.py +58 -91
app.py CHANGED
@@ -1,6 +1,6 @@
  """
  =========================================================
- app.py — Green Greta (Gradio + TF/Keras 3 + Local HF + LangChain v0.2/0.3)
  =========================================================
  """

@@ -8,13 +8,15 @@ import os
  import json
  import shutil

- # --- Environment / telemetry settings (before importing Chroma) ---
  os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
  os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
  os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")
  os.environ.setdefault("ANONYMIZED_TELEMETRY", "false")
- # Silence Chroma telemetry to avoid noisy warnings/tracebacks
  os.environ.setdefault("CHROMA_TELEMETRY_ENABLED", "FALSE")

  import gradio as gr
  import tensorflow as tf
@@ -29,38 +31,32 @@ except Exception:
  user_agent = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36"
  header_template = {"User-Agent": user_agent}

- # --- LangChain v0.2/0.3 family ---
  from langchain_text_splitters import RecursiveCharacterTextSplitter
  from langchain_core.prompts import ChatPromptTemplate
- from langchain_core.output_parsers import PydanticOutputParser
  from langchain.chains import ConversationalRetrievalChain
  from langchain.memory import ConversationBufferMemory
  from langchain_community.document_loaders import WebBaseLoader
  from langchain_community.vectorstores import Chroma

- # Embeddings (prefer langchain-huggingface if installed, else community)
  try:
      from langchain_huggingface import HuggingFaceEmbeddings  # pip install -U langchain-huggingface
  except ImportError:
      from langchain_community.embeddings import HuggingFaceEmbeddings

- # Context compression / retrievers
- from langchain.retrievers import ContextualCompressionRetriever
- from langchain.retrievers.document_compressors import DocumentCompressorPipeline

- # --- Advanced retrievers / reranker ---
  from langchain_community.retrievers import BM25Retriever
- from langchain.retrievers import EnsembleRetriever
- from langchain.retrievers.multi_query import MultiQueryRetriever
  from langchain_community.cross_encoders import HuggingFaceCrossEncoder
- from langchain.retrievers.document_compressors import CrossEncoderReranker

- from pydantic import BaseModel, Field  # Pydantic v2
-
- # HF Hub to download the image SavedModel
  from huggingface_hub import snapshot_download

- # === Modern LLM endpoint (langchain-huggingface) ===
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

  # Theming + URL list
@@ -72,27 +68,21 @@ theme = theme.Theme()
  # 1) IMAGE CLASSIFICATION — Keras 3-safe SavedModel loading
  # =========================================================
  MODEL_REPO = "rocioadlc/efficientnetB0_trash"
- MODEL_SERVING_SIGNATURE = "serving_default"  # adjust if the model exposes a different signature

- # Download the snapshot and wrap it with TFSMLayer (Keras 3-compatible)
  model_dir = snapshot_download(MODEL_REPO)
  image_model = keras.layers.TFSMLayer(model_dir, call_endpoint=MODEL_SERVING_SIGNATURE)

  class_labels = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]

  def predict_image(input_image: Image.Image):
-     """Preprocess for EfficientNetB0 (224x224) and run inference."""
      img = input_image.convert("RGB").resize((224, 224))
      x = tf.keras.preprocessing.image.img_to_array(img)
      x = tf.keras.applications.efficientnet.preprocess_input(x)
      x = tf.expand_dims(x, 0)

      outputs = image_model(x)
-     if isinstance(outputs, dict) and outputs:
-         preds = outputs[next(iter(outputs))]
-     else:
-         preds = outputs
-
      arr = preds.numpy() if hasattr(preds, "numpy") else preds
      probs = arr[0].tolist()
      return {label: float(probs[i]) for i, label in enumerate(class_labels)}
@@ -126,7 +116,6 @@ def safe_load_all_urls(urls):

  all_loaded_docs = safe_load_all_urls(URLS)

- # Somewhat longer chunks (better for the reranker)
  base_splitter = RecursiveCharacterTextSplitter(
      chunk_size=900,
      chunk_overlap=100,
@@ -134,7 +123,7 @@ base_splitter = RecursiveCharacterTextSplitter(
  )
  docs = base_splitter.split_documents(all_loaded_docs)

- # IMPROVED embeddings (retrieval)
  embeddings = HuggingFaceEmbeddings(model_name="intfloat/e5-base-v2")

  # Vector store
@@ -146,22 +135,28 @@ vectordb = Chroma.from_documents(
      persist_directory=persist_directory,
  )

- # Base retriever (vector)
  vec_retriever = vectordb.as_retriever(search_kwargs={"k": 8}, search_type="mmr")

- # BM25 + Ensemble (hybrid)
- bm25 = BM25Retriever.from_documents(docs)
- bm25.k = 8
- hybrid_retriever = EnsembleRetriever(retrievers=[bm25, vec_retriever], weights=[0.4, 0.6])
-
- # --- Multi-Query (query paraphrasing) ---
- # Relies on the LLM itself to generate query variants and raise recall
- # (defined after the LLM is created; see section 4)
-
- # --- Compression / fine splitting for the downstream compressor ---
  try:
      from langchain_text_splitters import TokenTextSplitter
-     splitter_for_compression = TokenTextSplitter(chunk_size=220, chunk_overlap=30)  # requires tiktoken
  except Exception:
      from langchain_text_splitters import RecursiveCharacterTextSplitter as FallbackSplitter
      splitter_for_compression = FallbackSplitter(chunk_size=300, chunk_overlap=50)
@@ -169,34 +164,24 @@ except Exception:
  compressor_pipeline = DocumentCompressorPipeline(transformers=[splitter_for_compression])

  # ======================================
- # 3) PROMPT & Pydantic schema parsing
  # ======================================
- class FinalAnswer(BaseModel):
-     question: str = Field(description="User question")
-     answer: str = Field(description="Direct answer")
-
- parser = PydanticOutputParser(pydantic_object=FinalAnswer)
-
  SYSTEM_TEMPLATE = (
      "Eres Greta, una asistente bilingüe (ES/EN) experta en reciclaje y sostenibilidad. "
-     "Responde de forma directa, útil y en el idioma del usuario. "
-     "Si la respuesta no aparece en los fragmentos, dilo explícitamente y ofrece pasos prácticos. "
-     "No inventes datos.\n\n"
-     "Fragmentos:\n{context}\n\n"
-     "Pregunta: {question}\n"
-     "{format_instructions}"
- )
-
- qa_prompt = ChatPromptTemplate.from_template(SYSTEM_TEMPLATE).partial(
-     format_instructions=parser.get_format_instructions()
  )
 
  # ===========================================
  # 4) LLM — Hugging Face Inference (Llama 3.1 8B)
  # ===========================================
  endpoint = HuggingFaceEndpoint(
      repo_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
-     task="text-generation",  # stable for chat via HF Inference
      max_new_tokens=900,
      temperature=0.2,
      top_k=40,
@@ -206,28 +191,23 @@ endpoint = HuggingFaceEndpoint(
      timeout=120,
      model_kwargs={},
  )
-
- # NOTE: use llm= (not client=)
  llm = ChatHuggingFace(llm=endpoint)

  # ===========================================
- # 5) Chain (memory + improved RAG + robust JSON)
  # ===========================================
-
- # Memory (deprecation warning, but still functional in LC 0.3)
  memory = ConversationBufferMemory(
      memory_key="chat_history",
      return_messages=True,
  )

- # Multi-Query on top of the hybrid retriever
- mqr = MultiQueryRetriever.from_llm(retriever=hybrid_retriever, llm=llm, include_original=True)

- # Lighter reranker (cuts cost and latency)
  cross_encoder = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
  reranker = CrossEncoderReranker(model=cross_encoder, top_n=4)

- # Contextual compressor (hybrid + multi-query → rerank → fine compression)
  compression_retriever = ContextualCompressionRetriever(
      base_retriever=mqr,
      base_compressor=reranker,
@@ -241,29 +221,19 @@ qa_chain = ConversationalRetrievalChain.from_llm(
      combine_docs_chain_kwargs={"prompt": qa_prompt},
      get_chat_history=lambda h: h,
      rephrase_question=False,
-     output_key="output",
  )

- def _safe_json_extract(raw: str, question: str) -> dict:
-     """Try strict JSON; if that fails, extract the first {...}; otherwise fall back to plain text."""
-     raw = (raw or "").strip()
-     try:
-         return json.loads(raw)
-     except json.JSONDecodeError:
-         start = raw.find("{")
-         end = raw.rfind("}")
-         if start != -1 and end != -1 and end > start:
-             try:
-                 return json.loads(raw[start : end + 1])
-             except json.JSONDecodeError:
-                 pass
-     return {"question": question, "answer": raw or "No answer produced."}
-
  def chat_interface(question, history):
      try:
          result = qa_chain.invoke({"question": question})
-         payload = _safe_json_extract(result.get("output", ""), question)
-         return payload.get("answer", "")
      except Exception as e:
          return (
              "Lo siento, tuve un problema procesando tu pregunta. "
@@ -300,13 +270,10 @@ banner_tab = gr.Markdown(banner_tab_content)
  # ============================
  # 7) Gradio app (tabs + run)
  # ============================
-
- # Simple CSS to visually enlarge the chat area without using height=
  custom_css = """
- /* Increase the minimum height of the chatbot message container */
- .gr-chatbot { min-height: 520px !important; }
- .gr-chatbot > div { min-height: 520px !important; }
- /* A bit more overall width */
  .gradio-container { max-width: 1200px !important; }
  """

@@ -320,7 +287,7 @@ app = gr.TabbedInterface(
      [banner_tab, image_gradio_app, chatbot_gradio_app],
      tab_names=["Welcome to Green Greta", "Green Greta Image Classification", "Green Greta Chat"],
      theme=theme,
-     css=custom_css,  # applies CSS globally across tabs
  )

  app.queue()
 
  """
  =========================================================
+ app.py — Green Greta (Gradio + TF/Keras 3 + LangChain 0.3)
  =========================================================
  """

  import json
  import shutil

+ # --- Env / telemetry (set before imports that use them) ---
  os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
  os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
  os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")
  os.environ.setdefault("ANONYMIZED_TELEMETRY", "false")
  os.environ.setdefault("CHROMA_TELEMETRY_ENABLED", "FALSE")
+ os.environ.setdefault("USER_AGENT", "green-greta/1.0 (+contact-or-repo)")
+ # If you want deterministic CPU math from TF (optional):
+ # os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0")

  import gradio as gr
  import tensorflow as tf
 
  user_agent = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36"
  header_template = {"User-Agent": user_agent}

+ # --- LangChain core ---
  from langchain_text_splitters import RecursiveCharacterTextSplitter
  from langchain_core.prompts import ChatPromptTemplate
  from langchain.chains import ConversationalRetrievalChain
  from langchain.memory import ConversationBufferMemory
  from langchain_community.document_loaders import WebBaseLoader
  from langchain_community.vectorstores import Chroma

+ # Embeddings
  try:
      from langchain_huggingface import HuggingFaceEmbeddings  # pip install -U langchain-huggingface
  except ImportError:
      from langchain_community.embeddings import HuggingFaceEmbeddings

+ # Retrieval utilities
+ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
+ from langchain.retrievers.document_compressors import DocumentCompressorPipeline, CrossEncoderReranker
+ from langchain.retrievers.multi_query import MultiQueryRetriever
  from langchain_community.retrievers import BM25Retriever
  from langchain_community.cross_encoders import HuggingFaceCrossEncoder

+ # HF Hub for SavedModel
  from huggingface_hub import snapshot_download

+ # LLM via HF Inference
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

  # Theming + URL list
 
  # 1) IMAGE CLASSIFICATION — Keras 3-safe SavedModel loading
  # =========================================================
  MODEL_REPO = "rocioadlc/efficientnetB0_trash"
+ MODEL_SERVING_SIGNATURE = "serving_default"

  model_dir = snapshot_download(MODEL_REPO)
  image_model = keras.layers.TFSMLayer(model_dir, call_endpoint=MODEL_SERVING_SIGNATURE)

  class_labels = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]

  def predict_image(input_image: Image.Image):
      img = input_image.convert("RGB").resize((224, 224))
      x = tf.keras.preprocessing.image.img_to_array(img)
      x = tf.keras.applications.efficientnet.preprocess_input(x)
      x = tf.expand_dims(x, 0)

      outputs = image_model(x)
+     preds = outputs[next(iter(outputs))] if isinstance(outputs, dict) and outputs else outputs
      arr = preds.numpy() if hasattr(preds, "numpy") else preds
      probs = arr[0].tolist()
      return {label: float(probs[i]) for i, label in enumerate(class_labels)}
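Note: call_endpoint must name a signature the SavedModel actually exports, and the dict handling above assumes the signature's first output holds the class scores. Both can be confirmed up front; a minimal illustrative check (not part of the commit):

# Illustrative: list the exported signatures and their output keys before
# hard-coding "serving_default" and trusting the first-key lookup.
loaded = tf.saved_model.load(model_dir)
print(list(loaded.signatures.keys()))                      # e.g. ['serving_default']
print(loaded.signatures["serving_default"].structured_outputs)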
 
  all_loaded_docs = safe_load_all_urls(URLS)

  base_splitter = RecursiveCharacterTextSplitter(
      chunk_size=900,
      chunk_overlap=100,

  )
  docs = base_splitter.split_documents(all_loaded_docs)

+ # Embeddings (better recall)
  embeddings = HuggingFaceEmbeddings(model_name="intfloat/e5-base-v2")

  # Vector store

      persist_directory=persist_directory,
  )

+ # Vector retriever
  vec_retriever = vectordb.as_retriever(search_kwargs={"k": 8}, search_type="mmr")

+ # BM25 + Ensemble with safe fallback if rank-bm25 isn't installed
+ use_bm25 = True
+ try:
+     bm25 = BM25Retriever.from_documents(docs)  # requires rank-bm25
+     bm25.k = 8
+ except Exception as e:
+     print(f"[RAG] BM25 unavailable ({e}). Falling back to vector-only retriever.")
+     use_bm25 = False
+     bm25 = None
+
+ if use_bm25:
+     base_retriever = EnsembleRetriever(retrievers=[bm25, vec_retriever], weights=[0.4, 0.6])
+ else:
+     base_retriever = vec_retriever
+
+ # Fine-grained compressor (splitter)
  try:
      from langchain_text_splitters import TokenTextSplitter
+     splitter_for_compression = TokenTextSplitter(chunk_size=220, chunk_overlap=30)  # needs tiktoken
  except Exception:
      from langchain_text_splitters import RecursiveCharacterTextSplitter as FallbackSplitter
      splitter_for_compression = FallbackSplitter(chunk_size=300, chunk_overlap=50)

  compressor_pipeline = DocumentCompressorPipeline(transformers=[splitter_for_compression])
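Since retrievers in LangChain 0.2/0.3 implement the Runnable interface, the hybrid setup can be smoke-tested before it is wired into the chain; an illustrative query (not part of the commit):

# Illustrative: .invoke(query) on any retriever returns a list of Documents.
hits = base_retriever.invoke("¿Dónde se recicla el vidrio?")
for doc in hits[:3]:
    print(doc.metadata.get("source"), "|", doc.page_content[:80])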

  # ======================================
+ # 3) PROMPT (NO JSON INSTRUCTIONS)
  # ======================================
  SYSTEM_TEMPLATE = (
      "Eres Greta, una asistente bilingüe (ES/EN) experta en reciclaje y sostenibilidad. "
+     "Responde en el idioma del usuario, de forma directa, práctica y basada en los fragmentos. "
+     "Si la información no está en los fragmentos, dilo claramente y sugiere pasos útiles. "
+     "No inventes datos ni menciones la palabra 'fragmentos'.\n\n"
+     "{context}\n\n"
+     "Pregunta: {question}"
  )
+ qa_prompt = ChatPromptTemplate.from_template(SYSTEM_TEMPLATE)
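With exactly two template variables left, a cheap render check catches placeholder typos before the chain ever runs; illustrative values:

# Illustrative: format_messages() raises a KeyError if a placeholder
# is missing, so this confirms {context} and {question} both resolve.
preview = qa_prompt.format_messages(
    context="El vidrio va al contenedor verde.",
    question="¿Dónde va el vidrio?",
)
print(preview[0].content)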
 
  # ===========================================
  # 4) LLM — Hugging Face Inference (Llama 3.1 8B)
  # ===========================================
  endpoint = HuggingFaceEndpoint(
      repo_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
+     task="text-generation",
      max_new_tokens=900,
      temperature=0.2,
      top_k=40,

      timeout=120,
      model_kwargs={},
  )
  llm = ChatHuggingFace(llm=endpoint)
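meta-llama/Meta-Llama-3.1-8B-Instruct is a gated repo, so the endpoint only authenticates if an HF token is available in the environment (HUGGINGFACEHUB_API_TOKEN or HF_TOKEN). An optional smoke test, assuming token and gated-repo access are in place:

# Illustrative network call: ChatHuggingFace.invoke() accepts a plain string
# and returns an AIMessage whose .content is the generated text.
print(llm.invoke("Responde solo 'ok'.").content)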

  # ===========================================
+ # 5) Chain (memory + Multi-Query + reranker + compression)
  # ===========================================
  memory = ConversationBufferMemory(
      memory_key="chat_history",
      return_messages=True,
  )

+ # Multi-Query to boost recall
+ mqr = MultiQueryRetriever.from_llm(retriever=base_retriever, llm=llm, include_original=True)

+ # Cross-encoder reranker (lighter)
  cross_encoder = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
  reranker = CrossEncoderReranker(model=cross_encoder, top_n=4)

  compression_retriever = ContextualCompressionRetriever(
      base_retriever=mqr,
      base_compressor=reranker,

      combine_docs_chain_kwargs={"prompt": qa_prompt},
      get_chat_history=lambda h: h,
      rephrase_question=False,
+     return_source_documents=False,  # we only need the final answer
+     # Use the default output key "answer" so we don't need to parse JSON
  )
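For reference, under this configuration the chain result is a dict whose useful key is "answer" (the ConversationalRetrievalChain default), which is exactly what the handler below unpacks; an illustrative call:

# Illustrative shape: {"question": ..., "chat_history": [...], "answer": ...}
out = qa_chain.invoke({"question": "¿Qué va al contenedor amarillo?"})
print(out["answer"])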

  def chat_interface(question, history):
      try:
          result = qa_chain.invoke({"question": question})
+         # ConversationalRetrievalChain returns {"answer": "...", ...}
+         answer = result.get("answer", "")
+         # Safety fallback: if empty, return a friendly default
+         if not answer:
+             return "Lo siento, no pude generar una respuesta útil con los fragmentos disponibles."
+         return answer
      except Exception as e:
          return (
              "Lo siento, tuve un problema procesando tu pregunta. "
 
  # ============================
  # 7) Gradio app (tabs + run)
  # ============================
  custom_css = """
+ /* Make the chat area taller without using the height arg */
+ .gr-chatbot { min-height: 700px !important; }
+ .gr-chatbot > div { min-height: 700px !important; }
  .gradio-container { max-width: 1200px !important; }
  """

      [banner_tab, image_gradio_app, chatbot_gradio_app],
      tab_names=["Welcome to Green Greta", "Green Greta Image Classification", "Green Greta Chat"],
      theme=theme,
+     css=custom_css,
  )

  app.queue()