import re
from collections import Counter
from itertools import chain

from accelerate import Accelerator
from datasets import load_dataset, Dataset, DatasetDict
from tqdm import tqdm

LANGUAGES_TO_DECODE_FROM_BYTES = ["he", "fr", "uk"]
STREAMING_DATASETS = ["fineweb-edu"]


def load_pg19_val_and_test():
    """Load only the validation and test splits of PG-19, streaming to avoid downloading the full training set."""
    # Load the dataset in streaming mode
    streaming_dataset = load_dataset("deepmind/pg19", split=None, streaming=True)
    # Extract test and validation splits
    test_split = list(streaming_dataset["test"])
    validation_split = list(streaming_dataset["validation"])
    # Convert them into regular (in-memory) datasets
    test_dataset = Dataset.from_list(test_split)
    validation_dataset = Dataset.from_list(validation_split)
    # validation_dataset = load_dataset("deepmind/pg19", split="validation")
    # test_dataset = load_dataset("deepmind/pg19", split="test")
    return DatasetDict({"validation": validation_dataset, "test": test_dataset})
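

# Illustrative usage sketch, not part of the original module: the loader above returns a
# DatasetDict with only "validation" and "test" splits. Reading a "text" column is an
# assumption about the public deepmind/pg19 schema.
def _example_inspect_pg19():
    pg19 = load_pg19_val_and_test()
    print(pg19)  # DatasetDict({"validation": ..., "test": ...})
    print(pg19["validation"][0]["text"][:200])  # first characters of one book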


def load_pubmed(n_samples=10000):
    """Build small train/validation/test splits from the streaming MedRAG/pubmed corpus."""
    # Load the dataset in streaming mode
    streaming_dataset = load_dataset("MedRAG/pubmed", streaming=True)
    # Take a fixed prefix of the train split and partition it into train/validation/test
    data = list(streaming_dataset["train"].take(n_samples * 4))
    train = data[:2 * n_samples]
    validation = data[2 * n_samples:3 * n_samples]
    test = data[3 * n_samples:]
    # Convert them into regular (in-memory) datasets
    train = Dataset.from_list(train)
    validation = Dataset.from_list(validation)
    test = Dataset.from_list(test)
    dataset = DatasetDict({"train": train, "validation": validation, "test": test})
    dataset = dataset.rename_column("content", "text")
    return dataset


def load_lm_dataset(dataset_name, language="en", split=None):
    """
    Loads a popular pretraining or perplexity-evaluation dataset by name and language.

    Args:
        dataset_name (str): The name of the dataset to load. Options include:
            - 'wikitext' (WikiText-2, the smaller WikiText dataset)
            - 'wikitext-103' (the larger WikiText dataset)
            - 'fineweb-edu' (FineWeb-Edu sample, educational web text)
            - 'cord19' (COVID-19 Open Research Dataset, full text)
            - 'pubmed' (PubMed corpus via MedRAG)
            - 'wikilingua' (WikiLingua, English subset, for summarization)
            - 'xsum' (XSum abstractive summarization)
            - 'cnn' (CNN/DailyMail summarization)
            - 'pg19' (Project Gutenberg dataset for long-context modeling)
            - 'wiki40b' (Wikipedia dataset in multiple languages)
        language (str): Language code for datasets that support multilingual options
            (e.g., 'en' for English). Defaults to 'en'. Only used for 'wiki40b'.
        split (str): Optional split to load (e.g., 'train', 'validation', 'test').
            Datasets that ignore it return a DatasetDict.

    Returns:
        Dataset or DatasetDict: Loaded Hugging Face dataset.
    """
    if dataset_name.lower() == 'wikitext':
        return load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split=split)
    elif dataset_name.lower() == 'fineweb-edu':
        return load_dataset("HuggingFaceFW/fineweb-edu", name="sample-10BT")
    elif dataset_name.lower() == 'wikitext-103':
        return load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split=split)
    elif dataset_name.lower() == 'cord19':
        return load_dataset("allenai/cord19", "fulltext", trust_remote_code=True)
    elif dataset_name.lower() == 'pubmed':
        return load_pubmed()
    elif dataset_name.lower() == 'wikilingua':
        dataset = load_dataset("GEM/wiki_lingua", trust_remote_code=True)
        dataset = dataset.filter(lambda ex: (ex['source_language'] == "en") & (ex['target_language'] == "en"))
        dataset = dataset.rename_column("source", "text")
        dataset = dataset.rename_column("target", "summary")
        return dataset
    elif dataset_name.lower() == 'xsum':
        dataset = load_dataset("EdinburghNLP/xsum")
        dataset = dataset.rename_column("document", "text")
        return dataset
    elif dataset_name.lower() == 'cnn':
        dataset = load_dataset("abisee/cnn_dailymail", "3.0.0")
        dataset = dataset.rename_column("article", "text")
        dataset = dataset.rename_column("highlights", "summary")
        # Strip the "(CNN)" prefix that opens many articles
        dataset = dataset.map(lambda example: {"text": example["text"].replace("(CNN)", "")})
        return dataset
    elif dataset_name.lower() == 'pg19':
        return load_pg19_val_and_test()
    elif dataset_name.lower() == 'wiki40b':
        dataset = load_dataset("google/wiki40b", language, split=split)
        if language in LANGUAGES_TO_DECODE_FROM_BYTES:
            # wiki40b stores some languages as escaped byte strings; decode them back to UTF-8 text
            dataset = dataset.map(lambda x: {
                "text": bytes(x["text"][2:-1], "utf-8").decode("unicode_escape").encode("latin1").decode("utf-8").replace("_NEWLINE_", "\n")
            })
        return dataset
    else:
        raise ValueError(
            "Dataset not recognized. Available options: 'wikitext', 'wikitext-103', 'fineweb-edu', 'cord19', "
            "'pubmed', 'wikilingua', 'xsum', 'cnn', 'pg19', 'wiki40b'.")


def extract_new_words_from_dataset(
        dataset: Dataset, tokenizer, text_column: str = "text", max_samples: int = None, filter_func=(lambda word, token_count: True)):
    """
    Extracts all unique words from the specified text column of a dataset and keeps those
    that the tokenizer splits into more than one token (with or without a leading space).

    Args:
        dataset (Dataset): The Hugging Face dataset to scan.
        tokenizer: Tokenizer used to count how many tokens each word is split into.
        text_column (str): The column in the dataset containing text. Defaults to 'text'.
        max_samples (int): Optional cap on the number of samples to go over.
        filter_func (callable): Predicate over (word, token_count) for additional filtering.

    Returns:
        tuple: (list of new words, dict mapping each new word to its corpus frequency).
    """
    if max_samples:
        dataset = dataset.select(range(max_samples))

    # Regular expression to split text into words (adjust as needed for specific languages)
    # word_pattern = re.compile(r"\b\w+\b")
    word_pattern = re.compile(r"\b\w+(?:[-']\w+)*\b")

    # Iterate over each entry in the dataset and collect its words
    all_words = list()
    for record in tqdm(dataset, total=len(dataset), miniters=10, desc="Extracting all words from dataset...", unit="examples"):
        text = record.get(text_column, "")
        words = word_pattern.findall(text)
        all_words += words

    # Deduplicate while keeping corpus frequencies
    word_frequencies = Counter(all_words)
    all_words = list(word_frequencies.keys())

    # A word is "new" if the tokenizer needs more than one token for it, both with and
    # without a leading whitespace, and it passes filter_func
    token_counts = [len(x) for x in tokenizer(all_words, add_special_tokens=False)["input_ids"]]
    w_whitespace_token_counts = [len(x) for x in tokenizer([f" {w}" for w in all_words], add_special_tokens=False)["input_ids"]]
    new_words = [word for word, count, w_whitespace_count in zip(all_words, token_counts, w_whitespace_token_counts)
                 if ((count > 1) and (w_whitespace_count > 1) and filter_func(word, count))]
    new_words_freq = {word: word_frequencies[word] for word in new_words}
    return new_words, new_words_freq
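

# Illustrative usage sketch (not part of the original file): find corpus words that the
# tokenizer splits into multiple pieces. The "gpt2" checkpoint, the sample cap, and the
# length-based filter are assumptions chosen only for the example.
def _example_extract_new_words():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    dataset = load_lm_dataset("wikitext", split="validation")
    new_words, freqs = extract_new_words_from_dataset(
        dataset,
        tokenizer,
        max_samples=1000,
        filter_func=lambda word, token_count: len(word) > 3,
    )
    print(f"Found {len(new_words)} multi-token words")
    print(sorted(freqs, key=freqs.get, reverse=True)[:10])  # the most frequent ones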


def get_group_texts_func(block_size=1024):
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder; if total_length < block_size the batch is excluded and an empty dict is returned.
        # We could add padding instead of this drop if the model supported it; customize this part to your needs.
        total_length = (total_length // block_size) * block_size
        # Split into chunks of block_size.
        result = {
            k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    return group_texts
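

# Small, self-contained sketch (not from the original file) of what group_texts does:
# token ids from consecutive examples are concatenated and re-split into fixed-size
# blocks, dropping the remainder. block_size=4 is only for illustration.
def _example_group_texts():
    group = get_group_texts_func(block_size=4)
    batch = {"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8, 9]]}
    result = group(batch)
    print(result["input_ids"])  # [[1, 2, 3, 4], [5, 6, 7, 8]] -- the 9th token is dropped
    print(result["labels"])     # identical copy of input_ids, used as LM targets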


def get_tokenize_func(tokenizer, text_col_name):
    def _tokenize(examples):
        # Tokenize raw text without adding special tokens
        output = tokenizer(
            examples[text_col_name],
            return_token_type_ids=False,
            add_special_tokens=False,
        )
        return output

    return _tokenize


def tokenize_and_prepare_dataset(
        dataset, tokenizer, accelerator=None,
        text_col_name: str = "text",
        max_length: int = 256,
        eval_max_samples: int = None,
):
    if tokenizer.bos_token is not None and max_length:
        # leave room for a <BOS> token to be added:
        max_tokenized_len = max_length - 1
    else:
        max_tokenized_len = max_length

    tokenize_function = get_tokenize_func(tokenizer, text_col_name)
    column_names = dataset.column_names
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=column_names,
        load_from_cache_file=False,
        desc="Running tokenizer on dataset",
    )

    # Pack the tokenized examples into fixed-length blocks for language modeling
    group_texts = get_group_texts_func(block_size=max_tokenized_len)
    lm_dataset = tokenized_dataset.map(
        group_texts,
        batched=True,
    )

    if eval_max_samples:
        lm_dataset = lm_dataset.select(range(eval_max_samples))
    return lm_dataset
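

# End-to-end sketch (not part of the original file): tokenize a corpus and pack it into
# fixed-length LM blocks. The "gpt2" checkpoint, block length, and sample cap are
# illustrative assumptions; because GPT-2 defines a BOS token, blocks come out one token
# shorter than max_length.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    raw_dataset = load_lm_dataset("wikitext", split="test")
    lm_dataset = tokenize_and_prepare_dataset(
        raw_dataset,
        tokenizer,
        text_col_name="text",
        max_length=256,
        eval_max_samples=100,
    )
    print(lm_dataset)
    print(len(lm_dataset[0]["input_ids"]))  # 255: one position is reserved for <BOS>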