#!/usr/bin/env python3
"""Test script for torchao quantization inference."""

import logging

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from torchao.quantization import Int4WeightOnlyConfig, Int8WeightOnlyConfig
from torchao.dtypes import Int4CPULayout

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def test_torchao_quantization():
    """Test torchao quantization with different configurations."""
    model_id = "Tonic/petite-elle-L-aime-3-sft"
    device = "cuda" if torch.cuda.is_available() else "cpu"

    logger.info(f"Testing torchao quantization on device: {device}")

    # Pick a quantization config suited to the device: int8 weight-only on
    # GPU, int4 weight-only with the CPU-specific layout otherwise.
    configs_to_test = []
    if device == "cuda":
        configs_to_test.append(
            ("Int8WeightOnlyConfig", Int8WeightOnlyConfig(group_size=128))
        )
    else:
        configs_to_test.append(
            ("Int4WeightOnlyConfig CPU",
             Int4WeightOnlyConfig(group_size=128, layout=Int4CPULayout()))
        )

    for config_name, quant_config in configs_to_test:
        logger.info(f"\nTesting {config_name}...")
        try:
            # Wrap the torchao config so transformers applies it at load time
            quantization_config = TorchAoConfig(quant_type=quant_config)

            # Load tokenizer
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            if tokenizer.pad_token_id is None:
                tokenizer.pad_token_id = tokenizer.eos_token_id

            # Load model with quantization
            model_kwargs = {
                "device_map": "auto" if device == "cuda" else "cpu",
                "torch_dtype": torch.bfloat16 if device == "cuda" else torch.float32,
                "trust_remote_code": True,
                "low_cpu_mem_usage": True,
                "quantization_config": quantization_config,
            }

            logger.info(f"Loading model with {config_name}...")
            model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)

            # Test generation
            test_prompt = "Bonjour, comment allez-vous?"
            inputs = tokenizer(test_prompt, return_tensors="pt")
            if device == "cuda":
                inputs = {k: v.cuda() for k, v in inputs.items()}

            logger.info("Generating response...")
            with torch.no_grad():
                output_ids = model.generate(
                    inputs["input_ids"],
                    max_new_tokens=50,
                    temperature=0.7,
                    top_p=0.95,
                    do_sample=True,
                    attention_mask=inputs["attention_mask"],
                    pad_token_id=tokenizer.eos_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    cache_implementation="static",  # Important for torchao
                )

            response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
            # Strip the prompt to keep only the generated continuation
            assistant_response = response[len(test_prompt):].strip()

            logger.info(f"āœ… {config_name} test successful!")
            logger.info(f"Input: {test_prompt}")
            logger.info(f"Output: {assistant_response}")

            # Clean up
            del model
            if device == "cuda":
                torch.cuda.empty_cache()

        except Exception as e:
            logger.error(f"āŒ {config_name} test failed: {e}")
            continue

    logger.info("\nšŸŽ‰ All torchao quantization tests completed!")


if __name__ == "__main__":
    test_torchao_quantization()
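

# --- Optional follow-up: persisting a quantized model (sketch) ---
# A minimal sketch, not called by the test above, showing how one might
# save the quantized model for later reuse. torchao-quantized tensors are
# not safetensors-compatible, so transformers requires
# safe_serialization=False when saving them. The function name and
# output_dir default below are hypothetical, not part of the test.
def save_quantized_model(model, tokenizer, output_dir="quantized-model"):
    """Persist a torchao-quantized model and its tokenizer to output_dir."""
    model.save_pretrained(output_dir, safe_serialization=False)
    tokenizer.save_pretrained(output_dir)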