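"""Gradio demo for a diffusion-based language model.

The app masks the answer span of a chat prompt and iteratively denoises it:
each iteration samples tokens for the answer, streams an HTML snapshot of the
changes, then re-noises part of the answer according to a decaying schedule
until the output stabilizes or the iteration limit is reached.
"""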
import gradio as gr
import torch
import numpy as np
import json
import time
import os
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
import spaces
from dotenv import load_dotenv

from infer import (
    load_trained_model,
    find_answer_start,
    get_noising_schedule,
    noisify_answer,
    filter_logits,
    confidence_guided_noising,
    noisify_answer_without_remasking,
)
from models import CustomTransformerModel
from model_config import CustomTransformerConfig
# Load .env only when running locally
if os.getenv("HF_TOKEN") is None:
    load_dotenv()

hf_token = os.getenv("HF_TOKEN")
if hf_token is None:
    raise ValueError("HF_TOKEN is not set")

rng = np.random.default_rng()
def generate_diffusion_text(input_ids, top_k, top_p, eos_bias=0.0):
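    """Run one full-sequence forward pass and sample every position at once.

    Returns the sampled token ids and, per position, the probability the model
    assigned to the sampled token (used downstream as a confidence score).
    """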
    with torch.no_grad():
        input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
        with torch.cuda.amp.autocast(dtype=torch.float16):
            logits = model(input_ids=input_tensor)["logits"]

        # Apply eos_bias (negative values make EOS less likely, lengthening output)
        if eos_bias != 0.0:
            logits[0, :, eos_token_id] += eos_bias

        logits = filter_logits(logits, top_k=top_k, top_p=top_p)
        logits = logits.clamp(min=-1e8, max=1e4)

        probs = torch.nn.functional.softmax(logits, dim=-1)[0]
        probs = torch.clamp(probs, min=1e-8, max=1.0)

        assert torch.all(torch.isfinite(probs)), "Non-finite values in probs!"
        assert (probs >= 0).all(), "Negative probs!"

        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1).tolist()
        # Extract confidence of selected tokens
        conf = probs[range(len(sampled)), sampled].cpu().numpy()
    return sampled, conf
def format_chat_prompt(question):
    return (
        "<|begin_of_text|>\n"
        "<|start_header_id|>system<|end_header_id|>\n"
        "You are a helpful assistant.\n"
        "<|start_header_id|>user<|end_header_id|>\n"
        f"{question}\n"
        "<|start_header_id|>assistant<|end_header_id|>\n"
    )
def render_html(label, text):
    return f"<b>{label}</b><br><div style='white-space: pre-wrap; line-height:1.8'>{text}</div>"
def highlight_tokens(token_ids, answer_start, changed_indices, color):
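    """Render token ids as HTML, coloring positions listed in changed_indices.

    EOS tokens are skipped so they do not clutter the displayed answer.
    """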
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    highlighted = []
    for j, tok in enumerate(tokens):
        if tokenizer.convert_tokens_to_ids(tok) == eos_token_id:
            continue
        tok_str = tokenizer.convert_tokens_to_string([tok])
        if (answer_start + j) in changed_indices:
            highlighted.append(f'<span style="color:{color}">{tok_str}</span>')
        else:
            highlighted.append(tok_str)
    return "".join(highlighted)
def diffusion_chat(question, max_it, pause_length, eos_bias, sharpness,
                   noise_start, use_confidence_noising,
                   use_permanent_unmasking, noise_clipping, top_k,
                   top_p, added_tokens):
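    """Generator that streams HTML snapshots of the denoising process.

    Yields the fully noised answer, then after each iteration the regenerated
    answer (changed tokens in green) and the re-noised answer (masked tokens
    in red), and finally the decoded answer up to the first EOS token.
    """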
    # The slider exposes "generation length"; internally it is applied as a
    # negative EOS bias, so higher values make EOS less likely.
    eos_bias = -eos_bias

    if question.strip() == "":
        question = "What do you know about the city of Amsterdam?"

    prompt = format_chat_prompt(question)
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    answer_start = find_answer_start(input_ids, assistant_marker_ids)
    if answer_start is None:
        yield render_html("Error", "Could not find Assistant marker in input.")
        return

    # Pad the prompt with mask tokens up to the fixed 256-token context
    input_ids = (input_ids + [mask_token_id] * (256 - len(input_ids)))[:256]
    ori_input_tokens = input_ids

    # Initial noising
    current_tokens, just_noised_indices = noisify_answer(
        input_ids, answer_start, tokenizer, threshold=1.0, noise_start=1.0
    )
    yield render_html("Iteration 0 (initial noise)",
                      highlight_tokens(current_tokens[answer_start:], answer_start, just_noised_indices, color="red"))
    time.sleep(pause_length)

    last_tokens = []
    prev_decoded = []
    unmasked_mask = [False] * len(current_tokens)
    current_tokens = current_tokens[:answer_start]
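    # Main denoising loop: append mask tokens, sample a full answer, check for
    # convergence, then re-noise part of the answer for the next iteration.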
    for i in range(max_it):
        # Semi-autoregressive: extend the answer with fresh mask tokens each pass
        current_tokens = current_tokens + [mask_token_id] * added_tokens
        current_tokens = current_tokens[:256]  # Ensure we don't exceed the max length

        generated_tokens, confidences = generate_diffusion_text(
            current_tokens, top_k=top_k, top_p=top_p, eos_bias=eos_bias
        )
        current_tokens = ori_input_tokens[:answer_start] + generated_tokens[answer_start:]

        # GREEN highlighting: compare to previous tokens
        new_decoded = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
        diff_indices = {
            answer_start + j for j, tok in enumerate(new_decoded)
            if j >= len(prev_decoded) or tok != prev_decoded[j]
        }
        prev_decoded = new_decoded

        yield render_html(f"Iteration {i+1}/{max_it} (after generation)",
                          highlight_tokens(current_tokens[answer_start:], answer_start, diff_indices, color="green"))
        time.sleep(pause_length)

        # Early stopping: halt once three consecutive iterations are identical
        last_tokens.append(current_tokens)
        if len(last_tokens) > 3:
            last_tokens.pop(0)
        if len(last_tokens) == 3 and last_tokens[0] == last_tokens[1] == last_tokens[2]:
            yield render_html("Stopped early", f"After {i+1} iterations.")
            break
        # NOISING: re-mask part of the answer before the next iteration
        if i < max_it - 1:
            threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
            if use_confidence_noising:
                noised_answer, just_noised_indices = confidence_guided_noising(
                    current_tokens, answer_start, tokenizer, confidences, noise_clipping,
                    threshold=threshold, noise_start=noise_start
                )
            elif use_permanent_unmasking:
                noised_answer, just_noised_indices = noisify_answer_without_remasking(
                    current_tokens, answer_start, tokenizer, threshold=threshold,
                    noise_start=noise_start, unmasked_mask=unmasked_mask
                )
            else:
                noised_answer, just_noised_indices = noisify_answer(
                    current_tokens, answer_start, tokenizer,
                    threshold=threshold, noise_start=noise_start
                )

            # Track which answer positions have been revealed at least once
            for idx in range(answer_start, len(current_tokens)):
                if noised_answer[idx] != mask_token_id:
                    unmasked_mask[idx] = True

            yield render_html(f"Iteration {i+1}/{max_it} (before noising)",
                              highlight_tokens(current_tokens[answer_start:], answer_start, just_noised_indices, color="red"))

            current_tokens = ori_input_tokens[:answer_start] + noised_answer[answer_start:]
    # Final output
    answer_ids = current_tokens[answer_start:]
    try:
        final_ids = answer_ids[:answer_ids.index(eos_token_id)]
    except ValueError:
        final_ids = answer_ids

    final_output = tokenizer.decode(final_ids, skip_special_tokens=True)
    yield render_html(f"Final Output ({len(final_ids)} tokens after {i+1} iterations)", final_output)  # type: ignore
def is_running_on_spaces():
    return os.getenv("SPACE_ID") is not None
| print("Loading model...") | |
| if is_running_on_spaces(): | |
| # Load from Hugging Face Hub | |
| ckpt_path = hf_hub_download( | |
| repo_id="ruurd/tini_model", | |
| filename="diffusion-model-8B.pth", | |
| token=os.getenv("HF_TOKEN") | |
| ) | |
| else: | |
| # Load from local path | |
| ckpt_path = "diffusion-model-3B.pth" # change to your actual local path | |
| model, tokenizer = load_trained_model(checkpoint_path=ckpt_path) | |
| print("✅ Model loaded.") | |
| vocab_size = len(tokenizer) | |
| eos_token_id = tokenizer.eos_token_id | |
| mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0] | |
| assistant_marker_ids = tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>", add_special_tokens=False) | |
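# The inputs below are passed positionally to diffusion_chat, so their order
# must match the function signature above.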
demo = gr.Interface(
    fn=diffusion_chat,
    inputs=[
        gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of Amsterdam?"),
        gr.Slider(1, 512, value=64, step=1, label="Number of iterations: ↑ = more iterations"),
        gr.Slider(0.01, 5, value=0.01, step=0.01, label="Pause between iterations: ↑ = longer pause"),
        gr.Slider(-5.0, 5.0, value=0.0, step=0.1, label="Generation length: ↑ = more output tokens by decreasing eos token probability"),
        gr.Slider(1.0, 20.0, value=1.0, step=0.5, label="Noise decay sharpness: ↓ = more noise in later iterations"),
        gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Noise start fraction: ↑ = more noise"),
        gr.Checkbox(value=False, label="Use confidence-guided noising"),
        gr.Checkbox(value=False, label="Use permanent unmasking"),
        gr.Slider(0.01, 1.0, value=0.01, step=0.01, label="Noise clipping: ↓ = more confidence guidance"),
        gr.Slider(1, 1000, value=3, step=1, label="Top-k: ↑ = more random answers"),
        gr.Slider(0.0, 1.0, value=1.0, step=0.01, label="Top-p: ↑ = more random answers"),
        gr.Slider(1, 256, value=256, step=1, label="Semi-autoregressive generation: number of added tokens per iteration"),
    ],
    outputs=[gr.HTML(label="Diffusion Output")],
    title="Diffusion Language Model Chat",
    theme="default",
    description="This interface runs a diffusion-based language model to generate answers progressively.",
)

demo.launch(share=True, allowed_paths=["."], ssr_mode=False)