Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -81,7 +81,7 @@ def get_gallery_files(file_types):
|
|
| 81 |
import glob
|
| 82 |
return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
|
| 83 |
|
| 84 |
-
# Video
|
| 85 |
class VideoSnapshot:
|
| 86 |
def __init__(self):
|
| 87 |
self.snapshot = None
|
|
@@ -94,7 +94,7 @@ class VideoSnapshot:
|
|
| 94 |
return self.snapshot
|
| 95 |
|
| 96 |
# Main App
|
| 97 |
-
st.title("SFT Tiny Titans 🚀 (Fast &
|
| 98 |
|
| 99 |
# Sidebar Galleries
|
| 100 |
st.sidebar.header("Media Gallery 🎨")
|
|
@@ -107,15 +107,15 @@ for gallery_type, file_types, emoji in [("Images 📸", ["png", "jpg", "jpeg"],
|
|
| 107 |
with cols[idx % 2]:
|
| 108 |
if "Images" in gallery_type:
|
| 109 |
from PIL import Image
|
| 110 |
-
st.image(Image.open(file), caption=file.split('/')[-1],
|
| 111 |
elif "Videos" in gallery_type:
|
| 112 |
st.video(file)
|
| 113 |
|
| 114 |
# Sidebar Model Management
|
| 115 |
st.sidebar.subheader("Model Hub 🗂️")
|
| 116 |
model_type = st.sidebar.selectbox("Model Type", ["NLP (Causal LM)", "CV (Diffusion)"])
|
| 117 |
-
model_options =
|
| 118 |
-
selected_model = st.sidebar.selectbox("Select Model", ["None"
|
| 119 |
if selected_model != "None" and st.sidebar.button("Load Model 📂"):
|
| 120 |
builder = ModelBuilder() if "NLP" in model_type else DiffusionBuilder()
|
| 121 |
config = (ModelConfig if "NLP" in model_type else DiffusionConfig)(name=f"titan_{int(time.time())}", base_model=selected_model)
|
|
@@ -130,7 +130,7 @@ tab1, tab2, tab3, tab4 = st.tabs(["Build Titan 🌱", "Fine-Tune Titans 🔧", "
|
|
| 130 |
with tab1:
|
| 131 |
st.header("Build Titan 🌱 (Quick Start!)")
|
| 132 |
model_type = st.selectbox("Model Type", ["NLP (Causal LM)", "CV (Diffusion)"], key="build_type")
|
| 133 |
-
base_model = st.selectbox("Select Model", model_options, key="build_model")
|
| 134 |
if st.button("Download Model ⬇️"):
|
| 135 |
config = (ModelConfig if "NLP" in model_type else DiffusionConfig)(name=f"titan_{int(time.time())}", base_model=base_model)
|
| 136 |
builder = ModelBuilder() if "NLP" in model_type else DiffusionBuilder()
|
|
@@ -175,7 +175,7 @@ with tab2:
|
|
| 175 |
dataloader = DataLoader(dataset, batch_size=2)
|
| 176 |
optimizer = torch.optim.AdamW(st.session_state['builder'].model.parameters(), lr=2e-5)
|
| 177 |
st.session_state['builder'].model.train()
|
| 178 |
-
for _ in range(1):
|
| 179 |
for batch in dataloader:
|
| 180 |
optimizer.zero_grad()
|
| 181 |
outputs = st.session_state['builder'].model(**{k: v.to(st.session_state['builder'].model.device) for k, v in batch.items()})
|
|
@@ -194,7 +194,7 @@ with tab2:
|
|
| 194 |
texts = text_input.splitlines()[:len(images)]
|
| 195 |
optimizer = torch.optim.AdamW(st.session_state['builder'].pipeline.unet.parameters(), lr=1e-5)
|
| 196 |
st.session_state['builder'].pipeline.unet.train()
|
| 197 |
-
for _ in range(1):
|
| 198 |
for img, text in zip(images, texts):
|
| 199 |
optimizer.zero_grad()
|
| 200 |
latents = st.session_state['builder'].pipeline.vae.encode(torch.tensor(np.array(img)).permute(2, 0, 1).unsqueeze(0).float().to(st.session_state['builder'].pipeline.device)).latent_dist.sample()
|
|
@@ -233,7 +233,11 @@ with tab3:
|
|
| 233 |
with tab4:
|
| 234 |
st.header("Camera Snap 📷 (Instant Shots!)")
|
| 235 |
from streamlit_webrtc import webrtc_streamer
|
| 236 |
-
ctx = webrtc_streamer(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
if ctx.video_processor:
|
| 238 |
snapshot_text = st.text_input("Snapshot Text", "Live Snap")
|
| 239 |
if st.button("Snap It! 📸"):
|
|
@@ -241,7 +245,7 @@ with tab4:
|
|
| 241 |
if snapshot:
|
| 242 |
filename = generate_filename(snapshot_text)
|
| 243 |
snapshot.save(filename)
|
| 244 |
-
st.image(snapshot, caption=filename)
|
| 245 |
st.success("Snapped! 🎉")
|
| 246 |
|
| 247 |
# Demo Dataset
|
|
|
|
| 81 |
import glob
|
| 82 |
return sorted([f for ext in file_types for f in glob.glob(f"*.{ext}")])
|
| 83 |
|
| 84 |
+
# Video Processor for WebRTC
|
| 85 |
class VideoSnapshot:
|
| 86 |
def __init__(self):
|
| 87 |
self.snapshot = None
|
|
|
|
| 94 |
return self.snapshot
|
| 95 |
|
| 96 |
# Main App
|
| 97 |
+
st.title("SFT Tiny Titans 🚀 (Fast & Fixed!)")
|
| 98 |
|
| 99 |
# Sidebar Galleries
|
| 100 |
st.sidebar.header("Media Gallery 🎨")
|
|
|
|
| 107 |
with cols[idx % 2]:
|
| 108 |
if "Images" in gallery_type:
|
| 109 |
from PIL import Image
|
| 110 |
+
st.image(Image.open(file), caption=file.split('/')[-1], use_container_width=True)
|
| 111 |
elif "Videos" in gallery_type:
|
| 112 |
st.video(file)
|
| 113 |
|
| 114 |
# Sidebar Model Management
|
| 115 |
st.sidebar.subheader("Model Hub 🗂️")
|
| 116 |
model_type = st.sidebar.selectbox("Model Type", ["NLP (Causal LM)", "CV (Diffusion)"])
|
| 117 |
+
model_options = {"NLP (Causal LM)": "HuggingFaceTB/SmolLM-135M", "CV (Diffusion)": "CompVis/stable-diffusion-v1-4"}
|
| 118 |
+
selected_model = st.sidebar.selectbox("Select Model", ["None", model_options[model_type]])
|
| 119 |
if selected_model != "None" and st.sidebar.button("Load Model 📂"):
|
| 120 |
builder = ModelBuilder() if "NLP" in model_type else DiffusionBuilder()
|
| 121 |
config = (ModelConfig if "NLP" in model_type else DiffusionConfig)(name=f"titan_{int(time.time())}", base_model=selected_model)
|
|
|
|
| 130 |
with tab1:
|
| 131 |
st.header("Build Titan 🌱 (Quick Start!)")
|
| 132 |
model_type = st.selectbox("Model Type", ["NLP (Causal LM)", "CV (Diffusion)"], key="build_type")
|
| 133 |
+
base_model = st.selectbox("Select Model", [model_options[model_type]], key="build_model")
|
| 134 |
if st.button("Download Model ⬇️"):
|
| 135 |
config = (ModelConfig if "NLP" in model_type else DiffusionConfig)(name=f"titan_{int(time.time())}", base_model=base_model)
|
| 136 |
builder = ModelBuilder() if "NLP" in model_type else DiffusionBuilder()
|
|
|
|
| 175 |
dataloader = DataLoader(dataset, batch_size=2)
|
| 176 |
optimizer = torch.optim.AdamW(st.session_state['builder'].model.parameters(), lr=2e-5)
|
| 177 |
st.session_state['builder'].model.train()
|
| 178 |
+
for _ in range(1):
|
| 179 |
for batch in dataloader:
|
| 180 |
optimizer.zero_grad()
|
| 181 |
outputs = st.session_state['builder'].model(**{k: v.to(st.session_state['builder'].model.device) for k, v in batch.items()})
|
|
|
|
| 194 |
texts = text_input.splitlines()[:len(images)]
|
| 195 |
optimizer = torch.optim.AdamW(st.session_state['builder'].pipeline.unet.parameters(), lr=1e-5)
|
| 196 |
st.session_state['builder'].pipeline.unet.train()
|
| 197 |
+
for _ in range(1):
|
| 198 |
for img, text in zip(images, texts):
|
| 199 |
optimizer.zero_grad()
|
| 200 |
latents = st.session_state['builder'].pipeline.vae.encode(torch.tensor(np.array(img)).permute(2, 0, 1).unsqueeze(0).float().to(st.session_state['builder'].pipeline.device)).latent_dist.sample()
|
|
|
|
| 233 |
with tab4:
|
| 234 |
st.header("Camera Snap 📷 (Instant Shots!)")
|
| 235 |
from streamlit_webrtc import webrtc_streamer
|
| 236 |
+
ctx = webrtc_streamer(
|
| 237 |
+
key="camera",
|
| 238 |
+
video_processor_factory=VideoSnapshot,
|
| 239 |
+
frontend_rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
|
| 240 |
+
)
|
| 241 |
if ctx.video_processor:
|
| 242 |
snapshot_text = st.text_input("Snapshot Text", "Live Snap")
|
| 243 |
if st.button("Snap It! 📸"):
|
|
|
|
| 245 |
if snapshot:
|
| 246 |
filename = generate_filename(snapshot_text)
|
| 247 |
snapshot.save(filename)
|
| 248 |
+
st.image(snapshot, caption=filename, use_container_width=True)
|
| 249 |
st.success("Snapped! 🎉")
|
| 250 |
|
| 251 |
# Demo Dataset
|