Spaces:
Runtime error
Runtime error
| from utils import get_datasets, build_loaders | |
| from models import PoemTextModel | |
| from train import train, test | |
| from metrics import calc_metrics | |
| from inference import predict_poems_from_text | |
| from utils import get_poem_embeddings | |
| import config as CFG | |
| import json | |
| def main(): | |
| """ | |
| Creates a PoemTextModel based on configs and trains, tests and outputs some examples of its prediction. | |
| """ | |
| train_or_not = input("Train a new CLIP model using text embeddings? (needs the sajjadayobi360/cc3mfav2 and adityajn105/flickr8k datasets to be downloaded)\n[Y/N]") | |
| if train_or_not == 'Y': | |
| # Please download sajjadayobi360/cc3mfav2 and adityajn105/flickr8k datasets from kaggle | |
| # !kaggle datasets download -d sajjadayobi360/cc3mfav2 | |
| # !kaggle datasets download -d adityajn105/flickr8k | |
| #.... TODO | |
| clip_dataset_dict = [] | |
| # get dataset from dataset_path (the same datasets as the train, val and test dataset files in the data directory is made) | |
| train_dataset, val_dataset, test_dataset = get_clip_datasets(clip_dataset_dict) | |
| train_loader = build_image_loaders(train_dataset, mode="train") | |
| valid_loader = build_image_loaders(val_dataset, mode="valid") | |
| # train a PoemTextModel and write its loss history in a file | |
| model = CLIPModel(image_encoder_pretrained=True, | |
| text_encoder_pretrained=True, | |
| text_projection_trainable=False, | |
| is_image_poem_pair=False | |
| ).to(CFG.device) | |
| model, loss_history = train(model, train_loader, valid_loader) | |
| with open('loss_history_{}_{}.json'.format(CFG.poem_encoder_model, CFG.text_encoder_model),'w', encoding="utf-8") as f: | |
| f.write(json.dumps(loss_history, indent= 4)) | |
| # Inference: Get a filename and output predictions then write them in a file | |
| print("_"*20) | |
| print("INFERENCE PHASE") | |
| model = CLIPModel(image_encoder_pretrained=True, | |
| text_encoder_pretrained=True, | |
| text_projection_trainable=False, | |
| is_image_poem_pair=True | |
| ).to(CFG.device) | |
| model.eval() | |
| with open(CFG.dataset_path, encoding="utf-8") as f: | |
| dataset = json.load(f) | |
| model, poem_embeddings = get_poem_embeddings(test_dataset, model) | |
| while(True): | |
| image_filename = input("Enter an image filename to predict poems for") | |
| beyts = predict_poems_from_image(model, poem_embeddings, image_filename, [data['beyt'] for data in dataset], n=10) | |
| print("predicted Beyts: \n\t", "\n\t".join(beyts)) | |
| with open('{}_output__{}_{}.json'.format(image_filename, CFG.poem_encoder_model, CFG.text_encoder_model),'a+', encoding="utf-8") as f: | |
| f.write(json.dumps(beyts, ensure_ascii=False, indent= 4)) | |
if __name__ == "__main__":
    # Run the interactive train/predict pipeline only when executed as a script,
    # never as a side effect of importing this module.
    main()