pin torch to 2.4.0
- app.py +19 -19
- requirements.txt +2 -2
app.py
CHANGED
@@ -7,19 +7,18 @@ from datetime import datetime
 import gradio as gr
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from huggingface_hub import hf_hub_download
 from duckduckgo_search import DDGS
-import spaces
+import spaces  # Import spaces early to enable ZeroGPU support
+
+# Disable GPU visibility if you wish to force CPU usage outside of GPU functions
+# (Not strictly needed for ZeroGPU as the decorator handles allocation)
+# os.environ["CUDA_VISIBLE_DEVICES"] = ""
 
 # ------------------------------
 # Global Cancellation Event
 # ------------------------------
 cancel_event = threading.Event()
 
-# ------------------------------
-# Model Definitions and Global Variables (PyTorch/Transformers)
-# ------------------------------
-# Here, the repo_id should point to a model checkpoint that is compatible with Hugging Face Transformers.
 # ------------------------------
 # Torch-Compatible Model Definitions with Adjusted Descriptions
 # ------------------------------
@@ -70,7 +69,6 @@ MODELS = {
     },
 }
 
-
 LOADED_MODELS = {}
 CURRENT_MODEL_NAME = None
 
@@ -82,7 +80,7 @@ def load_model(model_name):
     if model_name in LOADED_MODELS:
         return LOADED_MODELS[model_name]
     selected_model = MODELS[model_name]
-    # Load
+    # Load the model and tokenizer using Transformers.
     model = AutoModelForCausalLM.from_pretrained(selected_model["repo_id"], trust_remote_code=True)
     tokenizer = AutoTokenizer.from_pretrained(selected_model["repo_id"], trust_remote_code=True)
     LOADED_MODELS[model_name] = (model, tokenizer)
@@ -106,15 +104,15 @@ def retrieve_context(query, max_results=6, max_chars_per_result=600):
         return ""
 
 # ------------------------------
-# Chat Response Generation
+# Chat Response Generation with ZeroGPU
 # ------------------------------
-@spaces.GPU
+@spaces.GPU(duration=60)  # This decorator triggers GPU allocation for up to 60 seconds.
 def chat_response(user_message, chat_history, system_prompt, enable_search,
                   max_results, max_chars, model_name, max_tokens, temperature, top_k, top_p, repeat_penalty):
     # Reset the cancellation event.
     cancel_event.clear()
 
-    # Prepare internal history.
+    # Prepare internal chat history.
     internal_history = list(chat_history) if chat_history else []
     internal_history.append({"role": "user", "content": user_message})
 
@@ -138,7 +136,7 @@ def chat_response(user_message, chat_history, system_prompt, enable_search,
         retrieved_context = ""
         debug_message = "Web search disabled."
 
-    # Augment prompt with search context if available.
+    # Augment the prompt with search context if available.
     if enable_search and retrieved_context:
         augmented_user_input = (
             f"{system_prompt.strip()}\n\n"
@@ -153,11 +151,13 @@ def chat_response(user_message, chat_history, system_prompt, enable_search,
     internal_history.append({"role": "assistant", "content": ""})
 
     try:
-        # Load the
+        # Load the model and tokenizer.
         model, tokenizer = load_model(model_name)
+        # Move the model to GPU (using .to('cuda')) inside the GPU-decorated function.
+        model = model.to('cuda')
+        # Tokenize the augmented prompt and move input tensors to GPU.
+        input_ids = tokenizer(augmented_user_input, return_tensors="pt").input_ids.to('cuda')
 
-        # Tokenize the input prompt.
-        input_ids = tokenizer(augmented_user_input, return_tensors="pt").input_ids
         with torch.no_grad():
             output_ids = model.generate(
                 input_ids,
@@ -168,13 +168,12 @@ def chat_response(user_message, chat_history, system_prompt, enable_search,
                 repetition_penalty=repeat_penalty,
                 do_sample=True
             )
-
         # Decode the generated tokens.
         generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
-        #
+        # Remove the original prompt to isolate the assistant's reply.
        assistant_text = generated_text[len(augmented_user_input):].strip()
 
-        # Simulate streaming by yielding
+        # Simulate streaming output by yielding word-by-word.
         words = assistant_text.split()
         assistant_message = ""
         for word in words:
@@ -205,7 +204,7 @@ def cancel_generation():
 with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
     gr.Markdown("## 🧠 ZeroGPU LLM Inference with Web Search")
     gr.Markdown("Interact with the model. Select your model, set your system prompt, and adjust parameters on the left.")
-
+
     with gr.Row():
         with gr.Column(scale=3):
             default_model = list(MODELS.keys())[0] if MODELS else "No models available"
@@ -252,6 +251,7 @@ with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
     clear_button.click(fn=clear_chat, outputs=[chatbot, msg_input, search_debug])
    cancel_button.click(fn=cancel_generation, outputs=search_debug)
 
+    # Submission: the chat_response function is now decorated with @spaces.GPU.
    msg_input.submit(
         fn=chat_response,
         inputs=[msg_input, chatbot, system_prompt_text, enable_search_checkbox,
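The app.py side of this commit follows the usual ZeroGPU recipe: import `spaces` early, keep model loading on the CPU, and move weights and inputs to CUDA only inside the `@spaces.GPU`-decorated function, which then yields partial text back to Gradio to simulate streaming. A minimal standalone sketch of that recipe, assuming the same `spaces` and `transformers` APIs used in the diff (the model id `distilgpt2` and the 64-token cap are placeholders, not this Space's configuration):

import torch
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "distilgpt2"  # placeholder checkpoint; the Space selects repos from its MODELS table

# Loaded once on CPU at import time; a ZeroGPU Space has no GPU outside decorated calls.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

@spaces.GPU(duration=60)  # a GPU is attached only while this function runs
def generate_stream(prompt: str, max_new_tokens: int = 64):
    gpu_model = model.to("cuda")  # move weights onto the allocated GPU
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
    with torch.no_grad():
        output_ids = gpu_model.generate(input_ids, max_new_tokens=max_new_tokens, do_sample=True)
    text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    reply = text[len(prompt):].strip()
    # Simulated streaming: generation already finished; yield growing prefixes word by word.
    partial = ""
    for word in reply.split():
        partial = f"{partial} {word}".strip()
        yield partial

Like chat_response in the diff, the sketch is a generator, so it can be wired directly to a Gradio event handler that streams each yielded prefix into the chatbot.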
requirements.txt
CHANGED
@@ -1,8 +1,8 @@
 wheel
 streamlit
 duckduckgo_search
-gradio
-torch
+gradio>=4.0.0
+torch==2.4.0
 transformers
 spaces
 sentencepiece
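requirements.txt now pins torch exactly and sets a gradio floor. A hypothetical startup check, not part of this commit, that fails fast if the Space was built against different wheels than the pins request:

import gradio
import torch

# torch wheels may carry a local suffix such as "2.4.0+cu121"; compare only the release part.
assert torch.__version__.split("+")[0] == "2.4.0", f"unexpected torch build: {torch.__version__}"
assert int(gradio.__version__.split(".")[0]) >= 4, f"gradio too old: {gradio.__version__}"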