Spaces:
Sleeping
Sleeping
| # Punctuation restoration — loads Oliver Guhr’s model and restores punctuation in raw text | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline | |
# Model: Oliver Guhr's multilingual punctuation-restoration checkpoint.
MODEL_NAME = "oliverguhr/fullstop-punctuation-multilang-large"

# Pipeline device index: 0 selects the first CUDA GPU, -1 selects the CPU.
if torch.cuda.is_available():
    DEVICE = 0
else:
    DEVICE = -1

print(f"Loading punctuation model ({MODEL_NAME}) on {'GPU' if DEVICE == 0 else 'CPU'}...")

# Load the tokenizer first, then the classification model; both happen once,
# when this module is imported.
tokenizer, model = (
    AutoTokenizer.from_pretrained(MODEL_NAME),
    AutoModelForTokenClassification.from_pretrained(MODEL_NAME),
)

# Token-classification pipeline shared by every call to punctuate_text.
# "simple" aggregation merges neighbouring tokens that carry the same label
# into one result entry (per the transformers pipeline documentation).
punctuation_pipeline = pipeline(
    task="token-classification",
    tokenizer=tokenizer,
    model=model,
    aggregation_strategy="simple",
    device=DEVICE,
)
# Maps the label reported by the token-classification pipeline to the mark
# inserted after the tagged word. The fullstop-punctuation checkpoint uses the
# marks themselves as labels ("0" means "no punctuation"); the spelled-out
# names are accepted as well in case a checkpoint reports them instead.
_PUNCTUATION_FOR_LABEL = {
    ".": ".",
    ",": ",",
    "?": "?",
    "PERIOD": ".",
    "COMMA": ",",
    "QUESTION": "?",
}

# A word already ending in one of these characters gets no additional mark.
_EXISTING_MARKS = ".,?!:;"


def _insert_punctuation(text: str, predictions: list) -> str:
    """Insert predicted punctuation into *text* using character offsets.

    Works on the original string instead of concatenating the pipeline's
    decoded ``word`` strings, whose spacing between groups is unreliable.
    Predictions without a punctuation label or without an ``end`` offset are
    skipped. Whitespace in the result is collapsed to single spaces.
    """
    marks = {}  # insertion offset in text -> punctuation mark
    for prediction in predictions:
        # Aggregated pipelines report "entity_group"; unaggregated ones "entity".
        label = prediction.get("entity_group", prediction.get("entity"))
        mark = _PUNCTUATION_FOR_LABEL.get(label)
        end = prediction.get("end")
        if mark is None or end is None:
            continue
        end = min(int(end), len(text))
        # A prediction may end inside a word (sub-word tokens), so move the
        # insertion point to the end of that word.
        while end < len(text) and not text[end].isspace():
            end += 1
        if end > 0 and text[end - 1] in _EXISTING_MARKS:
            continue  # word is already punctuated; avoid doubling the mark
        marks[end] = mark  # a later prediction for the same word wins

    pieces = []
    cursor = 0
    for offset in sorted(marks):
        pieces.append(text[cursor:offset])
        pieces.append(marks[offset])
        cursor = offset
    # Keep the remainder, including any text the model produced no output for.
    pieces.append(text[cursor:])
    # Clean spacing
    return " ".join("".join(pieces).split())


def punctuate_text(text: str, *, pipe=None) -> str:
    """
    Restores punctuation in the given text using Oliver Guhr's model.

    Args:
        text: Raw text, typically without punctuation.
        pipe: Optional token-classification callable used instead of the
            module-level ``punctuation_pipeline``. It must return a list of
            dicts carrying a label (``entity_group`` or ``entity``) and an
            ``end`` character offset into ``text``.

    Returns:
        The punctuated text with whitespace collapsed to single spaces, or
        ``text`` unchanged when it is blank or when prediction fails.
    """
    if not text.strip():
        return text
    try:
        classify = punctuation_pipeline if pipe is None else pipe
        results = classify(text)
        return _insert_punctuation(text, results)
    except Exception as e:
        # Best effort: report the problem and fall back to the raw text.
        print(f"[punctuate_text] Error: {e}")
        return text