#from myTextEmbedding import *
import gradio as gr
import torch
import torch.nn as nn
from torch import tensor
from transformers import BertModel, BertTokenizer
#import gzip
#import pandas as pd
import requests
import pickle


class EmbeddingModel(nn.Module):
    def __init__(self, bertName="bert-base-uncased"):  # other bert models can also be supported
        super().__init__()
        self.bertName = bertName
        # use BERT model
        self.tokenizer = BertTokenizer.from_pretrained(self.bertName)
        self.model = BertModel.from_pretrained(self.bertName)

    def forward(self, s, device="cuda"):
        # get tokens, which also include attention_mask
        tokens = self.tokenizer(s, return_tensors='pt', padding="max_length",
                                truncation=True, max_length=256).to(device)
        # get token embeddings
        output = self.model(**tokens)
        tokens_embeddings = output.last_hidden_state
        #print("tokens_embeddings:" + str(tokens_embeddings.shape))
        # mean pooling to get text embedding
        embeddings = tokens_embeddings * tokens.attention_mask[..., None]  # [B, T, emb]
        #print("embeddings:" + str(embeddings.shape))
        embeddings = embeddings.sum(1)                      # [B, emb]
        valid_tokens = tokens.attention_mask.sum(1)         # [B]
        embeddings = embeddings / valid_tokens[..., None]   # [B, emb]
        return embeddings

    # from scratch: nn.CosineSimilarity(dim=1)(q, a)
    def cos_score(self, q, a):
        q_norm = q / (q.pow(2).sum(dim=1, keepdim=True).pow(0.5))
        a_norm = a / (a.pow(2).sum(dim=1, keepdim=True).pow(0.5))
        return (q_norm @ a_norm.T).diagonal()


# contrastive training
class TrainModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.m = EmbeddingModel("bert-base-uncased")

    def forward(self, s1, s2, score):
        cos_score = self.m.cos_score(self.m(s1), self.m(s2))
        loss = nn.MSELoss()(cos_score, score)
        return loss, cos_score


def searchWiki(s):
    response = requests.get(
        'https://en.wikipedia.org/w/api.php',
        params={
            'action': 'query',
            'format': 'json',
            'titles': s,
            'prop': 'extracts',
            'exintro': True,
            'explaintext': True,
        }
    ).json()
    page = next(iter(response['query']['pages'].values()))
    return page['extract'].replace("\n", "")


# sentence chunking
def chunk(w):
    return w.split(".")


def generate_chunk_data(concepts):
    wiki_data = [searchWiki(c).replace("\n", "") for c in concepts]
    chunk_data = []
    for w in wiki_data:
        chunk_data = chunk_data + chunk(w)
    chunk_data = [c.strip() + "." for c in chunk_data]
    while '.' in chunk_data:
        chunk_data.remove('.')
    return chunk_data


def generate_chunk_emb(m, chunk_data):
    with torch.no_grad():
        emb = m(chunk_data, device="cpu")
    return emb


def search_document(s, chunk_data, chunk_emb, m, topk=3):
    question = [s]
    with torch.no_grad():
        result_score = m.cos_score(m(question, device="cpu").expand(chunk_emb.shape), chunk_emb)
        #result_score = m.cos_score(m(question, device = "cpu"), chunk_emb)
    print(result_score)
    _, idxs = torch.topk(result_score, topk)
    print([result_score.flatten()[idx] for idx in idxs.flatten().tolist()])
    print(idxs.flatten().tolist())
    print(chunk_data)
    print(len(chunk_data))
    return [chunk_data[idx] for idx in idxs.flatten().tolist() if idx < len(chunk_data)]


# create the student training model
class TrainStudent(nn.Module):
    def __init__(self, student_model):
        super().__init__()
        self.student_model = student_model

    def forward(self, s1, teacher_model):
        emb_student = self.student_model(s1)
        emb_teacher = teacher_model(s1)
        mse = (emb_student - emb_teacher).pow(2).mean()
        return mse


student_model = torch.load("myTextEmbeddingStudent.pt", map_location='cpu').student_model.eval()

with open("vector_database.pkl", "rb") as f:
    vector_database = pickle.load(f)


def addNewConcepts(user_concepts):
    return user_concepts


def search(input, user_concepts):
    result = search_document(input, vector_database["chunk_data"],
                             vector_database["chunk_emb"], student_model)
    return " ".join(result)


with gr.Blocks() as demo:
    gr.HTML("""

Sentence Embedding and Vector Database

    """)
    search_result = gr.Textbox(show_label=False, placeholder="Search Result", lines=8)
    with gr.Row():
        with gr.Column(scale=1):
            new_concept_box = gr.Textbox(
                show_label=False,
                placeholder="Currently supported concepts in vector database:" + str(vector_database["concepts"]),
                lines=8)
            #addConceptBtn = gr.Button("Add concepts")
        with gr.Column(scale=4):
            user_input = gr.Textbox(show_label=False, placeholder="Enter question on the concept...", lines=8)
            searchBtn = gr.Button("Search", variant="primary")

    #addConceptBtn.click(addNewConcepts, [user_concepts], [new_concept_box])
    searchBtn.click(search, inputs=[user_input, new_concept_box], outputs=[search_result], show_progress=True)

demo.queue().launch(share=False, inbrowser=True)