marstin committed
Commit cf24c8d · 1 Parent(s): c274a6b

[martin-dev] fix gpu device

Files changed (1):
  1. demo/launch_gradio.py +15 -12
demo/launch_gradio.py CHANGED
@@ -22,16 +22,6 @@ models_cache: Dict[str, Any] = {}
 current_model_selection: Optional[ModelSelection] = None
 
 
-def example1():
-    return "What is in this image? Describe in one word.", None, None
-
-def example2():
-    return "Describe the main object in the picture in one word.", None, None
-
-def example3():
-    return "What color is the dominant object? Describe in one word.", None, None
-
-
 def read_layer_spec(spec_file_path: str) -> List[str]:
     """Read available layers from the model spec file.
 
@@ -171,8 +161,12 @@ def get_single_image_probabilities(
         Tuple containing list of top tokens and their probabilities.
     """
     # Generate prompt and process inputs
+    vlm.model.eval()
     text = vlm._generate_prompt(instruction, has_images=True)
     inputs = vlm._generate_processor_output(text, image)
+    for key in inputs:
+        if isinstance(inputs[key], torch.Tensor):
+            inputs[key] = inputs[key].to(vlm.config.device)
 
     with torch.no_grad():
         outputs = vlm.model.generate(
@@ -366,9 +360,13 @@ def get_module_similarity_pooled(
         raise ValueError(f"Module '{module_name}' not found in model")
 
     try:
+        vlm.model.eval()
         # Extract embedding for image1
         text = vlm._generate_prompt(instruction, has_images=True)
         inputs1 = vlm._generate_processor_output(text, image1)
+        for key in inputs1:
+            if isinstance(inputs1[key], torch.Tensor):
+                inputs1[key] = inputs1[key].to(vlm.config.device)
 
         embeddings.clear()
         with torch.no_grad():
@@ -381,6 +379,9 @@ def get_module_similarity_pooled(
 
         # Extract embedding for image2
         inputs2 = vlm._generate_processor_output(text, image2)
+        for key in inputs2:
+            if isinstance(inputs2[key], torch.Tensor):
+                inputs2[key] = inputs2[key].to(vlm.config.device)
 
         embeddings.clear()
         with torch.no_grad():
@@ -441,7 +442,7 @@ def get_module_similarity_pooled(
         hook_handle.remove()
 
 
-@GPU(duration=60)
+@GPU(duration=120)
 def process_dual_inputs(
     model_choice: str,
     selected_layer: str,
@@ -638,7 +639,9 @@ def create_demo() -> gr.Blocks:
         # Add examples
         gr.Examples(
             examples=[
-                [example1()], [example2()], [example3()]
+                ['What is in this image? Describe in one word.', None, None],
+                ['Describe the main object in the picture in one word.', None, None],
+                ['What color is the dominant object? Describe in one word.', None, None],
             ],
            inputs=[instruction_input, image1_input, image2_input]
        )
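The device fix repeated across the three hunks above follows one pattern: put the model in eval mode, then move every tensor in the processor output onto the model's device before running it. A minimal standalone sketch of that pattern, assuming only that the processor output is a dict mixing tensors and other values (the helper name move_inputs_to_device is ours, not from the diff):

import torch

def move_inputs_to_device(inputs: dict, device: torch.device) -> dict:
    # Same per-key check as the diff: only tensors are moved; other
    # entries (strings, lists, None) pass through unchanged.
    return {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in inputs.items()
    }

With a helper like this, the three copies of the loop in get_single_image_probabilities and get_module_similarity_pooled could collapse to a single call such as inputs = move_inputs_to_device(inputs, vlm.config.device).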
 
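The gr.Examples change is likely more than cleanup: with multiple input components, Gradio expects each entry in examples to be a row with one value per component, so the old [example1()] produced a single-element row holding a 3-tuple rather than three values for the three inputs. A short sketch of the corrected wiring (component labels here are illustrative, not from the diff):

import gradio as gr

with gr.Blocks() as demo:
    instruction_input = gr.Textbox(label="Instruction")
    image1_input = gr.Image(label="Image 1")
    image2_input = gr.Image(label="Image 2")
    # One row per example, one value per input component.
    gr.Examples(
        examples=[["What is in this image? Describe in one word.", None, None]],
        inputs=[instruction_input, image1_input, image2_input],
    )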
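The @GPU(duration=60) -> @GPU(duration=120) change matches the Hugging Face Spaces ZeroGPU API, where duration is the upper bound in seconds on how long a GPU stays allocated to the decorated call; doubling it gives the dual-image path more headroom. A hedged sketch, assuming the demo imports the decorator from the spaces package:

import spaces

@spaces.GPU(duration=120)  # GPU is allocated for this call, up to 120 s
def process_dual_inputs(model_choice: str, selected_layer: str) -> str:
    ...  # model loading and inference run here while the GPU is held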