Update README.md
README.md CHANGED
```diff
@@ -60,25 +60,24 @@ from hqq.utils.patching import *
 from hqq.core.quantize import *
 from hqq.utils.generation_hf import HFGenerator
 
+#Settings
+###################################################
+backend = "torchao_int4" #"torchao_int4" (4-bit only) or "bitblas" (4-bit + 2-bit) or "gemlite" (8-bit, 4-bit, 2-bit, 1-bit)
+compute_dtype = torch.bfloat16 if backend=="torchao_int4" else torch.float16
+device = 'cuda:0'
+cache_dir = '.'
+
 #Load the model
 ###################################################
 #model_id = 'mobiuslabsgmbh/Llama-3.1-8b-instruct_4bitgs64_hqq' #no calib version
 model_id = 'mobiuslabsgmbh/Llama-3.1-8b-instruct_4bitgs64_hqq_calib' #calibrated version
 
-
-cache_dir = '.'
-model = AutoHQQHFModel.from_quantized(model_id, cache_dir=cache_dir, compute_dtype=compute_dtype)
+model = AutoHQQHFModel.from_quantized(model_id, cache_dir=cache_dir, compute_dtype=compute_dtype, device=device).eval()
 tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)
 
-quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False, axis=1)
-patch_linearlayers(model, patch_add_quant_config, quant_config)
-
 #Use optimized inference kernels
 ###################################################
-
-#prepare_for_inference(model) #default backend
-prepare_for_inference(model, backend="torchao_int4")
-#prepare_for_inference(model, backend="bitblas") #takes a while to init...
+prepare_for_inference(model, backend=backend)
 
 #Generate
 ###################################################
```