Update donut_inference.py
Browse files- donut_inference.py +3 -2
donut_inference.py
CHANGED
|
@@ -26,7 +26,7 @@ def inference(image):
|
|
| 26 |
|
| 27 |
# device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 28 |
# model.to(device)
|
| 29 |
-
|
| 30 |
outputs = model.generate(pixel_values.to(device),
|
| 31 |
decoder_input_ids=decoder_input_ids.to(device),
|
| 32 |
max_length=model.decoder.config.max_position_embeddings,
|
|
@@ -38,11 +38,12 @@ def inference(image):
|
|
| 38 |
bad_words_ids=[[processor.tokenizer.unk_token_id]],
|
| 39 |
return_dict_in_generate=True,
|
| 40 |
output_scores=True,)
|
| 41 |
-
|
| 42 |
sequence = processor.batch_decode(outputs.sequences)[0]
|
| 43 |
sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
|
| 44 |
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
|
| 45 |
print(processor.token2json(sequence))
|
|
|
|
| 46 |
return processor.token2json(sequence)
|
| 47 |
|
| 48 |
# data = inference(image)
|
|
|
|
| 26 |
|
| 27 |
# device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 28 |
# model.to(device)
|
| 29 |
+
start_time = time.time()
|
| 30 |
outputs = model.generate(pixel_values.to(device),
|
| 31 |
decoder_input_ids=decoder_input_ids.to(device),
|
| 32 |
max_length=model.decoder.config.max_position_embeddings,
|
|
|
|
| 38 |
bad_words_ids=[[processor.tokenizer.unk_token_id]],
|
| 39 |
return_dict_in_generate=True,
|
| 40 |
output_scores=True,)
|
| 41 |
+
end_time = time.time()
|
| 42 |
sequence = processor.batch_decode(outputs.sequences)[0]
|
| 43 |
sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
|
| 44 |
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
|
| 45 |
print(processor.token2json(sequence))
|
| 46 |
+
print(f"Donut Inference time {end_time-start_time}")
|
| 47 |
return processor.token2json(sequence)
|
| 48 |
|
| 49 |
# data = inference(image)
|