File size: 5,005 Bytes
ac72c21
d7f8dad
43130a6
 
 
 
 
 
 
 
 
 
 
 
93a3f9a
d7f8dad
 
7c2299f
d7f8dad
 
93cbacc
d7f8dad
 
e5cdcee
 
 
 
 
43130a6
bb934b2
617bd81
bb934b2
617bd81
 
bb934b2
ac72c21
 
 
310d018
6c99f7c
bb934b2
 
ac72c21
bb934b2
 
 
 
6c99f7c
bb934b2
 
310d018
d7f8dad
bb934b2
 
ac72c21
617bd81
bb934b2
 
 
6c99f7c
bb934b2
310d018
6c99f7c
bb934b2
 
 
310d018
bb934b2
ac72c21
6c99f7c
bb934b2
 
 
69a7fd1
 
 
 
93cbacc
bb934b2
 
 
 
310d018
bb934b2
310d018
bb934b2
 
93cbacc
bb934b2
ac72c21
bb934b2
 
 
 
 
 
310d018
bb934b2
9a72c69
bb934b2
310d018
bb934b2
 
310d018
9a72c69
 
bb934b2
9a72c69
bb934b2
 
 
 
 
 
 
 
9a72c69
93a3f9a
9a72c69
bb934b2
93a3f9a
bb934b2
 
 
9a72c69
93a3f9a
310d018
bb934b2
310d018
93a3f9a
bb934b2
 
9a72c69
 
bb934b2
310d018
9a72c69
bb934b2
69a7fd1
bb934b2
 
db4c5ab
bb934b2
310d018
bb934b2
 
 
 
 
9a72c69
175fea5
bb934b2
 
 
 
310d018
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import os
# Handle Spaces GPU
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    class spaces:
        @staticmethod
        def GPU(func):
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)
            return wrapper

@spaces.GPU
def fake_gpu():
    pass
    
import numpy as np
import pandas as pd
import torch
import gradio as gr
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces
from huggingface_hub import login

# Authenticate
HF_TOKEN = os.getenv('HF_TOKEN')
login(token=HF_TOKEN)

# Modelos disponibles
AVAILABLE_MODELS = {
    "BLOOMZ-560M": "bigscience/bloomz-560m"
}

# Inicializar modelo y tokenizer
current_model = None
current_tokenizer = None
current_model_name = None
device = "cuda" if torch.cuda.is_available() else "cpu"

def cargar_modelo(nombre_modelo):
    """Carga el modelo y el tokenizer seleccionado."""
    global current_model, current_tokenizer, current_model_name
    if current_model_name != nombre_modelo:
        current_model = AutoModelForCausalLM.from_pretrained(AVAILABLE_MODELS[nombre_modelo]).to(device)
        current_tokenizer = AutoTokenizer.from_pretrained(AVAILABLE_MODELS[nombre_modelo])
        current_model_name = nombre_modelo

# Cargar el modelo por defecto
cargar_modelo("BLOOMZ-560M")

@spaces.GPU()
def obtener_predicciones(texto, nombre_modelo, top_k=10):
    """Genera las predicciones de las siguientes palabras con sus probabilidades."""
    global current_model, current_tokenizer
    
    # Cargar modelo si ha cambiado
    if current_model_name != nombre_modelo:
        cargar_modelo(nombre_modelo)
    
    entradas = current_tokenizer(texto, return_tensors="pt").to(device)

    with torch.no_grad():
        salidas = current_model(**entradas)
        logits = salidas.logits[0, -1, :]
        probabilidades = torch.nn.functional.softmax(logits, dim=-1)
    
    top_k_prob, top_k_indices = torch.topk(probabilidades, k=top_k)
    top_k_tokens = [current_tokenizer.decode([idx.item()]) for idx in top_k_indices]
    
    return top_k_tokens, top_k_prob.cpu().tolist()

def generar_barplot(tokens, probabilidades):
    """Convierte los datos en un DataFrame para Gradio BarPlot."""
    df = pd.DataFrame({"Palabra": tokens, "Probabilidad": probabilidades})
    print(df)
    return df  # ✅ Now returning a Pandas DataFrame instead of a list

def predecir_siguiente_palabra(nombre_modelo, texto, top_k, token_custom=""):
    """Obtiene predicciones y actualiza la UI."""
    if token_custom:
        texto += token_custom

    tokens, probabilidades = obtener_predicciones(texto, nombre_modelo, int(top_k))

    # Generar gráfico con Gradio BarPlot
    barplot_data = generar_barplot(tokens, probabilidades)

    return gr.update(choices=[f"'{t}'" for t in tokens]), barplot_data

def agregar_token_seleccionado(texto, token_seleccionado):
    """Agrega el token seleccionado al texto de entrada."""
    if token_seleccionado:
        token_limpio = token_seleccionado.strip("'")
        texto += f" {token_limpio}"
    return texto

# Crear la interfaz en español
with gr.Blocks() as demo:
    gr.Markdown("# 🔥 Predicción de Texto con Modelos Transformadores")
    gr.Markdown(
        "Esta aplicación permite generar palabras utilizando un modelo de lenguaje. "
        "Selecciona un modelo, introduce un texto y explora las palabras más probables a continuación."
    )
    
    with gr.Row():
        dropdown_modelo = gr.Dropdown(
            choices=list(AVAILABLE_MODELS.keys()),
            value="BLOOMZ-560M",
            label="📌 Modelo de lenguaje"
        )

        dropdown_top_k = gr.Dropdown(
            choices=["5", "10", "15", "20"],
            value="10",
            label="🔢 Número de palabras a mostrar"
        )
    
    with gr.Row():
        texto_entrada = gr.Textbox(
            lines=5,
            label="📝 Texto de entrada",
            placeholder="Escribe aquí...",
            value="Mi abuela me dejó una gran"
        )
    
    with gr.Row():
        boton_predecir = gr.Button("🔮 Predecir")

    with gr.Row():
        dropdown_tokens = gr.Dropdown(
            label="🔠 Palabras predichas",
            choices=[]
        )
        boton_agregar = gr.Button("➕ Agregar palabra")

    with gr.Row():
        barplot_resultados = gr.BarPlot(
            value=pd.DataFrame(columns=["Palabra", "Probabilidad"]),  # ✅ Empty DataFrame to initialize
            x="Palabra",
            y="Probabilidad",
            title="📊 Predicciones del modelo"
        )

    # Acciones de botones
    boton_predecir.click(
        predecir_siguiente_palabra,
        inputs=[dropdown_modelo, texto_entrada, dropdown_top_k],
        outputs=[dropdown_tokens, barplot_resultados]
    )

    boton_agregar.click(
        agregar_token_seleccionado,
        inputs=[texto_entrada, dropdown_tokens],
        outputs=texto_entrada
    )

demo.queue().launch()