modify ComputeDtype params
Browse files- src/display/utils.py +8 -3
src/display/utils.py
CHANGED
|
@@ -230,17 +230,17 @@ class QuantType(Enum):
|
|
| 230 |
return QuantType.Unknown
|
| 231 |
|
| 232 |
|
|
|
|
| 233 |
class WeightDtype(Enum):
|
|
|
|
| 234 |
int2 = ModelDetails("int2")
|
| 235 |
-
int3 = ModelDetails("int3")
|
| 236 |
int4 = ModelDetails("int4")
|
| 237 |
nf4 = ModelDetails("nf4")
|
| 238 |
fp4 = ModelDetails("fp4")
|
| 239 |
|
| 240 |
-
|
| 241 |
Unknown = ModelDetails("?")
|
| 242 |
|
| 243 |
-
all = ModelDetails("All")
|
| 244 |
|
| 245 |
|
| 246 |
def from_str(weight_dtype):
|
|
@@ -259,11 +259,13 @@ class WeightDtype(Enum):
|
|
| 259 |
return WeightDtype.Unknown
|
| 260 |
|
| 261 |
class ComputeDtype(Enum):
|
|
|
|
| 262 |
fp16 = ModelDetails("float16")
|
| 263 |
bf16 = ModelDetails("bfloat16")
|
| 264 |
int8 = ModelDetails("int8")
|
| 265 |
fp32 = ModelDetails("float32")
|
| 266 |
|
|
|
|
| 267 |
Unknown = ModelDetails("?")
|
| 268 |
|
| 269 |
def from_str(compute_dtype):
|
|
@@ -275,8 +277,11 @@ class ComputeDtype(Enum):
|
|
| 275 |
return ComputeDtype.int8
|
| 276 |
if compute_dtype in ["float32"]:
|
| 277 |
return ComputeDtype.fp32
|
|
|
|
|
|
|
| 278 |
return ComputeDtype.Unknown
|
| 279 |
|
|
|
|
| 280 |
class GroupDtype(Enum):
|
| 281 |
group_1 = ModelDetails("-1")
|
| 282 |
group_1024 = ModelDetails("1024")
|
|
|
|
| 230 |
return QuantType.Unknown
|
| 231 |
|
| 232 |
|
| 233 |
+
|
| 234 |
class WeightDtype(Enum):
|
| 235 |
+
all = ModelDetails("All")
|
| 236 |
int2 = ModelDetails("int2")
|
| 237 |
+
int3 = ModelDetails("int3")
|
| 238 |
int4 = ModelDetails("int4")
|
| 239 |
nf4 = ModelDetails("nf4")
|
| 240 |
fp4 = ModelDetails("fp4")
|
| 241 |
|
|
|
|
| 242 |
Unknown = ModelDetails("?")
|
| 243 |
|
|
|
|
| 244 |
|
| 245 |
|
| 246 |
def from_str(weight_dtype):
|
|
|
|
| 259 |
return WeightDtype.Unknown
|
| 260 |
|
| 261 |
class ComputeDtype(Enum):
|
| 262 |
+
all = ModelDetails("All")
|
| 263 |
fp16 = ModelDetails("float16")
|
| 264 |
bf16 = ModelDetails("bfloat16")
|
| 265 |
int8 = ModelDetails("int8")
|
| 266 |
fp32 = ModelDetails("float32")
|
| 267 |
|
| 268 |
+
|
| 269 |
Unknown = ModelDetails("?")
|
| 270 |
|
| 271 |
def from_str(compute_dtype):
|
|
|
|
| 277 |
return ComputeDtype.int8
|
| 278 |
if compute_dtype in ["float32"]:
|
| 279 |
return ComputeDtype.fp32
|
| 280 |
+
if compute_dtype in ["All"]:
|
| 281 |
+
return ComputeDtype.all
|
| 282 |
return ComputeDtype.Unknown
|
| 283 |
|
| 284 |
+
|
| 285 |
class GroupDtype(Enum):
|
| 286 |
group_1 = ModelDetails("-1")
|
| 287 |
group_1024 = ModelDetails("1024")
|