Spaces:
Sleeping
Sleeping
File size: 1,695 Bytes
903b444 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
# Punctuation restoration — loads Oliver Guhr’s model and restores punctuation in raw text
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
# Model configuration: Oliver Guhr's multilingual punctuation-restoration model.
MODEL_NAME = "oliverguhr/fullstop-punctuation-multilang-large"
# transformers pipeline device convention: 0 = first CUDA GPU, -1 = CPU.
DEVICE = 0 if torch.cuda.is_available() else -1
print(f"Loading punctuation model ({MODEL_NAME}) on {'GPU' if DEVICE == 0 else 'CPU'}...")
# Load tokenizer and model weights (fetched from the Hugging Face hub on first use).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME)
# Token-classification pipeline; "simple" aggregation merges sub-word pieces into
# whole words so each result item carries one 'entity_group' label per word.
punctuation_pipeline = pipeline(
"token-classification",
model=model,
tokenizer=tokenizer,
device=DEVICE,
aggregation_strategy="simple"
)
# Main function
def punctuate_text(text: str) -> str:
    """
    Restore punctuation in *text* using Oliver Guhr's fullstop model.

    Runs the module-level token-classification pipeline, then appends the
    punctuation mark predicted for each word. Falls back to returning the
    input unchanged on any pipeline error (best-effort behavior).

    Args:
        text: Raw, unpunctuated text. Empty/whitespace-only input is
            returned as-is.

    Returns:
        The text with predicted punctuation inserted and whitespace
        normalized to single spaces.
    """
    # The fullstop model emits the literal mark as the label ("0" means
    # "no punctuation"); the word-form names are accepted too for
    # compatibility with label-renamed variants of the model.
    label_to_mark = {
        ".": ".", ",": ",", "?": "?", ":": ":", "-": "-",
        "PERIOD": ".", "COMMA": ",", "QUESTION": "?",
    }
    if not text.strip():
        return text
    try:
        results = punctuation_pipeline(text)
        pieces = []
        for item in results:
            # SentencePiece marks word starts with "▁"; turn it back into a space.
            word = item['word'].replace("▁", " ")
            # Unknown labels (e.g. "0") contribute no punctuation.
            mark = label_to_mark.get(item['entity_group'], "")
            pieces.append(word + mark)
        # Collapse any doubled/leading spaces introduced by the "▁" replacement.
        return " ".join("".join(pieces).split())
    except Exception as e:
        # Best-effort: log and hand back the original text rather than crash.
        print(f"[punctuate_text] Error: {e}")
        return text
|