Spaces:
Running
Running
Fix: Remove duplicates when using more than 1 device
Browse files
- translate.py +17 -4
translate.py
CHANGED
|
@@ -134,11 +134,17 @@ def main(
|
|
| 134 |
|
| 135 |
model, data_loader = accelerator.prepare(model, data_loader)
|
| 136 |
|
|
|
|
|
|
|
| 137 |
with tqdm(
|
| 138 |
-
total=total_lines,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
) as pbar, open(output_path, "w", encoding="utf-8") as output_file:
|
| 140 |
with torch.no_grad():
|
| 141 |
-
for batch in data_loader:
|
| 142 |
batch["input_ids"] = batch["input_ids"]
|
| 143 |
batch["attention_mask"] = batch["attention_mask"]
|
| 144 |
|
|
@@ -157,8 +163,15 @@ def main(
|
|
| 157 |
tgt_text = tokenizer.batch_decode(
|
| 158 |
generated_tokens, skip_special_tokens=True
|
| 159 |
)
|
| 160 |
-
|
| 161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
pbar.update(len(tgt_text))
|
| 164 |
|
|
|
|
| 134 |
|
| 135 |
model, data_loader = accelerator.prepare(model, data_loader)
|
| 136 |
|
| 137 |
+
samples_seen: int = 0
|
| 138 |
+
|
| 139 |
with tqdm(
|
| 140 |
+
total=total_lines,
|
| 141 |
+
desc="Dataset translation",
|
| 142 |
+
leave=True,
|
| 143 |
+
ascii=True,
|
| 144 |
+
disable=(not accelerator.is_main_process),
|
| 145 |
) as pbar, open(output_path, "w", encoding="utf-8") as output_file:
|
| 146 |
with torch.no_grad():
|
| 147 |
+
for step, batch in enumerate(data_loader):
|
| 148 |
batch["input_ids"] = batch["input_ids"]
|
| 149 |
batch["attention_mask"] = batch["attention_mask"]
|
| 150 |
|
|
|
|
| 163 |
tgt_text = tokenizer.batch_decode(
|
| 164 |
generated_tokens, skip_special_tokens=True
|
| 165 |
)
|
| 166 |
+
if accelerator.is_main_process:
|
| 167 |
+
if step == len(data_loader) - 1:
|
| 168 |
+
tgt_text = tgt_text[
|
| 169 |
+
: len(data_loader.dataset) - samples_seen
|
| 170 |
+
]
|
| 171 |
+
else:
|
| 172 |
+
samples_seen += len(tgt_text)
|
| 173 |
+
|
| 174 |
+
print("\n".join(tgt_text), file=output_file)
|
| 175 |
|
| 176 |
pbar.update(len(tgt_text))
|
| 177 |
|