import os
import time
import random
import importlib

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import display, HTML, Markdown, clear_output
from transformers import AutoTokenizer

rng = np.random.default_rng()

def disable_dropout(model):
    # Replace every Dropout with Identity; resolve the dotted module path first,
    # since setattr(model, "a.b.dropout", ...) would not swap the actual child module.
    for name, module in model.named_modules():
        if isinstance(module, nn.Dropout):
            parent_name, _, child_name = name.rpartition(".")
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, child_name, nn.Identity())
    return model

def load_trained_model(checkpoint_path: str, base_model_name: str = "meta-llama/Llama-3.2-3B"):
    # Load the tokenizer for the base model (HF_TOKEN is needed for gated Llama repos)
    hf_token = os.getenv("HF_TOKEN")
    tokenizer = AutoTokenizer.from_pretrained(base_model_name,
                                              use_fast=True,
                                              token=hf_token)
    # Load the full pickled model checkpoint onto the CPU first
    model = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=False)
    # Disable dropout for deterministic inference
    model = disable_dropout(model)
    print("✅ Model successfully loaded from checkpoint:", checkpoint_path)
    # Move to the best available device
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    model.to(device)
    model.eval()
    return model, tokenizer

def filter_logits(logits, top_k=0, top_p=1.0, temperature=1.0):
    """
    Vectorized top-k and/or top-p (nucleus) filtering with temperature scaling.
    Accepts logits of shape (seq_len, vocab_size) or (1, seq_len, vocab_size),
    and returns logits in the same shape.
    """
    original_shape = logits.shape
    if logits.dim() == 3:
        logits = logits.squeeze(0)  # shape: (seq_len, vocab_size)
    logits = logits.clone()

    # --- Temperature scaling ---
    if temperature != 1.0:
        logits = logits / temperature

    # --- Top-k filtering ---
    if 0 < top_k < logits.size(-1):
        topk_vals, _ = torch.topk(logits, top_k, dim=-1)
        thresholds = topk_vals[:, -1].unsqueeze(-1)
        logits = torch.where(logits < thresholds, torch.full_like(logits, float("-inf")), logits)

    # --- Top-p (nucleus) filtering ---
    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        probs = torch.softmax(sorted_logits, dim=-1)
        cum_probs = probs.cumsum(dim=-1)
        mask = cum_probs > top_p
        mask[:, 0] = False  # always keep the most likely token
        scatter_mask = torch.zeros_like(logits, dtype=torch.bool).scatter(dim=-1, index=sorted_indices, src=mask)
        logits = torch.where(scatter_mask, torch.full_like(logits, float("-inf")), logits)

    # Restore the leading batch dimension only if the input actually had one
    if len(original_shape) == 3:
        logits = logits.unsqueeze(0)
    return logits
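
# --- Example usage sketch for filter_logits (illustration only; toy values, not from the pipeline) ---
def _demo_filter_logits():
    # A single row of 5 toy logits; keep only the top-2 candidates.
    toy = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])
    kept = filter_logits(toy, top_k=2)
    # Every position outside the top-2 is now -inf, so sampling can only pick ids 0 or 1.
    return kept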

# --- Utility Functions ---
def decode_tokens_safe(token_ids, tokenizer):
    return tokenizer.decode(token_ids, skip_special_tokens=True).replace("\n", " ")

def find_answer_start(input_ids, marker_ids):
    # Return the index right after the first occurrence of the marker sequence
    for i in range(len(input_ids) - len(marker_ids) + 1):
        if input_ids[i:i + len(marker_ids)] == marker_ids:
            return i + len(marker_ids)
    return None

def get_noising_schedule(i, max_it, sharpness=5.0):
    # Exponentially decaying schedule: 1.0 at i == 0, 0.0 at i == max_it
    x = i / max_it
    return (np.exp(-sharpness * x) - np.exp(-sharpness)) / (1 - np.exp(-sharpness))
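
# --- Example sketch: how the noising schedule decays over iterations (illustration only) ---
def _demo_noising_schedule(max_it=8, sharpness=5.0):
    # The re-masking fraction starts at 1.0 for step 0 and decays exponentially
    # towards 0.0 as the step index approaches max_it.
    return [round(get_noising_schedule(i, max_it, sharpness), 3) for i in range(max_it)]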

def noisify_answer(input_ids, answer_start, tokenizer, threshold=1.0, clustering=0.5, noise_start=1.0):
    # Re-mask a fraction of the answer tokens, grouped into clusters of neighbouring positions
    noised = input_ids.copy()
    answer_len = len(noised) - answer_start
    num_to_noise = int(threshold * answer_len * noise_start)
    mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
    if num_to_noise == 0:
        return noised, []
    num_clusters = max(1, int((1 - clustering) * num_to_noise))
    cluster_size = max(1, int(num_to_noise / num_clusters))
    noised_indices = set()
    for _ in range(num_clusters):
        center = rng.integers(answer_start, len(noised))
        span_start = max(answer_start, center - cluster_size // 2)
        span_end = min(len(noised), span_start + cluster_size)
        noised_indices.update(range(span_start, span_end))
    noised_indices = sorted(noised_indices)[:num_to_noise]
    for idx in noised_indices:
        noised[idx] = mask_token_id
    return noised, noised_indices
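
# --- Example usage sketch for noisify_answer (illustration only) ---
# _FakeMaskTokenizer is a hypothetical stand-in so the demo runs without downloading a real
# tokenizer; it only implements the encode() call that noisify_answer needs.
class _FakeMaskTokenizer:
    def encode(self, text, add_special_tokens=False):
        return [0]  # pretend 'MASK' encodes to token id 0

def _demo_noisify_answer():
    toy_ids = list(range(100, 120))  # 20 dummy token ids, answer starting at position 5
    noised, idx = noisify_answer(toy_ids, answer_start=5, tokenizer=_FakeMaskTokenizer(),
                                 threshold=0.5, clustering=0.5)
    # Roughly half of the 15 answer positions are replaced by the mask id, in small clusters.
    return noised, idx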

def noisify_answer_without_remasking(input_ids, answer_start, tokenizer, threshold=1.0, noise_start=1.0, unmasked_mask=None):
    # Re-mask answer tokens but never touch positions already marked in unmasked_mask
    noised = input_ids.copy()
    mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
    eligible_indices = list(range(answer_start, len(noised)))
    if unmasked_mask is not None:
        eligible_indices = [i for i in eligible_indices if not unmasked_mask[i]]
    answer_len = len(noised) - answer_start
    # Cap at the number of eligible positions so rng.choice(..., replace=False) cannot fail
    num_to_noise = min(int(threshold * answer_len * noise_start), len(eligible_indices))
    if num_to_noise == 0:
        return noised, []
    selected = rng.choice(eligible_indices, size=num_to_noise, replace=False).tolist()
    for idx in selected:
        noised[idx] = mask_token_id
    return noised, selected

def confidence_guided_noising(input_ids, answer_start, tokenizer, confidences, noise_clipping, threshold=1.0, noise_start=1.0):
    noised = input_ids.copy()
    answer_len = len(input_ids) - answer_start
    num_to_noise = int(threshold * answer_len * noise_start)
    mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
    eos_token_id = tokenizer.eos_token_id
    if num_to_noise == 0:
        return noised, []
    all_indices = np.arange(answer_start, len(input_ids))
    eos_indices = [i for i in all_indices if input_ids[i] == eos_token_id]
    non_eos_indices = [i for i in all_indices if input_ids[i] != eos_token_id]
    # Proportionally split how many to noise
    num_non_eos_to_noise = int(num_to_noise * len(non_eos_indices) / (len(non_eos_indices) + len(eos_indices) + 1e-5))
    num_eos_to_noise = num_to_noise - num_non_eos_to_noise
    noised_indices = []
    # --- Non-EOS ---
    if non_eos_indices:
        raw_weights = 1.0 - np.array([confidences[i - answer_start] for i in non_eos_indices])
        raw_weights = np.clip(raw_weights, a_min=noise_clipping, a_max=None)
        weights = raw_weights / raw_weights.sum()
        chosen = rng.choice(non_eos_indices, size=min(num_non_eos_to_noise, len(non_eos_indices)), replace=False, p=weights)
        noised_indices.extend(chosen.tolist())
    # --- EOS ---
    if eos_indices and num_eos_to_noise > 0:
        raw_weights = 1.0 - np.array([confidences[i - answer_start] for i in eos_indices])
        raw_weights = np.clip(raw_weights, a_min=noise_clipping, a_max=None)
        weights = raw_weights / raw_weights.sum()
        chosen = rng.choice(eos_indices, size=min(num_eos_to_noise, len(eos_indices)), replace=False, p=weights)
        noised_indices.extend(chosen.tolist())
    for idx in noised_indices:
        noised[idx] = mask_token_id
    noised_indices = sorted(noised_indices)
    return noised, noised_indices

def generate_diffusion_text(model, input_ids, answer_start, top_k=0, top_p=1.0, temperature=1.0,
                            eos_token_id=None, eos_boost=0.0):
    model.eval()
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
        logits = model(input_ids=input_tensor)["logits"]  # (1, seq_len, vocab_size)
        # Optionally boost or suppress EOS token
        if eos_token_id is not None and eos_boost != 0.0:
            logits[:, :, eos_token_id] += eos_boost
        # Filter and sample
        filtered_logits = filter_logits(logits, top_k=top_k, top_p=top_p, temperature=temperature)
        probs = F.softmax(filtered_logits, dim=-1).squeeze()  # (seq_len, vocab_size)
        probs = torch.clamp(probs, min=1e-8, max=1.0)
        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
        confidences = probs.gather(1, sampled.unsqueeze(-1)).squeeze(-1)
    return input_ids[:answer_start] + sampled[answer_start:].tolist(), confidences

def calculate_answer_perplexity(prompt, answer, model_name='gpt2-large'):
    from transformers import AutoTokenizer, AutoModelForCausalLM
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
    model.to(device)
    full_input = prompt + answer
    enc = tokenizer(full_input, return_tensors="pt")
    input_ids = enc.input_ids.to(device)
    with torch.no_grad():
        labels = input_ids.clone()
        # Mask out the prompt tokens so the loss (and perplexity) covers only the answer
        prompt_len = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
        labels[0, :prompt_len] = -100
        loss = model(input_ids, labels=labels).loss
    return torch.exp(loss).item()
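
# --- Example usage sketch for calculate_answer_perplexity (illustration only) ---
# Wrapped in a function because it downloads gpt2-large on first use.
def _demo_answer_perplexity():
    prompt = "Question: What is the capital of France?\nAnswer:"
    answer = " The capital of France is Paris."
    return calculate_answer_perplexity(prompt, answer)  # lower perplexity = more fluent answer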

def format_token_colored_inline(token_id, conf, tokenizer, mask_token_id=128000):
    token_str = tokenizer.decode([token_id]).replace("\n", "<br>")
    if token_id == mask_token_id:
        color = "black"
    else:
        # Map confidence in [0, 1] to a hue from red (low) to green (high)
        color = f"hsl({int(conf * 120)}, 100%, 25%)"
    return f"<span style='color:{color}' title='Conf: {conf:.2f}'>{token_str}</span>"

def display_diffusion_output(i, max_it, question, ori_input_tokens, generated_tokens, confidences, answer_start, tokenizer):
    clear_output(wait=True)
    display(Markdown(f"### Iteration {i}/{max_it - 1}"))
    display(Markdown(f"**Question:** {tokenizer.decode(ori_input_tokens[:answer_start])}"))
    mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
    output_html = ''.join([
        format_token_colored_inline(tok, conf, tokenizer, mask_token_id)
        for tok, conf in zip(generated_tokens[answer_start:], confidences[answer_start:])
        if tok != 128001  # skip EOT
    ])
    output_html = f"<div style='white-space: pre-wrap'>{output_html}</div>"
    html = HTML(f"<b>Diffusion Output with Confidence:</b><br><div style='line-height:1.8; white-space: pre-wrap'>{output_html}</div>")
    display(html)
    return output_html

def save_html_colored_output(filename, html_content):
    with open(filename, "w", encoding="utf-8") as f:
        f.write(f"""
        <html>
        <head>
            <meta charset="utf-8">
            <style>
                body {{ font-family: sans-serif; line-height: 1.6; }}
                span {{ padding: 0 2px; }}
            </style>
        </head>
        <body>
            {html_content}
        </body>
        </html>
        """)

def generate_answer(question: str, model, tokenizer, max_it=16, noise_start=0.5,
                    noising_sharpness=5.0, max_length=256, top_k=100, top_p=1.0,
                    temperature=1.0, eos_token_id=None, eos_boost=0.0) -> str:
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id

    # Format prompt with LLaMA 3 chat template
    prompt = (
        "<|begin_of_text|>\n"
        "<|start_header_id|>system<|end_header_id|>\n"
        "You are a helpful assistant.\n"
        "<|eot_id|>\n"
        "<|start_header_id|>user<|end_header_id|>\n"
        f"{question.strip()}\n"
        "<|start_header_id|>assistant<|end_header_id|>\n"
    )
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    marker = tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>\n", add_special_tokens=False)

    # Locate the first answer position using the module-level find_answer_start helper
    answer_start = find_answer_start(input_ids, marker)
    if answer_start is None:
        raise ValueError("Assistant marker not found in prompt.")

    # Pad the answer region with mask tokens up to max_length
    mask_token = tokenizer.encode("MASK", add_special_tokens=False)[0]
    input_ids = input_ids[:max_length]
    if len(input_ids) < max_length:
        input_ids += [mask_token] * (max_length - len(input_ids))
    ori_tokens = input_ids

    # Start from a fully masked answer
    current_tokens, _ = noisify_answer(ori_tokens, answer_start, tokenizer, threshold=1.0)

    last_tokens = []
    for step in range(max_it):
        # Generate a new prediction
        current_tokens, confidence_scores = generate_diffusion_text(
            model, current_tokens, answer_start,
            top_k=top_k, top_p=top_p, temperature=temperature,
            eos_token_id=eos_token_id, eos_boost=eos_boost
        )
        # Display for debugging / tracking
        display_diffusion_output(
            step, max_it, question,
            ori_tokens, current_tokens, confidence_scores,
            answer_start, tokenizer
        )
        # Early stopping: break once the last few iterations produce identical output
        last_tokens.append(current_tokens)
        if len(last_tokens) > 4:
            last_tokens.pop(0)
            if all(t == last_tokens[0] for t in last_tokens):
                break
        # Re-apply noise for the next iteration with a decaying threshold
        if step < max_it - 1:
            threshold = noise_start * get_noising_schedule(step, max_it, sharpness=noising_sharpness)
            current_tokens, _ = noisify_answer(current_tokens, answer_start, tokenizer, threshold=threshold)
    return tokenizer.decode(current_tokens[answer_start:], skip_special_tokens=True).strip()
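
# --- Example end-to-end driver (illustration only) ---
# The checkpoint path below is a hypothetical placeholder; point it at the actual
# diffusion checkpoint used by this Space before running.
if __name__ == "__main__":
    model, tokenizer = load_trained_model("path/to/diffusion_checkpoint.pt")
    answer = generate_answer("What is the capital of France?", model, tokenizer,
                             max_it=8, noise_start=0.5, top_k=100, temperature=1.0)
    print(answer)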