Commit b65a786
taskswithcode committed
Parent: fb73c83

Fixes
Files changed (16):
- app.py +23 -15
- doc_app_models.json +61 -1
- text-search-ada-doc-001_planets_qna_search.json +0 -0
- text-search-ada-doc-001_qna2_search.json +0 -0
- text-search-ada-doc-001_qna_search.json +0 -0
- text-search-babbage-doc-001_planets_qna_search.json +0 -0
- text-search-babbage-doc-001_qna2_search.json +0 -0
- text-search-babbage-doc-001_qna_search.json +0 -0
- text-search-curie-doc-001_planets_qna_search.json +0 -0
- text-search-curie-doc-001_qna2_search.json +0 -0
- text-search-curie-doc-001_qna_search.json +0 -0
- text-search-davinci-doc-001_planets_qna_search.json +0 -0
- text-search-davinci-doc-001_qna2_search.json +0 -0
- text-search-davinci-doc-001_qna_search.json +0 -0
- twc_embeddings.py +6 -6
- twc_openai_search.py +124 -0
app.py
CHANGED
@@ -6,6 +6,7 @@ from io import StringIO
 import pdb
 import json
 from twc_embeddings import HFModel,SimCSEModel,SGPTModel,CausalLMModel,SGPTQnAModel
+from twc_openai_search import OpenAIQnAModel
 import torch
 import requests
 import socket
@@ -59,7 +60,7 @@ def get_views(action):
 
 def construct_model_info_for_display(model_names):
     options_arr = []
-    markdown_str = f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\"><br/><b>Models evaluated ({len(model_names)})</b><br/><i>
+    markdown_str = f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\"><br/><b>Models evaluated ({len(model_names)})</b><br/><i>The selected models satisfy one or more of the following (1) state-of-the-art (2) the most downloaded models on Hugging Face (3) Large Language Models (e.g. GPT-3)</i></div>"
     markdown_str += f"<div style=\"font-size:2px; color: #2f2f2f; text-align: left\"><br/></div>"
     for node in model_names:
         options_arr .append(node["name"])
@@ -102,15 +103,15 @@ def load_model(model_name,model_class,load_model_name):
 
 
 @st.experimental_memo
-def cached_compute_similarity(sentences,_model,model_name,main_index):
-    texts,embeddings = _model.compute_embeddings(sentences,is_file=False)
+def cached_compute_similarity(input_file_name,sentences,_model,model_name,main_index):
+    texts,embeddings = _model.compute_embeddings(input_file_name,sentences,is_file=False)
     results = _model.output_results(None,texts,embeddings,main_index)
     return results
 
 
-def uncached_compute_similarity(sentences,_model,model_name,main_index):
+def uncached_compute_similarity(input_file_name,sentences,_model,model_name,main_index):
     with st.spinner('Computing vectors for sentences'):
-        texts,embeddings = _model.compute_embeddings(sentences,is_file=False)
+        texts,embeddings = _model.compute_embeddings(input_file_name,sentences,is_file=False)
         results = _model.output_results(None,texts,embeddings,main_index)
         #st.success("Similarity computation complete")
     return results
@@ -123,7 +124,7 @@ def get_model_info(model_names,model_name):
     return get_model_info(model_names,DEFAULT_HF_MODEL)
 
 
-def run_test(model_names,model_name,sentences,display_area,main_index,user_uploaded,custom_model):
+def run_test(model_names,model_name,input_file_name,sentences,display_area,main_index,user_uploaded,custom_model):
     display_area.text("Loading model:" + model_name)
     #Note. model_name may get mapped to new name in the call below for custom models
     orig_model_name = model_name
@@ -135,14 +136,18 @@ def run_test(model_names,model_name,sentences,display_area,main_index,user_uploa
     if ("Note" in model_info):
         fail_link = f"{model_info['Note']} [link]({model_info['alt_url']})"
         display_area.write(fail_link)
+    if (user_uploaded and "custom_load" in model_info and model_info["custom_load"] == "False"):
+        fail_link = f"{model_info['Note']} [link]({model_info['alt_url']})"
+        display_area.write(fail_link)
+        return {"error":fail_link}
     model = load_model(model_name,model_info["class"],load_model_name)
     display_area.text("Model " + model_name + " load complete")
     try:
         if (user_uploaded):
-            results = uncached_compute_similarity(sentences,model,model_name,main_index)
+            results = uncached_compute_similarity(input_file_name,sentences,model,model_name,main_index)
         else:
             display_area.text("Computing vectors for sentences")
-            results = cached_compute_similarity(sentences,model,model_name,main_index)
+            results = cached_compute_similarity(input_file_name,sentences,model,model_name,main_index)
         display_area.text("Similarity computation complete")
         return results
 
@@ -254,15 +259,18 @@ def app_main(app_mode,example_files,model_name_files):
         run_model = selected_model
         st.session_state["model_name"] = selected_model
         st.session_state["main_index"] = main_index
-        results = run_test(model_names,run_model,sentences,display_area,main_index - 1,(uploaded_file is not None),(len(custom_model_selection) != 0))
+        results = run_test(model_names,run_model,st.session_state["file_name"],sentences,display_area,main_index - 1,(uploaded_file is not None),(len(custom_model_selection) != 0))
         display_area.empty()
         with display_area.container():
-            device = 'GPU' if torch.cuda.is_available() else 'CPU'
-            response_info = f"Computation time on {device}: {time.time() - start:.2f} secs for {len(sentences)} sentences"
-            if (len(custom_model_selection) != 0):
-                st.info("Custom model overrides model selection in step 2 above. So please clear the custom model text box to choose models from step 2")
-            display_results(sentences,main_index - 1,results,response_info,app_mode,run_model)
-            #st.json(results)
+            if ("error" in results):
+                st.error(results["error"])
+            else:
+                device = 'GPU' if torch.cuda.is_available() else 'CPU'
+                response_info = f"Computation time on {device}: {time.time() - start:.2f} secs for {len(sentences)} sentences"
+                if (len(custom_model_selection) != 0):
+                    st.info("Custom model overrides model selection in step 2 above. So please clear the custom model text box to choose models from step 2")
+                display_results(sentences,main_index - 1,results,response_info,app_mode,run_model)
+            #st.json(results)
             st.download_button(
                 label="Download results as json",
                 data= st.session_state["download_ready"] if st.session_state["download_ready"] != None else "",
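A note on the caching split above: the leading underscore on the _model parameter of cached_compute_similarity is the Streamlit convention that tells @st.experimental_memo not to hash that argument (model objects cannot be hashed), while the hashable arguments, including the new input_file_name, form the cache key. A minimal sketch of the pattern, using a hypothetical Model stand-in class:

    import streamlit as st

    class Model:
        # Stand-in for an embedding model object; Streamlit cannot hash it
        def compute(self, sentences):
            return [s.lower() for s in sentences]

    @st.experimental_memo
    def cached_compute(sentences, _model, model_name):
        # _model is excluded from the cache key because of the underscore;
        # sentences and model_name (both hashable) identify the cached entry
        return _model.compute(sentences)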
doc_app_models.json
CHANGED
@@ -108,7 +108,67 @@
     },
     "paper_url":"https://arxiv.org/abs/2104.08821v4",
     "mark":"True",
-    "class":"SimCSEModel","sota_link":"https://paperswithcode.com/sota/semantic-textual-similarity-on-sick"}
+    "class":"SimCSEModel","sota_link":"https://paperswithcode.com/sota/semantic-textual-similarity-on-sick"},
+    { "name":"GPT-3-175B (text-search-davinci-doc-001)" ,
+    "model":"text-search-davinci-doc-001",
+    "fork_url":"https://openai.com/api/",
+    "orig_author_url":"https://openai.com/api/",
+    "orig_author":"OpenAI",
+    "sota_info": {
+        "task":"GPT-3 achieves strong zero-shot and few-shot performance on many NLP datasets etc.",
+        "sota_link":"https://paperswithcode.com/method/gpt-3"
+    },
+    "paper_url":"https://arxiv.org/abs/2005.14165v4",
+    "mark":"True",
+    "custom_load":"False",
+    "Note":"Custom file upload requires OpenAI API access to create embeddings. For API access, use this link ",
+    "alt_url":"https://openai.com/api/",
+    "class":"OpenAIQnAModel","sota_link":"https://arxiv.org/abs/2005.14165v4"},
+    { "name":"GPT-3-6.7B (text-search-curie-doc-001)" ,
+    "model":"text-search-curie-doc-001",
+    "fork_url":"https://openai.com/api/",
+    "orig_author_url":"https://openai.com/api/",
+    "orig_author":"OpenAI",
+    "sota_info": {
+        "task":"GPT-3 achieves strong zero-shot and few-shot performance on many NLP datasets etc.",
+        "sota_link":"https://paperswithcode.com/method/gpt-3"
+    },
+    "paper_url":"https://arxiv.org/abs/2005.14165v4",
+    "mark":"True",
+    "custom_load":"False",
+    "Note":"Custom file upload requires OpenAI API access to create embeddings. For API access, use this link ",
+    "alt_url":"https://openai.com/api/",
+    "class":"OpenAIQnAModel","sota_link":"https://arxiv.org/abs/2005.14165v4"},
+    { "name":"GPT-3-1.3B (text-search-babbage-doc-001)" ,
+    "model":"text-search-babbage-doc-001",
+    "fork_url":"https://openai.com/api/",
+    "orig_author_url":"https://openai.com/api/",
+    "orig_author":"OpenAI",
+    "sota_info": {
+        "task":"GPT-3 achieves strong zero-shot and few-shot performance on many NLP datasets etc.",
+        "sota_link":"https://paperswithcode.com/method/gpt-3"
+    },
+    "paper_url":"https://arxiv.org/abs/2005.14165v4",
+    "mark":"True",
+    "custom_load":"False",
+    "Note":"Custom file upload requires OpenAI API access to create embeddings. For API access, use this link ",
+    "alt_url":"https://openai.com/api/",
+    "class":"OpenAIQnAModel","sota_link":"https://arxiv.org/abs/2005.14165v4"},
+    { "name":"GPT-3-350M (text-search-ada-doc-001)" ,
+    "model":"text-search-ada-doc-001",
+    "fork_url":"https://openai.com/api/",
+    "orig_author_url":"https://openai.com/api/",
+    "orig_author":"OpenAI",
+    "sota_info": {
+        "task":"GPT-3 achieves strong zero-shot and few-shot performance on many NLP datasets etc.",
+        "sota_link":"https://paperswithcode.com/method/gpt-3"
+    },
+    "paper_url":"https://arxiv.org/abs/2005.14165v4",
+    "mark":"True",
+    "custom_load":"False",
+    "Note":"Custom file upload requires OpenAI API access to create embeddings. For API access, use this link ",
+    "alt_url":"https://openai.com/api/",
+    "class":"OpenAIQnAModel","sota_link":"https://arxiv.org/abs/2005.14165v4"}
 
 
 ]
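Each new entry's "class" field ("OpenAIQnAModel") names the Python class that app.py's load_model instantiates for that model; the dispatch itself is not part of this diff. A hedged sketch of how such a string-to-class registry is typically wired (CLASS_MAP and instantiate are hypothetical names):

    from twc_embeddings import HFModel, SimCSEModel, SGPTModel, CausalLMModel, SGPTQnAModel
    from twc_openai_search import OpenAIQnAModel

    # Hypothetical registry mapping the "class" strings in doc_app_models.json
    # to the classes imported at the top of app.py
    CLASS_MAP = {
        "HFModel": HFModel,
        "SimCSEModel": SimCSEModel,
        "SGPTModel": SGPTModel,
        "CausalLMModel": CausalLMModel,
        "SGPTQnAModel": SGPTQnAModel,
        "OpenAIQnAModel": OpenAIQnAModel,
    }

    def instantiate(model_info):
        # model_info is one entry from doc_app_models.json
        model = CLASS_MAP[model_info["class"]]()
        model.init_model(model_info["model"])
        return model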
text-search-ada-doc-001_planets_qna_search.json ADDED
The diff for this file is too large to render. See raw diff.

text-search-ada-doc-001_qna2_search.json ADDED
The diff for this file is too large to render. See raw diff.

text-search-ada-doc-001_qna_search.json ADDED
The diff for this file is too large to render. See raw diff.

text-search-babbage-doc-001_planets_qna_search.json ADDED
The diff for this file is too large to render. See raw diff.

text-search-babbage-doc-001_qna2_search.json ADDED
The diff for this file is too large to render. See raw diff.

text-search-babbage-doc-001_qna_search.json ADDED
The diff for this file is too large to render. See raw diff.

text-search-curie-doc-001_planets_qna_search.json ADDED
The diff for this file is too large to render. See raw diff.

text-search-curie-doc-001_qna2_search.json ADDED
The diff for this file is too large to render. See raw diff.

text-search-curie-doc-001_qna_search.json ADDED
The diff for this file is too large to render. See raw diff.

text-search-davinci-doc-001_planets_qna_search.json ADDED
The diff for this file is too large to render. See raw diff.

text-search-davinci-doc-001_qna2_search.json ADDED
The diff for this file is too large to render. See raw diff.

text-search-davinci-doc-001_qna_search.json ADDED
The diff for this file is too large to render. See raw diff.
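The twelve *_search.json files added above are the embedding caches written by OpenAIQnAModel.compute_embeddings in twc_openai_search.py (new in this commit): the cache name is the doc model name, an underscore, the input file's base name without its extension, and a "_search.json" suffix. A worked example of that derivation (the input path example_files/planets_qna.txt is assumed for illustration):

    d_model_name = "text-search-ada-doc-001"
    input_file_name = "example_files/planets_qna.txt"   # assumed path
    base = input_file_name.split('/')[-1]               # "planets_qna.txt"
    cache = d_model_name + '_' + '.'.join(base.split('.')[:-1]) + "_search.json"
    print(cache)   # text-search-ada-doc-001_planets_qna_search.json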
twc_embeddings.py
CHANGED
@@ -32,7 +32,7 @@ class CausalLMModel:
         self.model.eval()
         self.prompt = 'Documents are searched to find matches with the same content.\nThe document "{}" is a good search result for "'
 
-    def compute_embeddings(self,input_data,is_file):
+    def compute_embeddings(self,input_file_name,input_data,is_file):
         if (self.debug):
             print("Computing embeddings for:", input_data[:20])
         model = self.model
@@ -160,7 +160,7 @@ class SGPTQnAModel:
 
         return embeddings
 
-    def compute_embeddings(self,input_data,is_file):
+    def compute_embeddings(self,input_file_name,input_data,is_file):
         if (self.debug):
             print("Computing embeddings for:", input_data[:20])
         model = self.model
@@ -215,7 +215,7 @@ class SimCSEModel:
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.model = AutoModel.from_pretrained(model_name)
 
-    def compute_embeddings(self,input_data,is_file):
+    def compute_embeddings(self,input_file_name,input_file,input_data,is_file):
         texts = read_text(input_data) if is_file == True else input_data
         inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
         with torch.no_grad():
@@ -266,7 +266,7 @@ class SGPTModel:
         # Deactivate Dropout (There is no dropout in the above models so it makes no difference here but other SGPT models may have dropout)
         self.model.eval()
 
-    def compute_embeddings(self,input_data,is_file):
+    def compute_embeddings(self,input_file_name,input_data,is_file):
         if (self.debug):
             print("Computing embeddings for:", input_data[:20])
         model = self.model
@@ -353,7 +353,7 @@ class HFModel:
         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
         return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
 
-    def compute_embeddings(self,input_data,is_file):
+    def compute_embeddings(self,input_file_name,input_data,is_file):
         #print("Computing embeddings for:", input_data[:20])
         model = self.model
         tokenizer = self.tokenizer
@@ -403,5 +403,5 @@ if __name__ == '__main__':
     results = parser.parse_args()
     obj = HFModel()
     obj.init_model(results.model)
-    texts, embeddings = obj.compute_embeddings(results.input,is_file = True)
+    texts, embeddings = obj.compute_embeddings(results.input,results.input,is_file = True)
     results = obj.output_results(results.output,texts,embeddings)
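The unchanged context lines in the HFModel hunk show the masked mean pooling these models use to turn token vectors into one sentence embedding: positions where attention_mask is 0 are zeroed out, the rest are summed over the sequence and divided by the (clamped) count of real tokens. A standalone sketch with random tensors:

    import torch

    token_embeddings = torch.randn(2, 5, 8)              # (batch, seq_len, hidden)
    attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                                   [1, 1, 1, 1, 1]])     # (batch, seq_len)

    # Expand the mask over the hidden dimension, zero out padding positions,
    # then average over the unmasked tokens only
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sentence_embeddings = torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
    print(sentence_embeddings.shape)                     # torch.Size([2, 8])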
twc_openai_search.py
ADDED
@@ -0,0 +1,124 @@
+from scipy.spatial.distance import cosine
+import argparse
+import json
+import os
+import openai
+import pdb
+
+def read_text(input_file):
+    arr = open(input_file).read().split("\n")
+    return arr[:-1]
+
+
+class OpenAIQnAModel:
+    def __init__(self):
+        self.debug = False
+        self.q_model_name = None
+        self.d_model_name = None
+        self.skip_key = True
+        print("In OpenAI API constructor")
+
+
+    def init_model(self,model_name = None):
+        #print("OpenAI: Init model",model_name)
+        openai.api_key = os.getenv("OPENAI_API_KEY")
+        if (openai.api_key == None):
+            openai.api_key = ""
+            print("API key not set")
+
+        if (len(openai.api_key) == 0 and not self.skip_key):
+            print("Open API key not set")
+
+        if (model_name is None):
+            self.d_model_name = "text-search-ada-doc-001"
+        else:
+            self.d_model_name = model_name
+        self.q_model_name = self.construct_query_model_name(self.d_model_name)
+        print(f"OpenAI: Init model complete :query model {self.q_model_name} doc:{self.d_model_name}")
+
+    def construct_query_model_name(self,d_model_name):
+        return d_model_name.replace('-doc-','-query-')
+
+
+    def compute_embeddings(self,input_file_name,input_data,is_file):
+        if (len(openai.api_key) == 0 and not self.skip_key):
+            print("Open API key not set")
+            return [],[]
+        #print("In compute embeddings after key check")
+        in_file = input_file_name.split('/')[-1]
+        in_file = self.d_model_name + '_' + '.'.join(in_file.split('.')[:-1]) + "_search.json"
+        cached = False
+        try:
+            fp = open(in_file)
+            cached = True
+            embeddings = json.load(fp)
+            q_embeddings = [embeddings[0]]
+            d_embeddings = embeddings[1:]
+            print("Using cached embeddings")
+        except:
+            pass
+
+        texts = read_text(input_data) if is_file == True else input_data
+        queries = [texts[0]]
+        docs = texts[1:]
+
+        if (not cached):
+            print(f"Computing embeddings for {input_file_name} and query model {self.q_model_name}")
+            query_embeds = openai.Embedding.create(
+                input=queries,
+                model=self.q_model_name
+            )
+            print(f"Computing embeddings for {input_file_name} and doc model {self.q_model_name}")
+            doc_embeds = openai.Embedding.create(
+                input=docs,
+                model=self.d_model_name
+            )
+            q_embeddings = []
+            d_embeddings = []
+            for i in range(len(query_embeds['data'])):
+                q_embeddings.append(query_embeds['data'][i]['embedding'])
+            for i in range(len(doc_embeds['data'])):
+                d_embeddings.append(doc_embeds['data'][i]['embedding'])
+        if (not cached):
+            embeddings = q_embeddings + d_embeddings
+            with open(in_file,"w") as fp:
+                json.dump(embeddings,fp)
+        return texts,(q_embeddings,d_embeddings)
+
+    def output_results(self,output_file,texts,embeddings,main_index = 0):
+        # Calculate cosine similarities
+        # Cosine similarities are in [-1, 1]. Higher means more similar
+        query_embeddings = embeddings[0]
+        doc_embeddings = embeddings[1]
+        cosine_dict = {}
+        queries = [texts[0]]
+        docs = texts[1:]
+        if (self.debug):
+            print("Total sentences",len(texts))
+        for i in range(len(docs)):
+            cosine_dict[docs[i]] = 1 - cosine(query_embeddings[0], doc_embeddings[i])
+
+        if (self.debug):
+            print("Input sentence:",texts[main_index])
+        sorted_dict = dict(sorted(cosine_dict.items(), key=lambda item: item[1],reverse = True))
+        if (self.debug):
+            for key in sorted_dict:
+                print("Cosine similarity with \"%s\" is: %.3f" % (key, sorted_dict[key]))
+        if (output_file is not None):
+            with open(output_file,"w") as fp:
+                fp.write(json.dumps(sorted_dict,indent=0))
+        return sorted_dict
+
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='OpenAI model for document search embeddings ',formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('-input', action="store", dest="input",required=True,help="Input file with sentences")
+    parser.add_argument('-output', action="store", dest="output",default="output.txt",help="Output file with results")
+    parser.add_argument('-model', action="store", dest="model",default="text-search-ada-doc-001",help="model name")
+
+    results = parser.parse_args()
+    obj = OpenAIQnAModel()
+    obj.init_model(results.model)
+    texts, embeddings = obj.compute_embeddings(results.input,results.input,is_file = True)
+    results = obj.output_results(results.output,texts,embeddings)
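The __main__ block above doubles as a usage example; the same flow works programmatically. In the sketch below the input file name qna.txt is assumed and OPENAI_API_KEY must be set in the environment; note that compute_embeddings treats the file's first line as the query and the remaining lines as documents, and that the file name is passed twice (once for cache naming, once as the data source):

    from twc_openai_search import OpenAIQnAModel

    obj = OpenAIQnAModel()
    obj.init_model("text-search-ada-doc-001")
    texts, embeddings = obj.compute_embeddings("qna.txt", "qna.txt", is_file=True)
    # Returns the documents sorted by cosine similarity to the query, highest first
    results = obj.output_results("output.json", texts, embeddings)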