Spaces:
Running
Running
| """ | |
| ========================================================= | |
| app.py — Green Greta (Gradio + TF/Keras 3 + Local HF + LangChain v0.2/0.3) | |
| ========================================================= | |
| """ | |
| import os | |
| import json | |
| import shutil | |
# --- Environment / telemetry defaults (must run before Chroma is imported) ---
# setdefault() keeps any value the Space's environment already provides.
_ENV_DEFAULTS = (
    ("TOKENIZERS_PARALLELISM", "false"),
    ("HF_HUB_DISABLE_TELEMETRY", "1"),
    ("GRADIO_ANALYTICS_ENABLED", "False"),
    ("ANONYMIZED_TELEMETRY", "false"),
    # Intended to silence Chroma telemetry and avoid noisy warnings/tracebacks.
    # NOTE(review): confirm the installed Chroma version actually reads this
    # variable; ANONYMIZED_TELEMETRY above may be the effective switch.
    ("CHROMA_TELEMETRY_ENABLED", "FALSE"),
)
for _env_key, _env_value in _ENV_DEFAULTS:
    os.environ.setdefault(_env_key, _env_value)
| import gradio as gr | |
| import tensorflow as tf | |
| from tensorflow import keras | |
| from PIL import Image | |
| import tenacity | |
# Browser-like User-Agent for the web scraper. Use a random real-world UA when
# fake_useragent works; on any failure (e.g. the package is not installed),
# use a fixed Chrome-on-Linux string.
_FALLBACK_USER_AGENT = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36"
try:
    from fake_useragent import UserAgent
    user_agent = UserAgent().random
except Exception:
    user_agent = _FALLBACK_USER_AGENT
# Headers passed to every WebBaseLoader request.
header_template = {"User-Agent": user_agent}
| # --- LangChain v0.2/0.3 family --- | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.output_parsers import PydanticOutputParser | |
| from langchain.chains import ConversationalRetrievalChain | |
| from langchain.memory import ConversationBufferMemory | |
| from langchain_community.document_loaders import WebBaseLoader | |
| from langchain_community.vectorstores import Chroma | |
| # Embeddings (prefer langchain-huggingface si está instal., si no community) | |
| try: | |
| from langchain_huggingface import HuggingFaceEmbeddings # pip install -U langchain-huggingface | |
| except ImportError: | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| # Context compression / retrievers | |
| from langchain.retrievers import ContextualCompressionRetriever | |
| from langchain.retrievers.document_compressors import DocumentCompressorPipeline | |
| # --- Retrievers avanzados / reranker --- | |
| from langchain_community.retrievers import BM25Retriever | |
| from langchain.retrievers import EnsembleRetriever | |
| from langchain.retrievers.multi_query import MultiQueryRetriever | |
| from langchain_community.cross_encoders import HuggingFaceCrossEncoder | |
| from langchain.retrievers.document_compressors import CrossEncoderReranker | |
| from pydantic import BaseModel, Field # Pydantic v2 | |
| # HF Hub para descargar el SavedModel de imagen | |
| from huggingface_hub import snapshot_download | |
| # === LLM endpoint moderno (langchain-huggingface) === | |
| from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint | |
| # Theming + URL list | |
| import theme | |
| from url_list import URLS | |
# Rebinds the name `theme` from the imported `theme` module to a Theme instance.
# Every Gradio interface below receives this instance via `theme=theme`.
# NOTE(review): shadowing the module name means the module itself is no longer
# reachable as `theme` after this line.
theme = theme.Theme()
| # ========================================================= | |
| # 1) IMAGE CLASSIFICATION — Keras 3-safe SavedModel loading | |
| # ========================================================= | |
# Hugging Face Hub repo holding the trash-classifier SavedModel
# (an EfficientNetB0 according to its name).
MODEL_REPO = "rocioadlc/efficientnetB0_trash"
MODEL_SERVING_SIGNATURE = "serving_default"  # adjust if the model exposes a different signature
# Download the repo snapshot at import time and wrap the SavedModel in a
# TFSMLayer (the Keras 3-compatible way to call it).
model_dir = snapshot_download(MODEL_REPO)
image_model = keras.layers.TFSMLayer(model_dir, call_endpoint=MODEL_SERVING_SIGNATURE)
# Output labels, paired position-by-position with the model's output vector.
# NOTE(review): assumed to match the label order used in training — confirm.
class_labels = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]
def predict_image(input_image: Image.Image):
    """Classify a waste photo with the EfficientNetB0 SavedModel.

    The image is converted to RGB, resized to 224x224, run through the
    EfficientNet ``preprocess_input`` and given a batch dimension before
    inference.

    Args:
        input_image: PIL image from the ``gr.Image`` input, or ``None`` when
            the component is empty (nothing uploaded / input cleared).

    Returns:
        Mapping ``label -> score`` for ``gr.Label``; empty when there is no image.
    """
    # Gradio calls fn with None for an empty image component; return an empty
    # result instead of crashing on None.convert().
    if input_image is None:
        return {}
    img = input_image.convert("RGB").resize((224, 224))
    x = tf.keras.preprocessing.image.img_to_array(img)
    x = tf.keras.applications.efficientnet.preprocess_input(x)
    x = tf.expand_dims(x, 0)  # add the batch dimension
    outputs = image_model(x)
    # The SavedModel layer may return a dict of named outputs; use the first one.
    if isinstance(outputs, dict) and outputs:
        preds = outputs[next(iter(outputs))]
    else:
        preds = outputs
    arr = preds.numpy() if hasattr(preds, "numpy") else preds
    probs = arr[0].tolist()
    # Pair labels and scores with zip() so a length mismatch between the model
    # output and class_labels cannot raise IndexError.
    return {label: float(score) for label, score in zip(class_labels, probs)}
# Image-classification tab: photo from upload or webcam -> label scores.
image_gradio_app = gr.Interface(
    predict_image,
    gr.Image(label="Image", sources=["upload", "webcam"], type="pil"),
    [gr.Label(label="Result")],
    theme=theme,
    title="<span style='color: rgb(243, 239, 224);'>Green Greta</span>",
)
| # ============================================ | |
| # 2) KNOWLEDGE LOADING (RAG: loader + splitter) | |
| # ============================================ | |
def load_url(url: str):
    """Fetch a single web page with the shared browser-like headers.

    Returns the list of LangChain documents produced by ``WebBaseLoader``.
    """
    return WebBaseLoader(web_paths=[url], header_template=header_template).load()
def safe_load_all_urls(urls):
    """Load every URL in *urls*, skipping (and reporting) the ones that fail.

    Best-effort by design: one unreachable page must not stop the app from
    starting, so each failure is printed and the loop moves on.

    Returns:
        A flat list with the documents of every URL that loaded, in input order.
    """
    collected = []
    for url in urls:
        try:
            collected += load_url(url)
        except Exception as exc:
            print(f"Skipping URL due to error: {url}\nError: {exc}\n")
    return collected
# Scrape every configured URL (individual failures are skipped and printed).
all_loaded_docs = safe_load_all_urls(URLS)
# Somewhat longer chunks (better input for the cross-encoder reranker).
base_splitter = RecursiveCharacterTextSplitter(
    chunk_size=900,
    chunk_overlap=100,
    length_function=len,
)
docs = base_splitter.split_documents(all_loaded_docs)
# Fail fast with an actionable message. If every URL failed (or produced no
# text), the vector store and BM25 index built next cannot be created from an
# empty corpus and would otherwise fail later with much less obvious errors.
if not docs:
    raise RuntimeError(
        "No documents were loaded from URLS (every URL failed or was empty); "
        "check network access and url_list.URLS."
    )
# Retrieval embeddings (E5 base v2).
# NOTE(review): E5 models are documented to expect "query: " / "passage: "
# input prefixes; nothing here adds them — consider it for retrieval quality.
embeddings = HuggingFaceEmbeddings(model_name="intfloat/e5-base-v2")
# Vector store: delete any previous on-disk index so each start rebuilds it
# from the freshly scraped chunks.
persist_directory = "docs/chroma/"
shutil.rmtree(persist_directory, ignore_errors=True)
vectordb = Chroma.from_documents(
    documents=docs,
    embedding=embeddings,
    persist_directory=persist_directory,
)
# Dense retriever: MMR search returning 8 chunks.
vec_retriever = vectordb.as_retriever(search_kwargs={"k": 8}, search_type="mmr")
# Sparse keyword retriever (BM25), also 8 chunks, then a hybrid ensemble that
# fuses both result lists with weights 0.4 (BM25) / 0.6 (vector).
bm25 = BM25Retriever.from_documents(docs)
bm25.k = 8
hybrid_retriever = EnsembleRetriever(retrievers=[bm25, vec_retriever], weights=[0.4, 0.6])
# --- Multi-Query (query paraphrasing) ---
# Relies on the LLM itself to generate query variants and raise recall
# (defined after the LLM is created, see section 4).
# --- Fine-grained splitting for a downstream compressor ---
try:
    from langchain_text_splitters import TokenTextSplitter
    splitter_for_compression = TokenTextSplitter(chunk_size=220, chunk_overlap=30)  # requires tiktoken
except Exception:
    # Character-based fallback when the token splitter cannot be created
    # (e.g. tiktoken is missing).
    from langchain_text_splitters import RecursiveCharacterTextSplitter as FallbackSplitter
    splitter_for_compression = FallbackSplitter(chunk_size=300, chunk_overlap=50)
# NOTE(review): compressor_pipeline is not passed to the retrieval chain built
# in section 5 (that chain only reranks); wire it in or remove it.
compressor_pipeline = DocumentCompressorPipeline(transformers=[splitter_for_compression])
| # ====================================== | |
| # 3) PROMPT & Pydantic schema parsing | |
| # ====================================== | |
# Schema the LLM is asked to answer in. PydanticOutputParser turns it into the
# JSON format instructions embedded in the prompt.
# NOTE: documented with comments rather than a docstring on purpose. A class
# docstring would become the schema's "description" and change the prompt text.
class FinalAnswer(BaseModel):
    question: str = Field(description="User question")  # echo of the user's question
    answer: str = Field(description="Direct answer")  # text returned to the chat UI
# Parser whose format instructions (a JSON schema for FinalAnswer) are
# pre-filled into the answer prompt below.
parser = PydanticOutputParser(pydantic_object=FinalAnswer)
# Answer prompt. The instructions are in Spanish, and the model is told to
# reply in the user's language. {context} holds the retrieved fragments,
# {question} the user question, and {format_instructions} is filled via .partial().
# NOTE(review): from_template() builds a single human message, despite the
# SYSTEM_ name.
SYSTEM_TEMPLATE = (
    "Eres Greta, una asistente bilingüe (ES/EN) experta en reciclaje y sostenibilidad. "
    "Responde de forma directa, útil y en el idioma del usuario. "
    "Si la respuesta no aparece en los fragmentos, dilo explícitamente y ofrece pasos prácticos. "
    "No inventes datos.\n\n"
    "Fragmentos:\n{context}\n\n"
    "Pregunta: {question}\n"
    "{format_instructions}"
)
qa_prompt = ChatPromptTemplate.from_template(SYSTEM_TEMPLATE).partial(
    format_instructions=parser.get_format_instructions()
)
| # =========================================== | |
| # 4) LLM — Hugging Face Inference (Llama 3.1 8B) | |
| # =========================================== | |
# Remote LLM through Hugging Face Inference: Llama 3.1 8B Instruct.
endpoint = HuggingFaceEndpoint(
    repo_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
    task="text-generation",  # stable for chat via HF Inference
    max_new_tokens=900,
    temperature=0.2,  # low temperature for focused, low-variance answers
    top_k=40,
    repetition_penalty=1.05,
    return_full_text=False,  # return only the generated text, not the prompt
    huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),  # None when the secret is unset
    timeout=120,
    model_kwargs={},
)
# Chat-model wrapper around the endpoint. Note: pass it as llm= (not client=).
llm = ChatHuggingFace(llm=endpoint)
| # =========================================== | |
| # 5) Chain (memory + RAG mejorado + robust JSON) | |
| # =========================================== | |
# Conversation memory (deprecation warning in LC 0.3, but still functional).
# return_messages=True stores the history as message objects.
memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True,
)
# Multi-Query over the hybrid retriever: the LLM writes paraphrases of the
# question, and include_original=True also searches with the question as typed.
mqr = MultiQueryRetriever.from_llm(retriever=hybrid_retriever, llm=llm, include_original=True)
# Lighter cross-encoder reranker (lower latency/cost); keeps the best 4 chunks.
cross_encoder = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
reranker = CrossEncoderReranker(model=cross_encoder, top_n=4)
# Retrieval pipeline: hybrid + multi-query -> cross-encoder rerank.
# (No token-level compression step is applied here.)
compression_retriever = ContextualCompressionRetriever(
    base_retriever=mqr,
    base_compressor=reranker,
)
# get_chat_history is deliberately left at the chain's default formatter. With
# return_messages=True, an identity function would hand the question-condensing
# prompt the repr of a list of message objects instead of "Human:/Assistant:" text.
qa_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=compression_retriever,
    memory=memory,
    verbose=True,
    combine_docs_chain_kwargs={"prompt": qa_prompt},
    # The condensed follow-up question is used only for retrieval; the answer
    # prompt still receives the user's original question.
    rephrase_question=False,
    output_key="output",
)
def _safe_json_extract(raw: str, question: str) -> dict:
    """Parse the LLM output into a ``{"question", "answer"}`` dict, tolerantly.

    Candidates are tried in order:

    1. the whole (stripped) output as JSON;
    2. the span from the first ``{`` to the last ``}`` (handles prose or
       Markdown code fences around the JSON);
    3. fallback: the raw text itself becomes the answer.

    A parsed value is only accepted when it is a JSON object whose ``"answer"``
    is a string. Other JSON (lists, numbers, objects without an answer) falls
    through to the next step, so callers can always rely on ``result["answer"]``.

    Args:
        raw: Text produced by the chain (may be empty or None).
        question: The user question, used to fill ``"question"`` when missing.

    Returns:
        A dict with at least a ``"question"`` key and a string ``"answer"``.
    """
    text = (raw or "").strip()
    candidates = [text]
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        candidates.append(text[start : end + 1])
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict) and isinstance(parsed.get("answer"), str):
            parsed.setdefault("question", question)
            return parsed
    return {"question": question, "answer": text or "No answer produced."}
def _gradio_history_turns(history):
    """Yield ``(role, text)`` pairs from Gradio chat history.

    Accepts both Gradio history formats: "tuples" (``[user, bot]`` pairs) and
    "messages" (``{"role": ..., "content": ...}`` dicts). Non-text entries
    (e.g. file uploads or a pending ``None`` reply) and roles other than
    user/assistant are skipped.
    """
    for item in history or []:
        if isinstance(item, dict):
            role, content = item.get("role"), item.get("content")
            if role in ("user", "assistant") and isinstance(content, str):
                yield role, content
        elif isinstance(item, (list, tuple)) and len(item) == 2:
            user_text, bot_text = item
            if isinstance(user_text, str):
                yield "user", user_text
            if isinstance(bot_text, str):
                yield "assistant", bot_text


def _sync_memory_with_history(history):
    """Replace the contents of the shared chain memory with this session's turns."""
    memory.clear()
    for role, text in _gradio_history_turns(history):
        if role == "user":
            memory.chat_memory.add_user_message(text)
        else:
            memory.chat_memory.add_ai_message(text)


def chat_interface(question, history):
    """Answer one chat message for the Gradio ChatInterface.

    Args:
        question: The user's new message.
        history: Earlier turns of this browser session as sent by Gradio
            (tuples or messages format).

    Returns:
        The assistant's answer text, or a user-facing apology that includes
        the technical error detail if anything fails.
    """
    try:
        # `memory` is one module-level object shared by every visitor.
        # Rebuilding it from the history Gradio sends for *this* session stops
        # conversations leaking between users and keeps the chain in sync with
        # the UI's Clear / Undo / Retry actions.
        # NOTE(review): relies on chat calls being processed one at a time
        # (Gradio's default queue concurrency) — revisit if concurrency is raised.
        _sync_memory_with_history(history)
        result = qa_chain.invoke({"question": question})
        payload = _safe_json_extract(result.get("output", ""), question)
        return payload.get("answer", "")
    except Exception as e:
        import traceback

        traceback.print_exc()  # keep the full traceback in the Space logs
        return (
            "Lo siento, tuve un problema procesando tu pregunta. "
            "Intenta de nuevo en un momento o formula la consulta de otra manera.\n\n"
            f"Detalle técnico: {e}"
        )
| # ============================ | |
| # 6) Banner / Welcome content | |
| # ============================ | |
# Welcome-tab HTML (Spanish first, then English), rendered by gr.Markdown.
# The detective emoji is written with escapes because it contains an invisible
# zero-width joiner (U+200D) that is easily lost when the file is copied;
# without it the emoji renders as two separate symbols.
banner_tab_content = """
<div style="background-color: #d3e3c3; text-align: center; padding: 20px; display: flex; flex-direction: column; align-items: center;">
<img src="https://huggingface.co/spaces/ALVHB95/TFM_DataScience_APP/resolve/main/front_4.jpg" alt="Banner Image" style="width: 50%; max-width: 500px; margin: 0 auto;">
<h1 style="font-size: 24px; color: #4e6339; margin-top: 20px;">¡Bienvenido a nuestro clasificador de imágenes y chatbot para un reciclaje más inteligente!♻️</h1>
<p style="font-size: 16px; color: #4e6339; text-align: justify;">¿Alguna vez te has preguntado si puedes reciclar un objeto en particular? ¿O te has sentido abrumado por la cantidad de residuos que generas y no sabes cómo manejarlos de manera más sostenible? ¡Estás en el lugar correcto!</p>
<p style="font-size: 16px; color: #4e6339; text-align: justify;">Nuestra plataforma combina la potencia de la inteligencia artificial con la comodidad de un chatbot para brindarte respuestas rápidas y precisas sobre qué objetos son reciclables y cómo hacerlo de la manera más eficiente.</p>
<p style="font-size: 16px; text-align:center;"><strong><span style="color: #4e6339;">¿Cómo usarlo?</span></strong></p>
<ul style="list-style-type: disc; text-align: justify; margin-top: 20px; padding-left: 20px;">
<li style="font-size: 16px; color: #4e6339;"><strong><span style="color: #4e6339;">Green Greta Image Classification:</span></strong> Ve a la pestaña Greta Image Classification y simplemente carga una foto del objeto que quieras reciclar, y nuestro modelo identificará de qué se trata\U0001F575\uFE0F\u200D\u2642\uFE0F para que puedas desecharlo adecuadamente.</li>
<li style="font-size: 16px; color: #4e6339;"><strong><span style="color: #4e6339;">Green Greta Chat:</span></strong> ¿Tienes preguntas sobre reciclaje, materiales específicos o prácticas sostenibles? ¡Pregunta a nuestro chatbot en la pestaña Green Greta Chat!📝 Está aquí para responder todas tus preguntas y ayudarte a tomar decisiones más informadas sobre tu reciclaje.</li>
</ul>
<h1 style="font-size: 24px; color: #4e6339; margin-top: 20px;">Welcome to our image classifier and chatbot for smarter recycling!♻️</h1>
<p style="font-size: 16px; color: #4e6339; text-align: justify;">Have you ever wondered if you can recycle a particular object? Or felt overwhelmed by the amount of waste you generate and don't know how to handle it more sustainably? You're in the right place!</p>
<p style="font-size: 16px; color: #4e6339; text-align: justify;">Our platform combines the power of artificial intelligence with the convenience of a chatbot to provide you with quick and accurate answers about which objects are recyclable and how to do it most efficiently.</p>
<p style="font-size: 16px; text-align:center;"><strong><span style="color: #4e6339;">How to use it?</span></strong></p>
<ul style="list-style-type: disc; text-align: justify; margin-top: 20px; padding-left: 20px;">
<li style="font-size: 16px; color: #4e6339;"><strong><span style="color: #4e6339;">Green Greta Image Classification:</span></strong> Go to the Greta Image Classification tab and simply upload a photo of the object you want to recycle, and our model will identify what it is\U0001F575\uFE0F\u200D\u2642\uFE0F so you can dispose of it properly.</li>
<li style="font-size: 16px; color: #4e6339;"><strong><span style="color: #4e6339;">Green Greta Chat:</span></strong> Have questions about recycling, specific materials, or sustainable practices? Ask our chatbot in the Green Greta Chat tab!📝 It's here to answer all your questions and help you make more informed decisions about your recycling.</li>
</ul>
</div>
"""
# Welcome tab: gr.Markdown renders the inline HTML string. It is passed to
# gr.TabbedInterface together with the two Interfaces.
banner_tab = gr.Markdown(banner_tab_content)
| # ============================ | |
| # 7) Gradio app (tabs + run) | |
| # ============================ | |
# Simple CSS to visually enlarge the chat area without using height=.
# Applied to the whole tabbed app via the css= argument of gr.TabbedInterface.
# NOTE(review): `.gr-chatbot` looks like an older Gradio class name; confirm
# these selectors still match elements in the installed Gradio version.
custom_css = """
/* Aumenta altura mínima del contenedor de mensajes del chatbot */
.gr-chatbot { min-height: 520px !important; }
.gr-chatbot > div { min-height: 520px !important; }
/* Un poco más de ancho general */
.gradio-container { max-width: 1200px !important; }
"""
# Chat tab: each message is answered by chat_interface (RAG chain).
chatbot_gradio_app = gr.ChatInterface(
    chat_interface,
    theme=theme,
    title="<span style='color: rgb(243, 239, 224);'>Green Greta</span>",
)
# Top-level app: three tabs (welcome, image classifier, chat) sharing the
# theme and the custom CSS.
app = gr.TabbedInterface(
    [banner_tab, image_gradio_app, chatbot_gradio_app],
    tab_names=["Welcome to Green Greta", "Green Greta Image Classification", "Green Greta Chat"],
    theme=theme,
    css=custom_css,  # applies the CSS globally to all tabs
)
# Enable Gradio's request queue, then start the server (runs at import time).
app.queue()
app.launch()