Spaces: Runtime error
Steven Chen committed
Update app.py

app.py CHANGED
@@ -181,85 +181,85 @@ def chat_llama3_8b(message: str,
    Returns:
        str: Generated response with citations if available
    """
-    try:
-        # 1. Get relevant citations from vector store
-        citation = query_vector_store(vector_store, message, k=4, relevance_threshold=0.7)
-
-        # 2. Format conversation history
-        conversation = []
-        for user, assistant in history:
-            conversation.extend([
-                {"role": "user", "content": str(user)},
-                {"role": "assistant", "content": str(assistant)}
-            ])
-
-        # 3. Construct the final prompt
-        final_message = ""
-        if citation:
-            final_message = f"{background_prompt}\nBased on these references:\n{citation}\nPlease answer: {message}"
-        else:
-            final_message = f"{background_prompt}\n{message}"
-
-        conversation.append({"role": "user", "content": final_message})
-
-        # 4. Prepare model inputs
-        input_ids = tokenizer.apply_chat_template(
-            conversation,
-            return_tensors="pt"
-        ).to(model.device)
-
-        # 5. Setup streamer
-        streamer = TextIteratorStreamer(
-            tokenizer,
-            timeout=10.0,
-            skip_prompt=True,
-            skip_special_tokens=True
-        )
-
-        # 6. Configure generation parameters
-        generation_config = {
-            "input_ids": input_ids,
-            "streamer": streamer,
-            "max_new_tokens": max_new_tokens,
-            "do_sample": temperature > 0,
-            "temperature": temperature,
-            "eos_token_id": terminators
-        }
-
-        # 7. Generate in a separate thread
-        thread = Thread(target=model.generate, kwargs=generation_config)
-        thread.start()
-
-        # 8. Stream the output
-        accumulated_text = []
-        final_chunk = False
-
-        for text_chunk in streamer:
-            accumulated_text.append(text_chunk)
-            current_response = "".join(accumulated_text)
-
-            # Check if this is the last chunk
-            try:
-                next_chunk = next(iter(streamer))
-                accumulated_text.append(next_chunk)
-            except (StopIteration, RuntimeError):
-                final_chunk = True
-
-            # Add citations on the final chunk if they exist
-            if final_chunk and citation:
-                formatted_citations = "\n\nReferences:\n" + "\n".join(
-                    f"[{i+1}] {cite.strip()}"
-                    for i, cite in enumerate(citation.split('\n'))
-                    if cite.strip()
-                )
-                current_response += formatted_citations
-
-            yield current_response
-
-    except Exception as e:
-        error_message = f"An error occurred: {str(e)}"
-        print(error_message)  # For logging
-        yield error_message
+    # try:
+    # 1. Get relevant citations from vector store
+    citation = query_vector_store(vector_store, message, k=4, relevance_threshold=0.7)
+
+    # 2. Format conversation history
+    conversation = []
+    for user, assistant in history:
+        conversation.extend([
+            {"role": "user", "content": str(user)},
+            {"role": "assistant", "content": str(assistant)}
+        ])
+
+    # 3. Construct the final prompt
+    final_message = ""
+    if citation:
+        final_message = f"{background_prompt}\nBased on these references:\n{citation}\nPlease answer: {message}"
+    else:
+        final_message = f"{background_prompt}\n{message}"
+
+    conversation.append({"role": "user", "content": final_message})
+
+    # 4. Prepare model inputs
+    input_ids = tokenizer.apply_chat_template(
+        conversation,
+        return_tensors="pt"
+    ).to(model.device)
+
+    # 5. Setup streamer
+    streamer = TextIteratorStreamer(
+        tokenizer,
+        timeout=10.0,
+        skip_prompt=True,
+        skip_special_tokens=True
+    )
+
+    # 6. Configure generation parameters
+    generation_config = {
+        "input_ids": input_ids,
+        "streamer": streamer,
+        "max_new_tokens": max_new_tokens,
+        "do_sample": temperature > 0,
+        "temperature": temperature,
+        "eos_token_id": terminators
+    }
+
+    # 7. Generate in a separate thread
+    thread = Thread(target=model.generate, kwargs=generation_config)
+    thread.start()
+
+    # 8. Stream the output
+    accumulated_text = []
+    final_chunk = False
+
+    for text_chunk in streamer:
+        accumulated_text.append(text_chunk)
+        current_response = "".join(accumulated_text)
+
+        # Check if this is the last chunk
+        try:
+            next_chunk = next(iter(streamer))
+            accumulated_text.append(next_chunk)
+        except (StopIteration, RuntimeError):
+            final_chunk = True
+
+        # Add citations on the final chunk if they exist
+        if final_chunk and citation:
+            formatted_citations = "\n\nReferences:\n" + "\n".join(
+                f"[{i+1}] {cite.strip()}"
+                for i, cite in enumerate(citation.split('\n'))
+                if cite.strip()
+            )
+            current_response += formatted_citations
+
+        yield current_response
+
+    # except Exception as e:
+    #     error_message = f"An error occurred: {str(e)}"
+    #     print(error_message)  # For logging
+    #     yield error_message


# Gradio block
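
The commit only touches error handling: the body of chat_llama3_8b is kept, while the surrounding try / except Exception wrapper is commented out, so exceptions now propagate with a full traceback instead of being collapsed into a single yielded error string. For readers who want to see the streaming mechanics of steps 5-8 in isolation, below is a minimal, self-contained sketch of the Thread + TextIteratorStreamer pattern. The model id ("gpt2"), the prompt, and the generation settings are placeholders chosen so the snippet runs on its own; they are not taken from this Space's app.py, which builds its inputs with apply_chat_template and yields partial strings to Gradio.

# Minimal sketch of the Thread + TextIteratorStreamer pattern.
# "gpt2", the prompt, and max_new_tokens are placeholders, not values from app.py.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "gpt2"  # placeholder; the Space targets a Llama 3 8B chat model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Explain retrieval-augmented generation in one sentence.",
                   return_tensors="pt").to(model.device)

# skip_prompt drops the echoed prompt; skip_special_tokens hides EOS markers.
streamer = TextIteratorStreamer(tokenizer, timeout=10.0,
                                skip_prompt=True, skip_special_tokens=True)

# generate() blocks until finished, so it runs in a worker thread while the
# main thread iterates over the streamer, which yields decoded text chunks.
thread = Thread(target=model.generate,
                kwargs={**inputs, "streamer": streamer,
                        "max_new_tokens": 64, "do_sample": False})
thread.start()

pieces = []
for chunk in streamer:        # each chunk arrives as soon as it is decoded
    pieces.append(chunk)
    print("".join(pieces))    # the Gradio handler yields this string instead
thread.join()

The handler in app.py additionally peeks at the following chunk with next(iter(streamer)) so it can append the formatted references exactly once, on the final chunk, before yielding the finished response.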