Steven Chen committed: update chat_llama3_8b function

app.py (CHANGED)
@@ -165,67 +165,101 @@ def query_vector_store(vector_store: FAISS, query, k=4, relevance_threshold=0.8)
 
 @spaces.GPU(duration=120)
 def chat_llama3_8b(message: str,
-                   [old parameter lines not captured in this page]
-    """
-    Generate a streaming response using the
-    Will display citations after the response if citations are available.
-    """
-    # Get citations from vector store
-    citation = query_vector_store(vector_store, message, 4, 0.7)
-
-    # Build conversation history
-    conversation = []
-    for user, assistant in history:
-        conversation.extend([
-            {"role": "user", "content": user},
-            {"role": "assistant", "content": assistant}
-        ])
-
-    # Construct the final message with background prompt and citations
-    if citation:
-        message = f"{background_prompt}Based on these citations: {citation}\nPlease answer question: {message}"
-    else:
-        message = f"{background_prompt}{message}"
-
-    conversation.append({"role": "user", "content": message})
-
-    # Generate response
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-
-    generate_kwargs = dict(
-        input_ids=input_ids,
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        temperature=temperature,
-        eos_token_id=terminators,
-    )
-    [remaining lines of the old implementation not captured in this page]
+                   history: list,
+                   temperature=0.6,
+                   max_new_tokens=4096
+                   ) -> str:
+    """
+    Generate a streaming response using the LLaMA model.
+
+    Args:
+        message (str): The current user message
+        history (list): List of previous conversation turns
+        temperature (float): Sampling temperature (0.0 to 1.0)
+        max_new_tokens (int): Maximum number of tokens to generate
+
+    Returns:
+        str: Generated response with citations if available
+    """
+    try:
+        # 1. Get relevant citations from vector store
+        citation = query_vector_store(vector_store, message, k=4, relevance_threshold=0.7)
+
+        # 2. Format conversation history
+        conversation = []
+        for user, assistant in history:
+            conversation.extend([
+                {"role": "user", "content": str(user)},
+                {"role": "assistant", "content": str(assistant)}
+            ])
+
+        # 3. Construct the final prompt
+        final_message = ""
+        if citation:
+            final_message = f"{background_prompt}\nBased on these references:\n{citation}\nPlease answer: {message}"
+        else:
+            final_message = f"{background_prompt}\n{message}"
+
+        conversation.append({"role": "user", "content": final_message})
+
+        # 4. Prepare model inputs
+        input_ids = tokenizer.apply_chat_template(
+            conversation,
+            return_tensors="pt"
+        ).to(model.device)
+
+        # 5. Setup streamer
+        streamer = TextIteratorStreamer(
+            tokenizer,
+            timeout=10.0,
+            skip_prompt=True,
+            skip_special_tokens=True
+        )
+
+        # 6. Configure generation parameters
+        generation_config = {
+            "input_ids": input_ids,
+            "streamer": streamer,
+            "max_new_tokens": max_new_tokens,
+            "do_sample": temperature > 0,
+            "temperature": temperature,
+            "eos_token_id": terminators
+        }
+
+        # 7. Generate in a separate thread
+        thread = Thread(target=model.generate, kwargs=generation_config)
+        thread.start()
+
+        # 8. Stream the output
+        accumulated_text = []
+        final_chunk = False
+
+        for text_chunk in streamer:
+            accumulated_text.append(text_chunk)
+            current_response = "".join(accumulated_text)
+
+            # Check if this is the last chunk
+            try:
+                next_chunk = next(iter(streamer))
+                accumulated_text.append(next_chunk)
+            except (StopIteration, RuntimeError):
+                final_chunk = True
+
+            # Add citations on the final chunk if they exist
+            if final_chunk and citation:
+                formatted_citations = "\n\nReferences:\n" + "\n".join(
+                    f"[{i+1}] {cite.strip()}"
+                    for i, cite in enumerate(citation.split('\n'))
+                    if cite.strip()
+                )
+                current_response += formatted_citations
+
+            yield current_response
+
+    except Exception as e:
+        error_message = f"An error occurred: {str(e)}"
+        print(error_message)  # For logging
+        yield error_message
 
 
 # Gradio block
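
For reference, a minimal sketch of how a streaming generator with this signature is commonly wired into the "# Gradio block" that follows the hunk. The Gradio code in app.py is outside this diff, so the ChatInterface call, slider ranges, and labels below are assumptions rather than the Space's actual implementation.

import gradio as gr

# Hypothetical wiring: gr.ChatInterface passes (message, history, *additional_inputs)
# to the function and renders each yielded string as a streaming update.
demo = gr.ChatInterface(
    fn=chat_llama3_8b,
    additional_inputs=[
        gr.Slider(minimum=0.0, maximum=1.0, value=0.6, step=0.05, label="Temperature"),      # assumed range
        gr.Slider(minimum=256, maximum=4096, value=4096, step=256, label="Max new tokens"),  # assumed range
    ],
)

if __name__ == "__main__":
    demo.launch()

Because chat_llama3_8b yields progressively longer strings, ChatInterface displays a token-by-token response; the citations block is appended only once the streamer is exhausted.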