import gradio as gr
import torch
import numpy as np
import json
import time
from transformers import AutoTokenizer
import os
from huggingface_hub import hf_hub_download
from llama_diffusion_model import CustomTransformerModel, CustomTransformerConfig, disable_dropout
import spaces
hf_token = os.getenv("HF_TOKEN")

# --- Load tokenizer ---
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B", use_fast=True, token=hf_token)
vocab_size = len(tokenizer)
pad_token = tokenizer.pad_token_id or tokenizer.eos_token_id
eot_token_id = tokenizer.eos_token_id
assistant_marker_ids = tokenizer.encode("Assistant:", add_special_tokens=False)
# --- Load token probabilities ---
with open("token_probabilities.json") as f:
    token_probs_dict = json.load(f)
token_probabilities = np.array([token_probs_dict[str(i)] for i in range(len(token_probs_dict))], dtype=np.float32)
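# Added sanity checks (assumes the JSON covers the tokenizer's full vocab):
# the noising code below uses token_probabilities as ``p`` for rng.choice over
# np.arange(vocab_size), which requires a full, normalized distribution.
assert len(token_probabilities) == vocab_size, "token_probabilities.json does not match the tokenizer vocab"
assert np.isclose(token_probabilities.sum(), 1.0, atol=1e-3), "token probabilities should sum to ~1"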
# def load_model():
#     ckpt_path = hf_hub_download(
#         repo_id="ruurd/tini_bi_m",
#         filename="diffusion-model.pth",
#         token=os.getenv("HF_TOKEN")
#     )
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model = torch.load(ckpt_path, map_location=device)
#     model = disable_dropout(model)
#     model.to(device)
#     model.eval()
#     return model
def load_model():
    ckpt_path = hf_hub_download(
        repo_id="ruurd/tini_bi",
        filename="diffusion-model.pth",
        token=os.getenv("HF_TOKEN"),
        revision="5a22a8b6168466dbbf704efd00d8cbf2eee51426",
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Step 1: Create model from scratch
    model = CustomTransformerModel(CustomTransformerConfig())

    # Step 2: Load state_dict from full checkpoint
    full_model = torch.load(ckpt_path, map_location=device)

    # This handles both a full pickled model and a bare state_dict
    try:
        state_dict = full_model.state_dict()
    except AttributeError:
        state_dict = full_model  # already a state_dict

    # Step 3: Load weights (might print mismatches)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print("Missing keys:", missing)
    print("Unexpected keys:", unexpected)

    model = disable_dropout(model)
    model.to(device)
    model.eval()
    return model
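# Note: strict=False tolerates drift between the checkpoint and
# CustomTransformerConfig; non-empty missing/unexpected key lists printed above
# are worth inspecting rather than ignoring.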
rng = np.random.default_rng()
# --- Utility Functions ---
def decode_tokens_safe(token_ids):
    return tokenizer.decode(token_ids, skip_special_tokens=True).replace("\n", " ")

def find_answer_start(input_ids, marker_ids):
    for i in range(len(input_ids) - len(marker_ids) + 1):
        if input_ids[i:i + len(marker_ids)] == marker_ids:
            return i + len(marker_ids)
    return None
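# Illustrative example: for ids encoding "User: hi\nAssistant:",
# find_answer_start(ids, assistant_marker_ids) returns the index just past the
# "Assistant:" marker, i.e. the position where the answer tokens begin.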
def get_noising_schedule(i, max_it, sharpness=5.0):
    x = i / max_it
    return (np.exp(-sharpness * x) - np.exp(-sharpness)) / (1 - np.exp(-sharpness))
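# The schedule is a normalized exponential decay:
#   s(x) = (exp(-k*x) - exp(-k)) / (1 - exp(-k)),  x = i / max_it,  k = sharpness
# so s(0) = 1 (re-noise the whole answer at the start) and s(1) = 0 (no noise
# at the end); larger sharpness front-loads the decay.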
def noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=1.0, mask_weight=0.0, clustering=0.5, noise_start=1.0):
    noised = input_ids.copy()
    answer_len = len(noised) - answer_start
    num_to_noise = int(threshold * answer_len * noise_start)
    # Uses the first sub-token of the literal string 'MASK', not a dedicated special token
    mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]

    if num_to_noise == 0:
        return noised, []

    mixed_probs = token_probabilities.copy()

    # Apply EOT weighting
    mixed_probs[eot_token_id] *= eot_weight

    # Scale all other probabilities so they sum to 1 - mask_weight
    total_other = mixed_probs.sum() - mixed_probs[mask_token_id]
    scale = (1.0 - mask_weight) / total_other
    mixed_probs *= scale

    # Set mask_token_id to mask_weight explicitly
    mixed_probs[mask_token_id] = mask_weight

    num_clusters = max(1, int((1 - clustering) * num_to_noise))
    cluster_size = max(1, int(num_to_noise / num_clusters))

    noised_indices = set()
    for _ in range(num_clusters):
        center = rng.integers(answer_start, len(noised))
        span_start = max(answer_start, center - cluster_size // 2)
        span_end = min(len(noised), span_start + cluster_size)
        noised_indices.update(range(span_start, span_end))

    noised_indices = sorted(noised_indices)[:num_to_noise]
    noise = rng.choice(np.arange(vocab_size), size=len(noised_indices), p=mixed_probs)
    for idx, val in zip(noised_indices, noise):
        noised[idx] = val

    return noised, noised_indices
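# Clustering behavior: clustering=0.0 yields num_to_noise single-token clusters
# (scattered noise); clustering near 1.0 spends the same budget on one or a few
# contiguous spans, producing block-wise corruption.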
# Alternative noising function: re-noise where the model was least confident
def confidence_guided_noising(input_ids, answer_start, confidences, noise_clipping, threshold=1.0, eot_weight=1.0, mask_weight=0.0, noise_start=1.0):
    noised = input_ids.copy()
    answer_len = len(input_ids) - answer_start
    num_to_noise = int(threshold * answer_len * noise_start)
    mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]

    if num_to_noise == 0:
        return noised

    raw_weights = 1.0 - np.array(confidences[answer_start:])

    # Avoid zero-probability weights for selection:
    # if noise_clipping == 1, all tokens have an equal chance of being noised;
    # as noise_clipping approaches 0, selection follows the (inverted)
    # confidence of the previous prediction.
    raw_weights = np.clip(raw_weights, a_min=noise_clipping, a_max=None)
    weights = raw_weights / raw_weights.sum()

    if num_to_noise > len(weights):
        num_to_noise = len(weights)  # prevent oversampling

    indices = rng.choice(
        np.arange(answer_start, len(input_ids)),
        size=num_to_noise,
        replace=False,
        p=weights,
    )

    mixed_probs = token_probabilities.copy()

    # Apply EOT weighting
    mixed_probs[eot_token_id] *= eot_weight

    # Scale all other probabilities so they sum to 1 - mask_weight
    total_other = mixed_probs.sum() - mixed_probs[mask_token_id]
    scale = (1.0 - mask_weight) / total_other
    mixed_probs *= scale

    # Set mask_token_id to mask_weight explicitly
    mixed_probs[mask_token_id] = mask_weight

    noise = rng.choice(np.arange(vocab_size), size=num_to_noise, p=mixed_probs)
    for idx, val in zip(indices, noise):
        noised[idx] = val

    return noised
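# Selection weights: w_j = clip(1 - confidence_j, noise_clipping, None) and
# p_j = w_j / sum(w). Low-confidence positions are re-noised preferentially,
# while the clipping floor keeps every position selectable.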
def generate_diffusion_text(input_ids):
    with torch.no_grad():
        input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
        logits = model(input_ids=input_tensor)["logits"]
        # Clamp logits before softmax to avoid overflow, and clamp probs away
        # from zero so multinomial sampling stays well-defined
        logits = logits.clamp(min=-1e4, max=1e4)
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]
        probs = torch.clamp(probs, min=1e-8, max=1.0)
        assert torch.all(torch.isfinite(probs)), "Non-finite values in probs!"
        assert (probs >= 0).all(), "Negative probs!"
        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1).tolist()

        # Extract confidence of selected tokens
        conf = probs[range(len(sampled)), sampled].cpu().numpy()
    return sampled, conf
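# The returned ``conf`` vector (probability assigned to each sampled id) is
# what confidence_guided_noising uses to decide which positions to corrupt next.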
# --- Inference Wrapper ---
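# Assumption: on a ZeroGPU Space (this one imports ``spaces`` and runs on
# Zero), the GPU-bound entry point is typically decorated so a GPU is attached
# for the duration of each call.
@spaces.GPU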
def diffusion_chat(question, eot_weight, mask_weight, max_it, pause_length, sharpness, clustering, noise_start, use_confidence_noising, noise_clipping):
    placeholder = "What do you know about the city of New York?"
    if question.strip() == "":
        question = placeholder

    print('started generation')
    prompt = f"User: {question}\nAssistant:"
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    answer_start = find_answer_start(input_ids, assistant_marker_ids)
    if answer_start is None:
        yield "Error: Could not find Assistant marker in input."
        return

    # Pad or truncate the full sequence to a fixed length of 256 tokens
    if len(input_ids) < 256:
        input_ids += [pad_token] * (256 - len(input_ids))
    else:
        input_ids = input_ids[:256]

    ori_input_tokens = input_ids
    current_tokens, just_noised_indices = noisify_answer(
        input_ids, answer_start, threshold=1.0, eot_weight=eot_weight, mask_weight=mask_weight, clustering=clustering, noise_start=1.0,
    )
    yield f"<b>Iteration 0 (initial noise):</b><br>" + tokenizer.decode(current_tokens[answer_start:], skip_special_tokens=True).replace('\n', '<br>')
    time.sleep(pause_length)

    last_tokens = []
    prev_decoded_tokens = []
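    # Each iteration alternates a denoising pass (the model resamples all
    # answer tokens) with a re-noising pass whose strength decays with the
    # schedule, so edits become progressively more local.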
    for i in range(max_it):
        print('Generating output')

        # Model step
        generated_tokens, confidences = generate_diffusion_text(current_tokens)

        # Save full output for noising step
        current_tokens = ori_input_tokens[:answer_start] + generated_tokens[answer_start:]

        # --- GREEN HIGHLIGHT: tokens changed since the previous iteration ---
        decoded_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
        highlighted = []
        for j, tok in enumerate(decoded_tokens):
            tok_id = tokenizer.convert_tokens_to_ids(tok)
            if tok_id == eot_token_id:
                continue
            token_str = tokenizer.convert_tokens_to_string([tok])
            if prev_decoded_tokens and j < len(prev_decoded_tokens) and tok != prev_decoded_tokens[j]:
                highlighted.append(f'<span style="color:green">{token_str}</span>')
            else:
                highlighted.append(token_str)

        prev_decoded_tokens = decoded_tokens
        yield f"<b>Iteration {i+1}/{max_it} (after generation):</b><br>" + "".join(highlighted).replace('\n', '<br>')
        time.sleep(pause_length)

        # --- Early stopping: identical output three iterations in a row ---
        last_tokens.append(current_tokens)
        if len(last_tokens) > 3:
            last_tokens.pop(0)
        if len(last_tokens) == 3 and last_tokens[0] == last_tokens[1] == last_tokens[2]:
            yield f"<b>Stopped early after {i+1} iterations.</b>"
            break
        # --- NOISING STEP ---
        threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
        if use_confidence_noising:
            noised_answer = confidence_guided_noising(
                current_tokens, answer_start, confidences, noise_clipping, threshold=threshold, eot_weight=eot_weight, mask_weight=mask_weight, noise_start=noise_start
            )
            # The confidence-guided path does not report which indices were
            # noised, so nothing is highlighted red below
            just_noised_indices = []
        else:
            noised_answer, just_noised_indices = noisify_answer(
                current_tokens, answer_start, threshold=threshold, eot_weight=eot_weight, mask_weight=mask_weight, clustering=clustering, noise_start=noise_start,
            )

        # Compose full input again: prompt + noised answer
        current_tokens = ori_input_tokens[:answer_start] + noised_answer[answer_start:]

        # --- RED HIGHLIGHT: positions corrupted by the noising step ---
        decoded_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
        highlighted = []
        for j, tok in enumerate(decoded_tokens):
            tok_id = tokenizer.convert_tokens_to_ids(tok)
            if tok_id == eot_token_id:
                continue
            token_str = tokenizer.convert_tokens_to_string([tok])
            abs_idx = answer_start + j
            if abs_idx in just_noised_indices:
                highlighted.append(f'<span style="color:red">{token_str}</span>')
            else:
                highlighted.append(token_str)

        yield f"<b>Iteration {i+1}/{max_it} (after noising):</b><br>" + "".join(highlighted).replace('\n', '<br>')
        time.sleep(pause_length)
    # Final pass: strip EOT tokens and show the clean answer
    final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
    final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
    final_output = tokenizer.convert_tokens_to_string(final_tokens)
    print(final_output)
    yield f"<b>Final Output (after {i+1} iterations):</b><br>" + final_output.replace('\n', '<br>')
# --- Gradio Interface ---
print("Loading model...")
model = load_model()
print("✅ Model loaded.")
demo = gr.Interface(
    fn=diffusion_chat,
    inputs=[
        gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of New York?"),
        gr.Slider(0, 1, value=0.5, step=0.05, label="↓ = longer answers (EOT weight)"),
        gr.Slider(0, 1, value=0.5, step=0.05, label="↓ = more random answers (MASK weight)"),
        gr.Slider(1, 512, value=32, step=1, label="↑ = more iterations"),
        gr.Slider(0.01, 5, value=0.01, step=0.01, label="↑ = longer pause (for visualization)"),
        gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="↓ = more noising (sharpness)"),
        gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="↑ = more clustered noising (fewer, larger edits)"),
        gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="↑ = more noise (noise start)"),
        gr.Checkbox(value=False, label="Use confidence-guided noising"),
        gr.Slider(0.01, 1.0, value=0.01, step=0.01, label="↓ = more confidence guidance (noise clipping)"),
    ],
    outputs=[gr.HTML(label="Diffusion Output")],
    title="Diffusion Language Model Chat",
    theme="default",
    description="This interface runs a diffusion-based language model that refines an answer over several iterations: green marks tokens the model changed, red marks tokens re-noised between iterations.",
)

demo.launch(share=True, allowed_paths=["."], ssr_mode=False)