Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -49,6 +49,8 @@ if 'builder' not in st.session_state:
|
|
| 49 |
st.session_state['builder'] = None
|
| 50 |
if 'model_loaded' not in st.session_state:
|
| 51 |
st.session_state['model_loaded'] = False
|
|
|
|
|
|
|
| 52 |
|
| 53 |
# Model Configuration Classes
|
| 54 |
@dataclass
|
|
@@ -191,18 +193,19 @@ class DiffusionBuilder:
|
|
| 191 |
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
|
| 192 |
optimizer = torch.optim.AdamW(self.pipeline.unet.parameters(), lr=1e-5)
|
| 193 |
self.pipeline.unet.train()
|
|
|
|
| 194 |
for epoch in range(epochs):
|
| 195 |
with st.spinner(f"Training diffusion epoch {epoch + 1}/{epochs}... ⚙️"):
|
| 196 |
total_loss = 0
|
| 197 |
for batch in dataloader:
|
| 198 |
optimizer.zero_grad()
|
| 199 |
-
image = batch["image"][0].to(
|
| 200 |
text = batch["text"][0]
|
| 201 |
-
latents = self.pipeline.vae.encode(torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float().to(
|
| 202 |
-
noise = torch.randn_like(
|
| 203 |
timesteps = torch.randint(0, self.pipeline.scheduler.num_train_timesteps, (latents.shape[0],), device=latents.device)
|
| 204 |
noisy_latents = self.pipeline.scheduler.add_noise(latents, noise, timesteps)
|
| 205 |
-
text_embeddings = self.pipeline.text_encoder(self.pipeline.tokenizer(text, return_tensors="pt").input_ids.to(
|
| 206 |
pred_noise = self.pipeline.unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample
|
| 207 |
loss = torch.nn.functional.mse_loss(pred_noise, noise)
|
| 208 |
loss.backward()
|
|
@@ -225,7 +228,7 @@ def generate_filename(sequence, ext="png"):
|
|
| 225 |
import pytz
|
| 226 |
central = pytz.timezone('US/Central')
|
| 227 |
dt = datetime.now(central)
|
| 228 |
-
return f"{dt.strftime('%m-%d-%Y-%I-%M-%p')}.{ext}"
|
| 229 |
|
| 230 |
def get_download_link(file_path, mime_type="text/plain", label="Download"):
|
| 231 |
with open(file_path, 'rb') as f:
|
|
@@ -244,8 +247,7 @@ def get_model_files(model_type="causal_lm"):
|
|
| 244 |
return [d for d in glob.glob(path) if os.path.isdir(d)]
|
| 245 |
|
| 246 |
def get_gallery_files(file_types):
|
| 247 |
-
|
| 248 |
-
return files
|
| 249 |
|
| 250 |
def update_gallery():
|
| 251 |
media_files = get_gallery_files(["png"])
|
|
@@ -337,11 +339,19 @@ if selected_model != "None" and st.sidebar.button("Load Model 📂"):
|
|
| 337 |
st.rerun()
|
| 338 |
|
| 339 |
# Tabs
|
| 340 |
-
|
| 341 |
"Build Titan 🌱", "Camera Snap 📷",
|
| 342 |
"Fine-Tune Titan (NLP) 🔧", "Test Titan (NLP) 🧪", "Agentic RAG Party (NLP) 🌐",
|
| 343 |
"Fine-Tune Titan (CV) 🔧", "Test Titan (CV) 🧪", "Agentic RAG Party (CV) 🌐"
|
| 344 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
|
| 346 |
with tab1:
|
| 347 |
st.header("Build Titan 🌱")
|
|
@@ -350,9 +360,9 @@ with tab1:
|
|
| 350 |
["HuggingFaceTB/SmolLM-135M", "HuggingFaceTB/SmolLM-360M", "Qwen/Qwen1.5-0.5B-Chat"] if model_type == "Causal LM" else
|
| 351 |
["stabilityai/stable-diffusion-2-base", "runwayml/stable-diffusion-v1-5"])
|
| 352 |
model_name = st.text_input("Model Name", f"tiny-titan-{int(time.time())}")
|
| 353 |
-
domain = st.text_input("Target Domain", "general", help="Where will your Titan flex its muscles? 💪")
|
| 354 |
if st.button("Download Model ⬇️"):
|
| 355 |
-
config =
|
| 356 |
builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
|
| 357 |
builder.load_model(base_model, config)
|
| 358 |
builder.save_model(config.model_path)
|
|
|
|
| 49 |
st.session_state['builder'] = None
|
| 50 |
if 'model_loaded' not in st.session_state:
|
| 51 |
st.session_state['model_loaded'] = False
|
| 52 |
+
if 'active_tab' not in st.session_state:
|
| 53 |
+
st.session_state['active_tab'] = "Build Titan 🌱"
|
| 54 |
|
| 55 |
# Model Configuration Classes
|
| 56 |
@dataclass
|
|
|
|
| 193 |
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
|
| 194 |
optimizer = torch.optim.AdamW(self.pipeline.unet.parameters(), lr=1e-5)
|
| 195 |
self.pipeline.unet.train()
|
| 196 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 197 |
for epoch in range(epochs):
|
| 198 |
with st.spinner(f"Training diffusion epoch {epoch + 1}/{epochs}... ⚙️"):
|
| 199 |
total_loss = 0
|
| 200 |
for batch in dataloader:
|
| 201 |
optimizer.zero_grad()
|
| 202 |
+
image = batch["image"][0].to(device)
|
| 203 |
text = batch["text"][0]
|
| 204 |
+
latents = self.pipeline.vae.encode(torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float().to(device)).latent_dist.sample()
|
| 205 |
+
noise = torch.randn_like(latents)
|
| 206 |
timesteps = torch.randint(0, self.pipeline.scheduler.num_train_timesteps, (latents.shape[0],), device=latents.device)
|
| 207 |
noisy_latents = self.pipeline.scheduler.add_noise(latents, noise, timesteps)
|
| 208 |
+
text_embeddings = self.pipeline.text_encoder(self.pipeline.tokenizer(text, return_tensors="pt").input_ids.to(device))[0]
|
| 209 |
pred_noise = self.pipeline.unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample
|
| 210 |
loss = torch.nn.functional.mse_loss(pred_noise, noise)
|
| 211 |
loss.backward()
|
|
|
|
| 228 |
import pytz
|
| 229 |
central = pytz.timezone('US/Central')
|
| 230 |
dt = datetime.now(central)
|
| 231 |
+
return f"{dt.strftime('%m-%d-%Y-%I-%M-%S-%p')}.{ext}"
|
| 232 |
|
| 233 |
def get_download_link(file_path, mime_type="text/plain", label="Download"):
|
| 234 |
with open(file_path, 'rb') as f:
|
|
|
|
| 247 |
return [d for d in glob.glob(path) if os.path.isdir(d)]
|
| 248 |
|
| 249 |
def get_gallery_files(file_types):
|
| 250 |
+
return sorted(list(set(f for ext in file_types for f in glob.glob(f"*.{ext}")))) # Remove duplicates and sort
|
|
|
|
| 251 |
|
| 252 |
def update_gallery():
|
| 253 |
media_files = get_gallery_files(["png"])
|
|
|
|
| 339 |
st.rerun()
|
| 340 |
|
| 341 |
# Tabs
|
| 342 |
+
tabs = [
|
| 343 |
"Build Titan 🌱", "Camera Snap 📷",
|
| 344 |
"Fine-Tune Titan (NLP) 🔧", "Test Titan (NLP) 🧪", "Agentic RAG Party (NLP) 🌐",
|
| 345 |
"Fine-Tune Titan (CV) 🔧", "Test Titan (CV) 🧪", "Agentic RAG Party (CV) 🌐"
|
| 346 |
+
]
|
| 347 |
+
tab1, tab2, tab3, tab4, tab5, tab6, tab7, tab8 = st.tabs(tabs)
|
| 348 |
+
|
| 349 |
+
# Log Tab Switches
|
| 350 |
+
for i, tab in enumerate(tabs):
|
| 351 |
+
if st.session_state['active_tab'] != tab and st.session_state.get(f'tab{i}_active', False):
|
| 352 |
+
logger.info(f"Switched to tab: {tab}")
|
| 353 |
+
st.session_state['active_tab'] = tab
|
| 354 |
+
st.session_state[f'tab{i}_active'] = (st.session_state['active_tab'] == tab)
|
| 355 |
|
| 356 |
with tab1:
|
| 357 |
st.header("Build Titan 🌱")
|
|
|
|
| 360 |
["HuggingFaceTB/SmolLM-135M", "HuggingFaceTB/SmolLM-360M", "Qwen/Qwen1.5-0.5B-Chat"] if model_type == "Causal LM" else
|
| 361 |
["stabilityai/stable-diffusion-2-base", "runwayml/stable-diffusion-v1-5"])
|
| 362 |
model_name = st.text_input("Model Name", f"tiny-titan-{int(time.time())}")
|
| 363 |
+
domain = st.text_input("Target Domain", "general", help="Where will your Titan flex its muscles? 💪") if model_type == "Causal LM" else None
|
| 364 |
if st.button("Download Model ⬇️"):
|
| 365 |
+
config = ModelConfig(name=model_name, base_model=base_model, size="small", domain=domain) if model_type == "Causal LM" else DiffusionConfig(name=model_name, base_model=base_model, size="small")
|
| 366 |
builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
|
| 367 |
builder.load_model(base_model, config)
|
| 368 |
builder.save_model(config.model_path)
|