Spaces:
Running
Running
| import streamlit as st | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
# Introduction: page title and a short description of what the app does.
st.title('RNN Character Prediction')
st.write("""
This app demonstrates how to train a Recurrent Neural Network (RNN) to predict the next character in a given string.
The RNN learns the sequence of characters from a provided text and generates text based on the learned patterns.
You can choose different options for training the RNN and see how it affects the generated text.
""")

# User input for the training string.
sequence = st.text_area('Enter the training string:', 'In the vast expanse of the digital realm.')

# Vocabulary: the distinct characters of the training string. Sorting makes
# the character -> index assignment reproducible; iterating a bare set of
# strings can yield a different order in each interpreter process because
# string hashing is randomized per process.
chars = sorted(set(sequence))
data_size, vocab_size = len(sequence), len(chars)

# Create mappings from characters to indices and vice versa.
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = dict(enumerate(chars))

# Convert the sequence to indices.
indices = np.array([char_to_idx[ch] for ch in sequence])
class RNN(nn.Module):
    """Minimal single-layer recurrent cell built from two linear maps.

    At each step the current input and the previous hidden state are
    concatenated; one linear layer produces the next hidden state and a
    second one produces log-probabilities over the output classes.
    """

    def __init__(self, input_size, hidden_size, output_size):
        """Create the cell.

        Args:
            input_size: Width of each input vector.
            hidden_size: Width of the hidden state.
            output_size: Number of output classes.
        """
        super().__init__()
        self.hidden_size = hidden_size
        # Both layers read the concatenated [input, hidden] vector.
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Run one time step.

        Args:
            input: Tensor of shape (batch, input_size).
            hidden: Tensor of shape (batch, hidden_size).

        Returns:
            A pair ``(output, hidden)``: log-probabilities of shape
            (batch, output_size) and the next hidden state of shape
            (batch, hidden_size).
        """
        combined = torch.cat((input, hidden), 1)
        # The next hidden state is a plain linear function of `combined`;
        # no activation is applied to it.
        next_hidden = self.i2h(combined)
        log_probs = self.softmax(self.i2o(combined))
        return log_probs, next_hidden

    def init_hidden(self, batch_size=1):
        """Return an all-zero initial hidden state.

        Args:
            batch_size: Number of rows in the returned state. Defaults to 1,
                which matches feeding one sequence at a time.

        Returns:
            A zero tensor of shape (batch_size, hidden_size).
        """
        return torch.zeros(batch_size, self.hidden_size)
# Hyperparameters
n_hidden = 128
learning_rate = 0.005

# Initialize the model, loss function, and optimizer.
#
# Streamlit re-executes this script from top to bottom on every widget
# interaction. Building the model unconditionally here would therefore throw
# away the trained weights as soon as another button (e.g. 'Generate Text')
# is pressed. Keep the model in st.session_state instead, and rebuild it only
# when the training text or the hidden size changes, since either one
# invalidates the existing weights.
model_signature = (sequence, n_hidden)
if 'rnn' not in st.session_state or st.session_state['rnn_signature'] != model_signature:
    st.session_state['rnn'] = RNN(vocab_size, n_hidden, vocab_size)
    st.session_state['rnn_signature'] = model_signature
rnn = st.session_state['rnn']
criterion = nn.NLLLoss()

# Define training options: label shown in the select box -> number of epochs.
options = {
    'Quick Train (100 epochs)': 100,
    'Medium Train (500 epochs)': 500,
    'Long Train (1000 epochs)': 1000,
}
train_option = st.selectbox('Select training option:', list(options))
n_epochs = options[train_option]

# The optimizer is recreated on every rerun, so its internal Adam statistics
# start fresh each time; the model parameters it updates are the persisted ones.
optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)
def char_tensor(char):
    """Return the one-hot row vector for ``char``.

    Args:
        char: A character that must be a key of ``char_to_idx``.

    Returns:
        A float tensor of shape (1, vocab_size) that is zero everywhere
        except for a 1 at the index assigned to ``char``.

    Raises:
        ValueError: If ``char`` is not in the vocabulary.
    """
    position = char_to_idx.get(char)
    if position is None:
        raise ValueError(f"Character '{char}' not in vocabulary.")
    one_hot = torch.zeros(1, vocab_size)
    one_hot[0, position] = 1
    return one_hot
# Training function
def train_model(n_epochs):
    """Train the module-level ``rnn`` on ``sequence`` for ``n_epochs`` epochs.

    Each epoch feeds the whole training string through the network one
    character at a time, sums the next-character loss over all positions,
    back-propagates once, and takes one optimizer step. The mean loss per
    position is reported every 10 epochs.

    Args:
        n_epochs: Number of passes over the training string.
    """
    n_steps = data_size - 1
    if n_steps < 1:
        # With fewer than two characters there is no (input, next-char) pair:
        # the summed loss would stay the int 0, which has no .backward(), and
        # the per-step average would divide by zero.
        st.warning('Enter at least two characters to train the model.')
        return

    # The one-hot inputs and the target indices do not change between epochs,
    # so build them once instead of once per epoch.
    inputs = [char_tensor(ch) for ch in sequence[:-1]]
    targets = [
        torch.tensor([char_to_idx[ch]], dtype=torch.long)
        for ch in sequence[1:]
    ]

    for epoch in range(n_epochs):
        hidden = rnn.init_hidden()
        rnn.zero_grad()
        loss = 0
        for input_char, target_char in zip(inputs, targets):
            output, hidden = rnn(input_char, hidden)
            loss += criterion(output, target_char)
        loss.backward()
        optimizer.step()
        if epoch % 10 == 0:
            st.write(f'Epoch {epoch} loss: {loss.item() / n_steps}')
    st.write("Training complete.")
# Train the model when the user presses the button.
train_requested = st.button('Train Model')
if train_requested:
    train_model(n_epochs)
def generate(start_char, predict_len=100):
    """Generate text greedily with the module-level ``rnn``.

    Starting from ``start_char``, the network is run one step at a time; at
    each step the highest-scoring character is appended to the result and
    fed back in as the next input.

    Args:
        start_char: First character of the result; must be in the vocabulary.
        predict_len: Number of characters to generate after ``start_char``.

    Returns:
        ``start_char`` followed by ``predict_len`` predicted characters.

    Raises:
        ValueError: If ``start_char`` is not in the vocabulary.
    """
    if start_char not in char_to_idx:
        raise ValueError(f"Start character '{start_char}' not in vocabulary.")

    pieces = [start_char]
    # Inference only: without no_grad() every step would be recorded by
    # autograd and the graph hanging off `hidden` would grow with each
    # generated character.
    with torch.no_grad():
        hidden = rnn.init_hidden()
        input_char = char_tensor(start_char)
        for _ in range(predict_len):
            output, hidden = rnn(input_char, hidden)
            _, topi = output.topk(1)
            predicted_char = idx_to_char[topi[0][0].item()]
            pieces.append(predicted_char)
            input_char = char_tensor(predicted_char)
    return ''.join(pieces)
# Controls for text generation: the seed character and the output length.
start_char = st.text_input('Enter a starting character:', 'h')
predict_len = st.slider(
    'Select the length of the generated text:',
    min_value=10,
    max_value=200,
    value=50,
)

generate_requested = st.button('Generate Text')
if generate_requested:
    try:
        generated_text = generate(start_char, predict_len)
        st.write('Generated Text:')
        st.text(generated_text)
    except ValueError as err:
        # generate() rejects a seed that is not in the vocabulary; show the
        # message in the page instead of a traceback.
        st.error(str(err))