Update README.md
README.md CHANGED

```diff
@@ -137,9 +137,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
 
 model_id = "google/gemma-3-12b-it"
 model_to_quantize = "google/gemma-3-12b-it"
-
-
-from torchao.quantization import Int4WeightOnlyConfig, quantize_
+from torchao.quantization import Int4WeightOnlyConfig, quantize_, ModuleFqnToConfig
 from torchao.prototype.awq import (
     AWQConfig,
 )
@@ -150,14 +148,24 @@ model = AutoModelForCausalLM.from_pretrained(
     torch_dtype=torch.bfloat16,
 )
 tokenizer = AutoTokenizer.from_pretrained(model_id)
-
+def get_quant_config(linear_config):
+    return ModuleFqnToConfig({
+        r"re:language_model\.model\.layers\..+\.mlp\..+_proj": linear_config,
+        r"re:language_model\.model\.layers\..+\.self_attn\..+_proj": linear_config,
+        r"re:model\.language_model\.layers\..+\.mlp\..+_proj": linear_config,
+        r"re:model\.language_model\.layers\..+\.self_attn\..+_proj": linear_config,
+    })
 # AWQ only works for H100 INT4 so far
 base_config = Int4WeightOnlyConfig(group_size=128)
-
+linear_config = AWQConfig(base_config, step="prepare")
+quant_config = get_quant_config(linear_config)
 quantize_(
     model,
     quant_config,
 )
+tasks = ["mmlu_philosophy"]
+calibration_limit=30
+max_seq_length=2048
 TransformerEvalWrapper(
     model=model,
     tokenizer=tokenizer,
@@ -166,22 +174,25 @@ TransformerEvalWrapper(
     tasks=tasks,
     limit=calibration_limit,
 )
-
+linear_config = AWQConfig(base_config, step="convert")
+quant_config = get_quant_config(linear_config)
 quantize_(model, quant_config)
-
 quantized_model = model
-
+linear_config = AWQConfig(base_config, step="prepare_for_loading")
+quant_config = get_quant_config(linear_config)
 quantized_model.config.quantization_config = TorchAoConfig(quant_config)
-
-
 # Push to hub
 USER_ID = "YOUR_USER_ID"
 MODEL_NAME = model_id.split("/")[-1]
 save_to = f"{USER_ID}/{MODEL_NAME}-AWQ-INT4"
 quantized_model.push_to_hub(save_to, safe_serialization=False)
 tokenizer.push_to_hub(save_to)
-
 # Manual Testing
+quantized_model = AutoModelForCausalLM.from_pretrained(
+    save_to,
+    device_map="cuda:0",
+    torch_dtype=torch.bfloat16,
+)
 prompt = "Hey, are you conscious? Can you talk to me?"
 messages = [
     {
```