Update README.md
Browse files
README.md
CHANGED
|
@@ -99,7 +99,7 @@ from transformers import (
|
|
| 99 |
from torchao.quantization.quant_api import (
|
| 100 |
IntxWeightOnlyConfig,
|
| 101 |
Int8DynamicActivationIntxWeightConfig,
|
| 102 |
-
|
| 103 |
quantize_,
|
| 104 |
)
|
| 105 |
from torchao.quantization.granularity import PerGroup, PerAxis
|
|
@@ -121,7 +121,7 @@ linear_config = Int8DynamicActivationIntxWeightConfig(
|
|
| 121 |
weight_granularity=PerGroup(32),
|
| 122 |
weight_scale_dtype=torch.bfloat16,
|
| 123 |
)
|
| 124 |
-
quant_config =
|
| 125 |
quantization_config = TorchAoConfig(quant_type=quant_config, include_embedding=True, untie_embedding_weights=True, modules_to_not_convert=[])
|
| 126 |
|
| 127 |
# either use `untied_model_id` or `untied_model_local_path`
|
|
@@ -130,7 +130,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
|
|
| 130 |
|
| 131 |
# Push to hub
|
| 132 |
MODEL_NAME = model_id.split("/")[-1]
|
| 133 |
-
save_to = f"{USER_ID}/{MODEL_NAME}-
|
| 134 |
quantized_model.push_to_hub(save_to, safe_serialization=False)
|
| 135 |
tokenizer.push_to_hub(save_to)
|
| 136 |
|
|
|
|
| 99 |
from torchao.quantization.quant_api import (
|
| 100 |
IntxWeightOnlyConfig,
|
| 101 |
Int8DynamicActivationIntxWeightConfig,
|
| 102 |
+
ModuleFqnToConfig,
|
| 103 |
quantize_,
|
| 104 |
)
|
| 105 |
from torchao.quantization.granularity import PerGroup, PerAxis
|
|
|
|
| 121 |
weight_granularity=PerGroup(32),
|
| 122 |
weight_scale_dtype=torch.bfloat16,
|
| 123 |
)
|
| 124 |
+
quant_config = ModuleFqnToConfig({"_default": linear_config, "model.embed_tokens": embedding_config})
|
| 125 |
quantization_config = TorchAoConfig(quant_type=quant_config, include_embedding=True, untie_embedding_weights=True, modules_to_not_convert=[])
|
| 126 |
|
| 127 |
# either use `untied_model_id` or `untied_model_local_path`
|
|
|
|
| 130 |
|
| 131 |
# Push to hub
|
| 132 |
MODEL_NAME = model_id.split("/")[-1]
|
| 133 |
+
save_to = f"{USER_ID}/{MODEL_NAME}-8da4w"
|
| 134 |
quantized_model.push_to_hub(save_to, safe_serialization=False)
|
| 135 |
tokenizer.push_to_hub(save_to)
|
| 136 |
|