Spaces:
Runtime error
Runtime error
Update model_worker.py
Browse files- model_worker.py +5 -4
model_worker.py
CHANGED
|
@@ -76,7 +76,8 @@ class ModelWorker:
|
|
| 76 |
@torch.inference_mode()
|
| 77 |
def generate_stream(self, params):
|
| 78 |
tokenizer, model = self.tokenizer, self.model
|
| 79 |
-
|
|
|
|
| 80 |
prompt = params["prompt"]
|
| 81 |
ori_prompt = prompt
|
| 82 |
images = params.get("images", None)
|
|
@@ -90,9 +91,9 @@ class ModelWorker:
|
|
| 90 |
assert prompt.count(DEFAULT_IMAGE_TOKEN) == 1
|
| 91 |
|
| 92 |
images, patch_positions, prompt = self.doc_image_processor(images=image, query=prompt)
|
| 93 |
-
images = images.to(self.
|
| 94 |
-
# images = images.to(self.
|
| 95 |
-
patch_positions = patch_positions.to(self.
|
| 96 |
|
| 97 |
replace_token = DEFAULT_IMAGE_TOKEN
|
| 98 |
prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
|
|
|
|
| 76 |
@torch.inference_mode()
|
| 77 |
def generate_stream(self, params):
|
| 78 |
tokenizer, model = self.tokenizer, self.model
|
| 79 |
+
# for adjust to zero environment of huggingface
|
| 80 |
+
model.to(self.device)
|
| 81 |
prompt = params["prompt"]
|
| 82 |
ori_prompt = prompt
|
| 83 |
images = params.get("images", None)
|
|
|
|
| 91 |
assert prompt.count(DEFAULT_IMAGE_TOKEN) == 1
|
| 92 |
|
| 93 |
images, patch_positions, prompt = self.doc_image_processor(images=image, query=prompt)
|
| 94 |
+
images = images.to(self.device, dtype=torch.float16)
|
| 95 |
+
# images = images.to(self.device, dtype=torch.bfloat16)
|
| 96 |
+
patch_positions = patch_positions.to(self.device)
|
| 97 |
|
| 98 |
replace_token = DEFAULT_IMAGE_TOKEN
|
| 99 |
prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
|