Spaces:
Runtime error
Runtime error
Trent
commited on
Commit
·
883e41e
1
Parent(s):
75c3a89
Clustering function
Browse files- app.py +5 -1
- backend/inference.py +68 -1
- requirements.txt +2 -0
app.py
CHANGED
|
@@ -118,4 +118,8 @@ For more cool information on sentence embeddings, see the [sBert project](https:
|
|
| 118 |
|
| 119 |
if st.button('Give me my search.'):
|
| 120 |
results = {model: inference.text_search(anchor, n_texts, model, QA_MODELS_ID) for model in select_models}
|
| 121 |
-
st.table(pd.DataFrame(results[select_models[0]]).T)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
if st.button('Give me my search.'):
|
| 120 |
results = {model: inference.text_search(anchor, n_texts, model, QA_MODELS_ID) for model in select_models}
|
| 121 |
+
st.table(pd.DataFrame(results[select_models[0]]).T)
|
| 122 |
+
|
| 123 |
+
if st.button('3D Clustering of search result (new window)'):
|
| 124 |
+
fig = inference.text_cluster(anchor, 1000, select_models[0], QA_MODELS_ID)
|
| 125 |
+
fig.show()
|
backend/inference.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import gzip
|
| 2 |
import json
|
|
|
|
| 3 |
|
| 4 |
import pandas as pd
|
| 5 |
import numpy as np
|
|
@@ -11,7 +12,7 @@ from typing import List, Union
|
|
| 11 |
import torch
|
| 12 |
|
| 13 |
from backend.utils import load_model, filter_questions, load_embeddings
|
| 14 |
-
|
| 15 |
|
| 16 |
def cos_sim(a, b):
|
| 17 |
return jnp.matmul(a, jnp.transpose(b)) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))
|
|
@@ -71,3 +72,69 @@ def text_search(anchor: str, n_answers: int, model_name: str, model_dict: dict):
|
|
| 71 |
urls.append(f"https://stackoverflow.com/q/{post['id']}")
|
| 72 |
|
| 73 |
return hits_titles, hits_scores, urls
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gzip
|
| 2 |
import json
|
| 3 |
+
from collections import Counter
|
| 4 |
|
| 5 |
import pandas as pd
|
| 6 |
import numpy as np
|
|
|
|
| 12 |
import torch
|
| 13 |
|
| 14 |
from backend.utils import load_model, filter_questions, load_embeddings
|
| 15 |
+
from MulticoreTSNE import MulticoreTSNE as TSNE
|
| 16 |
|
| 17 |
def cos_sim(a, b):
|
| 18 |
return jnp.matmul(a, jnp.transpose(b)) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))
|
|
|
|
| 72 |
urls.append(f"https://stackoverflow.com/q/{post['id']}")
|
| 73 |
|
| 74 |
return hits_titles, hits_scores, urls
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def text_cluster(anchor: str, n_answers: int, model_name: str, model_dict: dict):
|
| 78 |
+
# Proceeding with model
|
| 79 |
+
print(model_name)
|
| 80 |
+
assert model_name == "mpnet_qa"
|
| 81 |
+
model = load_model(model_name, model_dict)
|
| 82 |
+
|
| 83 |
+
# Creating embeddings
|
| 84 |
+
query_emb = model.encode(anchor, convert_to_tensor=True)[None, :]
|
| 85 |
+
|
| 86 |
+
print("loading embeddings")
|
| 87 |
+
corpus_emb = load_embeddings()
|
| 88 |
+
|
| 89 |
+
# Getting hits
|
| 90 |
+
hits = util.semantic_search(query_emb, corpus_emb, score_function=util.dot_score, top_k=n_answers)[0]
|
| 91 |
+
|
| 92 |
+
filtered_posts = filter_questions("python")
|
| 93 |
+
|
| 94 |
+
hits_dict = [filtered_posts[hit['corpus_id']] for hit in hits]
|
| 95 |
+
hits_dict.append(dict(id = '1', title = anchor, tags = ['']))
|
| 96 |
+
|
| 97 |
+
hits_emb = torch.stack([corpus_emb[hit['corpus_id']] for hit in hits])
|
| 98 |
+
hits_emb = torch.cat((hits_emb, query_emb))
|
| 99 |
+
|
| 100 |
+
# Dimensionality reduction with t-SNE
|
| 101 |
+
tsne = TSNE(n_components=3, verbose=1, perplexity=15, n_iter=1000)
|
| 102 |
+
tsne_results = tsne.fit_transform(hits_emb.cpu())
|
| 103 |
+
df = pd.DataFrame(hits_dict)
|
| 104 |
+
tags = list(df['tags'])
|
| 105 |
+
|
| 106 |
+
counter = Counter(tags[0])
|
| 107 |
+
for i in tags[1:]:
|
| 108 |
+
counter.update(i)
|
| 109 |
+
|
| 110 |
+
df_tags = pd.DataFrame(counter.most_common(), columns=['Tag', 'Mentions'])
|
| 111 |
+
most_common_tags = list(df_tags['Tag'])[1:5]
|
| 112 |
+
|
| 113 |
+
labels = []
|
| 114 |
+
|
| 115 |
+
for tags_list in list(df['tags']):
|
| 116 |
+
for common_tag in most_common_tags:
|
| 117 |
+
if common_tag in tags_list:
|
| 118 |
+
labels.append(common_tag)
|
| 119 |
+
break
|
| 120 |
+
elif common_tag != most_common_tags[-1]:
|
| 121 |
+
continue
|
| 122 |
+
else:
|
| 123 |
+
labels.append('others')
|
| 124 |
+
|
| 125 |
+
df['title'] = [post['title'] for post in hits_dict]
|
| 126 |
+
df['labels'] = labels
|
| 127 |
+
df['tsne_x'] = tsne_results[:, 0]
|
| 128 |
+
df['tsne_y'] = tsne_results[:, 1]
|
| 129 |
+
df['tsne_z'] = tsne_results[:, 2]
|
| 130 |
+
|
| 131 |
+
df['size'] = [2 for i in range(len(df))]
|
| 132 |
+
|
| 133 |
+
# Making the query bigger than the rest of the observations
|
| 134 |
+
df['size'][len(df) - 1] = 10
|
| 135 |
+
df['labels'][len(df) - 1] = 'QUERY'
|
| 136 |
+
import plotly.express as px
|
| 137 |
+
|
| 138 |
+
fig = px.scatter_3d(df, x='tsne_x', y='tsne_y', z='tsne_z', color='labels', size='size',
|
| 139 |
+
color_discrete_sequence=px.colors.qualitative.D3, hover_data=[df.title])
|
| 140 |
+
return fig
|
requirements.txt
CHANGED
|
@@ -5,3 +5,5 @@ jaxlib
|
|
| 5 |
streamlit
|
| 6 |
numpy
|
| 7 |
torch
|
|
|
|
|
|
|
|
|
| 5 |
streamlit
|
| 6 |
numpy
|
| 7 |
torch
|
| 8 |
+
MulticoreTSNE
|
| 9 |
+
plotly
|