update WeightDtype params
Browse files- src/display/utils.py +7 -1
src/display/utils.py
CHANGED
|
@@ -237,19 +237,25 @@ class WeightDtype(Enum):
|
|
| 237 |
nf4 = ModelDetails("nf4")
|
| 238 |
fp4 = ModelDetails("fp4")
|
| 239 |
|
|
|
|
| 240 |
Unknown = ModelDetails("?")
|
| 241 |
|
|
|
|
|
|
|
|
|
|
| 242 |
def from_str(weight_dtype):
|
| 243 |
if weight_dtype in ["int2"]:
|
| 244 |
return WeightDtype.int2
|
| 245 |
if weight_dtype in ["int3"]:
|
| 246 |
-
return WeightDtype.int3
|
| 247 |
if weight_dtype in ["int4"]:
|
| 248 |
return WeightDtype.int4
|
| 249 |
if weight_dtype in ["nf4"]:
|
| 250 |
return WeightDtype.nf4
|
| 251 |
if weight_dtype in ["fp4"]:
|
| 252 |
return WeightDtype.fp4
|
|
|
|
|
|
|
| 253 |
return WeightDtype.Unknown
|
| 254 |
|
| 255 |
class ComputeDtype(Enum):
|
|
|
|
| 237 |
nf4 = ModelDetails("nf4")
|
| 238 |
fp4 = ModelDetails("fp4")
|
| 239 |
|
| 240 |
+
|
| 241 |
Unknown = ModelDetails("?")
|
| 242 |
|
| 243 |
+
all = ModelDetails("All")
|
| 244 |
+
|
| 245 |
+
|
| 246 |
def from_str(weight_dtype):
|
| 247 |
if weight_dtype in ["int2"]:
|
| 248 |
return WeightDtype.int2
|
| 249 |
if weight_dtype in ["int3"]:
|
| 250 |
+
return WeightDtype.int3
|
| 251 |
if weight_dtype in ["int4"]:
|
| 252 |
return WeightDtype.int4
|
| 253 |
if weight_dtype in ["nf4"]:
|
| 254 |
return WeightDtype.nf4
|
| 255 |
if weight_dtype in ["fp4"]:
|
| 256 |
return WeightDtype.fp4
|
| 257 |
+
if weight_dtype in ["All"]:
|
| 258 |
+
return WeightDtype.all
|
| 259 |
return WeightDtype.Unknown
|
| 260 |
|
| 261 |
class ComputeDtype(Enum):
|