jerryzh168 committed
Commit 488fc91 · verified · 1 Parent(s): 638f535

Update README.md

Files changed (1): README.md (+22 -11)
README.md CHANGED

@@ -137,9 +137,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
 
 model_id = "google/gemma-3-12b-it"
 model_to_quantize = "google/gemma-3-12b-it"
-
-
-from torchao.quantization import Int4WeightOnlyConfig, quantize_
+from torchao.quantization import Int4WeightOnlyConfig, quantize_, ModuleFqnToConfig
 from torchao.prototype.awq import (
     AWQConfig,
 )
@@ -150,14 +148,24 @@ model = AutoModelForCausalLM.from_pretrained(
     torch_dtype=torch.bfloat16,
 )
 tokenizer = AutoTokenizer.from_pretrained(model_id)
-
+def get_quant_config(linear_config):
+    return ModuleFqnToConfig({
+        r"re:language_model\.model\.layers\..+\.mlp\..+_proj": linear_config,
+        r"re:language_model\.model\.layers\..+\.self_attn\..+_proj": linear_config,
+        r"re:model\.language_model\.layers\..+\.mlp\..+_proj": linear_config,
+        r"re:model\.language_model\.layers\..+\.self_attn\..+_proj": linear_config,
+    })
 # AWQ only works for H100 INT4 so far
 base_config = Int4WeightOnlyConfig(group_size=128)
-quant_config = AWQConfig(base_config, step="prepare")
+linear_config = AWQConfig(base_config, step="prepare")
+quant_config = get_quant_config(linear_config)
 quantize_(
     model,
     quant_config,
 )
+tasks = ["mmlu_philosophy"]
+calibration_limit=30
+max_seq_length=2048
 TransformerEvalWrapper(
     model=model,
     tokenizer=tokenizer,
@@ -166,22 +174,25 @@ TransformerEvalWrapper(
     tasks=tasks,
     limit=calibration_limit,
 )
-quant_config = AWQConfig(base_config, step="convert")
+linear_config = AWQConfig(base_config, step="convert")
+quant_config = get_quant_config(linear_config)
 quantize_(model, quant_config)
-
 quantized_model = model
-quant_config = AWQConfig(base_config, step="prepare_for_loading")
+linear_config = AWQConfig(base_config, step="prepare_for_loading")
+quant_config = get_quant_config(linear_config)
 quantized_model.config.quantization_config = TorchAoConfig(quant_config)
-
-
 # Push to hub
 USER_ID = "YOUR_USER_ID"
 MODEL_NAME = model_id.split("/")[-1]
 save_to = f"{USER_ID}/{MODEL_NAME}-AWQ-INT4"
 quantized_model.push_to_hub(save_to, safe_serialization=False)
 tokenizer.push_to_hub(save_to)
-
 # Manual Testing
+quantized_model = AutoModelForCausalLM.from_pretrained(
+    save_to,
+    device_map="cuda:0",
+    torch_dtype=torch.bfloat16,
+)
 prompt = "Hey, are you conscious? Can you talk to me?"
 messages = [
 {
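
The heart of this change is the new `get_quant_config()` helper: rather than passing `AWQConfig` straight to `quantize_`, the README now scopes it to specific linear layers through `ModuleFqnToConfig` regex keys, and it lists both the `language_model.model.layers...` and `model.language_model.layers...` prefixes, presumably to cover the module naming used by different transformers versions of Gemma 3. Below is a minimal, self-contained sketch of the selection these patterns imply, assuming torchao strips the `re:` prefix and matches the remainder against each module's fully qualified name; the sample FQNs are illustrative, not dumped from the real model:

```python
import re

# Patterns from get_quant_config(), with the "re:" prefix removed.
patterns = [
    r"language_model\.model\.layers\..+\.mlp\..+_proj",
    r"language_model\.model\.layers\..+\.self_attn\..+_proj",
    r"model\.language_model\.layers\..+\.mlp\..+_proj",
    r"model\.language_model\.layers\..+\.self_attn\..+_proj",
]

# Illustrative module FQNs (assumed, not dumped from gemma-3-12b-it).
sample_fqns = [
    "language_model.model.layers.0.mlp.gate_proj",         # older layout: matched
    "model.language_model.layers.0.self_attn.q_proj",      # newer layout: matched
    "model.language_model.layers.47.mlp.down_proj",        # matched
    "vision_tower.vision_model.encoder.layers.0.mlp.fc1",  # vision tower: skipped
    "lm_head",                                             # output head: skipped
]

for fqn in sample_fqns:
    hit = any(re.fullmatch(p, fqn) for p in patterns)
    print(f"{'quantize' if hit else 'skip':>8}  {fqn}")
```

Since this only exercises the standard-library `re` module, it can be sanity-checked locally before committing to a full H100 quantization run.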