hperkins committed on
Commit
171cc73
·
verified ·
1 Parent(s): f508d32

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +17 -26
handler.py CHANGED
@@ -1,20 +1,19 @@
1
- from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
2
- from qwen_vl_utils import process_vision_info
3
  import torch
4
  import json
 
 
5
 
6
  class EndpointHandler:
7
  def __init__(self, model_dir):
8
  # Load the model and processor for Qwen2-VL-7B
9
  self.model = Qwen2VLForConditionalGeneration.from_pretrained(
10
  model_dir,
11
- torch_dtype=torch.float16, # Use float16 for reduced memory usage
12
- device_map="auto" # Automatically assign to available GPU(s)
13
  )
14
  self.processor = AutoProcessor.from_pretrained(model_dir)
15
  self.model.eval()
16
-
17
- # Enable gradient checkpointing for memory savings
18
  self.model.gradient_checkpointing_enable()
19
 
20
  def preprocess(self, request_data):
@@ -22,7 +21,7 @@ class EndpointHandler:
22
  messages = request_data.get('messages')
23
  if not messages:
24
  raise ValueError("Messages are required")
25
-
26
  # Process vision info (image or video) from the messages
27
  image_inputs, video_inputs = process_vision_info(messages)
28
 
@@ -39,18 +38,18 @@ class EndpointHandler:
39
  padding=True,
40
  return_tensors="pt",
41
  )
42
-
43
- return inputs.to(self.model.device)
44
 
45
  def inference(self, inputs):
46
  # Perform inference with the model
47
  with torch.no_grad():
 
48
  generated_ids = self.model.generate(
49
- **inputs,
50
- max_new_tokens=256, # Increased token length for richer output
51
- num_beams=5, # Increase beam size for better quality
52
- early_stopping=True, # Stop when all beams have finished
53
- max_batch_size=1 # Keep batch size small to manage memory usage
54
  )
55
 
56
  # Trim the output (remove input tokens from generated output)
@@ -72,22 +71,14 @@ class EndpointHandler:
72
 
73
  def __call__(self, request):
74
  try:
75
- # Ensure request is a string before attempting to load it as JSON
76
- if isinstance(request, dict):
77
- request_data = request
78
- else:
79
- request_data = json.loads(request) # Parse the JSON request data
80
-
81
  # Preprocess the input data (text, images, videos)
82
  inputs = self.preprocess(request_data)
83
-
84
  # Perform inference
85
  outputs = self.inference(inputs)
86
-
87
  # Postprocess the output
88
  result = self.postprocess(outputs)
89
-
90
- return json.dumps({"result": result}) # Return a JSON response
91
-
92
  except Exception as e:
93
- return json.dumps({"error": str(e)}) # Return error as JSON
 
 
 
1
  import torch
2
  import json
3
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
4
+ from qwen_vl_utils import process_vision_info
5
 
6
  class EndpointHandler:
7
  def __init__(self, model_dir):
8
  # Load the model and processor for Qwen2-VL-7B
9
  self.model = Qwen2VLForConditionalGeneration.from_pretrained(
10
  model_dir,
11
+ torch_dtype=torch.float16, # FP16 precision to reduce memory
12
+ device_map="auto" # Automatically distribute model across devices
13
  )
14
  self.processor = AutoProcessor.from_pretrained(model_dir)
15
  self.model.eval()
16
+ # Enable gradient checkpointing to save memory
 
17
  self.model.gradient_checkpointing_enable()
18
 
19
  def preprocess(self, request_data):
 
21
  messages = request_data.get('messages')
22
  if not messages:
23
  raise ValueError("Messages are required")
24
+
25
  # Process vision info (image or video) from the messages
26
  image_inputs, video_inputs = process_vision_info(messages)
27
 
 
38
  padding=True,
39
  return_tensors="pt",
40
  )
41
+
42
+ return inputs.to("cuda")
43
 
44
  def inference(self, inputs):
45
  # Perform inference with the model
46
  with torch.no_grad():
47
+ # Generate the output
48
  generated_ids = self.model.generate(
49
+ **inputs,
50
+ max_new_tokens=128,
51
+ num_beams=1,
52
+ max_batch_size=1
 
53
  )
54
 
55
  # Trim the output (remove input tokens from generated output)
 
71
 
72
  def __call__(self, request):
73
  try:
74
+ # Parse the JSON request data
75
+ request_data = json.loads(request)
 
 
 
 
76
  # Preprocess the input data (text, images, videos)
77
  inputs = self.preprocess(request_data)
 
78
  # Perform inference
79
  outputs = self.inference(inputs)
 
80
  # Postprocess the output
81
  result = self.postprocess(outputs)
82
+ return json.dumps({"result": result})
 
 
83
  except Exception as e:
84
+ return json.dumps({"error": str(e)})