"""Gramformer: grammar error correction with edit highlighting, using a
seq2seq correction model and errant for edit extraction/classification."""
import spacy.cli
import errant


class Gramformer:

    def __init__(self, models=1, use_gpu=False):
        from transformers import AutoTokenizer
        from transformers import AutoModelForSeq2SeqLM
        # Ensure the spaCy model 'en_core_web_sm' is downloaded; errant
        # parses sentences with it when extracting edits
        spacy.cli.download("en_core_web_sm")
        # errant.load() expects a language code, not a spaCy model name
        self.annotator = errant.load('en')
        if use_gpu:
            device = "cuda:0"
        else:
            device = "cpu"
        batch_size = 1  # currently unused
        self.device = device
        correction_model_tag = "prithivida/grammar_error_correcter_v1"
        self.model_loaded = False

        if models == 1:
            # Public model, so no auth token is needed
            self.correction_tokenizer = AutoTokenizer.from_pretrained(correction_model_tag)
            self.correction_model = AutoModelForSeq2SeqLM.from_pretrained(correction_model_tag)
            self.correction_model = self.correction_model.to(device)
            self.model_loaded = True
            print("[Gramformer] Grammar error correct/highlight model loaded..")
        elif models == 2:
            # TODO: Implement this part
            print("TO BE IMPLEMENTED!!!")

    def correct(self, input_sentence, max_candidates=1):
        if self.model_loaded:
            # The correction model expects a "gec: " task prefix
            correction_prefix = "gec: "
            input_sentence = correction_prefix + input_sentence
            input_ids = self.correction_tokenizer.encode(input_sentence, return_tensors='pt')
            input_ids = input_ids.to(self.device)
            preds = self.correction_model.generate(
                input_ids,
                do_sample=True,
                max_length=128,
                num_beams=7,
                early_stopping=True,
                num_return_sequences=max_candidates)
            # Decode and deduplicate the candidate corrections
            corrected = set()
            for pred in preds:
                corrected.add(self.correction_tokenizer.decode(pred, skip_special_tokens=True).strip())
            return corrected
        else:
            print("Model is not loaded")
            return None

    def highlight(self, orig, cor):
        """Return orig with each edit wrapped in an <a> (addition),
        <d> (deletion) or <c> (correction) tag."""
        edits = self._get_edits(orig, cor)
        orig_tokens = orig.split()
        ignore_indexes = []

        for edit in edits:
            edit_type = edit[0]       # error type, e.g. 'VERB:SVA', 'PUNCT'
            edit_str_start = edit[1]  # span text in the original sentence
            edit_spos = edit[2]       # start token index in the original
            edit_epos = edit[3]       # end token index in the original
            edit_str_end = edit[4]    # span text in the corrected sentence

            # If the original span covers more than one token, mark every
            # token after the first for deletion
            for i in range(edit_spos + 1, edit_epos):
                ignore_indexes.append(i)

            if edit_str_start == "":
                # Addition: attach the inserted text to a neighbouring token
                if edit_spos - 1 >= 0:
                    new_edit_str = orig_tokens[edit_spos - 1]
                    edit_spos -= 1
                else:
                    new_edit_str = orig_tokens[edit_spos + 1]
                    edit_spos += 1
                if edit_type == "PUNCT":
                    st = f"<a type='{edit_type}' edit='{edit_str_end}'>{new_edit_str}</a>"
                else:
                    st = f"<a type='{edit_type}' edit='{new_edit_str} {edit_str_end}'>{new_edit_str}</a>"
                orig_tokens[edit_spos] = st
            elif edit_str_end == "":
                # Deletion: the original span has no counterpart in the correction
                st = f"<d type='{edit_type}' edit=''>{edit_str_start}</d>"
                orig_tokens[edit_spos] = st
            else:
                # Correction: the original span is replaced by edit_str_end
                st = f"<c type='{edit_type}' edit='{edit_str_end}'>{edit_str_start}</c>"
                orig_tokens[edit_spos] = st

        # Drop the extra tokens of multi-token spans, back to front so
        # earlier indexes stay valid
        for i in sorted(ignore_indexes, reverse=True):
            del orig_tokens[i]

        return " ".join(orig_tokens)

    def detect(self, input_sentence):
        # TO BE IMPLEMENTED
        pass

    def _get_edits(self, orig, cor):
        orig = self.annotator.parse(orig)
        cor = self.annotator.parse(cor)
        alignment = self.annotator.align(orig, cor)
        edits = self.annotator.merge(alignment)

        if len(edits) == 0:
            return []

        edit_annotations = []
        for e in edits:
            e = self.annotator.classify(e)
            # e.type looks like 'R:VERB:SVA'; strip the 'M:'/'R:'/'U:'
            # operation prefix and keep only the error category
            edit_annotations.append((e.type[2:], e.o_str, e.o_start, e.o_end,
                                     e.c_str, e.c_start, e.c_end))
        return edit_annotations

    def get_edits(self, orig, cor):
        return self._get_edits(orig, cor)
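
# A minimal usage sketch (assumes the dependencies above are installed; the
# sample sentence and prints are illustrative, not part of the original class):
if __name__ == "__main__":
    gf = Gramformer(models=1, use_gpu=False)
    influent = "He are moving here"
    for corrected in gf.correct(influent, max_candidates=2):
        print("Corrected  :", corrected)
        # Edits wrapped in <a>/<d>/<c> tags for display
        print("Highlighted:", gf.highlight(influent, corrected))
        # Raw edit tuples: (type, o_str, o_start, o_end, c_str, c_start, c_end)
        print("Edits      :", gf.get_edits(influent, corrected))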