import string
from typing import List, Union

import nltk
import torch
from numpy.typing import NDArray

from .type_aliases import DEVICE_TYPE, ENCODER_DEVICE_TYPE, NumSentencesType, EmbeddingSlicesType


def get_gpu(gpu: DEVICE_TYPE) -> ENCODER_DEVICE_TYPE:
    """
    Determine the correct GPU device based on the provided input. In the following, an output of 0 means CUDA device 0.

    Args:
        gpu (Union[bool, str, int, List[Union[str, int]]]): Input specifying the GPU device(s):
            - bool: If True and CUDA is available, returns 0; otherwise returns "cpu".
            - str: Can be "cpu", "gpu", or "cuda" (case-insensitive). Returns 0 if CUDA is available
              and the input is not "cpu", otherwise returns "cpu".
            - int: Should be a valid GPU index. Returns the index if CUDA is available and the index is valid,
              otherwise returns "cpu".
            - List[Union[str, int]]: List containing any combination of the above str/int values. Each
              element is processed and a list of the corresponding results is returned.

    Returns:
        Union[str, int, List[Union[str, int]]]: Depending on the input type:
            - str: Returns "cpu" if no GPU is available or the input is "cpu".
            - int: Returns the GPU index if it is valid and CUDA is available.
            - List[Union[str, int]]: Returns a list of strings and/or integers based on the input list.

    Raises:
        ValueError: If the input gpu type is not recognized.
        ValueError: If a string input is not one of ["cpu", "gpu", "cuda"].
        ValueError: If an integer input is outside the valid range of GPU indices.

    Notes:
        - CUDA availability is checked with torch.cuda.is_available() and the number of
          available GPUs with torch.cuda.device_count().
        - String inputs ("cpu", "gpu", "cuda") are handled case-insensitively.
        - Duplicate GPU indices in a list input are de-duplicated in the returned list.
    """
    # Ensure the gpu index is within the range of total available GPUs.
    gpu_available = torch.cuda.is_available()
    gpu_count = torch.cuda.device_count()
    correct_strs = ["cpu", "gpu", "cuda"]

    def _get_single_device(gpu_item):
        if isinstance(gpu_item, bool):
            return 0 if gpu_item and gpu_available else "cpu"
        elif isinstance(gpu_item, str):
            if gpu_item.lower() not in correct_strs:
                raise ValueError(f"Wrong gpu type: {gpu_item}. Should be one of {correct_strs}")
            return 0 if (gpu_item.lower() != "cpu") and gpu_available else "cpu"
        elif isinstance(gpu_item, int):
            if gpu_item >= gpu_count:
                raise ValueError(
                    f"There are {gpu_count} GPUs available. Provide a valid GPU index. You provided: {gpu_item}"
                )
            return gpu_item if gpu_available else "cpu"
        else:
            raise ValueError(f"Invalid gpu type: {type(gpu_item)}. Must be bool, str, or int.")

    if isinstance(gpu, list):
        seen_indices = set()
        result = []
        for item in gpu:
            device = _get_single_device(item)
            if isinstance(device, int):
                # Keep each GPU index only once, preserving the order of first appearance.
                if device not in seen_indices:
                    seen_indices.add(device)
                    result.append(device)
            else:
                result.append(device)
        return result[0] if len(result) == 1 else result
    else:
        return _get_single_device(gpu)
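

# get_gpu usage (illustrative sketch, not part of the original module; actual results
# depend on whether CUDA is available on the host):
#
#   get_gpu(True)        # -> 0 on a CUDA machine, "cpu" otherwise
#   get_gpu("cuda")      # -> 0 on a CUDA machine, "cpu" otherwise
#   get_gpu("cpu")       # -> "cpu"
#   get_gpu([0, 1, 1])   # -> [0, 1] on a 2-GPU machine (duplicate indices dropped)
#   get_gpu(5)           # -> ValueError on a 2-GPU machine (index out of range)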


def prep_sentences(sentences: List[str]) -> List[str]:
    """
    Processes a list of sentences by stripping leading and trailing whitespace and
    filtering out sentences that are empty or contain only punctuation.

    Args:
        sentences (List[str]): A list of sentences to be processed.

    Returns:
        List[str]: A list of cleaned sentences.

    Raises:
        ValueError: If the resulting list of sentences is empty.

    Example:
        >>> prep_sentences(["Hello, world!", " This is a test. ", "!!!"])
        ['Hello, world!', 'This is a test.']
        >>> prep_sentences(["!!!", "..."])
        ValueError: Document can't be empty.
    """
    out = []
    for sent in sentences:
        sent = sent.strip()
        # Drop the sentence if nothing remains after removing punctuation.
        sent_wo_punctuation = (
            sent.translate(str.maketrans("", "", string.punctuation))
        ).strip()
        if sent_wo_punctuation:
            out.append(sent)
    if len(out) == 0:
        raise ValueError("Document can't be empty.")
    return out
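

# prep_sentences detail (illustrative, not part of the original module): the
# punctuation check only decides whether a sentence is kept; returned sentences
# keep their original punctuation, e.g.
#
#   prep_sentences(["  Hi!  ", " ?! ", "ok"])  # -> ['Hi!', 'ok']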


def tokenize_and_prep_document(document: Union[str, List[str]], tokenize: bool) -> List[str]:
    """
    Tokenizes and prepares a document by either tokenizing it into sentences and processing each sentence,
    or directly processing each element if `tokenize` is False.

    Args:
        document (Union[str, List[str]]): The document to be processed. It can be a single string (the entire
            document) or a list of strings (a list of sentences).
        tokenize (bool): If True, tokenizes `document` into sentences using NLTK's sentence tokenizer before
            processing. If False, processes each element of `document` directly as a sentence.

    Returns:
        List[str]: A list of cleaned sentences.

    Raises:
        ValueError: If the resulting list of sentences is empty after processing.

    Example:
        >>> tokenize_and_prep_document("Hello, world! This is a test.", True)
        ['Hello, world!', 'This is a test.']
        >>> tokenize_and_prep_document(["Hello, world!", "This is a test."], False)
        ['Hello, world!', 'This is a test.']
        >>> tokenize_and_prep_document("!!! ...", True)
        ValueError: Document can't be empty.

    Note: Only the following two combinations are expected:
        tokenize=True -> document: str
        tokenize=False -> document: List[str]
    """
    if tokenize:
        return prep_sentences(nltk.tokenize.sent_tokenize(document))
    return prep_sentences(document)
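

# Illustrative note (not part of the original module): nltk.tokenize.sent_tokenize
# relies on the Punkt tokenizer data, so an environment that has never downloaded it
# would need something like the following once before calling
# tokenize_and_prep_document(..., tokenize=True):
#
#   import nltk
#   nltk.download("punkt")   # newer NLTK releases use the "punkt_tab" resource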


def flatten_list(nested_list: list) -> list:
    """
    Recursively flattens a nested list of any depth.

    Args:
        nested_list (list): The nested list to flatten.

    Returns:
        list: A flat list containing all the elements of the nested list.
    """
    flat_list = []
    for item in nested_list:
        if isinstance(item, list):
            flat_list.extend(flatten_list(item))
        else:
            flat_list.append(item)
    return flat_list
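

# flatten_list usage (illustrative sketch, not part of the original module):
#
#   flatten_list([1, [2, [3, 4]], [5]])  # -> [1, 2, 3, 4, 5]
#   flatten_list([])                     # -> []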


def is_nested_list_of_type(lst_obj, element_type, depth: int) -> bool:
    """
    Check if the given object is a nested list of a specific type up to a specified depth.

    Args:
        - lst_obj: The object to check, expected to be a list or a single element.
        - element_type: The type that each element in the nested list should match.
        - depth (int): The depth of nesting to check. Must be non-negative.

    Returns:
        - bool: True if lst_obj is a nested list of the specified type up to the given depth, False otherwise.

    Raises:
        - ValueError: If depth is negative.

    Example:
        ```python
        is_nested_list_of_type("test", str, 0)                    # Returns True
        is_nested_list_of_type([1, 2, 3], str, 0)                 # Returns False
        is_nested_list_of_type(["apple", "banana"], str, 1)       # Returns True
        is_nested_list_of_type([[1, 2], [3, 4]], int, 2)          # Returns True
        is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2)      # Returns False
        is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3)  # Returns True
        ```

    Explanation:
        - The function checks if `lst_obj` is a nested list of elements of type `element_type` up to `depth` levels deep.
        - If `depth` is 0, it checks if `lst_obj` itself is of type `element_type`.
        - If `depth` is greater than 0, it recursively checks each level of nesting to ensure all elements match `element_type`.
        - A `ValueError` is raised if `depth` is negative, since depth must be a non-negative integer.
    """
    if depth == 0:
        return isinstance(lst_obj, element_type)
    elif depth > 0:
        return isinstance(lst_obj, list) and all(
            is_nested_list_of_type(item, element_type, depth - 1) for item in lst_obj
        )
    else:
        raise ValueError("Depth can't be negative")
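

# Illustrative note (not part of the original module): because all() over an empty
# iterable is True, an empty list satisfies the check at any positive depth, e.g.
#
#   is_nested_list_of_type([], str, 1)    # -> True
#   is_nested_list_of_type([[]], int, 2)  # -> True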


def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> EmbeddingSlicesType:
    """
    Slice embeddings into segments based on the provided number of sentences per segment.

    Args:
        - embeddings (np.ndarray): The array of embeddings to be sliced.
        - num_sentences (Union[List[int], List[List[int]]]):
            - If a list of integers: Specifies the number of embeddings to take in each slice.
            - If a list of lists of integers: Specifies two levels of slicing, one inner list of
              slice sizes per outer group.

    Returns:
        - List[np.ndarray]: A list of numpy arrays where each array represents a slice of embeddings.

    Raises:
        - TypeError: If `num_sentences` is not of type List[int] or List[List[int]].

    Example Usage:
        ```python
        embeddings = np.random.rand(10, 5)
        num_sentences = [3, 2, 5]
        result = slice_embeddings(embeddings, num_sentences)
        # `result` will be a list of numpy arrays:
        # [embeddings[:3], embeddings[3:5], embeddings[5:]]

        num_sentences_nested = [[2, 1], [3, 4]]
        result_nested = slice_embeddings(embeddings, num_sentences_nested)
        # `result_nested` will be a nested list of numpy arrays:
        # [[embeddings[:2], embeddings[2:3]], [embeddings[3:6], embeddings[6:]]]

        slice_embeddings(embeddings, "invalid")  # Raises a TypeError
        ```
    """
    def _slice_embeddings(s_idx: int, n_sentences: List[int]):
        """
        Helper function to slice embeddings starting from index `s_idx`.

        Args:
            - s_idx (int): Starting index for slicing.
            - n_sentences (List[int]): List specifying the number of sentences in each slice.

        Returns:
            - Tuple[List[np.ndarray], int]: A tuple containing a list of sliced embeddings and the next starting index.
        """
        _result = []
        for count in n_sentences:
            _result.append(embeddings[s_idx:s_idx + count])
            s_idx += count
        return _result, s_idx

    if isinstance(num_sentences, list) and all(isinstance(item, int) for item in num_sentences):
        result, _ = _slice_embeddings(0, num_sentences)
        return result
    elif isinstance(num_sentences, list) and all(
        isinstance(sublist, list) and all(
            isinstance(item, int) for item in sublist
        )
        for sublist in num_sentences
    ):
        nested_result = []
        start_idx = 0
        for nested_num_sentences in num_sentences:
            embedding_slice, start_idx = _slice_embeddings(start_idx, nested_num_sentences)
            nested_result.append(embedding_slice)
        return nested_result
    else:
        raise TypeError(f"Incorrect Type for {num_sentences=}")
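

# slice_embeddings sanity sketch (illustrative, not part of the original module;
# assumes numpy is installed):
#
#   import numpy as np
#   emb = np.random.rand(10, 5)
#   flat = slice_embeddings(emb, [3, 2, 5])
#   [s.shape[0] for s in flat]                          # -> [3, 2, 5]
#   nested = slice_embeddings(emb, [[2, 1], [3, 4]])
#   [[s.shape[0] for s in group] for group in nested]   # -> [[2, 1], [3, 4]]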