WithAnyone committed on
Commit
693172b
·
verified ·
1 Parent(s): 559ce6b

Update withanyone/flux/model.py

Browse files
Files changed (1) hide show
  1. withanyone/flux/model.py +7 -6
withanyone/flux/model.py CHANGED
@@ -145,13 +145,14 @@ class SiglipEmbedding(nn.Module):
145
  self.model = SiglipModel.from_pretrained(siglip_path).vision_model.to(torch.bfloat16)
146
  self.processor = AutoProcessor.from_pretrained(siglip_path)
147
  # self.model.to(torch.cuda.current_device())
148
- self.model.to("cuda" if torch.cuda.is_available() else "cpu", dtype=torch.bfloat16)
 
149
 
150
  # BiRefNet matting setup
151
  self.use_matting = use_matting
152
  if self.use_matting:
153
  self.birefnet = AutoModelForImageSegmentation.from_pretrained(
154
- 'briaai/RMBG-2.0', trust_remote_code=True).to(torch.cuda.current_device(), dtype=torch.bfloat16)
155
  # Apply half precision to the entire model after loading
156
  self.matting_transform = transforms.Compose([
157
  # transforms.Resize((512, 512)),
@@ -165,7 +166,7 @@ class SiglipEmbedding(nn.Module):
165
  return image
166
 
167
  # Convert to input format and move to GPU
168
- input_image = self.matting_transform(image).unsqueeze(0).to(torch.cuda.current_device(), dtype=torch.bfloat16)
169
 
170
  # Generate prediction
171
  with torch.no_grad(), autocast(dtype=torch.bfloat16):
@@ -205,7 +206,7 @@ class SiglipEmbedding(nn.Module):
205
 
206
  pixel_values = self.processor(images=processed_images, return_tensors="pt").pixel_values
207
  # device
208
- pixel_values = pixel_values.to(torch.cuda.current_device(), dtype=torch.bfloat16)
209
  last_hidden_state = self.model(pixel_values).last_hidden_state # 2, 256 768
210
  # pooled_output = self.model(pixel_values).pooler_output # 2, 768
211
  siglip_embedding.append(last_hidden_state)
@@ -217,14 +218,14 @@ class SiglipEmbedding(nn.Module):
217
  for _ in range(4 - batch_size):
218
  pixel_values = self.processor(images=processed_images, return_tensors="pt").pixel_values
219
  # device
220
- pixel_values = pixel_values.to(torch.cuda.current_device(), dtype=torch.bfloat16)
221
  last_hidden_state = self.model(pixel_values).last_hidden_state
222
 
223
  elif isinstance(refimage, torch.Tensor):
224
  # refimage is a tensor of shape (batch_size, num_of_person, 3, H, W)
225
  batch_size, num_of_person, C, H, W = refimage.shape
226
  refimage = refimage.view(batch_size * num_of_person, C, H, W)
227
- refimage = refimage.to(torch.cuda.current_device(), dtype=torch.bfloat16)
228
  last_hidden_state = self.model(refimage).last_hidden_state
229
  siglip_embedding = last_hidden_state.view(batch_size, num_of_person, 256, 768)
230
 
 
145
  self.model = SiglipModel.from_pretrained(siglip_path).vision_model.to(torch.bfloat16)
146
  self.processor = AutoProcessor.from_pretrained(siglip_path)
147
  # self.model.to(torch.cuda.current_device())
148
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
149
+ self.model.to(self.device)
150
 
151
  # BiRefNet matting setup
152
  self.use_matting = use_matting
153
  if self.use_matting:
154
  self.birefnet = AutoModelForImageSegmentation.from_pretrained(
155
+ 'briaai/RMBG-2.0', trust_remote_code=True).to(self.device, dtype=torch.bfloat16)
156
  # Apply half precision to the entire model after loading
157
  self.matting_transform = transforms.Compose([
158
  # transforms.Resize((512, 512)),
 
166
  return image
167
 
168
  # Convert to input format and move to GPU
169
+ input_image = self.matting_transform(image).unsqueeze(0).to(self.device, dtype=torch.bfloat16)
170
 
171
  # Generate prediction
172
  with torch.no_grad(), autocast(dtype=torch.bfloat16):
 
206
 
207
  pixel_values = self.processor(images=processed_images, return_tensors="pt").pixel_values
208
  # device
209
+ pixel_values = pixel_values.to(self.device, dtype=torch.bfloat16)
210
  last_hidden_state = self.model(pixel_values).last_hidden_state # 2, 256 768
211
  # pooled_output = self.model(pixel_values).pooler_output # 2, 768
212
  siglip_embedding.append(last_hidden_state)
 
218
  for _ in range(4 - batch_size):
219
  pixel_values = self.processor(images=processed_images, return_tensors="pt").pixel_values
220
  # device
221
+ pixel_values = pixel_values.to(self.device, dtype=torch.bfloat16)
222
  last_hidden_state = self.model(pixel_values).last_hidden_state
223
 
224
  elif isinstance(refimage, torch.Tensor):
225
  # refimage is a tensor of shape (batch_size, num_of_person, 3, H, W)
226
  batch_size, num_of_person, C, H, W = refimage.shape
227
  refimage = refimage.view(batch_size * num_of_person, C, H, W)
228
+ refimage = refimage.to(self.device, dtype=torch.bfloat16)
229
  last_hidden_state = self.model(refimage).last_hidden_state
230
  siglip_embedding = last_hidden_state.view(batch_size, num_of_person, 256, 768)
231