Spaces:
Running
Running
workaround for quantization and push
Browse files- QUANTIZATION_FIX_SUMMARY.md +165 -0
- requirements_quantization.txt +17 -0
- scripts/model_tonic/quantize_model.py +154 -22
- test_quantization_fix.py +149 -0
QUANTIZATION_FIX_SUMMARY.md
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Quantization Fix Summary
|
| 2 |
+
|
| 3 |
+
## Issues Identified
|
| 4 |
+
|
| 5 |
+
The quantization script was failing due to several compatibility issues:
|
| 6 |
+
|
| 7 |
+
1. **Int8 Quantization Error**:
|
| 8 |
+
- Error: `The model is quantized with QuantizationMethod.TORCHAO and is not serializable`
|
| 9 |
+
- Cause: Offloaded modules in the model cannot be quantized with torchao
|
| 10 |
+
- Solution: Added alternative save method and fallback to bitsandbytes
|
| 11 |
+
|
| 12 |
+
2. **Int4 Quantization Error**:
|
| 13 |
+
- Error: `Could not run 'aten::_convert_weight_to_int4pack_for_cpu' with arguments from the 'CUDA' backend`
|
| 14 |
+
- Cause: Int4 quantization requires CPU backend but was being attempted on CUDA
|
| 15 |
+
- Solution: Added proper device selection logic
|
| 16 |
+
|
| 17 |
+
3. **Monitoring Error**:
|
| 18 |
+
- Error: `'SmolLM3Monitor' object has no attribute 'log_event'`
|
| 19 |
+
- Cause: Incorrect monitoring API usage
|
| 20 |
+
- Solution: Added flexible monitoring method detection
|
| 21 |
+
|
| 22 |
+
## Fixes Implemented
|
| 23 |
+
|
| 24 |
+
### 1. Enhanced Device Management (`scripts/model_tonic/quantize_model.py`)
|
| 25 |
+
|
| 26 |
+
```python
|
| 27 |
+
def get_optimal_device(self, quant_type: str) -> str:
|
| 28 |
+
"""Get optimal device for quantization type"""
|
| 29 |
+
if quant_type == "int4_weight_only":
|
| 30 |
+
# Int4 quantization works better on CPU
|
| 31 |
+
return "cpu"
|
| 32 |
+
elif quant_type == "int8_weight_only":
|
| 33 |
+
# Int8 quantization works on GPU
|
| 34 |
+
if torch.cuda.is_available():
|
| 35 |
+
return "cuda"
|
| 36 |
+
else:
|
| 37 |
+
logger.warning("β οΈ CUDA not available, falling back to CPU for int8")
|
| 38 |
+
return "cpu"
|
| 39 |
+
else:
|
| 40 |
+
return "auto"
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
### 2. Alternative Quantization Method
|
| 44 |
+
|
| 45 |
+
Added `quantize_model_alternative()` method using bitsandbytes for better compatibility:
|
| 46 |
+
|
| 47 |
+
```python
|
| 48 |
+
def quantize_model_alternative(self, quant_type: str, device: str = "auto", group_size: int = 128, save_dir: Optional[str] = None) -> Optional[str]:
|
| 49 |
+
"""Alternative quantization using bitsandbytes for better compatibility"""
|
| 50 |
+
# Uses BitsAndBytesConfig instead of TorchAoConfig
|
| 51 |
+
# Handles serialization issues better
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
### 3. Improved Error Handling
|
| 55 |
+
|
| 56 |
+
- Added fallback from torchao to bitsandbytes
|
| 57 |
+
- Enhanced save method with alternative approaches
|
| 58 |
+
- Better device mapping for different quantization types
|
| 59 |
+
|
| 60 |
+
### 4. Fixed Monitoring Integration
|
| 61 |
+
|
| 62 |
+
```python
|
| 63 |
+
def log_to_trackio(self, action: str, details: Dict[str, Any]):
|
| 64 |
+
"""Log quantization events to Trackio"""
|
| 65 |
+
if self.monitor:
|
| 66 |
+
try:
|
| 67 |
+
# Use the correct monitoring method
|
| 68 |
+
if hasattr(self.monitor, 'log_event'):
|
| 69 |
+
self.monitor.log_event(action, details)
|
| 70 |
+
elif hasattr(self.monitor, 'log_metric'):
|
| 71 |
+
self.monitor.log_metric(action, details.get('value', 1.0))
|
| 72 |
+
elif hasattr(self.monitor, 'log'):
|
| 73 |
+
self.monitor.log(action, details)
|
| 74 |
+
else:
|
| 75 |
+
logger.info(f"π {action}: {details}")
|
| 76 |
+
except Exception as e:
|
| 77 |
+
logger.warning(f"β οΈ Failed to log to Trackio: {e}")
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
## Usage Instructions
|
| 81 |
+
|
| 82 |
+
### 1. Install Dependencies
|
| 83 |
+
|
| 84 |
+
```bash
|
| 85 |
+
pip install -r requirements_quantization.txt
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
### 2. Run Quantization
|
| 89 |
+
|
| 90 |
+
```bash
|
| 91 |
+
python3 quantize_and_push.py
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
### 3. Test Fixes
|
| 95 |
+
|
| 96 |
+
```bash
|
| 97 |
+
python3 test_quantization_fix.py
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
## Expected Behavior
|
| 101 |
+
|
| 102 |
+
### Successful Quantization
|
| 103 |
+
|
| 104 |
+
The script will now:
|
| 105 |
+
|
| 106 |
+
1. **Try torchao first** for each quantization type
|
| 107 |
+
2. **Fall back to bitsandbytes** if torchao fails
|
| 108 |
+
3. **Use appropriate devices** (CPU for int4, GPU for int8)
|
| 109 |
+
4. **Handle serialization issues** with alternative save methods
|
| 110 |
+
5. **Log progress** without monitoring errors
|
| 111 |
+
|
| 112 |
+
### Output
|
| 113 |
+
|
| 114 |
+
```
|
| 115 |
+
β
Model files validated
|
| 116 |
+
π Processing quantization type: int8_weight_only
|
| 117 |
+
π Using device: cuda
|
| 118 |
+
β
int8_weight_only quantization and push completed
|
| 119 |
+
π Processing quantization type: int4_weight_only
|
| 120 |
+
π Using device: cpu
|
| 121 |
+
β
int4_weight_only quantization and push completed
|
| 122 |
+
π Quantization summary: 2/2 successful
|
| 123 |
+
β
Quantization completed successfully!
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
## Troubleshooting
|
| 127 |
+
|
| 128 |
+
### If All Quantization Fails
|
| 129 |
+
|
| 130 |
+
1. **Install bitsandbytes**:
|
| 131 |
+
```bash
|
| 132 |
+
pip install bitsandbytes
|
| 133 |
+
```
|
| 134 |
+
|
| 135 |
+
2. **Check model path**:
|
| 136 |
+
```bash
|
| 137 |
+
ls -la /output-checkpoint
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
3. **Verify dependencies**:
|
| 141 |
+
```bash
|
| 142 |
+
python3 test_quantization_fix.py
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
+
### Common Issues
|
| 146 |
+
|
| 147 |
+
1. **Memory Issues**: Use CPU for int4 quantization
|
| 148 |
+
2. **Serialization Errors**: The script now handles these automatically
|
| 149 |
+
3. **Device Conflicts**: Automatic device selection based on quantization type
|
| 150 |
+
|
| 151 |
+
## Files Modified
|
| 152 |
+
|
| 153 |
+
1. `scripts/model_tonic/quantize_model.py` - Main quantization logic
|
| 154 |
+
2. `quantize_and_push.py` - Main script with better error handling
|
| 155 |
+
3. `test_quantization_fix.py` - Test script for verification
|
| 156 |
+
4. `requirements_quantization.txt` - Dependencies file
|
| 157 |
+
|
| 158 |
+
## Next Steps
|
| 159 |
+
|
| 160 |
+
1. Run the test script to verify fixes
|
| 161 |
+
2. Install bitsandbytes if not already installed
|
| 162 |
+
3. Run the quantization script
|
| 163 |
+
4. Check the Hugging Face repository for quantized models
|
| 164 |
+
|
| 165 |
+
The fixes ensure robust quantization with multiple fallback options and proper error handling.
|
requirements_quantization.txt
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Quantization Dependencies
|
| 2 |
+
# Core quantization libraries
|
| 3 |
+
torchao>=0.1.0
|
| 4 |
+
bitsandbytes>=0.41.0
|
| 5 |
+
|
| 6 |
+
# Transformers with quantization support
|
| 7 |
+
transformers>=4.36.0
|
| 8 |
+
|
| 9 |
+
# Hugging Face Hub for model pushing
|
| 10 |
+
huggingface_hub>=0.19.0
|
| 11 |
+
|
| 12 |
+
# Optional: For better performance
|
| 13 |
+
accelerate>=0.24.0
|
| 14 |
+
safetensors>=0.4.0
|
| 15 |
+
|
| 16 |
+
# Optional: For monitoring
|
| 17 |
+
datasets>=2.14.0
|
scripts/model_tonic/quantize_model.py
CHANGED
|
@@ -101,27 +101,16 @@ class ModelQuantizer:
|
|
| 101 |
return False
|
| 102 |
|
| 103 |
# Check for essential model files
|
| 104 |
-
required_files = ['config.json']
|
| 105 |
optional_files = ['tokenizer.json', 'tokenizer_config.json']
|
| 106 |
|
| 107 |
-
|
| 108 |
-
model_files = [
|
| 109 |
-
"model.safetensors.index.json", # Safetensors format
|
| 110 |
-
"pytorch_model.bin" # PyTorch format
|
| 111 |
-
]
|
| 112 |
-
|
| 113 |
-
missing_files = []
|
| 114 |
for file in required_files:
|
| 115 |
if not (self.model_path / file).exists():
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
# Check if at least one model file exists
|
| 119 |
-
model_file_exists = any((self.model_path / file).exists() for file in model_files)
|
| 120 |
-
if not model_file_exists:
|
| 121 |
-
missing_files.extend(model_files)
|
| 122 |
|
| 123 |
-
if
|
| 124 |
-
logger.error(f"β Missing required model files: {
|
| 125 |
return False
|
| 126 |
|
| 127 |
logger.info(f"β
Model path validated: {self.model_path}")
|
|
@@ -144,6 +133,99 @@ class ModelQuantizer:
|
|
| 144 |
|
| 145 |
return TorchAoConfig(quant_type=quant_config)
|
| 146 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
def quantize_model(
|
| 148 |
self,
|
| 149 |
quant_type: str,
|
|
@@ -162,15 +244,32 @@ class ModelQuantizer:
|
|
| 162 |
logger.info(f"π Device: {device}")
|
| 163 |
logger.info(f"π Group size: {group_size}")
|
| 164 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
# Create quantization config
|
| 166 |
quantization_config = self.create_quantization_config(quant_type, group_size)
|
| 167 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
# Load and quantize the model
|
| 169 |
quantized_model = AutoModelForCausalLM.from_pretrained(
|
| 170 |
str(self.model_path),
|
| 171 |
-
torch_dtype=
|
| 172 |
-
device_map=
|
| 173 |
-
quantization_config=quantization_config
|
|
|
|
| 174 |
)
|
| 175 |
|
| 176 |
# Determine save directory
|
|
@@ -183,7 +282,24 @@ class ModelQuantizer:
|
|
| 183 |
|
| 184 |
# Save quantized model (don't use safetensors for torchao)
|
| 185 |
logger.info(f"πΎ Saving quantized model to: {save_path}")
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
# Copy tokenizer files if they exist
|
| 189 |
tokenizer_files = ['tokenizer.json', 'tokenizer_config.json', 'special_tokens_map.json']
|
|
@@ -198,7 +314,9 @@ class ModelQuantizer:
|
|
| 198 |
|
| 199 |
except Exception as e:
|
| 200 |
logger.error(f"β Quantization failed: {e}")
|
| 201 |
-
|
|
|
|
|
|
|
| 202 |
|
| 203 |
def create_quantized_model_card(self, quant_type: str, original_model: str, subdir: str) -> str:
|
| 204 |
"""Create a model card for the quantized model"""
|
|
@@ -470,10 +588,24 @@ For questions and support, please open an issue on the Hugging Face repository.
|
|
| 470 |
"""Log quantization events to Trackio"""
|
| 471 |
if self.monitor:
|
| 472 |
try:
|
| 473 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
logger.info(f"π Logged to Trackio: {action}")
|
| 475 |
except Exception as e:
|
| 476 |
logger.warning(f"β οΈ Failed to log to Trackio: {e}")
|
|
|
|
|
|
|
|
|
|
| 477 |
|
| 478 |
def quantize_and_push(
|
| 479 |
self,
|
|
|
|
| 101 |
return False
|
| 102 |
|
| 103 |
# Check for essential model files
|
| 104 |
+
required_files = ['config.json', 'pytorch_model.bin']
|
| 105 |
optional_files = ['tokenizer.json', 'tokenizer_config.json']
|
| 106 |
|
| 107 |
+
missing_required = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
for file in required_files:
|
| 109 |
if not (self.model_path / file).exists():
|
| 110 |
+
missing_required.append(file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
+
if missing_required:
|
| 113 |
+
logger.error(f"β Missing required model files: {missing_required}")
|
| 114 |
return False
|
| 115 |
|
| 116 |
logger.info(f"β
Model path validated: {self.model_path}")
|
|
|
|
| 133 |
|
| 134 |
return TorchAoConfig(quant_type=quant_config)
|
| 135 |
|
| 136 |
+
def get_optimal_device(self, quant_type: str) -> str:
|
| 137 |
+
"""Get optimal device for quantization type"""
|
| 138 |
+
if quant_type == "int4_weight_only":
|
| 139 |
+
# Int4 quantization works better on CPU
|
| 140 |
+
return "cpu"
|
| 141 |
+
elif quant_type == "int8_weight_only":
|
| 142 |
+
# Int8 quantization works on GPU
|
| 143 |
+
if torch.cuda.is_available():
|
| 144 |
+
return "cuda"
|
| 145 |
+
else:
|
| 146 |
+
logger.warning("β οΈ CUDA not available, falling back to CPU for int8")
|
| 147 |
+
return "cpu"
|
| 148 |
+
else:
|
| 149 |
+
return "auto"
|
| 150 |
+
|
| 151 |
+
def quantize_model_alternative(
|
| 152 |
+
self,
|
| 153 |
+
quant_type: str,
|
| 154 |
+
device: str = "auto",
|
| 155 |
+
group_size: int = 128,
|
| 156 |
+
save_dir: Optional[str] = None
|
| 157 |
+
) -> Optional[str]:
|
| 158 |
+
"""Alternative quantization using bitsandbytes for better compatibility"""
|
| 159 |
+
try:
|
| 160 |
+
logger.info(f"π Attempting alternative quantization for: {quant_type}")
|
| 161 |
+
|
| 162 |
+
# Import bitsandbytes if available
|
| 163 |
+
try:
|
| 164 |
+
import bitsandbytes as bnb
|
| 165 |
+
from transformers import BitsAndBytesConfig
|
| 166 |
+
BNB_AVAILABLE = True
|
| 167 |
+
except ImportError:
|
| 168 |
+
BNB_AVAILABLE = False
|
| 169 |
+
logger.error("β bitsandbytes not available for alternative quantization")
|
| 170 |
+
return None
|
| 171 |
+
|
| 172 |
+
if not BNB_AVAILABLE:
|
| 173 |
+
return None
|
| 174 |
+
|
| 175 |
+
# Create bitsandbytes config
|
| 176 |
+
if quant_type == "int8_weight_only":
|
| 177 |
+
bnb_config = BitsAndBytesConfig(
|
| 178 |
+
load_in_8bit=True,
|
| 179 |
+
llm_int8_threshold=6.0,
|
| 180 |
+
llm_int8_has_fp16_weight=False
|
| 181 |
+
)
|
| 182 |
+
elif quant_type == "int4_weight_only":
|
| 183 |
+
bnb_config = BitsAndBytesConfig(
|
| 184 |
+
load_in_4bit=True,
|
| 185 |
+
bnb_4bit_compute_dtype=torch.float16,
|
| 186 |
+
bnb_4bit_use_double_quant=True,
|
| 187 |
+
bnb_4bit_quant_type="nf4"
|
| 188 |
+
)
|
| 189 |
+
else:
|
| 190 |
+
logger.error(f"β Unsupported quantization type for alternative method: {quant_type}")
|
| 191 |
+
return None
|
| 192 |
+
|
| 193 |
+
# Load model with bitsandbytes quantization
|
| 194 |
+
quantized_model = AutoModelForCausalLM.from_pretrained(
|
| 195 |
+
str(self.model_path),
|
| 196 |
+
quantization_config=bnb_config,
|
| 197 |
+
device_map="auto",
|
| 198 |
+
torch_dtype=torch.bfloat16,
|
| 199 |
+
low_cpu_mem_usage=True
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
# Determine save directory
|
| 203 |
+
if save_dir is None:
|
| 204 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 205 |
+
save_dir = f"quantized_{quant_type}_bnb_{timestamp}"
|
| 206 |
+
|
| 207 |
+
save_path = Path(save_dir)
|
| 208 |
+
save_path.mkdir(parents=True, exist_ok=True)
|
| 209 |
+
|
| 210 |
+
# Save quantized model
|
| 211 |
+
logger.info(f"πΎ Saving quantized model to: {save_path}")
|
| 212 |
+
quantized_model.save_pretrained(save_path, safe_serialization=False)
|
| 213 |
+
|
| 214 |
+
# Copy tokenizer files if they exist
|
| 215 |
+
tokenizer_files = ['tokenizer.json', 'tokenizer_config.json', 'special_tokens_map.json']
|
| 216 |
+
for file in tokenizer_files:
|
| 217 |
+
src_file = self.model_path / file
|
| 218 |
+
if src_file.exists():
|
| 219 |
+
shutil.copy2(src_file, save_path / file)
|
| 220 |
+
logger.info(f"π Copied {file}")
|
| 221 |
+
|
| 222 |
+
logger.info(f"β
Alternative quantization successful: {save_path}")
|
| 223 |
+
return str(save_path)
|
| 224 |
+
|
| 225 |
+
except Exception as e:
|
| 226 |
+
logger.error(f"β Alternative quantization failed: {e}")
|
| 227 |
+
return None
|
| 228 |
+
|
| 229 |
def quantize_model(
|
| 230 |
self,
|
| 231 |
quant_type: str,
|
|
|
|
| 244 |
logger.info(f"π Device: {device}")
|
| 245 |
logger.info(f"π Group size: {group_size}")
|
| 246 |
|
| 247 |
+
# Determine optimal device
|
| 248 |
+
if device == "auto":
|
| 249 |
+
device = self.get_optimal_device(quant_type)
|
| 250 |
+
logger.info(f"π Using device: {device}")
|
| 251 |
+
|
| 252 |
# Create quantization config
|
| 253 |
quantization_config = self.create_quantization_config(quant_type, group_size)
|
| 254 |
|
| 255 |
+
# Load model with appropriate device mapping
|
| 256 |
+
if device == "cpu":
|
| 257 |
+
device_map = "cpu"
|
| 258 |
+
torch_dtype = torch.float32
|
| 259 |
+
elif device == "cuda":
|
| 260 |
+
device_map = "auto"
|
| 261 |
+
torch_dtype = torch.bfloat16
|
| 262 |
+
else:
|
| 263 |
+
device_map = "auto"
|
| 264 |
+
torch_dtype = "auto"
|
| 265 |
+
|
| 266 |
# Load and quantize the model
|
| 267 |
quantized_model = AutoModelForCausalLM.from_pretrained(
|
| 268 |
str(self.model_path),
|
| 269 |
+
torch_dtype=torch_dtype,
|
| 270 |
+
device_map=device_map,
|
| 271 |
+
quantization_config=quantization_config,
|
| 272 |
+
low_cpu_mem_usage=True
|
| 273 |
)
|
| 274 |
|
| 275 |
# Determine save directory
|
|
|
|
| 282 |
|
| 283 |
# Save quantized model (don't use safetensors for torchao)
|
| 284 |
logger.info(f"πΎ Saving quantized model to: {save_path}")
|
| 285 |
+
|
| 286 |
+
# For torchao models, we need to handle serialization carefully
|
| 287 |
+
try:
|
| 288 |
+
quantized_model.save_pretrained(save_path, safe_serialization=False)
|
| 289 |
+
except Exception as save_error:
|
| 290 |
+
logger.warning(f"β οΈ Standard save failed: {save_error}")
|
| 291 |
+
logger.info("π Attempting alternative save method...")
|
| 292 |
+
|
| 293 |
+
# Try saving without quantization config
|
| 294 |
+
try:
|
| 295 |
+
# Remove quantization config temporarily
|
| 296 |
+
original_config = quantized_model.config.quantization_config
|
| 297 |
+
quantized_model.config.quantization_config = None
|
| 298 |
+
quantized_model.save_pretrained(save_path, safe_serialization=False)
|
| 299 |
+
quantized_model.config.quantization_config = original_config
|
| 300 |
+
except Exception as alt_save_error:
|
| 301 |
+
logger.error(f"β Alternative save also failed: {alt_save_error}")
|
| 302 |
+
return None
|
| 303 |
|
| 304 |
# Copy tokenizer files if they exist
|
| 305 |
tokenizer_files = ['tokenizer.json', 'tokenizer_config.json', 'special_tokens_map.json']
|
|
|
|
| 314 |
|
| 315 |
except Exception as e:
|
| 316 |
logger.error(f"β Quantization failed: {e}")
|
| 317 |
+
# Try alternative quantization method
|
| 318 |
+
logger.info("π Attempting alternative quantization method...")
|
| 319 |
+
return self.quantize_model_alternative(quant_type, device, group_size, save_dir)
|
| 320 |
|
| 321 |
def create_quantized_model_card(self, quant_type: str, original_model: str, subdir: str) -> str:
|
| 322 |
"""Create a model card for the quantized model"""
|
|
|
|
| 588 |
"""Log quantization events to Trackio"""
|
| 589 |
if self.monitor:
|
| 590 |
try:
|
| 591 |
+
# Use the correct monitoring method
|
| 592 |
+
if hasattr(self.monitor, 'log_event'):
|
| 593 |
+
self.monitor.log_event(action, details)
|
| 594 |
+
elif hasattr(self.monitor, 'log_metric'):
|
| 595 |
+
# Log as metric instead
|
| 596 |
+
self.monitor.log_metric(action, details.get('value', 1.0))
|
| 597 |
+
elif hasattr(self.monitor, 'log'):
|
| 598 |
+
# Use generic log method
|
| 599 |
+
self.monitor.log(action, details)
|
| 600 |
+
else:
|
| 601 |
+
# Just log locally if no monitoring method available
|
| 602 |
+
logger.info(f"π {action}: {details}")
|
| 603 |
logger.info(f"π Logged to Trackio: {action}")
|
| 604 |
except Exception as e:
|
| 605 |
logger.warning(f"β οΈ Failed to log to Trackio: {e}")
|
| 606 |
+
else:
|
| 607 |
+
# Log locally if no monitor available
|
| 608 |
+
logger.info(f"π {action}: {details}")
|
| 609 |
|
| 610 |
def quantize_and_push(
|
| 611 |
self,
|
test_quantization_fix.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Test script to verify quantization fixes
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import logging
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
# Setup logging
|
| 12 |
+
logging.basicConfig(
|
| 13 |
+
level=logging.INFO,
|
| 14 |
+
format='%(asctime)s - %(levelname)s - %(message)s'
|
| 15 |
+
)
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
def test_quantization_imports():
|
| 19 |
+
"""Test that all required imports work"""
|
| 20 |
+
try:
|
| 21 |
+
# Test torchao imports
|
| 22 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
|
| 23 |
+
from torchao.quantization import (
|
| 24 |
+
Int8WeightOnlyConfig,
|
| 25 |
+
Int4WeightOnlyConfig,
|
| 26 |
+
Int8DynamicActivationInt8WeightConfig
|
| 27 |
+
)
|
| 28 |
+
from torchao.dtypes import Int4CPULayout
|
| 29 |
+
logger.info("β
torchao imports successful")
|
| 30 |
+
|
| 31 |
+
# Test bitsandbytes imports
|
| 32 |
+
try:
|
| 33 |
+
import bitsandbytes as bnb
|
| 34 |
+
from transformers import BitsAndBytesConfig
|
| 35 |
+
logger.info("β
bitsandbytes imports successful")
|
| 36 |
+
except ImportError:
|
| 37 |
+
logger.warning("β οΈ bitsandbytes not available - alternative quantization disabled")
|
| 38 |
+
|
| 39 |
+
# Test HF imports
|
| 40 |
+
from huggingface_hub import HfApi
|
| 41 |
+
logger.info("β
huggingface_hub imports successful")
|
| 42 |
+
|
| 43 |
+
return True
|
| 44 |
+
|
| 45 |
+
except ImportError as e:
|
| 46 |
+
logger.error(f"β Import failed: {e}")
|
| 47 |
+
return False
|
| 48 |
+
|
| 49 |
+
def test_model_quantizer():
|
| 50 |
+
"""Test ModelQuantizer initialization"""
|
| 51 |
+
try:
|
| 52 |
+
from scripts.model_tonic.quantize_model import ModelQuantizer
|
| 53 |
+
|
| 54 |
+
# Test with dummy values
|
| 55 |
+
quantizer = ModelQuantizer(
|
| 56 |
+
model_path="/output-checkpoint",
|
| 57 |
+
repo_name="test/test-repo",
|
| 58 |
+
token="dummy_token"
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
logger.info("β
ModelQuantizer initialization successful")
|
| 62 |
+
return True
|
| 63 |
+
|
| 64 |
+
except Exception as e:
|
| 65 |
+
logger.error(f"β ModelQuantizer test failed: {e}")
|
| 66 |
+
return False
|
| 67 |
+
|
| 68 |
+
def test_quantization_configs():
|
| 69 |
+
"""Test quantization config creation"""
|
| 70 |
+
try:
|
| 71 |
+
from scripts.model_tonic.quantize_model import ModelQuantizer
|
| 72 |
+
|
| 73 |
+
quantizer = ModelQuantizer(
|
| 74 |
+
model_path="/output-checkpoint",
|
| 75 |
+
repo_name="test/test-repo",
|
| 76 |
+
token="dummy_token"
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
# Test int8 config
|
| 80 |
+
config = quantizer.create_quantization_config("int8_weight_only", 128)
|
| 81 |
+
logger.info("β
int8_weight_only config creation successful")
|
| 82 |
+
|
| 83 |
+
# Test int4 config
|
| 84 |
+
config = quantizer.create_quantization_config("int4_weight_only", 128)
|
| 85 |
+
logger.info("β
int4_weight_only config creation successful")
|
| 86 |
+
|
| 87 |
+
return True
|
| 88 |
+
|
| 89 |
+
except Exception as e:
|
| 90 |
+
logger.error(f"β Quantization config test failed: {e}")
|
| 91 |
+
return False
|
| 92 |
+
|
| 93 |
+
def test_device_selection():
|
| 94 |
+
"""Test optimal device selection"""
|
| 95 |
+
try:
|
| 96 |
+
from scripts.model_tonic.quantize_model import ModelQuantizer
|
| 97 |
+
|
| 98 |
+
quantizer = ModelQuantizer(
|
| 99 |
+
model_path="/output-checkpoint",
|
| 100 |
+
repo_name="test/test-repo",
|
| 101 |
+
token="dummy_token"
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
# Test device selection
|
| 105 |
+
device = quantizer.get_optimal_device("int8_weight_only")
|
| 106 |
+
logger.info(f"β
int8 device selection: {device}")
|
| 107 |
+
|
| 108 |
+
device = quantizer.get_optimal_device("int4_weight_only")
|
| 109 |
+
logger.info(f"β
int4 device selection: {device}")
|
| 110 |
+
|
| 111 |
+
return True
|
| 112 |
+
|
| 113 |
+
except Exception as e:
|
| 114 |
+
logger.error(f"β Device selection test failed: {e}")
|
| 115 |
+
return False
|
| 116 |
+
|
| 117 |
+
def main():
|
| 118 |
+
"""Run all tests"""
|
| 119 |
+
logger.info("π§ͺ Testing quantization fixes...")
|
| 120 |
+
|
| 121 |
+
tests = [
|
| 122 |
+
("Import Test", test_quantization_imports),
|
| 123 |
+
("ModelQuantizer Test", test_model_quantizer),
|
| 124 |
+
("Config Creation Test", test_quantization_configs),
|
| 125 |
+
("Device Selection Test", test_device_selection),
|
| 126 |
+
]
|
| 127 |
+
|
| 128 |
+
passed = 0
|
| 129 |
+
total = len(tests)
|
| 130 |
+
|
| 131 |
+
for test_name, test_func in tests:
|
| 132 |
+
logger.info(f"\nπ Running {test_name}...")
|
| 133 |
+
if test_func():
|
| 134 |
+
passed += 1
|
| 135 |
+
logger.info(f"β
{test_name} passed")
|
| 136 |
+
else:
|
| 137 |
+
logger.error(f"β {test_name} failed")
|
| 138 |
+
|
| 139 |
+
logger.info(f"\nπ Test Results: {passed}/{total} tests passed")
|
| 140 |
+
|
| 141 |
+
if passed == total:
|
| 142 |
+
logger.info("π All tests passed! Quantization fixes are working.")
|
| 143 |
+
return 0
|
| 144 |
+
else:
|
| 145 |
+
logger.error("β Some tests failed. Please check the errors above.")
|
| 146 |
+
return 1
|
| 147 |
+
|
| 148 |
+
if __name__ == "__main__":
|
| 149 |
+
exit(main())
|