| from typing import Dict | |
| import torch | |
| from .config_manager import ConfigManager | |
class Diacritizer:
    """Base class for Arabic diacritization models.

    Builds a :class:`ConfigManager` from a config file, resolves the torch
    device, and optionally restores a trained model checkpoint. Subclasses
    implement :meth:`diacritize_batch` for their specific architecture.
    """

    def __init__(
        self, config_path: str, model_kind: str, load_model: bool = False
    ) -> None:
        """Initialize configuration, text encoder, device and (optionally) the model.

        Args:
            config_path: Path to the model configuration file.
            model_kind: Architecture identifier understood by ConfigManager.
            load_model: When True, load the trained checkpoint and move it
                to the resolved device.
        """
        self.config_path = config_path
        self.model_kind = model_kind
        self.config_manager = ConfigManager(
            config_path=config_path, model_kind=model_kind
        )
        self.config = self.config_manager.config
        self.text_encoder = self.config_manager.text_encoder
        # Prefer an explicit device from the config; otherwise auto-detect CUDA.
        if self.config.get("device"):
            self.device = self.config["device"]
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if load_model:
            self.model, self.global_step = self.config_manager.load_model()
            self.model = self.model.to(self.device)
        self.start_symbol_id = self.text_encoder.start_symbol_id

    def set_model(self, model: torch.nn.Module):
        """Attach an externally constructed model (e.g. during training)."""
        self.model = model

    def diacritize_text(self, text: str):
        """Diacritize a single string.

        Encodes the text into an id sequence, wraps it as a batch of one,
        and returns the result of :meth:`diacritize_batch`.

        BUG FIX: the original computed the output but never returned it,
        so this method always yielded None.
        """
        seq = self.text_encoder.input_to_sequence(text)
        return self.diacritize_batch(torch.LongTensor([seq]).to(self.device))

    def diacritize_batch(self, batch):
        """Diacritize one batch; must be implemented by subclasses."""
        raise NotImplementedError()

    def diacritize_iterators(self, iterator):
        # Intentionally a no-op in the base class; subclasses may override.
        pass
class CBHGDiacritizer(Diacritizer):
    """Diacritizer backed by a CBHG model."""

    def diacritize_batch(self, batch):
        """Run inference on one batch and return diacritized sentences.

        Args:
            batch: Mapping with "src" (LongTensor of character ids, shape
                presumably (batch, seq) — confirm against the data loader)
                and "lengths" (per-sample sequence lengths).

        Returns:
            List of diacritized sentence strings, one per batch row.
        """
        self.model.eval()
        inputs = batch["src"]
        lengths = batch["lengths"]
        # IMPROVEMENT: inference under no_grad — the original built an
        # autograd graph for every forward pass, wasting memory.
        with torch.no_grad():
            # NOTE(review): lengths are forced to CPU — presumably required
            # by pack_padded_sequence inside the model; confirm.
            outputs = self.model(inputs.to(self.device), lengths.to("cpu"))
        diacritics = outputs["diacritics"]
        # Most probable diacritic class per character position.
        predictions = torch.max(diacritics, 2).indices
        sentences = []
        for src, prediction in zip(inputs, predictions):
            sentence = self.text_encoder.combine_text_and_haraqat(
                list(src.detach().cpu().numpy()),
                list(prediction.detach().cpu().numpy()),
            )
            sentences.append(sentence)
        return sentences
class Seq2SeqDiacritizer(Diacritizer):
    """Diacritizer backed by a sequence-to-sequence model."""

    def diacritize_batch(self, batch):
        """Run inference on one batch and return diacritized sentences.

        Args:
            batch: Mapping with "src" (LongTensor of character ids, shape
                presumably (batch, seq) — confirm against the data loader)
                and "lengths" (per-sample sequence lengths).

        Returns:
            List of diacritized sentence strings, one per batch row.
        """
        self.model.eval()
        inputs = batch["src"]
        lengths = batch["lengths"]
        # IMPROVEMENT: inference under no_grad — the original built an
        # autograd graph for every forward pass, wasting memory.
        with torch.no_grad():
            # NOTE(review): lengths are forced to CPU — presumably required
            # by pack_padded_sequence inside the model; confirm.
            outputs = self.model(inputs.to(self.device), lengths.to("cpu"))
        diacritics = outputs["diacritics"]
        # Most probable diacritic class per character position.
        predictions = torch.max(diacritics, 2).indices
        sentences = []
        for src, prediction in zip(inputs, predictions):
            sentence = self.text_encoder.combine_text_and_haraqat(
                list(src.detach().cpu().numpy()),
                list(prediction.detach().cpu().numpy()),
            )
            sentences.append(sentence)
        return sentences
class GPTDiacritizer(Diacritizer):
    """Diacritizer backed by a GPT-style model."""

    def diacritize_batch(self, batch):
        """Run inference on one batch and return diacritized sentences.

        Args:
            batch: Mapping with "src" (LongTensor of character ids, shape
                presumably (batch, seq) — confirm against the data loader)
                and "lengths" (per-sample sequence lengths).

        Returns:
            List of diacritized sentence strings, one per batch row.
        """
        self.model.eval()
        inputs = batch["src"]
        lengths = batch["lengths"]
        # IMPROVEMENT: inference under no_grad — the original built an
        # autograd graph for every forward pass, wasting memory.
        with torch.no_grad():
            # NOTE(review): lengths are forced to CPU — presumably required
            # by pack_padded_sequence inside the model; confirm.
            outputs = self.model(inputs.to(self.device), lengths.to("cpu"))
        diacritics = outputs["diacritics"]
        # Most probable diacritic class per character position.
        predictions = torch.max(diacritics, 2).indices
        sentences = []
        for src, prediction in zip(inputs, predictions):
            sentence = self.text_encoder.combine_text_and_haraqat(
                list(src.detach().cpu().numpy()),
                list(prediction.detach().cpu().numpy()),
            )
            sentences.append(sentence)
        return sentences