"""
=========================================================
app.py — Green Greta (Gradio + TF/Keras 3 + Local HF + LangChain v0.2/0.3)
=========================================================
"""
import os
import json
import shutil
# --- Environment / telemetry settings (before importing Chroma) ---
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")
os.environ.setdefault("ANONYMIZED_TELEMETRY", "false")
# Silence Chroma telemetry to avoid noisy warnings/tracebacks
os.environ.setdefault("CHROMA_TELEMETRY_ENABLED", "FALSE")
import gradio as gr
import tensorflow as tf
from tensorflow import keras
from PIL import Image
import tenacity
try:
    from fake_useragent import UserAgent
    user_agent = UserAgent().random
except Exception:
    user_agent = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36"
header_template = {"User-Agent": user_agent}
# --- LangChain v0.2/0.3 family ---
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
# Embeddings (prefer langchain-huggingface if installed, otherwise community)
try:
    from langchain_huggingface import HuggingFaceEmbeddings  # pip install -U langchain-huggingface
except ImportError:
    from langchain_community.embeddings import HuggingFaceEmbeddings
# Context compression / retrievers
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
# --- Advanced retrievers / reranker ---
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain.retrievers.document_compressors import CrossEncoderReranker
from pydantic import BaseModel, Field # Pydantic v2
# HF Hub, used to download the image SavedModel
from huggingface_hub import snapshot_download
# === Modern LLM endpoint (langchain-huggingface) ===
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
# Theming + URL list
import theme
from url_list import URLS
theme = theme.Theme()
# =========================================================
# 1) IMAGE CLASSIFICATION — Keras 3-safe SavedModel loading
# =========================================================
MODEL_REPO = "rocioadlc/efficientnetB0_trash"
MODEL_SERVING_SIGNATURE = "serving_default"  # adjust if the model exposes a different signature
# Download the snapshot and wrap it in TFSMLayer (Keras 3 compatible)
model_dir = snapshot_download(MODEL_REPO)
image_model = keras.layers.TFSMLayer(model_dir, call_endpoint=MODEL_SERVING_SIGNATURE)
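# If loading fails with a signature error, the available endpoints can be
# listed with the SavedModel API (debugging sketch, not needed at runtime):
#   sm = tf.saved_model.load(model_dir)
#   print(list(sm.signatures.keys()))  # e.g. ["serving_default"]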
class_labels = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]
def predict_image(input_image: Image.Image):
    """Preprocess for EfficientNetB0 (224x224) and run inference."""
    img = input_image.convert("RGB").resize((224, 224))
    x = tf.keras.preprocessing.image.img_to_array(img)
    x = tf.keras.applications.efficientnet.preprocess_input(x)
    x = tf.expand_dims(x, 0)
    outputs = image_model(x)
    if isinstance(outputs, dict) and outputs:
        preds = outputs[next(iter(outputs))]
    else:
        preds = outputs
    arr = preds.numpy() if hasattr(preds, "numpy") else preds
    probs = arr[0].tolist()
    return {label: float(probs[i]) for i, label in enumerate(class_labels)}
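# Quick smoke test (sketch; assumes a local file "sample.jpg" exists):
#   print(predict_image(Image.open("sample.jpg")))
# The result is a dict mapping each of the six labels to a probability.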
image_gradio_app = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(label="Image", sources=["upload", "webcam"], type="pil"),
    outputs=[gr.Label(label="Result")],
    title="<span style='color: rgb(243, 239, 224);'>Green Greta</span>",
    theme=theme,
)
# ============================================
# 2) KNOWLEDGE LOADING (RAG: loader + splitter)
# ============================================
@tenacity.retry(wait=tenacity.wait_fixed(3), stop=tenacity.stop_after_attempt(3), reraise=True)
def load_url(url: str):
    loader = WebBaseLoader(web_paths=[url], header_template=header_template)
    return loader.load()
def safe_load_all_urls(urls):
    all_docs = []
    for link in urls:
        try:
            docs = load_url(link)
            all_docs.extend(docs)
        except Exception as e:
            print(f"Skipping URL due to error: {link}\nError: {e}\n")
    return all_docs
all_loaded_docs = safe_load_all_urls(URLS)
# Somewhat longer chunks (works better for the reranker)
base_splitter = RecursiveCharacterTextSplitter(
    chunk_size=900,
    chunk_overlap=100,
    length_function=len,
)
docs = base_splitter.split_documents(all_loaded_docs)
# Upgraded embeddings (better retrieval)
embeddings = HuggingFaceEmbeddings(model_name="intfloat/e5-base-v2")
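# Note: the e5 family was trained with "query: " / "passage: " prefixes; it
# works without them, but prepending them (e.g. via a thin wrapper around the
# embedder) may improve retrieval quality. Left as-is here.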
# Vector store
persist_directory = "docs/chroma/"
shutil.rmtree(persist_directory, ignore_errors=True)
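# Wiping the persist directory forces a clean rebuild of the index on every
# start, so stale embeddings from a previous run never leak in.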
vectordb = Chroma.from_documents(
    documents=docs,
    embedding=embeddings,
    persist_directory=persist_directory,
)
# Base retriever (vector search)
vec_retriever = vectordb.as_retriever(search_kwargs={"k": 8}, search_type="mmr")
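# MMR (maximal marginal relevance) trades relevance off against diversity, so
# the k=8 returned chunks are not near-duplicates of one another.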
# BM25 + Ensemble (hybrid)
bm25 = BM25Retriever.from_documents(docs)
bm25.k = 8
hybrid_retriever = EnsembleRetriever(retrievers=[bm25, vec_retriever], weights=[0.4, 0.6])
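# EnsembleRetriever fuses both rankings with weighted reciprocal rank fusion;
# the 0.4/0.6 split biases the fusion toward semantic matches while keeping
# exact keyword hits (material names, brand terms) in play.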
# --- Multi-Query (query paraphrasing) ---
# Uses the LLM itself to generate query variants and raise recall
# (defined after the LLM is created; see section 4)
# --- Compression / fine-grained splitting for the downstream compressor ---
try:
    from langchain_text_splitters import TokenTextSplitter
    splitter_for_compression = TokenTextSplitter(chunk_size=220, chunk_overlap=30)  # requires tiktoken
except Exception:
    from langchain_text_splitters import RecursiveCharacterTextSplitter as FallbackSplitter
    splitter_for_compression = FallbackSplitter(chunk_size=300, chunk_overlap=50)
# The splitter is combined with the reranker into a DocumentCompressorPipeline
# in section 5, once the reranker has been created.
# ======================================
# 3) PROMPT & Pydantic schema parsing
# ======================================
class FinalAnswer(BaseModel):
    question: str = Field(description="User question")
    answer: str = Field(description="Direct answer")
parser = PydanticOutputParser(pydantic_object=FinalAnswer)
SYSTEM_TEMPLATE = (
    "You are Greta, a bilingual (ES/EN) assistant and an expert in recycling and sustainability. "
    "Answer directly and helpfully, in the user's language. "
    "If the answer does not appear in the fragments, say so explicitly and offer practical steps. "
    "Do not make up facts.\n\n"
    "Fragments:\n{context}\n\n"
    "Question: {question}\n"
    "{format_instructions}"
)
qa_prompt = ChatPromptTemplate.from_template(SYSTEM_TEMPLATE).partial(
    format_instructions=parser.get_format_instructions()
)
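# parser.get_format_instructions() injects a JSON-schema instruction, so the
# model is asked to answer with an object shaped like:
#   {"question": "...", "answer": "..."}
# which _safe_json_extract() below recovers even when the model wraps the
# JSON in extra prose.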
# ===========================================
# 4) LLM — Hugging Face Inference (Llama 3.1 8B)
# ===========================================
endpoint = HuggingFaceEndpoint(
    repo_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
    task="text-generation",  # stable for chat via HF Inference
    max_new_tokens=900,
    temperature=0.2,
    top_k=40,
    repetition_penalty=1.05,
    return_full_text=False,
    huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
    timeout=120,
    model_kwargs={},
)
# NOTE: use llm= (not client=)
llm = ChatHuggingFace(llm=endpoint)
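# ChatHuggingFace applies the model's chat template on top of the raw
# text-generation endpoint, so messages are formatted the way Llama 3.1
# Instruct expects.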
# ===========================================
# 5) Chain (memory + RAG mejorado + robust JSON)
# ===========================================
# Memory (deprecation warning, but still functional in LC 0.3)
memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True,
)
# Multi-Query on top of the hybrid retriever
mqr = MultiQueryRetriever.from_llm(retriever=hybrid_retriever, llm=llm, include_original=True)
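# MultiQueryRetriever asks the LLM for several paraphrases of the user query,
# runs each through the hybrid retriever, and de-duplicates the union of
# results; include_original=True keeps the verbatim query among the variants.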
# Lighter reranker (cuts latency cost)
cross_encoder = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
reranker = CrossEncoderReranker(model=cross_encoder, top_n=4)
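# The cross-encoder scores each (query, chunk) pair jointly; top_n=4 keeps
# only the four highest-scoring chunks for the final context.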
# Contextual compression (hybrid + multi-query → rerank → fine-grained split)
compressor_pipeline = DocumentCompressorPipeline(
    transformers=[reranker, splitter_for_compression]
)
compression_retriever = ContextualCompressionRetriever(
    base_retriever=mqr,
    base_compressor=compressor_pipeline,
)
qa_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=compression_retriever,
    memory=memory,
    verbose=True,
    combine_docs_chain_kwargs={"prompt": qa_prompt},
    get_chat_history=lambda h: h,
    rephrase_question=False,
    output_key="output",
)
def _safe_json_extract(raw: str, question: str) -> dict:
    """Try strict JSON first; on failure, extract the first {...}; otherwise fall back to plain text."""
    raw = (raw or "").strip()
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        start = raw.find("{")
        end = raw.rfind("}")
        if start != -1 and end != -1 and end > start:
            try:
                return json.loads(raw[start : end + 1])
            except json.JSONDecodeError:
                pass
    return {"question": question, "answer": raw or "No answer produced."}
def chat_interface(question, history):
    try:
        result = qa_chain.invoke({"question": question})
        payload = _safe_json_extract(result.get("output", ""), question)
        return payload.get("answer", "")
    except Exception as e:
        return (
            "Sorry, something went wrong while processing your question. "
            "Please try again in a moment, or rephrase your query.\n\n"
            f"Technical detail: {e}"
        )
# ============================
# 6) Banner / Welcome content
# ============================
banner_tab_content = """
<div style="background-color: #d3e3c3; text-align: center; padding: 20px; display: flex; flex-direction: column; align-items: center;">
<img src="https://huggingface.co/spaces/ALVHB95/TFM_DataScience_APP/resolve/main/front_4.jpg" alt="Banner Image" style="width: 50%; max-width: 500px; margin: 0 auto;">
<h1 style="font-size: 24px; color: #4e6339; margin-top: 20px;">¡Bienvenido a nuestro clasificador de imágenes y chatbot para un reciclaje más inteligente!♻️</h1>
<p style="font-size: 16px; color: #4e6339; text-align: justify;">¿Alguna vez te has preguntado si puedes reciclar un objeto en particular? ¿O te has sentido abrumado por la cantidad de residuos que generas y no sabes cómo manejarlos de manera más sostenible? ¡Estás en el lugar correcto!</p>
<p style="font-size: 16px; color: #4e6339; text-align: justify;">Nuestra plataforma combina la potencia de la inteligencia artificial con la comodidad de un chatbot para brindarte respuestas rápidas y precisas sobre qué objetos son reciclables y cómo hacerlo de la manera más eficiente.</p>
<p style="font-size: 16px; text-align:center;"><strong><span style="color: #4e6339;">¿Cómo usarlo?</span></strong></p>
<ul style="list-style-type: disc; text-align: justify; margin-top: 20px; padding-left: 20px;">
<li style="font-size: 16px; color: #4e6339;"><strong><span style="color: #4e6339;">Green Greta Image Classification:</span></strong> Ve a la pestaña Greta Image Classification y simplemente carga una foto del objeto que quieras reciclar, y nuestro modelo identificará de qué se trata🕵️‍♂️ para que puedas desecharlo adecuadamente.</li>
<li style="font-size: 16px; color: #4e6339;"><strong><span style="color: #4e6339;">Green Greta Chat:</span></strong> ¿Tienes preguntas sobre reciclaje, materiales específicos o prácticas sostenibles? ¡Pregunta a nuestro chatbot en la pestaña Green Greta Chat!📝 Está aquí para responder todas tus preguntas y ayudarte a tomar decisiones más informadas sobre tu reciclaje.</li>
</ul>
<h1 style="font-size: 24px; color: #4e6339; margin-top: 20px;">Welcome to our image classifier and chatbot for smarter recycling!♻️</h1>
<p style="font-size: 16px; color: #4e6339; text-align: justify;">Have you ever wondered if you can recycle a particular object? Or felt overwhelmed by the amount of waste you generate and don't know how to handle it more sustainably? You're in the right place!</p>
<p style="font-size: 16px; color: #4e6339; text-align: justify;">Our platform combines the power of artificial intelligence with the convenience of a chatbot to provide you with quick and accurate answers about which objects are recyclable and how to do it most efficiently.</p>
<p style="font-size: 16px; text-align:center;"><strong><span style="color: #4e6339;">How to use it?</span></strong>
<ul style="list-style-type: disc; text-align: justify; margin-top: 20px; padding-left: 20px;">
<li style="font-size: 16px; color: #4e6339;"><strong><span style="color: #4e6339;">Green Greta Image Classification:</span></strong> Go to the Greta Image Classification tab and simply upload a photo of the object you want to recycle, and our model will identify what it is🕵️‍♂️ so you can dispose of it properly.</li>
<li style="font-size: 16px; color: #4e6339;"><strong><span style="color: #4e6339;">Green Greta Chat:</span></strong> Have questions about recycling, specific materials, or sustainable practices? Ask our chatbot in the Green Greta Chat tab!📝 It's here to answer all your questions and help you make more informed decisions about your recycling.</li>
</ul>
</div>
"""
banner_tab = gr.Markdown(banner_tab_content)
# ============================
# 7) Gradio app (tabs + run)
# ============================
# Simple CSS to visually enlarge the chat area without using height=
custom_css = """
/* Increase the minimum height of the chatbot message container */
.gr-chatbot { min-height: 520px !important; }
.gr-chatbot > div { min-height: 520px !important; }
/* A bit more overall width */
.gradio-container { max-width: 1200px !important; }
"""
chatbot_gradio_app = gr.ChatInterface(
    fn=chat_interface,
    title="<span style='color: rgb(243, 239, 224);'>Green Greta</span>",
    theme=theme,
)
app = gr.TabbedInterface(
    [banner_tab, image_gradio_app, chatbot_gradio_app],
    tab_names=["Welcome to Green Greta", "Green Greta Image Classification", "Green Greta Chat"],
    theme=theme,
    css=custom_css,  # applies the CSS globally across the tabs
)
app.queue()
app.launch()