🚧 update for longt5
Signed-off-by: peter szemraj <peterszemraj@gmail.com>
- summarize.py +23 -13
summarize.py
CHANGED
@@ -15,20 +15,19 @@ def load_model_and_tokenizer(model_name):
         AutoModelForSeq2SeqLM: the model
         AutoTokenizer: the tokenizer
     """
-
+    device = "cuda" if torch.cuda.is_available() else "cpu"
     model = AutoModelForSeq2SeqLM.from_pretrained(
         model_name,
         # low_cpu_mem_usage=True,
         # use_cache=False,
-    )
+    ).to(device)
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = model.to("cuda") if torch.cuda.is_available() else model
 
-    logging.info(f"Loaded model {model_name}")
+    logging.info(f"Loaded model {model_name} to {device}")
     return model, tokenizer
 
 
-def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
+def summarize_and_score(ids, mask, model, tokenizer, is_general_attention_model=True, **kwargs):
     """
     summarize_and_score - given a batch of ids and a mask, return a summary and a score for the summary
 
@@ -37,6 +36,7 @@ def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
         mask (): the attention mask for the batch
         model (): the model to use for summarization
         tokenizer (): the tokenizer to use for summarization
+        is_general_attention_model (bool, optional): whether the model is a general attention model. Defaults to True.
 
     Returns:
         str: the summary of the batch
@@ -52,14 +52,23 @@ def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
     # put global attention on <s> token
     global_attention_mask[:, 0] = 1
 
-    summary_pred_ids = model.generate(
-        input_ids,
-        attention_mask=attention_mask,
-        global_attention_mask=global_attention_mask,
-        output_scores=True,
-        return_dict_in_generate=True,
-        **kwargs,
-    )
+    if is_general_attention_model:
+        summary_pred_ids = model.generate(
+            input_ids,
+            attention_mask=attention_mask,
+            output_scores=True,
+            return_dict_in_generate=True,
+            **kwargs,
+        )
+    else:
+        summary_pred_ids = model.generate(
+            input_ids,
+            attention_mask=attention_mask,
+            global_attention_mask=global_attention_mask,
+            output_scores=True,
+            return_dict_in_generate=True,
+            **kwargs,
+        )
     summary = tokenizer.batch_decode(
         summary_pred_ids.sequences,
         skip_special_tokens=True,
@@ -70,6 +79,7 @@ def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
     return summary, score
 
 
+
 def summarize_via_tokenbatches(
     input_text: str,
     model,