future-xy committed
Commit 794e78c · 1 Parent(s): 3237d78

fix cuda mismatch bugs
src/backend/moe_infinity.py
CHANGED
@@ -2,6 +2,7 @@ import torch
 import os
 from transformers import AutoTokenizer
 import transformers
+from transformers import AutoModelForCausalLM
 from moe_infinity import MoE
 from typing import List, Tuple, Optional, Union
 
@@ -26,7 +27,9 @@ class MoEHFLM(HFLM):
         self.offload_path = offload_path
         self.device_memory_ratio = device_memory_ratio
         self.use_chat_template = use_chat_template
-
+        if "device" in kwargs:
+            kwargs.pop("device")
+        super().__init__(*args, **kwargs, pretrained=pretrained, device="cuda:0")  # Assuming HFLM accepts a 'pretrained' arg and handles it
         # self._create_model()
 
     def _create_model(self, *args, **kwargs):
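
The fix pins MoEHFLM to a single CUDA device: any caller-supplied "device" kwarg is dropped before delegating to the base class, which is then given device="cuda:0" explicitly. Below is a minimal, self-contained sketch of that pattern; BaseLM and MoELM are hypothetical stand-ins for HFLM and MoEHFLM, since the real HFLM constructor is not shown in this diff.

# Sketch only: hypothetical BaseLM/MoELM illustrating the device-handling change.
class BaseLM:
    def __init__(self, pretrained: str = "", device: str = "cpu", **kwargs):
        self.pretrained = pretrained
        self.device = device

class MoELM(BaseLM):
    def __init__(self, pretrained: str = "", *args, **kwargs):
        # Drop a conflicting "device" kwarg, then pin everything to one device
        # ("cuda:0" in the commit) so tensors don't end up on mismatched devices.
        if "device" in kwargs:
            kwargs.pop("device")
        super().__init__(*args, **kwargs, pretrained=pretrained, device="cuda:0")

if __name__ == "__main__":
    lm = MoELM(pretrained="some/model", device="cpu")  # caller's device is overridden
    print(lm.device)  # -> cuda:0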