Commit 35ef920 · Parent(s): c5ca37a
- app.py +207 -68
- checkpoint-31250/checkpoint-decoder-31250/pytorch_model.bin +1 -1
- checkpoint-31250/checkpoint-decoder-31250/training_decoder_args.bin +2 -2
- checkpoint-31250/checkpoint-encoder-31250/pytorch_model.bin +1 -1
- checkpoint-31250/checkpoint-encoder-31250/training_encoder_args.bin +2 -2
- checkpoint-31250/checkpoint-full-31250/training.bin +2 -2
- real_im_emb_plot.jpg +0 -0

app.py CHANGED
@@ -7,53 +7,194 @@ Original file is located at
     https://colab.research.google.com/drive/1I47sLakpuwERGzn-XoNct67mwiDS1mQD
 """
 
+import matplotlib.pyplot as plt
+import matplotlib
+
+import argparse
+import glob
+import logging
+import os
+import pickle
+import random
+
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+from tqdm import tqdm, trange
+from types import SimpleNamespace
+
+import sys
+sys.path.append('./Optimus/code/examples/big_ae/')
+sys.path.append('./Optimus/code/')
+
+from pytorch_transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, BertConfig
+from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2ForLatentConnector
+from pytorch_transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
+from pytorch_transformers import XLNetLMHeadModel, XLNetTokenizer
+from pytorch_transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
+from pytorch_transformers import BertForLatentConnector, BertTokenizer
+
+from modules import VAE
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 torch.set_float32_matmul_precision('high')
 
 from tqdm import tqdm
-from transformers import AutoTokenizer, AutoModelForCausalLM
-
-class BottleneckT5Autoencoder:
-    def __init__(self, model_path: str, device='cuda'):
-        self.device = device
-        self.tokenizer = AutoTokenizer.from_pretrained(model_path, model_max_length=512, torch_dtype=torch.bfloat16)
-        self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(self.device)
-        self.model.eval()
-        # self.model = torch.compile(self.model)
-
-    def embed(self, text: str) -> torch.FloatTensor:
-        inputs = self.tokenizer(text, return_tensors='pt', padding=True).to(self.device)
-        decoder_inputs = self.tokenizer('', return_tensors='pt').to(self.device)
-        return self.model(
-            **inputs,
-            decoder_input_ids=decoder_inputs['input_ids'],
-            encode_only=True,
-        )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+################################################
+
+
+def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
+    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+        Args:
+            logits: logits distribution shape (vocabulary size)
+            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
+    """
+    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
+    top_k = min(top_k, logits.size(-1))  # Safety check
+    if top_k > 0:
+        # Remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+
+    if top_p > 0.0:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+        # Remove tokens with cumulative probability above the threshold
+        sorted_indices_to_remove = cumulative_probs > top_p
+        # Shift the indices to the right to keep also the first token above the threshold
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+
+        indices_to_remove = sorted_indices[sorted_indices_to_remove]
+        logits[indices_to_remove] = filter_value
+    return logits
+
+def sample_sequence_conditional(model, length, context, past=None, num_samples=1, temperature=1, top_k=0, top_p=0.0, device='cpu', decoder_tokenizer=None):
+
+    context = torch.tensor(context, dtype=torch.long, device=device)
+    context = context.unsqueeze(0).repeat(num_samples, 1)
+    generated = context
+    with torch.no_grad():
+        while True:
+        # for _ in trange(length):
+            inputs = {'input_ids': generated, 'past': past}
+            outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
+            next_token_logits = outputs[0][0, -1, :] / temperature
+            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
+            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
+            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
+
+            # pdb.set_trace()
+            if next_token.unsqueeze(0)[0,0].item() == decoder_tokenizer.encode('<EOS>')[0]:
+                break
+
+    return generated
+
+
+def latent_code_from_text(text,):# args):
+    tokenized1 = tokenizer_encoder.encode(text)
+    tokenized1 = [101] + tokenized1 + [102]
+    coded1 = torch.Tensor([tokenized1])
+    coded1 = torch.Tensor.long(coded1)
+    with torch.no_grad():
+        x0 = coded1
+        x0 = x0.to('cuda')
+        pooled_hidden_fea = model_vae.encoder(x0, attention_mask=(x0 > 0).float())[1]
+        mean, logvar = model_vae.encoder.linear(pooled_hidden_fea).chunk(2, -1)
+        latent_z = mean.squeeze(1)
+        coded_length = len(tokenized1)
+        return latent_z, coded_length
+
+# args
+def text_from_latent_code(latent_z):
+    past = latent_z
+    context_tokens = tokenizer_decoder.encode('<BOS>')
+
+    length = 128 # maximum length, but not used
+    out = sample_sequence_conditional(
+        model=model_vae.decoder,
+        context=context_tokens,
+        past=past,
+        length=length, # Chunyuan: Fix length; or use <EOS> to complete a sentence
+        temperature=.5,
+        top_k=100,
+        top_p=.95,
+        device='cuda',
+        decoder_tokenizer=tokenizer_decoder
+    )
+    text_x1 = tokenizer_decoder.decode(out[0,:].tolist(), clean_up_tokenization_spaces=True)
+    text_x1 = text_x1.split()[1:-1]
+    text_x1 = ' '.join(text_x1)
+    return text_x1
+
+
+################################################
+# Load model
+
+
+MODEL_CLASSES = {
+    'gpt2': (GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer),
+    'bert': (BertConfig, BertForLatentConnector, BertTokenizer)
+}
+
+latent_size = 768
+model_path = './checkpoint-31250/checkpoint-full-31250/'
+encoder_path = './checkpoint-31250/checkpoint-encoder-31250/'
+decoder_path = './checkpoint-31250/checkpoint-decoder-31250/'
+block_size = 100
+
+# Load a trained Encoder model and vocabulary that you have fine-tuned
+encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES['bert']
+model_encoder = encoder_model_class.from_pretrained(encoder_path, latent_size=latent_size)
+tokenizer_encoder = encoder_tokenizer_class.from_pretrained('bert-base-cased', do_lower_case=True)
+
+model_encoder.to('cuda')
+if block_size <= 0:
+    block_size = tokenizer_encoder.max_len_single_sentence  # Our input block size will be the max possible for the model
+block_size = min(block_size, tokenizer_encoder.max_len_single_sentence)
+
+# Load a trained Decoder model and vocabulary that you have fine-tuned
+decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES['gpt2']
+model_decoder = decoder_model_class.from_pretrained(decoder_path, latent_size=latent_size)
+tokenizer_decoder = decoder_tokenizer_class.from_pretrained('gpt2', do_lower_case=False)
+model_decoder.to('cuda')
+if block_size <= 0:
+    block_size = tokenizer_decoder.max_len_single_sentence  # Our input block size will be the max possible for the model
+block_size = min(block_size, tokenizer_decoder.max_len_single_sentence)
+
+# Load full model
+output_full_dir = '/home/ryn_mote/Misc/generative_recommender/text_space/'
+checkpoint = torch.load(os.path.join(model_path, 'training.bin'))
+
+# Chunyuan: Add Padding token to GPT2
+special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
+num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
+print('We have added', num_added_toks, 'tokens to GPT2')
+model_decoder.resize_token_embeddings(len(tokenizer_decoder))  # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
+assert tokenizer_decoder.pad_token == '<PAD>'
+
+
+# Evaluation
+model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, SimpleNamespace(**{'latent_size': latent_size, 'device':'cuda'}))
+model_vae.load_state_dict(checkpoint['model_state_dict'])
+print("Pre-trained Optimus is successfully loaded")
+model_vae.to('cuda').to(torch.bfloat16)
+model_vae = torch.compile(model_vae)
+
+l = latent_code_from_text('A photo of a mountain.')[0]
+t = text_from_latent_code(l)
+print(t, l, l.shape)
+################################################
 
 import gradio as gr
 import numpy as np

@@ -64,7 +205,7 @@ import pandas as pd
 import random
 import time
 
-
+
 dtype = torch.bfloat16
 torch.set_grad_enabled(False)
 

@@ -80,13 +221,20 @@ start_time = time.time()
 def generate(prompt, in_embs=None,):
   if prompt != '':
     print(prompt)
-    in_embs = in_embs / in_embs.abs().max() * .
-    in_embs =
+    in_embs = in_embs / in_embs.abs().max() * .6 if in_embs != None else None
+    in_embs = 1 * in_embs.to('cuda') + 1 * latent_code_from_text(prompt)[0] if in_embs != None else latent_code_from_text(prompt)[0]
   else:
     print('From embeds.')
-  in_embs = in_embs / in_embs.abs().max() * .
-
-
+  in_embs = in_embs / in_embs.abs().max() * .6
+  in_embs = in_embs.to('cuda').to(torch.bfloat16)
+  plt.close('all')
+  plt.hist(np.array(in_embs.detach().to('cpu').to(torch.float)).flatten(), bins=5)
+  plt.savefig('real_im_emb_plot.jpg')
+
+
+  text = text_from_latent_code(in_embs).replace('<unk> ', '')
+  in_embs = latent_code_from_text(text)[0]
+  print(text)
   return text, in_embs.to('cpu')
 
 

@@ -103,7 +251,6 @@ def next_one(embs, ys, calibrate_prompts):
         if len(calibrate_prompts) > 0:
             print('######### Calibrating with sample prompts #########')
             prompt = calibrate_prompts.pop(0)
-            print(prompt)
             text, img_embs = generate(prompt)
             embs += img_embs
             print(len(embs))

@@ -114,12 +261,12 @@ def next_one(embs, ys, calibrate_prompts):
 
             # handle case where every instance of calibration prompts is 'Neither' or 'Like' or 'Dislike'
             if len(list(set(ys))) <= 1:
-                embs.append(.01*torch.randn(
-                embs.append(.01*torch.randn(
+                embs.append(.01*torch.randn(latent_size))
+                embs.append(.01*torch.randn(latent_size))
                 ys.append(0)
                 ys.append(1)
             if len(list(ys)) < 10:
-                embs += [.01*torch.randn(
+                embs += [.01*torch.randn(latent_size)] * 3
                 ys += [0] * 3
 
             pos_indices = [i for i in range(len(embs)) if ys[i] == 1]

@@ -129,13 +276,6 @@ def next_one(embs, ys, calibrate_prompts):
             random.shuffle(pos_indices)
             random.shuffle(neg_indices)
 
-            #if len(pos_indices) - len(neg_indices) > 48 and len(pos_indices) > 80:
-            #    pos_indices = pos_indices[32:]
-            if len(neg_indices) - len(pos_indices) > 48/16 and len(pos_indices) > 6:
-                pos_indices = pos_indices[5:]
-            if len(neg_indices) - len(pos_indices) > 48/16 and len(neg_indices) > 6:
-                neg_indices = neg_indices[5:]
-
 
             if len(neg_indices) > 25:
                 neg_indices = neg_indices[1:]

@@ -150,17 +290,17 @@ def next_one(embs, ys, calibrate_prompts):
             indices = list(range(len(embs)))
 
             # also add the latest 0 and the latest 1
-            has_0 = False
-            has_1 = False
-            for i in reversed(range(len(ys))):
-                if ys[i] == 0 and has_0 == False:
-                    indices.append(i)
-                    has_0 = True
-                elif ys[i] == 1 and has_1 == False:
-                    indices.append(i)
-                    has_1 = True
-                if has_0 and has_1:
-                    break
+            #has_0 = False
+            #has_1 = False
+            #for i in reversed(range(len(ys))):
+            #    if ys[i] == 0 and has_0 == False:
+            #        indices.append(i)
+            #        has_0 = True
+            #    elif ys[i] == 1 and has_1 == False:
+            #        indices.append(i)
+            #        has_1 = True
+            #    if has_0 and has_1:
+            #        break
 
             # we may have just encountered a rare multi-threading diffusers issue (https://github.com/huggingface/diffusers/issues/5749);
             # this ends up adding a rating but losing an embedding, it seems.

@@ -177,7 +317,6 @@ def next_one(embs, ys, calibrate_prompts):
             print('Gathering coefficients')
             lin_class = SVC(max_iter=50000, kernel='linear', class_weight='balanced', C=.1).fit(feature_embs, chosen_y)
             coef_ = torch.tensor(lin_class.coef_, dtype=torch.double)
-            coef_ = coef_ / coef_.abs().max() * 3
             print(coef_.shape, 'COEF')
             print('Gathered')
 
    	
checkpoint-31250/checkpoint-decoder-31250/pytorch_model.bin CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:956e4d5b697320e6edce57414e379130230773a06073ac61e234148a8b4bbf5d
 size 578805986
    	
checkpoint-31250/checkpoint-decoder-31250/training_decoder_args.bin CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:2d81aab70fe9efffb1a6897b867bc45772a53476b746b8ab650150d7c7cd22a7
+size 2337
    	
checkpoint-31250/checkpoint-encoder-31250/pytorch_model.bin CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:12c72c37c42dc4b47d60e1f2cde70225c777927b52aaed16c21f75213eedf11a
 size 438007669
    	
checkpoint-31250/checkpoint-encoder-31250/training_encoder_args.bin CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:2d81aab70fe9efffb1a6897b867bc45772a53476b746b8ab650150d7c7cd22a7
+size 2337
    	
checkpoint-31250/checkpoint-full-31250/training.bin CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:78f8d855caf0b82d2912afd262a166d8588c500b0b0576d00cf4910834215627
+size 2949730415
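The .bin entries in this commit are Git LFS pointers: only the pointer fields (oid sha256 and size) change, not inline binary content. A hedged way to check a fetched blob against the training.bin pointer just above — the local path is illustrative, while the oid and size are the values from that pointer:

import hashlib

path = 'checkpoint-31250/checkpoint-full-31250/training.bin'   # illustrative local path
expected_oid = '78f8d855caf0b82d2912afd262a166d8588c500b0b0576d00cf4910834215627'
expected_size = 2949730415

h = hashlib.sha256()
size = 0
with open(path, 'rb') as f:
    for chunk in iter(lambda: f.read(1 << 20), b''):   # stream in 1 MiB chunks
        h.update(chunk)
        size += len(chunk)

assert size == expected_size, f'size mismatch: {size} != {expected_size}'
assert h.hexdigest() == expected_oid, 'sha256 mismatch'
print('pointer verified')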
    	
real_im_emb_plot.jpg ADDED