Update pages/19_RNN_Shakespeare.py
pages/19_RNN_Shakespeare.py (+61 / -55)

@@ -51,58 +51,64 @@ generate_length = st.number_input("Generated text length:", min_value=50, value=
 if st.button("Train and Generate"):
     # Data Preparation
     text = text_data
-[55 lines removed (old lines 54-108); their content is not captured in this view]
+    if len(text) <= seq_length:
+        st.error("Text data is too short for the given sequence length. Please enter more text data.")
+    else:
+        chars = sorted(list(set(text)))
+        char_to_int = {c: i for i, c in enumerate(chars)}
+        int_to_char = {i: c for i, c in enumerate(chars)}
+
+        # Prepare input-output pairs
+        dataX = []
+        dataY = []
+        for i in range(0, len(text) - seq_length):
+            seq_in = text[i:i + seq_length]
+            seq_out = text[i + seq_length]
+            dataX.append([char_to_int[char] for char in seq_in])
+            dataY.append(char_to_int[seq_out])
+
+        if len(dataX) == 0:
+            st.error("Not enough data to create input-output pairs. Please provide more text data.")
+        else:
+            X = np.reshape(dataX, (len(dataX), seq_length, 1))
+            X = X / float(len(chars))
+            Y = np.array(dataY)
+
+            # Convert to PyTorch tensors
+            X_tensor = torch.tensor(X, dtype=torch.float32)
+            Y_tensor = torch.tensor(Y, dtype=torch.long)
+
+            # Model initialization
+            model = LSTMModel(input_size=1, hidden_size=hidden_size, output_size=len(chars), num_layers=num_layers)
+
+            # Loss and optimizer
+            criterion = nn.CrossEntropyLoss()
+            optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+
+            # Training the model
+            for epoch in range(num_epochs):
+                h = (torch.zeros(num_layers, 1, hidden_size), torch.zeros(num_layers, 1, hidden_size))
+                epoch_loss = 0
+                for i in range(len(dataX)):
+                    inputs = X_tensor[i].unsqueeze(0)
+                    targets = Y_tensor[i].unsqueeze(0)
+
+                    # Forward pass
+                    outputs, h = model(inputs, h)
+                    h = (h[0].detach(), h[1].detach())
+                    loss = criterion(outputs, targets)
+
+                    # Backward pass and optimization
+                    optimizer.zero_grad()
+                    loss.backward()
+                    optimizer.step()
+
+                    epoch_loss += loss.item()
+
+                avg_loss = epoch_loss / len(dataX)
+                st.write(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}')
+
+            # Text generation
+            generated_text = generate_text(model, start_string, generate_length, char_to_int, int_to_char, num_layers, hidden_size)
+            st.subheader("Generated Text")
+            st.write(generated_text)
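
Note: the hunk above calls an LSTMModel class and a generate_text helper that are defined earlier in pages/19_RNN_Shakespeare.py and are not shown in this diff. The sketch below is only an assumption of what those definitions could look like, inferred from the call sites (a batch_first nn.LSTM whose last time step feeds a linear layer, and greedy argmax decoding); the file's actual implementations may differ.

import torch
import torch.nn as nn

# Hypothetical reconstruction, not the file's actual code.
class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, h):
        # x: (batch, seq_length, 1); h: (h_0, c_0), each (num_layers, batch, hidden_size)
        out, h = self.lstm(x, h)
        # Logits for the next character, taken from the last time step
        return self.fc(out[:, -1, :]), h

def generate_text(model, start_string, generate_length, char_to_int, int_to_char, num_layers, hidden_size):
    model.eval()
    vocab_size = len(char_to_int)
    # Seed with the start string; characters outside the training vocabulary are skipped
    pattern = [char_to_int[c] for c in start_string if c in char_to_int] or [0]
    generated = start_string
    with torch.no_grad():
        for _ in range(generate_length):
            # Same normalization as training: scaled character indices, shape (1, len(pattern), 1)
            x = torch.tensor(pattern, dtype=torch.float32).reshape(1, -1, 1) / float(vocab_size)
            h = (torch.zeros(num_layers, 1, hidden_size), torch.zeros(num_layers, 1, hidden_size))
            logits, _ = model(x, h)
            idx = int(torch.argmax(logits, dim=1).item())
            generated += int_to_char[idx]
            pattern.append(idx)
    return generated

Greedy decoding is assumed here for determinism; sampling from the softmax of the logits with a temperature would give more varied generated text.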