ALVHB95 committed
Commit d95504a · 1 Parent(s): d13e610
Files changed (1): app.py (+92 -60)

app.py CHANGED
@@ -1,6 +1,6 @@
 """
 =========================================================
-app.py — Green Greta (Gradio + TF/Keras 3 + Local HF + LangChain v0.2)
+app.py — Green Greta (Gradio + TF/Keras 3 + Local HF + LangChain v0.2/0.3)
 =========================================================
 """
 
@@ -8,6 +8,14 @@ import os
 import json
 import shutil
 
+# --- Environment / telemetry settings (set before importing Chroma) ---
+os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
+os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
+os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")
+os.environ.setdefault("ANONYMIZED_TELEMETRY", "false")
+# Silence Chroma telemetry to avoid noisy warnings/tracebacks
+os.environ.setdefault("CHROMA_TELEMETRY_ENABLED", "FALSE")
+
 import gradio as gr
 import tensorflow as tf
 from tensorflow import keras
@@ -18,12 +26,10 @@ try:
     from fake_useragent import UserAgent
     user_agent = UserAgent().random
 except Exception:
-    user_agent = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 " \
-                 "(KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36"
+    user_agent = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36"
 header_template = {"User-Agent": user_agent}
 
-
-# --- LangChain v0.2 family ---
+# --- LangChain v0.2/0.3 family ---
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.output_parsers import PydanticOutputParser
@@ -32,37 +38,36 @@ from langchain.memory import ConversationBufferMemory
 from langchain_community.document_loaders import WebBaseLoader
 from langchain_community.vectorstores import Chroma
 
-# Embeddings (prefer langchain-huggingface if installed; fall back to community)
+# Embeddings (prefer langchain-huggingface if installed, else community)
 try:
     from langchain_huggingface import HuggingFaceEmbeddings  # pip install -U langchain-huggingface
 except ImportError:
     from langchain_community.embeddings import HuggingFaceEmbeddings
 
-# Context compression (keeps inputs ≤ model limit)
+# Context compression / retrievers
 from langchain.retrievers import ContextualCompressionRetriever
 from langchain.retrievers.document_compressors import DocumentCompressorPipeline
 
-from pydantic import BaseModel, Field  # <-- switched to Pydantic v2
+# --- Advanced retrievers / reranker ---
+from langchain_community.retrievers import BM25Retriever
+from langchain.retrievers import EnsembleRetriever  # note: lives in langchain, not langchain_community
+from langchain.retrievers.multi_query import MultiQueryRetriever
+from langchain_community.cross_encoders import HuggingFaceCrossEncoder
+from langchain.retrievers.document_compressors import CrossEncoderReranker
 
-# HF Hub for downloading the SavedModel once (image classifier)
+from pydantic import BaseModel, Field  # Pydantic v2
+
+# HF Hub to download the image SavedModel once
 from huggingface_hub import snapshot_download
 
-# === Modern LLM endpoint (compatible with huggingface_hub>=0.23) ===
+# === Modern LLM endpoint (langchain-huggingface) ===
 from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
 
 # Theming + URL list
 import theme
 from url_list import URLS
-
 theme = theme.Theme()
 
-# (Optional) reduce telemetry/noise in Space logs
-os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
-os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
-os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")
-os.environ.setdefault("ANONYMIZED_TELEMETRY", "false")
-
-
 # =========================================================
 # 1) IMAGE CLASSIFICATION — Keras 3-safe SavedModel loading
 # =========================================================
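Note: the new retriever imports above imply extra runtime dependencies that this commit does not pin. A hedged reminder of the usual backing packages (names are assumptions, not taken from this repo's requirements.txt):

# Likely additions to requirements.txt for the new retrieval stack (assumption):
#   rank_bm25              # backs BM25Retriever
#   sentence-transformers  # backs HuggingFaceCrossEncoder
#   tiktoken               # enables TokenTextSplitter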
@@ -76,7 +81,7 @@ image_model = keras.layers.TFSMLayer(model_dir, call_endpoint=MODEL_SERVING_SIGN
 class_labels = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]
 
 def predict_image(input_image: Image.Image):
-    """Preprocess a EfficientNetB0 (224x224) and run inference."""
+    """Preprocess for EfficientNetB0 (224x224) and run inference."""
     img = input_image.convert("RGB").resize((224, 224))
     x = tf.keras.preprocessing.image.img_to_array(img)
     x = tf.keras.applications.efficientnet.preprocess_input(x)
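Note: the diff truncates predict_image after preprocessing. A minimal sketch of how the inference step might complete, assuming the TFSMLayer returns a dict with a single probability tensor (the output-key handling is an assumption, not the author's code):

# Hypothetical continuation of predict_image (sketch, not from this commit)
x = tf.expand_dims(x, axis=0)                        # batch of 1: (1, 224, 224, 3)
preds = image_model(x)                               # TFSMLayer returns a dict of tensors
probs = tf.squeeze(list(preds.values())[0]).numpy()  # assumes a single output tensor
return {label: float(p) for label, p in zip(class_labels, probs)}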
@@ -100,7 +105,6 @@ image_gradio_app = gr.Interface(
     theme=theme,
 )
 
-
 # ============================================
 # 2) KNOWLEDGE LOADING (RAG: loader + splitter)
 # ============================================
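Note: safe_load_all_urls is referenced by the next hunk but its body sits outside the diff context. A hypothetical sketch of such a loader, reusing the header_template defined earlier so sites see a browser User-Agent (the per-URL error handling is an assumption):

def safe_load_all_urls(urls):
    # Sketch: load each URL with WebBaseLoader, skipping any that fail
    loaded = []
    for url in urls:
        try:
            loaded.extend(WebBaseLoader(url, header_template=header_template).load())
        except Exception as e:
            print(f"[loader] skipping {url}: {e}")
    return loaded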
@@ -122,44 +126,47 @@ def safe_load_all_urls(urls):
 
 all_loaded_docs = safe_load_all_urls(URLS)
 
-# Small base chunks so the downstream compressor has less work to do
+# Somewhat longer chunks (better for the reranker)
 base_splitter = RecursiveCharacterTextSplitter(
-    chunk_size=700,
-    chunk_overlap=80,
+    chunk_size=900,
+    chunk_overlap=100,
     length_function=len,
 )
 docs = base_splitter.split_documents(all_loaded_docs)
 
-# Embeddings
-embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
+# UPGRADED embeddings (retrieval quality)
+embeddings = HuggingFaceEmbeddings(model_name="intfloat/e5-base-v2")
 
 # Vector store
 persist_directory = "docs/chroma/"
 shutil.rmtree(persist_directory, ignore_errors=True)
-
 vectordb = Chroma.from_documents(
     documents=docs,
     embedding=embeddings,
     persist_directory=persist_directory,
 )
 
-# Base retriever
-retriever = vectordb.as_retriever(search_kwargs={"k": 2}, search_type="mmr")
+# Base retriever (vector)
+vec_retriever = vectordb.as_retriever(search_kwargs={"k": 8}, search_type="mmr")
+
+# BM25 + Ensemble (hybrid)
+bm25 = BM25Retriever.from_documents(docs)
+bm25.k = 8
+hybrid_retriever = EnsembleRetriever(retrievers=[bm25, vec_retriever], weights=[0.4, 0.6])
+
+# --- Multi-Query (query paraphrasing) ---
+# Leans on the LLM itself to generate variants and raise recall
+# (defined after the LLM is created; see section 4)
 
-# --- Context compression for ~512-token inputs (t5/…); still useful with Mixtral ---
+# --- Compression / fine splitting for the downstream compressor ---
 try:
     from langchain_text_splitters import TokenTextSplitter
-    splitter_for_compression = TokenTextSplitter(chunk_size=200, chunk_overlap=30)  # requires tiktoken
+    splitter_for_compression = TokenTextSplitter(chunk_size=220, chunk_overlap=30)  # requires tiktoken
 except Exception:
     from langchain_text_splitters import RecursiveCharacterTextSplitter as FallbackSplitter
     splitter_for_compression = FallbackSplitter(chunk_size=300, chunk_overlap=50)
 
-compressor = DocumentCompressorPipeline(transformers=[splitter_for_compression])
-compression_retriever = ContextualCompressionRetriever(
-    base_retriever=retriever,
-    base_compressor=compressor,
-)
-
+compressor_pipeline = DocumentCompressorPipeline(transformers=[splitter_for_compression])
 
 # ======================================
 # 3) PROMPT & Pydantic schema parsing
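Note: intfloat/e5-base-v2 was trained with "query: " / "passage: " prefixes, so retrieval can underperform when raw strings are embedded; worth a quick probe. An illustrative smoke test of the hybrid retriever (not part of the commit):

# Probe: what does the hybrid retriever return for a sample question?
hits = hybrid_retriever.invoke("How should I dispose of glass bottles?")
for doc in hits[:3]:
    print(doc.metadata.get("source"), "->", doc.page_content[:120])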
@@ -171,44 +178,61 @@ class FinalAnswer(BaseModel):
 parser = PydanticOutputParser(pydantic_object=FinalAnswer)
 
 SYSTEM_TEMPLATE = (
-    "You are Greta, a bilingual (EN/ES) recycling assistant. "
-    "Answer fully using the snippets below. Do not mention 'context'.\n\n"
-    "Context:\n{context}\n\n"
-    "User: {question}\n"
+    "You are Greta, a bilingual (ES/EN) assistant and an expert on recycling and sustainability. "
+    "Answer directly, helpfully, and in the user's language. "
+    "If the answer does not appear in the snippets, say so explicitly and offer practical steps. "
+    "Do not invent facts.\n\n"
+    "Snippets:\n{context}\n\n"
+    "Question: {question}\n"
     "{format_instructions}"
 )
 
-qa_prompt = ChatPromptTemplate.from_template(
-    SYSTEM_TEMPLATE,
-    partial_variables={"format_instructions": parser.get_format_instructions()},
+qa_prompt = ChatPromptTemplate.from_template(SYSTEM_TEMPLATE).partial(
+    format_instructions=parser.get_format_instructions()
 )
 
-
-
-# 4) LLM — Hugging Face Inference API (Llama 3 chat)
+# ===========================================
+# 4) LLM — Hugging Face Inference (Llama 3.1 8B)
+# ===========================================
 endpoint = HuggingFaceEndpoint(
     repo_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
-    task="conversational",  # <-- important
-    max_new_tokens=2000,
-    temperature=0.1,
-    top_k=30,
-    repetition_penalty=1.03,
+    task="text-generation",  # stable for chat via HF Inference
+    max_new_tokens=900,
+    temperature=0.2,
+    top_k=40,
+    repetition_penalty=1.05,
     return_full_text=False,
     huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
     timeout=120,
     model_kwargs={},
 )
 
+# NB: use llm= (not client=)
 llm = ChatHuggingFace(llm=endpoint)
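Note: the switch from task="conversational" to task="text-generation", with the endpoint wrapped in ChatHuggingFace, appears to be the compatibility fix for recent huggingface_hub releases. A one-off sanity check, assuming HUGGINGFACEHUB_API_TOKEN is set (illustrative, not in the commit):

# Confirm the endpoint answers before wiring the whole chain
from langchain_core.messages import HumanMessage
reply = llm.invoke([HumanMessage(content="In one sentence: why recycle glass?")])
print(reply.content)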
 
 # ===========================================
-# 5) Chain (memory + robust JSON extraction)
+# 5) Chain (memory + improved RAG + robust JSON)
 # ===========================================
+
+# Memory (deprecation warning, but functional in LC 0.3)
 memory = ConversationBufferMemory(
     memory_key="chat_history",
     return_messages=True,
 )
 
+# Multi-Query on top of the hybrid retriever
+mqr = MultiQueryRetriever.from_llm(retriever=hybrid_retriever, llm=llm, include_original=True)
+
+# Lighter reranker (cuts latency cost)
+cross_encoder = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
+reranker = CrossEncoderReranker(model=cross_encoder, top_n=4)
+
+# Contextual compression (hybrid + multi-query → rerank → fine compression)
+compression_retriever = ContextualCompressionRetriever(
+    base_retriever=mqr,
+    base_compressor=reranker,
+)
+
 qa_chain = ConversationalRetrievalChain.from_llm(
     llm=llm,
     retriever=compression_retriever,
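Note: the new code builds compressor_pipeline but then passes only the reranker as base_compressor, so the fine token splitter is never applied, despite the "rerank → fine compression" comment. If that chaining is intended, a sketch (the ordering is an assumption):

# Possible wiring so the splitter actually runs after reranking
combined_compressor = DocumentCompressorPipeline(
    transformers=[reranker, splitter_for_compression]
)
compression_retriever = ContextualCompressionRetriever(
    base_retriever=mqr,
    base_compressor=combined_compressor,
)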
@@ -247,13 +271,6 @@ def chat_interface(question, history):
         f"Technical detail: {e}"
     )
 
-chatbot_gradio_app = gr.ChatInterface(
-    fn=chat_interface,
-    title="<span style='color: rgb(243, 239, 224);'>Green Greta</span>",
-    height=600,
-)
-
-
 # ============================
 # 6) Banner / Welcome content
 # ============================
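Note: chat_interface's body is largely outside the diff context, and section 5 promises robust JSON extraction from the model output. A hypothetical helper in that spirit (the helper name, the regex, and the FinalAnswer field accessed are all assumptions, since the schema's fields are not shown in this diff):

import re

def _extract_final_answer(raw_text: str) -> str:
    # Sketch: grab the first JSON object in the LLM output, parse it with
    # the Pydantic parser, and fall back to the raw text on failure.
    match = re.search(r"\{.*\}", raw_text, re.DOTALL)
    if match:
        try:
            return parser.parse(match.group(0)).answer  # field name assumed
        except Exception:
            pass
    return raw_text.strip()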
@@ -280,14 +297,30 @@ banner_tab_content = """
 """
 banner_tab = gr.Markdown(banner_tab_content)
 
-
 # ============================
 # 7) Gradio app (tabs + run)
 # ============================
+
+# Simple CSS to visually enlarge the chat area without using height=
+custom_css = """
+/* Increase the minimum height of the chatbot message container */
+.gr-chatbot { min-height: 520px !important; }
+.gr-chatbot > div { min-height: 520px !important; }
+/* A bit more overall width */
+.gradio-container { max-width: 1200px !important; }
+"""
+
+chatbot_gradio_app = gr.ChatInterface(
+    fn=chat_interface,
+    title="<span style='color: rgb(243, 239, 224);'>Green Greta</span>",
+    theme=theme,
+)
+
 app = gr.TabbedInterface(
     [banner_tab, image_gradio_app, chatbot_gradio_app],
     tab_names=["Welcome to Green Greta", "Green Greta Image Classification", "Green Greta Chat"],
     theme=theme,
+    css=custom_css,  # applies the CSS globally across tabs
 )
 
 app.queue()
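Note: the diff's context ends at app.queue(); a launch call presumably follows just past the hunk boundary. A typical closing, shown only for completeness (an assumption, not visible in this commit):

if __name__ == "__main__":
    app.launch()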
 