#!/usr/bin/env python3
"""
Test script for torchao quantization inference
"""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from torchao.quantization import Int4WeightOnlyConfig, Int8WeightOnlyConfig
from torchao.dtypes import Int4CPULayout
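# Note: Int4WeightOnlyConfig / Int8WeightOnlyConfig are the config-class API of
# recent torchao releases; older releases expose equivalent int4_weight_only /
# int8_weight_only helper functions instead.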
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def test_torchao_quantization():
    """Test torchao quantization with different configurations"""
    
    model_id = "Tonic/petite-elle-L-aime-3-sft"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    logger.info(f"Testing torchao quantization on device: {device}")
    
    # Test different quantization configs
    configs_to_test = []
    
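    # Choose a config suited to the device: this script uses int8 weight-only on
    # CUDA, and int4 weight-only on CPU, which needs the dedicated Int4CPULayout
    # (the default int4 layout targets CUDA tinygemm kernels).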
    if device == "cuda":
        configs_to_test.append(("Int8WeightOnlyConfig", Int8WeightOnlyConfig(group_size=128)))
    else:
        configs_to_test.append(("Int4WeightOnlyConfig CPU", Int4WeightOnlyConfig(group_size=128, layout=Int4CPULayout())))
    
    for config_name, quant_config in configs_to_test:
        logger.info(f"\nTesting {config_name}...")
        
        try:
            # Create quantization config
            quantization_config = TorchAoConfig(quant_type=quant_config)
            
            # Load tokenizer
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            if tokenizer.pad_token_id is None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            
            # Load model with quantization
            model_kwargs = {
                "device_map": "auto" if device == "cuda" else "cpu",
                "torch_dtype": torch.bfloat16 if device == "cuda" else torch.float32,
                "trust_remote_code": True,
                "low_cpu_mem_usage": True,
                "quantization_config": quantization_config,
            }
            
            logger.info(f"Loading model with {config_name}...")
            model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
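            # Log the quantized model's memory footprint as a quick sanity check
            # (get_memory_footprint() is a standard transformers PreTrainedModel helper).
            logger.info(f"Memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")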
            
            # Test generation
            test_prompt = "Bonjour, comment allez-vous?"
            inputs = tokenizer(test_prompt, return_tensors="pt")
            
            if device == "cuda":
                inputs = {k: v.cuda() for k, v in inputs.items()}
            
            logger.info("Generating response...")
            with torch.no_grad():
                output_ids = model.generate(
                    inputs['input_ids'],
                    max_new_tokens=50,
                    temperature=0.7,
                    top_p=0.95,
                    do_sample=True,
                    attention_mask=inputs['attention_mask'],
                    pad_token_id=tokenizer.eos_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    cache_implementation="static"  # static KV cache is recommended for torchao-quantized models (enables torch.compile speedups)
                )
            
            # Decode only the newly generated tokens; slicing the decoded string
            # by prompt length can misalign when the tokenizer normalizes text.
            prompt_length = inputs['input_ids'].shape[1]
            assistant_response = tokenizer.decode(
                output_ids[0][prompt_length:], skip_special_tokens=True
            ).strip()
            
            logger.info(f"βœ… {config_name} test successful!")
            logger.info(f"Input: {test_prompt}")
            logger.info(f"Output: {assistant_response}")
            
            # Clean up
            del model
            if device == "cuda":
                torch.cuda.empty_cache()
            
        except Exception as e:
            logger.error(f"❌ {config_name} test failed: {e}")
            continue
    
    logger.info("\nπŸŽ‰ All torchao quantization tests completed!")

if __name__ == "__main__":
    test_torchao_quantization()