Spaces: Running on Zero
altered processor due to huggingface update
app.py CHANGED
@@ -6,6 +6,7 @@ from transformers import (
     TextIteratorStreamer,
     Gemma3Processor,
     Gemma3nForConditionalGeneration,
+    Gemma3nProcessor
 )
 import spaces
 from threading import Thread
@@ -22,7 +23,8 @@ load_dotenv(dotenv_path)
 model_12_id = os.getenv("MODEL_12_ID", "google/gemma-3-12b-it")
 model_3n_id = os.getenv("MODEL_3N_ID", "google/gemma-3n-E4B-it")

-
+input_processor_12 = Gemma3Processor.from_pretrained(model_12_id)
+input_processor_3n = Gemma3nProcessor.from_pretrained(model_3n_id)

 model_12 = Gemma3ForConditionalGeneration.from_pretrained(
     model_12_id,
@@ -70,11 +72,13 @@ def run(

     def try_fallback_model(original_model_choice: str):
         fallback_model = model_3n if original_model_choice == "Gemma 3 12B" else model_12
+        fallback_processor = input_processor_3n if original_model_choice == "Gemma 3 12B" else input_processor_12
         fallback_name = "Gemma 3n E4B" if original_model_choice == "Gemma 3 12B" else "Gemma 3 12B"
         logger.info(f"Attempting fallback to {fallback_name} model")
-        return fallback_model, fallback_name
+        return fallback_model, fallback_processor, fallback_name

     selected_model = model_12 if model_choice == "Gemma 3 12B" else model_3n
+    selected_processor = input_processor_12 if model_choice == "Gemma 3 12B" else input_processor_3n
     current_model_name = model_choice

     try:
@@ -94,7 +98,7 @@ def run(
         for i, msg in enumerate(messages):
             logger.debug(f"Message {i}: role={msg.get('role', 'MISSING')}, content_type={type(msg.get('content', 'MISSING'))}")

-        inputs =
+        inputs = selected_processor.apply_chat_template(
             messages,
             add_generation_prompt=True,
             tokenize=True,
@@ -103,7 +107,7 @@ def run(
         ).to(device=selected_model.device, dtype=torch.bfloat16)

         streamer = TextIteratorStreamer(
-
+            selected_processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
         )
         generate_kwargs = dict(
             inputs,
@@ -156,11 +160,11 @@ def run(

         # Try fallback model
         try:
-            selected_model, fallback_name = try_fallback_model(model_choice)
+            selected_model, fallback_processor, fallback_name = try_fallback_model(model_choice)
             logger.info(f"Switching to fallback model: {fallback_name}")

             # Rebuild inputs for fallback model
-            inputs =
+            inputs = fallback_processor.apply_chat_template(
                 messages,
                 add_generation_prompt=True,
                 tokenize=True,
@@ -169,7 +173,7 @@ def run(
             ).to(device=selected_model.device, dtype=torch.bfloat16)

             streamer = TextIteratorStreamer(
-
+                fallback_processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
             )
             generate_kwargs = dict(
                 inputs,
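The change is easiest to see end to end: each checkpoint now gets the processor class that matches its architecture (Gemma3Processor for the 12B model, Gemma3nProcessor for the 3n E4B model), and the UI choice selects the model and its processor together. A minimal sketch of that pairing follows; only the model IDs and class names come from the diff above, while the registry layout and the pick_model helper are illustrative assumptions, not code from app.py.

# Illustrative sketch only: keep each Gemma checkpoint paired with its matching
# processor class so the two are always selected together. The registry layout
# and pick_model helper are assumptions, not the Space's actual code.
import torch
from transformers import (
    Gemma3ForConditionalGeneration,
    Gemma3Processor,
    Gemma3nForConditionalGeneration,
    Gemma3nProcessor,
)

MODEL_12_ID = "google/gemma-3-12b-it"
MODEL_3N_ID = "google/gemma-3n-E4B-it"

REGISTRY = {
    "Gemma 3 12B": (
        Gemma3ForConditionalGeneration.from_pretrained(
            MODEL_12_ID, torch_dtype=torch.bfloat16, device_map="auto"
        ),
        Gemma3Processor.from_pretrained(MODEL_12_ID),
    ),
    "Gemma 3n E4B": (
        Gemma3nForConditionalGeneration.from_pretrained(
            MODEL_3N_ID, torch_dtype=torch.bfloat16, device_map="auto"
        ),
        Gemma3nProcessor.from_pretrained(MODEL_3N_ID),
    ),
}

def pick_model(choice: str):
    # Unknown choices fall back to the lighter 3n E4B pair.
    return REGISTRY.get(choice, REGISTRY["Gemma 3n E4B"])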
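Downstream of that pairing, the hunks around lines 98-113 build the chat inputs with the selected processor and stream tokens from a background generate() call. A rough sketch of that path is below; the apply_chat_template and TextIteratorStreamer calls mirror the diff, while return_dict/return_tensors, max_new_tokens, and the threading layout are assumptions rather than values taken from the Space.

# Rough sketch of the generation path this diff touches: build inputs with the
# selected processor's chat template, then stream tokens from a background
# generate() call. return_dict/return_tensors, max_new_tokens, and the threading
# layout are assumptions; the processor calls mirror the diff above.
from threading import Thread

import torch
from transformers import TextIteratorStreamer


def stream_reply(model, processor, messages, max_new_tokens=512):
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device=model.device, dtype=torch.bfloat16)

    # The processor itself is passed to the streamer; its decode() delegates
    # to the underlying tokenizer.
    streamer = TextIteratorStreamer(
        processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
    )
    generate_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    for chunk in streamer:
        yield chunk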
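The last two hunks exist because the fallback path can no longer reuse a shared processor: when generation with the chosen model fails, both the inputs and the streamer have to be rebuilt with the fallback model's own processor before retrying. A condensed sketch of that flow, reusing the hypothetical pick_model and stream_reply helpers from the sketches above (app.py structures this differently, inside run()):

# Condensed sketch of the fallback behaviour after this change: if generation
# with the chosen pair fails, retry once with the other (model, processor) pair.
# pick_model and stream_reply are the hypothetical helpers sketched above.
def run_with_fallback(model_choice: str, messages):
    model, processor = pick_model(model_choice)
    try:
        yield from stream_reply(model, processor, messages)
    except Exception:
        fallback_choice = "Gemma 3n E4B" if model_choice == "Gemma 3 12B" else "Gemma 3 12B"
        model, processor = pick_model(fallback_choice)
        # Inputs and the streamer are rebuilt inside stream_reply with the
        # fallback processor, mirroring the "Rebuild inputs" hunk above.
        yield from stream_reply(model, processor, messages)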