# ks-version-1-1/backend/punctuation.py
# Punctuation restoration — loads Oliver Guhr’s model and restores punctuation in raw text
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
# Model
MODEL_NAME = "oliverguhr/fullstop-punctuation-multilang-large"
DEVICE = 0 if torch.cuda.is_available() else -1
print(f"Loading punctuation model ({MODEL_NAME}) on {'GPU' if DEVICE == 0 else 'CPU'}...")
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME)
# pipeline for token classification
punctuation_pipeline = pipeline(
    "token-classification",
    model=model,
    tokenizer=tokenizer,
    device=DEVICE,
    aggregation_strategy="simple",
)
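# For reference (illustrative, not captured output): with an aggregation
# strategy set, a call like punctuation_pipeline("some unpunctuated text")
# returns a list of dicts with the standard token-classification keys
# "entity_group", "score", "word", "start" and "end". "entity_group" is the
# punctuation label predicted for that span; the exact label set comes from
# the model's config.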
# Main function
def punctuate_text(text: str) -> str:
    """
    Restores punctuation in the given text using Oliver Guhr's model.
    Returns the punctuated text, or the original text if restoration fails.
    """
    if not text.strip():
        return text
    try:
        results = punctuation_pipeline(text)
        # Map predicted labels to punctuation marks. The mapping accepts both
        # name-style labels and the raw punctuation characters the model may emit.
        label_to_punct = {
            "COMMA": ",", ",": ",",
            "PERIOD": ".", ".": ".",
            "QUESTION": "?", "?": "?",
        }
        pieces = []
        for item in results:
            # SentencePiece marks word boundaries with "▁"; turn them into spaces.
            word = item["word"].replace("▁", " ").strip()
            punct = label_to_punct.get(item["entity_group"], "")
            pieces.append(word + punct)
        # Join the groups and normalize spacing
        return " ".join(" ".join(pieces).split())
    except Exception as e:
        print(f"[punctuate_text] Error: {e}")
        return text
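

# Minimal usage sketch for a quick local check when the module is run directly;
# the sample sentence is illustrative and not part of the deployed Space.
if __name__ == "__main__":
    sample = "hello how are you today i hope everything is fine"
    print(punctuate_text(sample))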