Spaces: Runtime error (fixed by this commit)
Commit: "fix past"
Files changed (Browse files):
- app.py        (+3 −1)
- midi_model.py (+3 −1)

app.py (CHANGED)
|
@@ -53,10 +53,11 @@ def generate(model: MIDIModel, prompt=None, batch_size=1, max_len=512, temp=1.0,
|
|
| 53 |
cur_len = input_tensor.shape[1]
|
| 54 |
bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
|
| 55 |
cache1 = DynamicCache()
|
|
|
|
| 56 |
with bar:
|
| 57 |
while cur_len < max_len:
|
| 58 |
end = [False] * batch_size
|
| 59 |
-
hidden = model.forward(input_tensor[:, …], cache=cache1)[:, -1]  [removed line truncated by page extraction — the original slice argument is not recoverable here; the replacement line below shows the fixed form using past_len]
|
| 60 |
next_token_seq = None
|
| 61 |
event_names = [""] * batch_size
|
| 62 |
cache2 = DynamicCache()
|
|
@@ -110,6 +111,7 @@ def generate(model: MIDIModel, prompt=None, batch_size=1, max_len=512, temp=1.0,
|
|
| 110 |
"constant", value=tokenizer.pad_id)
|
| 111 |
next_token_seq = next_token_seq.unsqueeze(1)
|
| 112 |
input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
|
|
|
|
| 113 |
cur_len += 1
|
| 114 |
bar.update(1)
|
| 115 |
yield next_token_seq[:, 0].cpu().numpy()
|
|
|
|
| 53 |
cur_len = input_tensor.shape[1]
|
| 54 |
bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
|
| 55 |
cache1 = DynamicCache()
|
| 56 |
+
past_len = 0
|
| 57 |
with bar:
|
| 58 |
while cur_len < max_len:
|
| 59 |
end = [False] * batch_size
|
| 60 |
+
hidden = model.forward(input_tensor[:, past_len:], cache=cache1)[:, -1]
|
| 61 |
next_token_seq = None
|
| 62 |
event_names = [""] * batch_size
|
| 63 |
cache2 = DynamicCache()
|
|
|
|
| 111 |
"constant", value=tokenizer.pad_id)
|
| 112 |
next_token_seq = next_token_seq.unsqueeze(1)
|
| 113 |
input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
|
| 114 |
+
past_len = cur_len
|
| 115 |
cur_len += 1
|
| 116 |
bar.update(1)
|
| 117 |
yield next_token_seq[:, 0].cpu().numpy()
|
midi_model.py (CHANGED)
|
@@ -160,10 +160,11 @@ class MIDIModel(nn.Module, PeftAdapterMixin):
|
|
| 160 |
cur_len = input_tensor.shape[1]
|
| 161 |
bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
|
| 162 |
cache1 = DynamicCache()
|
|
|
|
| 163 |
with bar:
|
| 164 |
while cur_len < max_len:
|
| 165 |
end = [False] * batch_size
|
| 166 |
-
hidden = self.forward(input_tensor[…], cache=cache1)[:, -1]  [removed line truncated by page extraction — the original slice argument is not recoverable here; the replacement line below shows the fixed form using past_len]
|
| 167 |
next_token_seq = None
|
| 168 |
event_names = [""] * batch_size
|
| 169 |
cache2 = DynamicCache()
|
|
@@ -210,6 +211,7 @@ class MIDIModel(nn.Module, PeftAdapterMixin):
|
|
| 210 |
"constant", value=tokenizer.pad_id)
|
| 211 |
next_token_seq = next_token_seq.unsqueeze(1)
|
| 212 |
input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
|
|
|
|
| 213 |
cur_len += 1
|
| 214 |
bar.update(1)
|
| 215 |
|
|
|
|
| 160 |
cur_len = input_tensor.shape[1]
|
| 161 |
bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
|
| 162 |
cache1 = DynamicCache()
|
| 163 |
+
past_len = 0
|
| 164 |
with bar:
|
| 165 |
while cur_len < max_len:
|
| 166 |
end = [False] * batch_size
|
| 167 |
+
hidden = self.forward(input_tensor[:,past_len:], cache=cache1)[:, -1]
|
| 168 |
next_token_seq = None
|
| 169 |
event_names = [""] * batch_size
|
| 170 |
cache2 = DynamicCache()
|
|
|
|
| 211 |
"constant", value=tokenizer.pad_id)
|
| 212 |
next_token_seq = next_token_seq.unsqueeze(1)
|
| 213 |
input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
|
| 214 |
+
past_len = cur_len
|
| 215 |
cur_len += 1
|
| 216 |
bar.update(1)
|
| 217 |
|