import re
import pathlib

import numpy as np
import fasttext
import sentencepiece
import kenlm

from languages_id import langs_id
from parameters_filtering import parameters_filtering
from normalization import normalization
from stopwords import stopwords
from badwords import badwords
class LoadParameters:
    @staticmethod
    def load_parameters(lang_dataset_id):
        if lang_dataset_id in parameters_filtering:
            param = parameters_filtering[lang_dataset_id]
        else:
            param = parameters_filtering["default"]
        return param

    @staticmethod
    def load_stopwords(lang_dataset_id):
        stopwords_lang_id = langs_id.loc[
            langs_id["dataset_id"] == lang_dataset_id, "stopwords_id"
        ].iloc[0]
        if stopwords_lang_id:
            stopwords_lang = set(stopwords[stopwords_lang_id])
        else:
            stopwords_lang = None
        return stopwords_lang

    @staticmethod
    def load_badwords(lang_dataset_id):
        badwords_lang_id = langs_id.loc[
            langs_id["dataset_id"] == lang_dataset_id, "badwords_id"
        ].iloc[0]
        if badwords_lang_id:
            badwords_lang = set(badwords[badwords_lang_id])
        else:
            badwords_lang = None
        return badwords_lang

    @staticmethod
    def load_model_lang_id(lang_dataset_id, path_fasttext_model):
        fasttext_lang_id = langs_id.loc[
            langs_id["dataset_id"] == lang_dataset_id, "fasttext_id"
        ].iloc[0]
        if fasttext_lang_id:
            model_lang_id = fasttext.load_model(path_fasttext_model)
        else:
            model_lang_id = None
        return model_lang_id

    @staticmethod
    def load_sentencepiece_model(lang_dataset_id, path_sentencepiece_model):
        sentencepiece_lang_id = langs_id.loc[
            langs_id["dataset_id"] == lang_dataset_id, "sentencepiece_id"
        ].iloc[0]
        if sentencepiece_lang_id:
            sentencepiece_model = sentencepiece.SentencePieceProcessor()
            sentencepiece_model.load(path_sentencepiece_model)
        else:
            sentencepiece_model = None
        return sentencepiece_model

    @staticmethod
    def load_kenlm_model(lang_dataset_id, path_kenlm_model):
        kenlm_lang_id = langs_id.loc[
            langs_id["dataset_id"] == lang_dataset_id, "kenlm_id"
        ].iloc[0]
        if kenlm_lang_id:
            kenlm_model = kenlm.Model(path_kenlm_model)
        else:
            kenlm_model = None
        return kenlm_model
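
# Minimal usage sketch, assuming "en" is a valid dataset_id in langs_id and
# that the model path below is hypothetical:
#
#   param = LoadParameters.load_parameters("en")
#   stopwords_en = LoadParameters.load_stopwords("en")
#   model_lang_id = LoadParameters.load_model_lang_id("en", "lid.176.bin")
#
# Each loader returns None when langs_id has no resource id for the language,
# and the corresponding check_* function is then skipped downstream.
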
class ModifyingDocuments:
    @staticmethod
    def remove_empty_el_from_list(list_):
        return [el for el in list_ if el]

    @staticmethod
    def remove_non_printing_characters(document, non_printing_characters_re):
        return non_printing_characters_re.sub("", document)
    @staticmethod
    def uniform_whitespace(
        document,
        whitespace=[
            " ",
            "\u2009",  # thin space
            "\u200a",  # hair space
            "\u202f",  # narrow no-break space
            "\u2060",  # word joiner
            "\u3000",  # ideographic space
            "\u200b",  # zero-width space
            "\u00a0",  # no-break space
            "\u200c",  # zero-width non-joiner
            "\u200d",  # zero-width joiner
            "\ufffc",  # object replacement character
            "\u00ad",  # soft hyphen
        ],
    ):
        """Replace the various Unicode whitespace characters with a standard space."""
        whitespace = set(whitespace)
        document = "".join(
            [char if char not in whitespace else " " for char in document]
        )
        return document
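
    # Example, assuming the defaults above: no-break spaces and ideographic
    # spaces are mapped to plain spaces, so
    #   ModifyingDocuments.uniform_whitespace("a\u00a0b\u3000c") == "a b c"
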
    @staticmethod
    def replace_digits_with_zeros(document, digits_re):
        return digits_re.sub("0", document)

    @staticmethod
    def replace_unicode_punctuation(document, unicode_punctuation):
        return "".join(unicode_punctuation.get(c, c) for c in document)
    @staticmethod
    def normalization(
        document,
        remove_non_printing_characters,
        strip,
        lower_case,
        uniform_whitespace,
        replace_digits_with_zeros,
        replace_unicode_punctuation,
        # The defaults below are bound at class-definition time from the
        # imported `normalization` dict (compiled regexes and a punctuation map).
        non_printing_characters_re=normalization["non_printing_characters_re"],
        digits_re=normalization["digits_re"],
        unicode_punctuation=normalization["unicode_punctuation"],
    ):
        if remove_non_printing_characters:
            document = ModifyingDocuments.remove_non_printing_characters(
                document, non_printing_characters_re
            )
        if strip:
            document = document.strip()
        if not document:
            return document
        if lower_case:
            document = document.lower()
        if uniform_whitespace:
            document = ModifyingDocuments.uniform_whitespace(document)
        if replace_digits_with_zeros:
            document = ModifyingDocuments.replace_digits_with_zeros(
                document, digits_re
            )
        if replace_unicode_punctuation:
            document = ModifyingDocuments.replace_unicode_punctuation(
                document, unicode_punctuation
            )
        return document
    @staticmethod
    def tokenization(document, sentencepiece_model, join_on_whitespace):
        document_tokenized = sentencepiece_model.encode_as_pieces(document)
        if join_on_whitespace:
            document_tokenized = " ".join(document_tokenized)
        return document_tokenized

    @staticmethod
    def split_on_whitespace(
        document,
        new_line=False,
        tab=False,
    ):
        """This method also removes concatenated separators."""
        sep = [" "] + new_line * ["\n"] + tab * ["\t"]
        sep = "|".join(sep)
        split_document = re.split(sep, document)
        split_document = ModifyingDocuments.remove_empty_el_from_list(split_document)
        return split_document
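
    # Example: consecutive separators produce empty fields in re.split, which
    # are then dropped, so
    #   ModifyingDocuments.split_on_whitespace("a  b\nc", new_line=True)
    #   == ["a", "b", "c"]
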
    @staticmethod
    def strip(document, strip_characters):
        """Way faster than document.strip(strip_characters) since
        strip_characters is a set instead of a str and contains a lot
        of elements (all the emojis, for instance)."""
        if not document:
            return document
        beg_ind = 0
        end_ind = len(document)
        for i in range(len(document)):
            if document[i] in strip_characters:
                beg_ind += 1
            else:
                break
        for i in range(1, len(document) + 1):
            if document[-i] in strip_characters:
                end_ind -= 1
            else:
                break
        document_stripped = document[beg_ind:end_ind]
        return document_stripped
    @staticmethod
    def get_words_from_document(
        document, sentencepiece_model_tok, lower_case, strip_characters
    ):
        """Get words from a document. Not reversible, since the document
        is split on multiple characters, words are stripped of
        special characters, and characters are converted to lower case.
        Useful to compute ratios such as the stopword ratio."""
        if sentencepiece_model_tok:
            document_normalized = ModifyingDocuments.normalization(
                document=document,
                remove_non_printing_characters=True,
                strip=True,
                lower_case=True,
                uniform_whitespace=True,
                replace_digits_with_zeros=True,
                replace_unicode_punctuation=True,
            )
            words = ModifyingDocuments.tokenization(
                document_normalized, sentencepiece_model_tok, join_on_whitespace=False
            )
        else:
            words = ModifyingDocuments.split_on_whitespace(
                document, new_line=True, tab=True
            )
        if lower_case:
            words = [word.lower() for word in words]
        if strip_characters:
            words = [ModifyingDocuments.strip(word, strip_characters) for word in words]
            words = ModifyingDocuments.remove_empty_el_from_list(words)
        return words
    @staticmethod
    def words_augmentation(words, group_size, join_char):
        """Augment words by joining consecutive groups of `group_size` words.
        Especially useful for Chinese (no space between words) and
        Vietnamese (a space between syllables)."""
        augmentation = [
            join_char.join(words[i : i + group_size])
            for i in range(len(words) - group_size + 1)
        ]
        return augmentation
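
    # Example, for Vietnamese-style syllable grouping with group_size=2:
    #   ModifyingDocuments.words_augmentation(["xin", "chao", "cac", "ban"], 2, " ")
    #   == ["xin chao", "chao cac", "cac ban"]
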
    @staticmethod
    def split_on_newline_tab_whitespace(document):
        """First split on "\\n", then on "\\t", then on " "."""
        sentences = document.split("\n")
        sentences = [sentence.split("\t") for sentence in sentences]
        sentences = [
            [
                ModifyingDocuments.split_on_whitespace(subsentence)
                for subsentence in sentence
            ]
            for sentence in sentences
        ]
        return sentences
    @staticmethod
    def merge_on_whitespace_tab_newline(sentences):
        """Invert the method split_on_newline_tab_whitespace.
        Removes concatenated separators."""
        sentences = [
            [" ".join(subsentence) for subsentence in sentence if subsentence]
            for sentence in sentences
        ]
        sentences = ["\t".join(sentence) for sentence in sentences if sentence]
        if not sentences:
            return ""
        document = "\n".join(sentences)
        return document
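
    # Round-trip example: splitting then merging normalizes runs of separators,
    #   doc = "a  b\t\tc\nd"
    #   merged = ModifyingDocuments.merge_on_whitespace_tab_newline(
    #       ModifyingDocuments.split_on_newline_tab_whitespace(doc)
    #   )
    #   # merged == "a b\tc\nd"
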
    @staticmethod
    def should_keep_word_with_incorrect_substrings(
        word, strip_characters, incorrect_word_substrings
    ):
        word = ModifyingDocuments.strip(word, strip_characters)
        should_keep = all(
            (i_substr not in word) for i_substr in incorrect_word_substrings
        )
        return should_keep

    @staticmethod
    def remove_words_with_incorrect_substrings(
        document,
        strip_characters,
        incorrect_word_substrings,
    ):
        sentences = ModifyingDocuments.split_on_newline_tab_whitespace(document)
        sentences = [
            [
                [
                    word
                    for word in subsentence
                    if ModifyingDocuments.should_keep_word_with_incorrect_substrings(
                        word, strip_characters, incorrect_word_substrings
                    )
                ]
                for subsentence in sentence
            ]
            for sentence in sentences
        ]
        document = ModifyingDocuments.merge_on_whitespace_tab_newline(sentences)
        return document
    @staticmethod
    def should_keep_long_word(word, strip_characters, length_word_max_cutoff):
        """If the word is too long but contains only one special character,
        it might be a concatenation of one word, a punctuation mark, and
        another word, with no space between them. In this case, we give
        the word a pass."""
        if len(word) <= length_word_max_cutoff:
            return True
        word = ModifyingDocuments.strip(word, strip_characters)
        if not word:  # The word consisted only of strip characters
            return False
        if len(word) <= length_word_max_cutoff:
            return True
        return False

    @staticmethod
    def remove_long_words(
        document,
        strip_characters,
        length_word_max_cutoff,
    ):
        sentences = ModifyingDocuments.split_on_newline_tab_whitespace(document)
        sentences = [
            [
                [
                    word
                    for word in subsentence
                    if ModifyingDocuments.should_keep_long_word(
                        word,
                        strip_characters,
                        length_word_max_cutoff,
                    )
                ]
                for subsentence in sentence
            ]
            for sentence in sentences
        ]
        document = ModifyingDocuments.merge_on_whitespace_tab_newline(sentences)
        return document
    @staticmethod
    def modifying_documents(
        document,
        cond_uniform_whitespace,
        cond_replace_unicode_punctuation,
        cond_remove_words_with_incorrect_substrings,
        strip_characters,
        incorrect_word_substrings,
        cond_remove_long_words,
        length_word_max_cutoff,
    ):
        document = ModifyingDocuments.normalization(
            document=document,
            remove_non_printing_characters=False,
            strip=True,
            lower_case=False,
            uniform_whitespace=cond_uniform_whitespace,
            replace_digits_with_zeros=False,
            replace_unicode_punctuation=cond_replace_unicode_punctuation,
        )
        if cond_remove_words_with_incorrect_substrings:
            document = ModifyingDocuments.remove_words_with_incorrect_substrings(
                document,
                strip_characters,
                incorrect_word_substrings,
            )
        if cond_remove_long_words:
            document = ModifyingDocuments.remove_long_words(
                document,
                strip_characters,
                length_word_max_cutoff,
            )
        return document
class FunctionDatasetModifyingDocuments:
    def __init__(self, lang_dataset_id):
        self.lang_dataset_id = lang_dataset_id
        self.param = LoadParameters.load_parameters(lang_dataset_id)

    def __call__(self, example):
        example["text"] = ModifyingDocuments.modifying_documents(
            document=example["text"],
            cond_uniform_whitespace=self.param["cond_uniform_whitespace"],
            cond_replace_unicode_punctuation=self.param[
                "cond_replace_unicode_punctuation"
            ],
            cond_remove_words_with_incorrect_substrings=self.param[
                "cond_remove_words_with_incorrect_substrings"
            ],
            strip_characters=self.param["strip_characters"],
            incorrect_word_substrings=self.param["incorrect_word_substrings"],
            cond_remove_long_words=self.param["cond_remove_long_words"],
            length_word_max_cutoff=self.param["length_word_max_cutoff"],
        )
        return example

    def __reduce__(self):
        return (self.__class__, (self.lang_dataset_id,))
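
# Note: __reduce__ makes these callables picklable by reconstructing them from
# their constructor arguments alone, which is what lets datasets.Dataset.map /
# .filter ship them to worker processes when num_proc > 1. The loaded
# parameters (and, for FunctionDatasetFiltering below, the fasttext / kenlm
# models, which are not picklable) are then re-created inside each worker.
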
class Filtering:
    @staticmethod
    def check_number_words(
        document,
        sentencepiece_model_tok,
        strip_characters,
        number_words_min_cutoff,
        number_words_max_cutoff,
    ):
        words = ModifyingDocuments.get_words_from_document(
            document,
            sentencepiece_model_tok,
            lower_case=False,
            strip_characters=strip_characters,
        )
        cond = (len(words) >= number_words_min_cutoff) and (
            len(words) <= number_words_max_cutoff
        )
        return cond
    @staticmethod
    def compute_repetitions_ratio(document, repetitions_length):
        def get_freq_ngrams(document, n):
            ngrams = [document[i : i + n] for i in range(len(document) - n + 1)]
            freq_ngrams = {}
            for ngram in ngrams:
                freq_ngrams[ngram] = freq_ngrams.get(ngram, 0) + 1
            return freq_ngrams

        freq_ngrams = get_freq_ngrams(document, repetitions_length)
        if len(freq_ngrams) == 0:
            return 0
        freq_ngrams = list(freq_ngrams.values())
        freq_ngrams = sorted(freq_ngrams, reverse=True)
        # The ratio is the share of all character n-gram occurrences taken by
        # the sqrt(#distinct n-grams) most frequent n-grams.
        num_rep_ngrams = int(np.sqrt(len(freq_ngrams)))
        repetitions_ratio = sum(freq_ngrams[:num_rep_ngrams]) / sum(freq_ngrams)
        return repetitions_ratio
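
    # Worked example with repetitions_length=2 on "ababab": the character
    # 2-grams are ["ab", "ba", "ab", "ba", "ab"], so the counts are
    # {"ab": 3, "ba": 2}. With 2 distinct 2-grams, num_rep_ngrams =
    # int(sqrt(2)) = 1, and the ratio is 3 / 5 = 0.6 (the most frequent
    # 2-gram accounts for 60% of occurrences).
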
    @staticmethod
    def check_repetitions_removal(
        document,
        repetitions_length,
        repetitions_max_cutoff,
    ):
        repetitions_ratio = Filtering.compute_repetitions_ratio(
            document, repetitions_length
        )
        cond = repetitions_ratio <= repetitions_max_cutoff
        return cond

    @staticmethod
    def compute_special_characters_ratio(document, special_characters):
        if not document:  # Guard against division by zero on empty documents
            return 0
        special_characters_ratio = len(
            [char for char in document if char in special_characters]
        ) / len(document)
        return special_characters_ratio

    @staticmethod
    def check_special_characters(
        document,
        special_characters,
        special_characters_max_cutoff,
    ):
        special_characters_ratio = Filtering.compute_special_characters_ratio(
            document, special_characters
        )
        cond = special_characters_ratio <= special_characters_max_cutoff
        return cond
    @staticmethod
    def compute_stopwords_ratio(
        document,
        sentencepiece_model_tok,
        strip_characters,
        cond_words_augmentation,
        words_augmentation_group_sizes,
        words_augmentation_join_char,
        stopwords,
    ):
        words = ModifyingDocuments.get_words_from_document(
            document,
            sentencepiece_model_tok,
            lower_case=True,
            strip_characters=strip_characters,
        )
        if not words:
            return 0
        augmentation = []
        if cond_words_augmentation:
            augmentation = [
                ModifyingDocuments.words_augmentation(
                    words, group_size, words_augmentation_join_char
                )
                for group_size in words_augmentation_group_sizes
            ]
            augmentation = [word for augm in augmentation for word in augm]
        stopwords_ratio = len(
            [word for word in words + augmentation if word in stopwords]
        ) / len(words)
        # Augmented n-grams are counted against len(words) only, so the
        # ratio can exceed 1.0 and is clamped.
        if stopwords_ratio > 1.0:
            stopwords_ratio = 1.0
        return stopwords_ratio
    @staticmethod
    def check_stopwords(
        document,
        sentencepiece_model_tok,
        strip_characters,
        cond_words_augmentation,
        words_augmentation_group_sizes,
        words_augmentation_join_char,
        stopwords,
        stopwords_min_cutoff,
    ):
        cond = True
        if stopwords:
            stopwords_ratio = Filtering.compute_stopwords_ratio(
                document,
                sentencepiece_model_tok,
                strip_characters,
                cond_words_augmentation,
                words_augmentation_group_sizes,
                words_augmentation_join_char,
                stopwords,
            )
            cond = stopwords_ratio >= stopwords_min_cutoff
        return cond
    @staticmethod
    def compute_badwords_ratio(
        document,
        sentencepiece_model_tok,
        strip_characters,
        cond_words_augmentation,
        words_augmentation_group_sizes,
        words_augmentation_join_char,
        badwords,
    ):
        words = ModifyingDocuments.get_words_from_document(
            document,
            sentencepiece_model_tok,
            lower_case=True,
            strip_characters=strip_characters,
        )
        if not words:
            return 0
        augmentation = []
        if cond_words_augmentation:
            augmentation = [
                ModifyingDocuments.words_augmentation(
                    words, group_size, words_augmentation_join_char
                )
                for group_size in words_augmentation_group_sizes
            ]
            augmentation = [word for augm in augmentation for word in augm]
        badwords_ratio = len(
            [word for word in words + augmentation if word in badwords]
        ) / len(words)
        if badwords_ratio > 1.0:
            badwords_ratio = 1.0
        return badwords_ratio
    @staticmethod
    def check_badwords(
        document,
        sentencepiece_model_tok,
        strip_characters,
        cond_words_augmentation,
        words_augmentation_group_sizes,
        words_augmentation_join_char,
        badwords,
        badwords_max_cutoff,
    ):
        cond = True
        if badwords:
            badwords_ratio = Filtering.compute_badwords_ratio(
                document,
                sentencepiece_model_tok,
                strip_characters,
                cond_words_augmentation,
                words_augmentation_group_sizes,
                words_augmentation_join_char,
                badwords,
            )
            cond = badwords_ratio <= badwords_max_cutoff
        return cond
    @staticmethod
    def compute_lang_id_pred_score(document, model_lang_id):
        document = document.lower().replace("\n", " ")
        pred = model_lang_id.predict(document)
        lang_pred_fasttext_id = pred[0][0].replace("__label__", "")
        score_pred = pred[1][0]
        lang_pred_dataset_id = langs_id.loc[
            langs_id["fasttext_id"] == lang_pred_fasttext_id, "dataset_id"
        ]
        if len(lang_pred_dataset_id) > 0:
            lang_pred_dataset_id = lang_pred_dataset_id.iloc[0]
        else:
            lang_pred_dataset_id = "unknown"
        return lang_pred_dataset_id, score_pred
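
    # fasttext's predict returns labels and probabilities, e.g.
    #   (("__label__en",), array([0.98]))
    # so pred[0][0] is the top label and pred[1][0] its confidence; the label
    # is then mapped back to the dataset's language id through langs_id.
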
    @staticmethod
    def check_lang_id(
        document,
        lang_dataset_id,
        model_lang_id,
        lang_id_min_cutoff,
    ):
        cond = True
        if model_lang_id:
            lang_pred_dataset_id, score_pred = Filtering.compute_lang_id_pred_score(
                document, model_lang_id
            )
            cond = (lang_pred_dataset_id == lang_dataset_id) and (
                score_pred >= lang_id_min_cutoff
            )
        return cond
    @staticmethod
    def compute_perplexity_score(document, sentencepiece_model, kenlm_model):
        document = ModifyingDocuments.normalization(
            document=document,
            remove_non_printing_characters=True,
            strip=True,
            lower_case=True,
            uniform_whitespace=True,
            replace_digits_with_zeros=True,
            replace_unicode_punctuation=True,
        )
        document = ModifyingDocuments.tokenization(
            document, sentencepiece_model, join_on_whitespace=True
        )
        doc_log_score, doc_length = 0, 0
        for line in document.split("\n"):
            log_score = kenlm_model.score(line)
            length = len(line.split()) + 1  # + 1 for the end-of-sentence token
            doc_log_score += log_score
            doc_length += length
        pp_score = 10.0 ** (-doc_log_score / doc_length)
        pp_score = round(pp_score, 1)
        return pp_score
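
    # kenlm's score is a log10 probability, so the perplexity of a document
    # with total log score S over N tokens is 10 ** (-S / N). For example,
    # S = -2000 over N = 500 tokens gives 10 ** 4 = 10000.0; lower means more
    # fluent under the language model.
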
    @staticmethod
    def check_perplexity(
        document,
        sentencepiece_model,
        kenlm_model,
        perplexity_max_cutoff,
    ):
        cond = True
        if kenlm_model:
            score = Filtering.compute_perplexity_score(
                document, sentencepiece_model, kenlm_model
            )
            cond = score <= perplexity_max_cutoff
        return cond
    @staticmethod
    def filtering(
        document,
        cond_check_number_words,
        sentencepiece_model_tok,
        strip_characters,
        number_words_min_cutoff,
        number_words_max_cutoff,
        cond_check_repetitions_removal,
        repetitions_length,
        repetitions_max_cutoff,
        cond_check_special_characters,
        special_characters,
        special_characters_max_cutoff,
        cond_words_augmentation,
        words_augmentation_group_sizes,
        words_augmentation_join_char,
        cond_check_stopwords,
        stopwords,
        stopwords_min_cutoff,
        cond_check_badwords,
        badwords,
        badwords_max_cutoff,
        cond_check_lang_id,
        lang_dataset_id,
        model_lang_id,
        lang_id_min_cutoff,
        cond_check_perplexity,
        sentencepiece_model,
        kenlm_model,
        perplexity_max_cutoff,
    ):
        if cond_check_number_words:
            if not Filtering.check_number_words(
                document,
                sentencepiece_model_tok,
                strip_characters,
                number_words_min_cutoff,
                number_words_max_cutoff,
            ):
                return False
        if cond_check_repetitions_removal:
            if not Filtering.check_repetitions_removal(
                document,
                repetitions_length,
                repetitions_max_cutoff,
            ):
                return False
        if cond_check_special_characters:
            if not Filtering.check_special_characters(
                document,
                special_characters,
                special_characters_max_cutoff,
            ):
                return False
        if cond_check_stopwords:
            if not Filtering.check_stopwords(
                document,
                sentencepiece_model_tok,
                strip_characters,
                cond_words_augmentation,
                words_augmentation_group_sizes,
                words_augmentation_join_char,
                stopwords,
                stopwords_min_cutoff,
            ):
                return False
        if cond_check_badwords:
            if not Filtering.check_badwords(
                document,
                sentencepiece_model_tok,
                strip_characters,
                cond_words_augmentation,
                words_augmentation_group_sizes,
                words_augmentation_join_char,
                badwords,
                badwords_max_cutoff,
            ):
                return False
        if cond_check_lang_id:
            if not Filtering.check_lang_id(
                document,
                lang_dataset_id,
                model_lang_id,
                lang_id_min_cutoff,
            ):
                return False
        if cond_check_perplexity:
            if not Filtering.check_perplexity(
                document,
                sentencepiece_model,
                kenlm_model,
                perplexity_max_cutoff,
            ):
                return False
        return True
class FunctionDatasetFiltering:
    def __init__(
        self,
        lang_dataset_id,
        path_fasttext_model,
        path_sentencepiece_model,
        path_kenlm_model,
    ):
        self.lang_dataset_id = lang_dataset_id
        self.path_fasttext_model = path_fasttext_model
        self.path_sentencepiece_model = path_sentencepiece_model
        self.path_kenlm_model = path_kenlm_model

        self.param = LoadParameters.load_parameters(lang_dataset_id)
        self.stopwords = LoadParameters.load_stopwords(lang_dataset_id)
        self.badwords = LoadParameters.load_badwords(lang_dataset_id)
        self.model_lang_id = LoadParameters.load_model_lang_id(
            lang_dataset_id, path_fasttext_model
        )
        self.sentencepiece_model = LoadParameters.load_sentencepiece_model(
            lang_dataset_id, path_sentencepiece_model
        )
        self.sentencepiece_model_tok = (
            self.sentencepiece_model if self.param["tokenization"] else None
        )
        self.kenlm_model = LoadParameters.load_kenlm_model(
            lang_dataset_id, path_kenlm_model
        )

    def __call__(self, example):
        keep_example = Filtering.filtering(
            document=example["text"],
            cond_check_number_words=self.param["cond_check_number_words"],
            sentencepiece_model_tok=self.sentencepiece_model_tok,
            strip_characters=self.param["strip_characters"],
            number_words_min_cutoff=self.param["number_words_min_cutoff"],
            number_words_max_cutoff=self.param["number_words_max_cutoff"],
            cond_check_repetitions_removal=self.param["check_repetitions_removal"],
            repetitions_length=self.param["repetitions_length"],
            repetitions_max_cutoff=self.param["repetitions_max_cutoff"],
            cond_check_special_characters=self.param["cond_check_special_characters"],
            special_characters=self.param["special_characters"],
            special_characters_max_cutoff=self.param["special_characters_max_cutoff"],
            cond_words_augmentation=self.param["cond_words_augmentation"],
            words_augmentation_group_sizes=self.param["words_augmentation_group_sizes"],
            words_augmentation_join_char=self.param["words_augmentation_join_char"],
            cond_check_stopwords=self.param["cond_check_stopwords"],
            stopwords=self.stopwords,
            stopwords_min_cutoff=self.param["stopwords_min_cutoff"],
            cond_check_badwords=self.param["cond_check_badwords"],
            badwords=self.badwords,
            badwords_max_cutoff=self.param["badwords_max_cutoff"],
            cond_check_lang_id=self.param["cond_check_lang_id"],
            lang_dataset_id=self.lang_dataset_id,
            model_lang_id=self.model_lang_id,
            lang_id_min_cutoff=self.param["lang_id_min_cutoff"],
            cond_check_perplexity=self.param["cond_check_perplexity"],
            sentencepiece_model=self.sentencepiece_model,
            kenlm_model=self.kenlm_model,
            perplexity_max_cutoff=self.param["perplexity_max_cutoff"],
        )
        return keep_example

    def __reduce__(self):
        return (
            self.__class__,
            (
                self.lang_dataset_id,
                self.path_fasttext_model,
                self.path_sentencepiece_model,
                self.path_kenlm_model,
            ),
        )
class DatasetFiltering:
    def __init__(
        self,
        dataset,
        lang_dataset_id,
        path_fasttext_model,
        path_sentencepiece_model,
        path_kenlm_model,
        num_proc,
        path_dir_save_dataset,
    ):
        self.ds = dataset
        self.lang_dataset_id = lang_dataset_id
        self.path_fasttext_model = path_fasttext_model
        self.path_sentencepiece_model = path_sentencepiece_model
        self.path_kenlm_model = path_kenlm_model
        self.num_proc = num_proc
        self.path_dir_save_dataset = path_dir_save_dataset

    def modifying_documents(self):
        dataset_modifying_documents = FunctionDatasetModifyingDocuments(
            self.lang_dataset_id
        )
        self.ds = self.ds.map(dataset_modifying_documents, num_proc=self.num_proc)

    def filtering(self):
        func_dataset_filtering = FunctionDatasetFiltering(
            self.lang_dataset_id,
            self.path_fasttext_model,
            self.path_sentencepiece_model,
            self.path_kenlm_model,
        )
        self.ds = self.ds.filter(func_dataset_filtering, num_proc=self.num_proc)

    def save_dataset(self):
        pathlib.Path(self.path_dir_save_dataset).mkdir(parents=True, exist_ok=True)
        path_dir_save_dataset = pathlib.PurePath(
            self.path_dir_save_dataset, self.lang_dataset_id
        )
        pathlib.Path(path_dir_save_dataset).mkdir(parents=True, exist_ok=True)
        self.ds.save_to_disk(path_dir_save_dataset)
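
# End-to-end usage sketch, assuming a Hugging Face `datasets` Dataset with a
# "text" column; the model paths and output directory below are hypothetical:
#
#   from datasets import load_dataset
#
#   ds = load_dataset("oscar", "unshuffled_deduplicated_en", split="train")
#   dataset_filtering = DatasetFiltering(
#       dataset=ds,
#       lang_dataset_id="en",
#       path_fasttext_model="lid.176.bin",
#       path_sentencepiece_model="en.sp.model",
#       path_kenlm_model="en.arpa.bin",
#       num_proc=4,
#       path_dir_save_dataset="./filtered",
#   )
#   dataset_filtering.modifying_documents()
#   dataset_filtering.filtering()
#   dataset_filtering.save_dataset()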