# Author: wb-droid
# Commit fe60887: use pre-built vector table for performance.
#from myTextEmbedding import *
import gradio as gr
import torch
import torch.nn as nn
from torch import tensor
from transformers import BertModel, BertTokenizer
#import gzip
#import pandas as pd
import requests
import pickle
class EmbeddingModel(nn.Module):
    """Sentence-embedding model: a BERT encoder followed by masked mean pooling.

    Args:
        bertName: model id passed to ``BertTokenizer.from_pretrained`` and
            ``BertModel.from_pretrained``.
    """

    def __init__(self, bertName = "bert-base-uncased"): # other bert models can also be supported
        super().__init__()
        self.bertName = bertName
        # use BERT model
        self.tokenizer = BertTokenizer.from_pretrained(self.bertName)
        self.model = BertModel.from_pretrained(self.bertName)

    def forward(self, s, device = "cuda"):
        """Embed a batch of texts into one vector per text.

        Args:
            s: the text batch handed to the tokenizer (callers in this file
                pass a list of strings).
            device: device the tokenized batch is moved to; it has to match
                the device holding the model weights.

        Returns:
            Tensor [B, emb]: the mean of the last-layer token embeddings over
            the non-padding positions of each input.
        """
        # Every input is padded/truncated to exactly 256 tokens; the attention
        # mask marks which positions are real tokens.
        tokens = self.tokenizer(s, return_tensors='pt', padding="max_length", truncation=True, max_length=256).to(device)
        # Per-token embeddings from the last encoder layer.
        tokens_embeddings = self.model(**tokens).last_hidden_state  # [B, T, emb]
        # Masked mean pooling: zero out padding positions, sum over tokens, and
        # divide by the number of real tokens in each row.
        embeddings = tokens_embeddings * tokens.attention_mask[..., None]  # [B, T, emb]
        embeddings = embeddings.sum(1)  # [B, emb]
        valid_tokens = tokens.attention_mask.sum(1)  # [B]
        return embeddings / valid_tokens[..., None]  # [B, emb]

    # from scratch: nn.CosineSimilarity(dim = 1)(q,a)
    def cos_score(self, q, a):
        """Row-wise cosine similarity between two embedding matrices.

        Args:
            q: tensor [B, emb].
            a: tensor [B, emb] (a single-row argument broadcasts against the
                other).

        Returns:
            Tensor [B] where element i is cos(q[i], a[i]). There is no epsilon
            clamp, so an all-zero row yields NaN.
        """
        q_norm = q / (q.pow(2).sum(dim=1, keepdim=True).pow(0.5))
        a_norm = a / (a.pow(2).sum(dim=1, keepdim=True).pow(0.5))
        # Row-wise dot product is O(B * emb); the previous
        # (q_norm @ a_norm.T).diagonal() built a full [B, B] matrix only to
        # keep its diagonal.
        return (q_norm * a_norm).sum(dim=1)
# contrastive training
class TrainModel(nn.Module):
    """Contrastive fine-tuning wrapper around ``EmbeddingModel``.

    The forward pass scores how similar the model thinks each sentence pair
    is and regresses that score onto the labelled similarity with MSE.
    """

    def __init__(self):
        super().__init__()
        # Registered as submodule "m"; the attribute name determines the
        # state_dict key prefix, so it is kept as-is.
        self.m = EmbeddingModel("bert-base-uncased")

    def forward(self, s1, s2, score):
        """Return ``(loss, predicted_cosine)`` for a batch of sentence pairs.

        Args:
            s1: first sentences of each pair.
            s2: second sentences of each pair.
            score: target similarity per pair, compared against the predicted
                cosine similarity.
        """
        left = self.m(s1)
        right = self.m(s2)
        predicted = self.m.cos_score(left, right)
        criterion = nn.MSELoss()
        return criterion(predicted, score), predicted
def searchWiki(s, timeout=10):
    """Fetch the plain-text introduction of a Wikipedia article.

    Args:
        s: article title to look up.
        timeout: seconds to wait for the Wikipedia API. Without a timeout,
            ``requests`` can block forever on an unresponsive server.

    Returns:
        The intro text, with line breaks replaced by single spaces.

    Raises:
        requests.HTTPError: if the API answers with an HTTP error status.
        KeyError: if the returned page has no ``extract`` field (presumably
            when the title does not exist).
    """
    response = requests.get(
        'https://en.wikipedia.org/w/api.php',
        params={
            'action': 'query',
            'format': 'json',
            'titles': s,
            'prop': 'extracts',
            'exintro': True,
            'explaintext': True,
        },
        timeout=timeout,
    )
    response.raise_for_status()
    # Take the first page in the response; a single title yields one page.
    page = next(iter(response.json()['query']['pages'].values()))
    # Replace (rather than delete) line breaks so words on either side of a
    # paragraph break are not glued together.
    return page['extract'].replace("\n", " ")
def chunk(w):
    """Sentence chunking: split ``w`` on every "." character.

    The periods are dropped and surrounding whitespace is kept, so a trailing
    period or two adjacent periods produce an empty string in the result.
    """
    pieces = w.split(".")
    return pieces
def generate_chunk_data(concepts):
    """Build the sentence chunks for the vector database from Wikipedia.

    Args:
        concepts: iterable of Wikipedia article titles.

    Returns:
        List of sentences in article order. Each sentence is stripped of
        surrounding whitespace and terminated with "."; empty pieces are
        dropped.
    """
    chunk_data = []
    for concept in concepts:
        # searchWiki already returns text without newlines.
        for piece in chunk(searchWiki(concept)):
            sentence = piece.strip()
            # Empty pieces come from adjacent periods or an article's trailing
            # period. A single pass replaces the old O(n^2) list concatenation
            # and the repeated list.remove('.') loop.
            if sentence:
                chunk_data.append(sentence + ".")
    return chunk_data
@torch.no_grad()
def generate_chunk_emb(m, chunk_data):
    """Embed every chunk on the CPU without tracking gradients.

    Args:
        m: embedding model called as ``m(texts, device=...)``.
        chunk_data: list of chunk strings.

    Returns:
        Whatever ``m`` returns for the batch (one embedding row per chunk for
        ``EmbeddingModel``).
    """
    return m(chunk_data, device="cpu")
def search_document(s, chunk_data, chunk_emb, m, topk=3):
    """Return the chunks most similar to a query string.

    Args:
        s: query text.
        chunk_data: list of chunk strings; row i of ``chunk_emb`` is the
            embedding of ``chunk_data[i]``.
        chunk_emb: 2-D tensor of chunk embeddings, one row per chunk.
        m: embedding model exposing ``m(texts, device=...)`` and
            ``m.cos_score(q, a)`` (row-wise similarity).
        topk: maximum number of chunks to return. When there are fewer
            chunks than ``topk``, all of them are returned.

    Returns:
        Up to ``topk`` chunk strings, most similar first.
    """
    with torch.no_grad():
        query_emb = m([s], device="cpu")
        # Repeat the single query row so cos_score compares it against every
        # chunk row by row.
        result_score = m.cos_score(query_emb.expand(chunk_emb.shape), chunk_emb)
    # torch.topk raises if k exceeds the number of scores, so clamp it.
    k = min(topk, result_score.numel())
    _, idxs = torch.topk(result_score, k)
    # Skip indices with no matching text in case chunk_emb has more rows than
    # chunk_data.
    return [chunk_data[idx] for idx in idxs.flatten().tolist() if idx < len(chunk_data)]
# create the student training model
class TrainStudent(nn.Module):
    """Distillation wrapper: trains ``student_model`` to reproduce the
    embeddings of a fixed teacher model.

    The module-level ``torch.load`` appears to unpickle an instance of this
    class (it reads ``.student_model`` from the result), so the class name and
    the ``student_model`` attribute should not be renamed.
    """

    def __init__(self, student_model):
        super().__init__()
        self.student_model = student_model

    def forward(self, s1, teacher_model):
        """Return the mean squared error between student and teacher embeddings.

        Args:
            s1: batch of texts embedded by both models.
            teacher_model: frozen model whose embeddings are the regression
                target.

        The teacher runs without autograd. It is only a target, so building a
        graph through it would waste memory and accumulate unused ``.grad``
        on its parameters. The loss value and the student's gradients are
        unaffected.
        """
        emb_student = self.student_model(s1)
        with torch.no_grad():
            emb_teacher = teacher_model(s1)
        mse = (emb_student - emb_teacher).pow(2).mean()
        return mse
# Load the distilled student encoder onto the CPU. The checkpoint stores a
# pickled module object (the code reads ``.student_model`` from it), so
# unpickling presumably relies on the model classes defined in this file.
# NOTE(review): PyTorch 2.6+ defaults torch.load to weights_only=True, which
# would reject a full pickled module. Confirm the pinned torch version, or
# pass weights_only=False for this trusted local file.
student_model = torch.load("myTextEmbeddingStudent.pt", map_location='cpu').student_model
student_model.eval()

# Pre-built vector database: a pickled object whose "chunk_data", "chunk_emb"
# and "concepts" entries are read below. pickle can run arbitrary code, so
# only load a trusted file.
with open("vector_database.pkl", "rb") as db_file:
    vector_database = pickle.load(db_file)
def addNewConcepts(user_concepts):
    """Placeholder for adding concepts to the vector database.

    Currently an identity function: the argument is returned unchanged.
    """
    unchanged = user_concepts
    return unchanged
def search(input, user_concepts=None):
    """Gradio handler: return the best-matching database chunks for a question.

    Args:
        input: the user's question. The name is kept because it is the
            handler's existing parameter; it shadows the builtin only locally.
        user_concepts: value of the concepts textbox; currently unused. It
            defaults to None so the handler also works when an event passes
            only the question.

    Returns:
        The top matching chunks joined with single spaces.
    """
    result = search_document(input, vector_database["chunk_data"], vector_database["chunk_emb"], student_model)
    return " ".join(result)
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Sentence Embedding and Vector Database</h1>""")
    search_result = gr.Textbox(show_label=False, placeholder="Search Result", lines=8)
    with gr.Row():
        with gr.Column(scale=1):
            # The placeholder lists the concepts covered by the pre-built database.
            new_concept_box = gr.Textbox(show_label=False, placeholder="Currently supported concepts in vector database:" + str(vector_database["concepts"]), lines=8)
            #addConceptBtn = gr.Button("Add concepts")
        with gr.Column(scale=4):
            user_input = gr.Textbox(show_label=False, placeholder="Enter question on the concept...", lines=8)
            searchBtn = gr.Button("Search", variant="primary")
    #addConceptBtn.click(addNewConcepts, [user_concepts], [new_concept_box])
    # Register the search handler exactly once, with inputs matching search()'s
    # two parameters. A second registration passing only [user_input] made
    # every click run the search twice, and that copy failed on the missing
    # argument.
    searchBtn.click(search, inputs=[user_input, new_concept_box], outputs=[search_result], show_progress=True)

demo.queue().launch(share=False, inbrowser=True)