AlexanderKazakov committed · Commit 34b78ab · Parent(s): eeafaaa

add cross-encoder and HF API LLM

Files changed:
- gradio_app/app.py +47 -18
- gradio_app/backend/ChatGptInteractor.py +34 -32
- gradio_app/backend/HuggingfaceGenerator.py +44 -0
- gradio_app/backend/cross_encoder.py +32 -0
- gradio_app/backend/query_llm.py +41 -145
- settings.py +8 -2
gradio_app/app.py
CHANGED
@@ -13,7 +13,8 @@ import markdown
 from jinja2 import Environment, FileSystemLoader
 
 from gradio_app.backend.ChatGptInteractor import num_tokens_from_messages
-from gradio_app.backend.
+from gradio_app.backend.cross_encoder import rerank_with_cross_encoder
+from gradio_app.backend.query_llm import *
 from gradio_app.backend.semantic_search import table, embedder
 
 from settings import *

@@ -45,42 +46,52 @@ def add_text(history, text):
     return history, gr.Textbox(value="", interactive=False)
 
 
-def bot(history,
-    top_k_rank = 5
-    thresh_dist = 1.2
+def bot(history, llm, cross_enc):
     history[-1][1] = ""
     query = history[-1][0]
 
     if not query:
-        gr.
-        raise ValueError("Empty string was submitted")
+        raise gr.Error("Empty string was submitted")
 
     logger.info('Retrieving documents...')
+    gr.Info('Start documents retrieval ...')
+    time = perf_counter()
 
     query_vec = embedder.embed(query)[0]
-    documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME)
+    documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME)
+    documents = documents.limit(TOP_K_RANK).to_list()
+    thresh_dist = thresh_distances[EMBED_NAME]
     thresh_dist = max(thresh_dist, min(d['_distance'] for d in documents))
     documents = [d for d in documents if d['_distance'] <= thresh_dist]
    documents = [doc[TEXT_COLUMN_NAME] for doc in documents]
 
-    logger.info(f'Finished Retrieving documents in {round(
+    time = perf_counter() - time
+    logger.info(f'Finished Retrieving documents in {round(time, 2)} seconds...')
 
+    logger.info('Reranking documents...')
+    gr.Info('Start documents reranking ...')
+    time = perf_counter()
+
+    documents = rerank_with_cross_encoder(cross_enc, documents, query)
+
+    time = perf_counter() - time
+    logger.info(f'Finished Reranking documents in {round(time, 2)} seconds...')
+
+    msg_constructor = get_message_constructor(llm)
     while len(documents) != 0:
         context = context_template.render(documents=documents)
         documents_html = [markdown.markdown(d) for d in documents]
         context_html = context_html_template.render(documents=documents_html)
-        messages =
-        num_tokens = num_tokens_from_messages(messages,
-        if num_tokens + 512 < context_lengths[
+        messages = msg_constructor(context, history)
+        num_tokens = num_tokens_from_messages(messages, 'gpt-3.5-turbo')  # todo for HF, it is approximation
+        if num_tokens + 512 < context_lengths[llm]:
            break
        documents.pop()
    else:
        raise gr.Error('Model context length exceeded, reload the page')
 
+    llm_gen = get_llm_generator(llm)
+    for part in llm_gen(messages):
        history[-1][1] += part
        yield history, context_html
    else:

@@ -110,7 +121,25 @@ with gr.Blocks() as demo:
     )
     txt_btn = gr.Button(value="Submit text", scale=1)
 
+    llm_name = gr.Radio(
+        choices=[
+            "gpt-3.5-turbo",
+            "mistralai/Mistral-7B-Instruct-v0.1",
+            "GeneZC/MiniChat-3B",
+        ],
+        value="gpt-3.5-turbo",
+        label='LLM'
+    )
+
+    cross_enc_name = gr.Radio(
+        choices=[
+            None,
+            "cross-encoder/ms-marco-TinyBERT-L-2-v2",
+            "cross-encoder/ms-marco-MiniLM-L-12-v2",
+        ],
+        value=None,
+        label='Cross-Encoder'
+    )
 
     # Examples
     gr.Examples(examples, input_textbox)

@@ -122,7 +151,7 @@ with gr.Blocks() as demo:
     txt_msg = txt_btn.click(
         add_text, [chatbot, input_textbox], [chatbot, input_textbox], queue=False
     ).then(
-        bot, [chatbot,
+        bot, [chatbot, llm_name, cross_enc_name], [chatbot, context_html]
     )
 
     # Turn it back on

@@ -130,7 +159,7 @@ with gr.Blocks() as demo:
 
     # Turn off interactivity while generating if you hit enter
     txt_msg = input_textbox.submit(add_text, [chatbot, input_textbox], [chatbot, input_textbox], queue=False).then(
-        bot, [chatbot,
+        bot, [chatbot, llm_name, cross_enc_name], [chatbot, context_html])
 
     # Turn it back on
     txt_msg.then(lambda: gr.Textbox(interactive=True), None, [input_textbox], queue=False)
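
For reference, two details of the reworked bot() above are easy to miss: the max(...) guard keeps at least the single closest hit even when every distance exceeds the per-embedder threshold, and the while/else loop drops the lowest-ranked documents until the prompt leaves roughly 512 tokens of headroom for the answer. A condensed restatement of that logic (hypothetical standalone snippet, reusing names from the diff):

# keep at least the closest hit, then apply the per-embedder distance cutoff
thresh_dist = max(thresh_distances[EMBED_NAME], min(d['_distance'] for d in documents))
documents = [d[TEXT_COLUMN_NAME] for d in documents if d['_distance'] <= thresh_dist]

# drop the lowest-ranked documents until the prompt leaves ~512 tokens for the completion
while documents:
    messages = msg_constructor(context_template.render(documents=documents), history)
    if num_tokens_from_messages(messages, 'gpt-3.5-turbo') + 512 < context_lengths[llm]:
        break
    documents.pop()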
gradio_app/backend/ChatGptInteractor.py
CHANGED
@@ -1,3 +1,4 @@
+import logging
 import time
 
 import tiktoken

@@ -9,6 +10,10 @@ with open('data/openaikey.txt') as f:
 openai.api_key = OPENAI_KEY
 
 
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
 def num_tokens_from_messages(messages, model):
     """
     Return the number of tokens used by a list of messages.

@@ -17,7 +22,7 @@ def num_tokens_from_messages(messages, model):
     try:
         encoding = tiktoken.encoding_for_model(model)
     except KeyError:
+        logger.info("Warning: model not found. Using cl100k_base encoding.")
         encoding = tiktoken.get_encoding("cl100k_base")
     if model in {
         "gpt-3.5-turbo-0613",

@@ -33,10 +38,10 @@ def num_tokens_from_messages(messages, model):
         tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
         tokens_per_name = -1  # if there's a name, the role is omitted
     elif "gpt-3.5-turbo" in model:
-        #
+        # logger.info()("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
         return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613")
     elif "gpt-4" in model:
-        #
+        # logger.info()("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
         return num_tokens_from_messages(messages, model="gpt-4-0613")
     else:
         raise NotImplementedError(

@@ -54,8 +59,11 @@ def num_tokens_from_messages(messages, model):
 
 
 class ChatGptInteractor:
-    def __init__(self, model_name='gpt-3.5-turbo'):
+    def __init__(self, model_name='gpt-3.5-turbo', max_tokens=None, temperature=None, stream=False):
         self.model_name = model_name
+        self.max_tokens = max_tokens
+        self.temperature = temperature
+        self.stream = stream
         self.tokenizer = tiktoken.encoding_for_model(self.model_name)
 
     def chat_completion_simple(

@@ -63,15 +71,9 @@ class ChatGptInteractor:
             *,
             user_text,
             system_text=None,
-            max_tokens=None,
-            temperature=None,
-            stream=False,
     ):
         return self.chat_completion(
             self._construct_messages_simple(user_text, system_text),
-            max_tokens=max_tokens,
-            temperature=temperature,
-            stream=stream,
         )
 
     def count_tokens_simple(self, *, user_text, system_text=None):

@@ -91,27 +93,17 @@ class ChatGptInteractor:
             })
         return messages
 
-    def chat_completion(
-            messages,
-            max_tokens=None,
-            temperature=None,
-            stream=False,
-    ):
-        print(f'Sending request to {self.model_name} stream={stream} ...')
+    def chat_completion(self, messages):
+        logger.info(f'Sending request to {self.model_name} stream={self.stream} ...')
         t1 = time.time()
-        completion = self._request(
-            stream=stream,
-        )
-        if stream:
-            return completion
+        completion = self._request(messages)
+
+        if self.stream:
+            return self._generator(completion)
+
         t2 = time.time()
         usage = completion['usage']
+        logger.info(
             f'Received response: {usage["prompt_tokens"]} in + {usage["completion_tokens"]} out'
             f' = {usage["total_tokens"]} total tokens. Time: {t2 - t1:3.1f} seconds'
         )

@@ -121,14 +113,23 @@ class ChatGptInteractor:
     def get_stream_text(stream_part):
         return stream_part['choices'][0]['delta'].get('content', '')
 
+    @staticmethod
+    def _generator(completion):
+        for part in completion:
+            yield ChatGptInteractor.get_stream_text(part)
+
     def count_tokens(self, messages):
         return num_tokens_from_messages(messages, self.model_name)
 
-    def _request(self,
+    def _request(self, messages):
         for _ in range(5):
             try:
                 completion = openai.ChatCompletion.create(
+                    messages=messages,
+                    model=self.model_name,
+                    max_tokens=self.max_tokens,
+                    temperature=self.temperature,
+                    stream=self.stream,
                     request_timeout=100.0,
                 )
                 return completion

@@ -164,7 +165,8 @@ if __name__ == '__main__':
     print(cgi.chat_completion_simple(user_text=ut, system_text=st))
     print('---')
 
+    cgi = ChatGptInteractor(stream=True)
+    for part in cgi.chat_completion_simple(user_text=ut, system_text=st):
+        print(part, end='')
     print('\n---')
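
The practical change in this file: max_tokens, temperature and stream are now fixed on the instance instead of being passed to every call, and a streaming completion is returned as a plain generator of text chunks. A minimal usage sketch (assuming a valid key in data/openaikey.txt):

cgi = ChatGptInteractor(model_name='gpt-3.5-turbo', max_tokens=256, temperature=0, stream=True)
for chunk in cgi.chat_completion_simple(user_text="Say hello in five words."):
    print(chunk, end='')  # with stream=False, the call would instead return the full completion object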
gradio_app/backend/HuggingfaceGenerator.py
ADDED
@@ -0,0 +1,44 @@
+import logging
+
+from huggingface_hub import InferenceClient
+from transformers import AutoTokenizer
+
+with open('data/hftoken.txt') as f:
+    HF_TOKEN = f.read().strip()
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+# noinspection PyTypeChecker
+class HuggingfaceGenerator:
+    def __init__(
+            self, model_name,
+            temperature: float = 0.9, max_new_tokens: int = 512,
+            top_p: float = None, repetition_penalty: float = None,
+            stream: bool = True,
+    ):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.hf_client = InferenceClient(model_name, token=HF_TOKEN)
+        self.stream = stream
+
+        self.generate_kwargs = {
+            'temperature': max(temperature, 0.1),
+            'max_new_tokens': max_new_tokens,
+            'top_p': top_p,
+            'repetition_penalty': repetition_penalty,
+            'do_sample': True,
+            'seed': 42,
+        }
+
+    def generate(self, messages):
+        formatted_prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
+
+        logger.info(f'Start HuggingFace generation, model {self.hf_client.model} ...')
+        stream = self.hf_client.text_generation(
+            formatted_prompt, **self.generate_kwargs,
+            stream=self.stream, details=True, return_full_text=not self.stream
+        )
+
+        for response in stream:
+            yield response.token.text
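
A minimal usage sketch of the new generator (hypothetical values; assumes a valid token in data/hftoken.txt and that the model is reachable through the HF Inference API):

from gradio_app.backend.HuggingfaceGenerator import HuggingfaceGenerator

hfg = HuggingfaceGenerator(model_name="mistralai/Mistral-7B-Instruct-v0.1", temperature=0, max_new_tokens=128)
messages = [{"role": "user", "content": "What is a cross-encoder?"}]
for piece in hfg.generate(messages):  # with stream=True (the default) text arrives incrementally
    print(piece, end="")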
gradio_app/backend/cross_encoder.py
ADDED
@@ -0,0 +1,32 @@
+import torch
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+from settings import *
+
+
+cross_encoder = None
+cross_enc_tokenizer = None
+
+
+@torch.no_grad()
+def rerank_with_cross_encoder(cross_enc_name, documents, query):
+    if cross_enc_name is None or len(documents) <= 1:
+        return documents
+
+    global cross_encoder, cross_enc_tokenizer
+    if cross_encoder is None or cross_encoder.name_or_path != cross_enc_name:
+        cross_encoder = AutoModelForSequenceClassification.from_pretrained(cross_enc_name)
+        cross_encoder.eval()
+        cross_enc_tokenizer = AutoTokenizer.from_pretrained(cross_enc_name)
+
+    features = cross_enc_tokenizer(
+        [query] * len(documents), documents, padding=True, truncation=True, return_tensors="pt"
+    )
+    scores = cross_encoder(**features).logits.squeeze()
+    ranks = torch.argsort(scores, descending=True)
+    documents = [documents[i] for i in ranks[:TOP_K_RERANK]]
+    return documents
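
A minimal usage sketch of the reranker (hypothetical documents; the model and tokenizer are loaded lazily on first use and cached in the module-level globals, and passing None as the model name skips reranking):

from gradio_app.backend.cross_encoder import rerank_with_cross_encoder

docs = [
    "LanceDB is an embedded vector database.",
    "Bananas are rich in potassium.",
    "A cross-encoder scores a (query, document) pair jointly in one transformer pass.",
]
top_docs = rerank_with_cross_encoder("cross-encoder/ms-marco-MiniLM-L-12-v2", docs, "what is a cross-encoder?")
# returns at most TOP_K_RERANK documents, best-scoring first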
gradio_app/backend/query_llm.py
CHANGED
@@ -1,102 +1,30 @@
-import gradio as gr
-
-from typing import Any, Dict, Generator, List
-
-# from huggingface_hub import InferenceClient
-# from transformers import AutoTokenizer
 from jinja2 import Environment, FileSystemLoader
 
-from settings import *
 from gradio_app.backend.ChatGptInteractor import *
-
-
-# tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
-# HF_TOKEN = None
-# hf_client = InferenceClient(LLM_NAME, token=HF_TOKEN)
-
-
-def format_prompt(message: str, api_kind: str):
-    """
-    Formats the given message using a chat template.
-
-    Args:
-        message (str): The user message to be formatted.
-
-    Returns:
-        str: Formatted message after applying the chat template.
-    """
-
-    # Create a list of message dictionaries with role and content
-    messages: List[Dict[str, Any]] = [{'role': 'user', 'content': message}]
-
-    if api_kind == "openai":
-        return messages
-    elif api_kind == "hf":
-        return tokenizer.apply_chat_template(messages, tokenize=False)
-    else:
-        raise ValueError("API is not supported")
-
-
-def generate_hf(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 512,
-                top_p: float = 0.6, repetition_penalty: float = 1.2) -> Generator[str, None, str]:
-    """
-    Generate a sequence of tokens based on a given prompt and history using Mistral client.
-
-    Args:
-        prompt (str): The initial prompt for the text generation.
-        history (str): Context or history for the text generation.
-        temperature (float, optional): The softmax temperature for sampling. Defaults to 0.9.
-        max_new_tokens (int, optional): Maximum number of tokens to be generated. Defaults to 256.
-        top_p (float, optional): Nucleus sampling probability. Defaults to 0.95.
-        repetition_penalty (float, optional): Penalty for repeated tokens. Defaults to 1.0.
-
-    Returns:
-        Generator[str, None, str]: A generator yielding chunks of generated text.
-        Returns a final string if an error occurs.
-    """
-
-    temperature = max(float(temperature), 1e-2)  # Ensure temperature isn't too low
-    top_p = float(top_p)
-
-    generate_kwargs = {
-        'temperature': temperature,
-        'max_new_tokens': max_new_tokens,
-        'top_p': top_p,
-        'repetition_penalty': repetition_penalty,
-        'do_sample': True,
-        'seed': 42,
-    }
-
-    formatted_prompt = format_prompt(prompt, "hf")
-
-    try:
-        stream = hf_client.text_generation(formatted_prompt, **generate_kwargs,
-                                           stream=True, details=True, return_full_text=False)
-        output = ""
-        for response in stream:
-            output += response.token.text
-            yield output
-
-    except Exception as e:
-        if "Too Many Requests" in str(e):
-            print("ERROR: Too many requests on Mistral client")
-            gr.Warning("Unfortunately Mistral is unable to process")
-            return "Unfortunately, I am not able to process your request now."
-        elif "Authorization header is invalid" in str(e):
-            print("Authetification error:", str(e))
-            gr.Warning("Authentication error: HF token was either not provided or incorrect")
-            return "Authentication error"
-        else:
-            print("Unhandled Exception:", str(e))
-            gr.Warning("Unfortunately Mistral is unable to process")
-            return "I do not know what happened, but I couldn't understand you."
-
+from gradio_app.backend.HuggingfaceGenerator import HuggingfaceGenerator
 
 env = Environment(loader=FileSystemLoader('gradio_app/templates'))
 context_template = env.get_template('context_template.j2')
 start_system_message = context_template.render(documents=[])
 
 
+def construct_mistral_messages(context, history):
+    messages = []
+    for q, a in history:
+        if len(a) == 0:  # the last message
+            q = context + f'\n\nQuery:\n\n{q}'
+        messages.append({
+            "role": "user",
+            "content": q,
+        })
+        if len(a) != 0:  # some of the previous LLM answers
+            messages.append({
+                "role": "assistant",
+                "content": a,
+            })
+    return messages
+
+
 def construct_openai_messages(context, history):
     messages = [
         {

@@ -122,64 +50,32 @@ def construct_openai_messages(context, history):
     return messages
 
 
-def
-
-
-def _generate_openai(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 512,
-                     top_p: float = 0.6, repetition_penalty: float = 1.2) -> Generator[str, None, str]:
-    """
-    Generate a sequence of tokens based on a given prompt and history using Mistral client.
-
-    Args:
-        prompt (str): The initial prompt for the text generation.
-        history (str): Context or history for the text generation.
-        temperature (float, optional): The softmax temperature for sampling. Defaults to 0.9.
-        max_new_tokens (int, optional): Maximum number of tokens to be generated. Defaults to 256.
-        top_p (float, optional): Nucleus sampling probability. Defaults to 0.95.
-        repetition_penalty (float, optional): Penalty for repeated tokens. Defaults to 1.0.
-
-    formatted_prompt = format_prompt(prompt, "openai")
-
-    try:
-        stream = openai.ChatCompletion.create(
-            model=LLM_NAME,
-            messages=formatted_prompt,
-            **generate_kwargs,
-            stream=True
-        )
-        output = ""
-        for chunk in stream:
-            output += chunk.choices[0].delta.get("content", "")
-            yield output
-
-    except Exception as e:
-        if "Too Many Requests" in str(e):
-            print("ERROR: Too many requests on OpenAI client")
-            gr.Warning("Unfortunately OpenAI is unable to process")
-            return "Unfortunately, I am not able to process your request now."
-        elif "You didn't provide an API key" in str(e):
-            print("Authetification error:", str(e))
-            gr.Warning("Authentication error: OpenAI key was either not provided or incorrect")
-            return "Authentication error"
-        else:
-            print("Unhandled Exception:", str(e))
-            gr.Warning("Unfortunately OpenAI is unable to process")
-            return "I do not know what happened, but I couldn't understand you."
+def get_message_constructor(llm_name):
+    if llm_name == 'gpt-3.5-turbo':
+        return construct_openai_messages
+    if llm_name in ['mistralai/Mistral-7B-Instruct-v0.1', "GeneZC/MiniChat-3B"]:
+        return construct_mistral_messages
+    raise ValueError('Unknown LLM name')
+
+
+def get_llm_generator(llm_name):
+    if llm_name == 'gpt-3.5-turbo':
+        cgi = ChatGptInteractor(
+            model_name=llm_name, max_tokens=512, temperature=0, stream=True
+        )
+        return cgi.chat_completion
+    if llm_name == 'mistralai/Mistral-7B-Instruct-v0.1':
+        hfg = HuggingfaceGenerator(
+            model_name=llm_name, temperature=0, max_new_tokens=512,
+        )
+        return hfg.generate
+
+    if llm_name == "GeneZC/MiniChat-3B":
+        hfg = HuggingfaceGenerator(
+            model_name=llm_name, temperature=0, max_new_tokens=250, stream=False,
+        )
+        return hfg.generate
+    raise ValueError('Unknown LLM name')
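
A minimal sketch of how the two new dispatch helpers are combined (mirroring the calls in bot(); the context string here is a hypothetical stand-in for the rendered Jinja template):

llm = "mistralai/Mistral-7B-Instruct-v0.1"
history = [["What is RAG?", ""]]  # Gradio-style history; the empty answer marks the pending turn

messages = get_message_constructor(llm)("Documents:\n<retrieved context here>", history)
llm_gen = get_llm_generator(llm)
answer = "".join(llm_gen(messages))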
settings.py
CHANGED
@@ -5,11 +5,11 @@ VECTOR_COLUMN_NAME = "embedding"
 TEXT_COLUMN_NAME = "text"
 DOCUMENT_PATH_COLUMN_NAME = "document_path"
 
-# LLM_NAME = "mistralai/Mistral-7B-Instruct-v0.1"
-LLM_NAME = "gpt-3.5-turbo"
 # EMBED_NAME = "sentence-transformers/all-MiniLM-L6-v2"
 EMBED_NAME = "text-embedding-ada-002"
 
+TOP_K_RANK = 50
+TOP_K_RERANK = 5
 
 emb_sizes = {
     "sentence-transformers/all-MiniLM-L6-v2": 384,

@@ -17,8 +17,14 @@ emb_sizes = {
     "text-embedding-ada-002": 1536,
 }
 
+thresh_distances = {
+    "sentence-transformers/all-MiniLM-L6-v2": 1.2,
+    "text-embedding-ada-002": 0.5,
+}
+
 context_lengths = {
     "mistralai/Mistral-7B-Instruct-v0.1": 4096,
+    "GeneZC/MiniChat-3B": 4096,
     "gpt-3.5-turbo": 4096,
     "sentence-transformers/all-MiniLM-L6-v2": 128,
     "thenlper/gte-large": 512,