Spaces: Running on Zero
Update withanyone/flux/model.py

withanyone/flux/model.py  CHANGED  (+7 −6)
```diff
@@ -145,13 +145,14 @@ class SiglipEmbedding(nn.Module):
         self.model = SiglipModel.from_pretrained(siglip_path).vision_model.to(torch.bfloat16)
         self.processor = AutoProcessor.from_pretrained(siglip_path)
         # self.model.to(torch.cuda.current_device())
-        self.
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.model.to(self.device)
 
         # BiRefNet matting setup
         self.use_matting = use_matting
         if self.use_matting:
             self.birefnet = AutoModelForImageSegmentation.from_pretrained(
-                'briaai/RMBG-2.0', trust_remote_code=True).to(
+                'briaai/RMBG-2.0', trust_remote_code=True).to(self.device, dtype=torch.bfloat16)
             # Apply half precision to the entire model after loading
             self.matting_transform = transforms.Compose([
                 # transforms.Resize((512, 512)),
@@ -165,7 +166,7 @@ class SiglipEmbedding(nn.Module):
             return image
 
         # Convert to input format and move to GPU
-        input_image = self.matting_transform(image).unsqueeze(0).to(
+        input_image = self.matting_transform(image).unsqueeze(0).to(self.device, dtype=torch.bfloat16)
 
         # Generate prediction
         with torch.no_grad(), autocast(dtype=torch.bfloat16):
@@ -205,7 +206,7 @@ class SiglipEmbedding(nn.Module):
 
                 pixel_values = self.processor(images=processed_images, return_tensors="pt").pixel_values
                 # device
-                pixel_values = pixel_values.to(
+                pixel_values = pixel_values.to(self.device, dtype=torch.bfloat16)
                 last_hidden_state = self.model(pixel_values).last_hidden_state # 2, 256 768
                 # pooled_output = self.model(pixel_values).pooler_output # 2, 768
                 siglip_embedding.append(last_hidden_state)
@@ -217,14 +218,14 @@ class SiglipEmbedding(nn.Module):
                 for _ in range(4 - batch_size):
                     pixel_values = self.processor(images=processed_images, return_tensors="pt").pixel_values
                     # device
-                    pixel_values = pixel_values.to(
+                    pixel_values = pixel_values.to(self.device, dtype=torch.bfloat16)
                     last_hidden_state = self.model(pixel_values).last_hidden_state
 
         elif isinstance(refimage, torch.Tensor):
             # refimage is a tensor of shape (batch_size, num_of_person, 3, H, W)
             batch_size, num_of_person, C, H, W = refimage.shape
             refimage = refimage.view(batch_size * num_of_person, C, H, W)
-            refimage = refimage.to(
+            refimage = refimage.to(self.device, dtype=torch.bfloat16)
             last_hidden_state = self.model(refimage).last_hidden_state
             siglip_embedding = last_hidden_state.view(batch_size, num_of_person, 256, 768)
 
```
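The commit follows the standard PyTorch placement pattern: resolve the device once in `__init__`, move the module there, and move every input tensor to the same device and dtype before the forward pass. (The removed lines above are truncated in the rendered diff, so only their unchanged prefixes survive.) Below is a minimal, self-contained sketch of that pattern; the `VisionEncoder` class and its `encode` method are illustrative stand-ins, not the Space's actual `SiglipEmbedding` code:

```python
import torch
import torch.nn as nn


class VisionEncoder(nn.Module):
    """Toy module illustrating the device/dtype pattern adopted in the commit."""

    def __init__(self):
        super().__init__()
        # Resolve the device once, at construction time, so the module also
        # works on hosts where CUDA is unavailable.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Move weights to the chosen device and to bfloat16 in one call.
        self.backbone = nn.Linear(768, 768).to(self.device, dtype=torch.bfloat16)

    @torch.no_grad()
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        # Inputs must match the module's device AND dtype; otherwise the
        # forward pass raises a device or dtype mismatch error.
        x = x.to(self.device, dtype=torch.bfloat16)
        return self.backbone(x)


encoder = VisionEncoder()
out = encoder.encode(torch.randn(2, 768))  # created on CPU, moved inside encode()
print(out.shape, out.device, out.dtype)
```

Guarding on `torch.cuda.is_available()` instead of calling `torch.cuda.current_device()` unconditionally lets the same code path run on CPU-only hosts, which is plausibly the motivation here given the Space runs on ZeroGPU.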