from __future__ import annotations

import json
import os

import cv2
import numpy as np
import regex
import torch
import torch.nn.functional as F

import config as CFG
from datasets import get_transforms

# imports used only when running this script as __main__
from models import PoemTextModel
from utils import get_datasets, build_loaders, get_poem_embeddings
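# CFG (config.py) is assumed to provide the attributes referenced below:
# device, tokenizers, text_encoder_model, text_tokenizer, projection_dim,
# poem_encoder_model and dataset_path.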
def predict_poems_from_text(model, poem_embeddings, query, poems, text_tokenizer=None, n=10, return_similarities=False):
    """
    Returns the n poems most similar to a text query.

    Parameters:
    -----------
    model: PoemTextModel
        model used to compute the text query's embedding
    poem_embeddings: tensor of shape (#poems, CFG.projection_dim)
        poem embeddings to rank by similarity
    query: str
        text query
    poems: list of str
        poems corresponding to poem_embeddings
    text_tokenizer: huggingface Tokenizer, optional
        tokenizer for the query. If None, a new text tokenizer is instantiated from the configs.
    n: int, optional
        number of poems to return
    return_similarities: bool, optional
        if True, returns a list of dicts holding the poem beyts and their similarities to the text

    Returns:
    --------
    A list of the n poem strings whose embeddings are most similar to the query text's embedding
    (or, if return_similarities is True, a list of {'beyt', 'similarity'} dicts).
    """
    # Tokenizing and encoding the query text
    if not text_tokenizer:
        text_tokenizer = CFG.tokenizers[CFG.text_encoder_model].from_pretrained(CFG.text_tokenizer)
    encoded_query = text_tokenizer([query])
    batch = {
        key: torch.tensor(values).to(CFG.device)
        for key, values in encoded_query.items()
    }
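    # encoded_query is a dict like {'input_ids': [[...]], 'attention_mask': [[...]]};
    # each value becomes a tensor of shape (1, sequence_length) on CFG.device.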
    # getting the query text's embedding
    model.eval()
    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        text_embeddings = model.text_projection(text_features)
    # normalizing and computing the dot similarity of poem and text embeddings
    poem_embeddings_n = F.normalize(poem_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    dot_similarity = text_embeddings_n @ poem_embeddings_n.T
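    # with L2-normalized embeddings the dot product equals cosine similarity;
    # dot_similarity has shape (1, #poems)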
    # ranking all poems by similarity (not just the top n, so duplicates can be skipped below)
    values, indices = torch.topk(dot_similarity.squeeze(0), len(poems))
    # since the poems were collected from many sources, some entries are equal
    # (the same beyt paired with different meaning texts), so poems added to the
    # result must be checked for duplicates
    def is_poem_duplicate(poem, poems):
        # compare only the letter runs: the zero-width non-joiner (U+200C) is stripped
        # first (so half-spaced and joined spellings match), then punctuation and
        # whitespace are dropped by keeping only \p{L}+ sequences
        poem = regex.findall(r'\p{L}+', poem.replace('\u200c', ''))
        for other_poem in poems:
            other_poem = regex.findall(r'\p{L}+', other_poem.replace('\u200c', ''))
            if poem == other_poem:
                return True
        return False
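    # e.g. is_poem_duplicate('del * jan!', ['del  *  jan']) is True, since both
    # reduce to the letter sequence ['del', 'jan'] (the strings here are illustrative)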
    results = []
    computed_k = 0
    for i in range(len(poems)):
        if computed_k == n:
            break
        if not is_poem_duplicate(poems[indices[i]], [res['beyt'] for res in results]):
            results.append({
                # cleaning up separator artifacts between beyt halves
                'beyt': poems[indices[i]].replace(' * * ', ' * ').replace('*** * ', ''),
                # .item() converts the 0-dim tensor to a plain float (e.g. for JSON dumps)
                'similarity': values[i].item()
            })
            computed_k += 1
    if return_similarities:
        return results
    else:
        return [res['beyt'] for res in results]
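# Example (a sketch; assumes a trained PoemTextModel and embeddings built as in __main__ below):
#   model, poem_embeddings = get_poem_embeddings(dataset, model)
#   beyts = predict_poems_from_text(model, poem_embeddings, "a night of longing",
#                                   [d['beyt'] for d in dataset], n=5)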
def predict_poems_from_image(model, poem_embeddings, image_filename, poems, n=10, return_similarities=False):
    """
    Returns the n poems most similar to an image query.

    Parameters:
    -----------
    model: CLIPModel
        model used to compute the image query's embedding
    poem_embeddings: tensor of shape (#poems, CFG.projection_dim)
        poem embeddings to rank by similarity
    image_filename: str
        path to the image query file
    poems: list of str
        poems corresponding to poem_embeddings
    n: int, optional
        number of poems to return
    return_similarities: bool, optional
        if True, returns a list of dicts holding the poem beyts and their similarities to the image

    Returns:
    --------
    A list of the n poem strings whose embeddings are most similar to the image query's embedding
    (or, if return_similarities is True, a list of {'beyt', 'similarity'} dicts).
    """
    # Reading, processing and applying transforms to the image (all explained in datasets.py)
    image = cv2.imread(image_filename)
    if image is None:
        # cv2.imread returns None instead of raising when the file cannot be read
        raise FileNotFoundError(f"Could not read image file: {image_filename}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = get_transforms(mode="test")(image=image)['image']
    image = torch.tensor(image).permute(2, 0, 1).float()
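    # image is now a float tensor of shape (3, H, W), ready to be batched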
    # getting the image query's embedding
    model.eval()
    with torch.no_grad():
        image_features = model.image_encoder(torch.unsqueeze(image, 0).to(CFG.device))
        image_embeddings = model.image_projection(image_features)
    # normalizing and computing the dot similarity of poem and image embeddings
    poem_embeddings_n = F.normalize(poem_embeddings, p=2, dim=-1)
    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    dot_similarity = image_embeddings_n @ poem_embeddings_n.T
    # ranking all poems by similarity (not just the top n, so duplicates can be skipped below)
    values, indices = torch.topk(dot_similarity.squeeze(0), len(poems))
    # since the poems were collected from many sources, some entries are equal
    # (the same beyt paired with different meaning texts), so poems added to the
    # result must be checked for duplicates (same helper as in predict_poems_from_text)
    def is_poem_duplicate(poem, poems):
        # compare only the letter runs, stripping the zero-width non-joiner (U+200C)
        # and dropping punctuation and whitespace
        poem = regex.findall(r'\p{L}+', poem.replace('\u200c', ''))
        for other_poem in poems:
            other_poem = regex.findall(r'\p{L}+', other_poem.replace('\u200c', ''))
            if poem == other_poem:
                return True
        return False
    results = []
    computed_k = 0
    for i in range(len(poems)):
        if computed_k == n:
            break
        if not is_poem_duplicate(poems[indices[i]], [res['beyt'] for res in results]):
            results.append({
                'beyt': poems[indices[i]].replace(' * * ', ' * ').replace('*** * ', ''),
                # .item() converts the 0-dim tensor to a plain float (e.g. for JSON dumps)
                'similarity': values[i].item()
            })
            computed_k += 1
    if return_similarities:
        return results
    else:
        return [res['beyt'] for res in results]
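# Example (a sketch; assumes a trained CLIP-style model with image_encoder/image_projection
# and poem embeddings built with get_poem_embeddings; "photo.jpg" is a placeholder path):
#   beyts = predict_poems_from_image(model, poem_embeddings, "photo.jpg",
#                                    [d['beyt'] for d in dataset], n=5)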
if __name__ == "__main__":
    # Creates a PoemTextModel based on the configs and prints some example predictions.

    # get the datasets from dataset_path (the same train, val and test splits
    # as the dataset files in the data directory are made)
    train_dataset, val_dataset, test_dataset = get_datasets()
    model = PoemTextModel(poem_encoder_pretrained=True, text_encoder_pretrained=True).to(CFG.device)
    model.eval()
    # Inference: output some example predictions and write them to a file
    print("_" * 20)
    print("Output Examples from test set")
    model, poem_embeddings = get_poem_embeddings(test_dataset, model)
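    # poem_embeddings is expected to be a tensor of shape (len(test_dataset), CFG.projection_dim)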
    example = {}
    for i, test_data in enumerate(test_dataset[:100]):
        example[i] = {
            'Text': test_data["text"],
            'True Beyt': test_data["beyt"],
            'Predicted Beyt': predict_poems_from_text(
                model, poem_embeddings, test_data["text"],
                [data['beyt'] for data in test_dataset], n=10
            )
        }
    for i in range(10):
        print("Text: ", example[i]['Text'])
        print("True Beyt: ", example[i]['True Beyt'])
        print("Predicted Beyts: \n\t", "\n\t".join(example[i]["Predicted Beyt"]))
    with open('example_output__{}_{}.json'.format(CFG.poem_encoder_model, CFG.text_encoder_model), 'w', encoding="utf-8") as f:
        f.write(json.dumps(example, ensure_ascii=False, indent=4))
| print("Preparing model for user input...") | |
| with open(CFG.dataset_path, encoding="utf-8") as f: | |
| dataset = json.load(f) | |
| model, poem_embeddings = get_poem_embeddings(dataset, model) | |
| while(True): | |
| user_text = input("Enter a Text to find poem beyts for: ") | |
| beyts = predict_poems_from_text(model, poem_embeddings, user_text, [data['beyt'] for data in dataset], n=10) | |
| print("predicted Beyts: \n\t", "\n\t".join(beyts)) | |
| with open('{}_output__{}_{}.json'.format(user_text, CFG.poem_encoder_model, CFG.text_encoder_model),'a+', encoding="utf-8") as f: | |
| f.write(json.dumps(beyts, ensure_ascii=False, indent= 4)) |