# Hugging Face Spaces — status: Sleeping
# NOTE(review): `spaces` (Hugging Face Spaces helper) is not used in the
# visible code; it is kept first, ahead of torch, because Spaces tooling such
# as ZeroGPU presumably expects that ordering — confirm before moving/removing.
import spaces
import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Name of the pretrained checkpoint; the model and tokenizer must match.
MODEL_NAME = "gpt2"

# Both objects are created once, at import time, and then shared as
# module-level globals by the prediction handler.
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME)
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
# Visible stand-ins for whitespace, so tokens such as " the" or "\n" stay
# readable (and on a single line) when rendered in a Markdown component.
_WHITESPACE_SYMBOLS = str.maketrans({" ": "␣", "\n": "↵", "\t": "⇥", "\r": "␍"})


def _format_prediction(rank, token, prob):
    """Format one ranked prediction as ``'<rank>. "<token>" (<pct>%)'``.

    ``prob`` is a probability in [0, 1] shown as a percentage with one
    decimal place; whitespace in ``token`` is replaced by visible symbols.
    """
    display_token = token.translate(_WHITESPACE_SYMBOLS)
    return f"{rank}. \"{display_token}\" ({prob * 100:.1f}%)"


def get_next_token_probs(text, top_k=5):
    """Return the ``top_k`` most likely next tokens for ``text`` as display strings.

    Runs the module-level GPT-2 ``model`` on ``text`` and ranks the whole
    vocabulary by the softmax probability of the token that follows it.

    Args:
        text: Prompt to continue. ``None``, empty, or whitespace-only input
            short-circuits without running the model.
        top_k: Number of predictions to return. Defaults to 5, matching the
            five output components wired up in the Gradio UI.

    Returns:
        A list of ``top_k`` strings formatted as ``'<rank>. "<token>" (<pct>%)'``
        (rank starting at 1, most likely first), or ``top_k`` copies of
        ``"No input text"`` when there is no usable input.
    """
    # Handle missing / empty input without touching the model.
    if not text or not text.strip():
        return ["No input text"] * top_k

    input_ids = tokenizer.encode(text, return_tensors="pt")
    # GPT-2 has a fixed number of position embeddings; longer prompts would
    # index past them and fail, so keep only the most recent tokens.
    input_ids = input_ids[:, -model.config.n_positions:]

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        logits = model(input_ids).logits

    # Distribution over the vocabulary for the position after the last token.
    next_token_probs = torch.softmax(logits[0, -1, :], dim=-1)
    topk_probs, topk_indices = torch.topk(next_token_probs, top_k)

    return [
        _format_prediction(rank, tokenizer.decode([token_id]), prob)
        for rank, (token_id, prob) in enumerate(
            zip(topk_indices.tolist(), topk_probs.tolist()), start=1
        )
    ]
# Minimal Blocks UI: a prompt box, a predict button, and one Markdown slot
# per ranked prediction. The CSS rule hides Gradio's page footer.
with gr.Blocks(css="footer {display: none}") as demo:
    gr.Markdown("### GPT-2 Next Token Predictor")

    input_text = gr.Textbox(
        label="Text Input",
        placeholder="Type text here...",
        value="The weather tomorrow will be",
    )
    predict_btn = gr.Button("Predict Next Tokens")

    gr.Markdown("##### Most likely next tokens:")

    # Five Markdown components, one per predicted token; the handler's
    # returned list is assigned to them in order.
    token_outputs = [gr.Markdown() for _ in range(5)]
    token1, token2, token3, token4, token5 = token_outputs

    # The same handler runs on every button click and once when the page
    # loads, so the default prompt is shown with predictions immediately.
    handler_kwargs = dict(
        fn=get_next_token_probs,
        inputs=input_text,
        outputs=token_outputs,
    )
    predict_btn.click(**handler_kwargs)
    demo.load(**handler_kwargs)
# Launch the app only when this file is executed as a script (e.g.
# `python app.py`), so merely importing the module does not start a server.
if __name__ == "__main__":
    demo.launch()