AI-generated Python code for quantized inference. I can't check whether it is correct because I don't understand Python code, but I ran it and it seems to work.
#1 by 21world - opened
import psutil
import torch
import threading
import queue
import time
import os
import curses
import warnings
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
try:
    import GPUtil
    gputil_available = True
    print("GPUtil found - using it for GPU monitoring.")
except ImportError:
    gputil_available = False
    print("GPUtil not found - using torch for basic GPU monitoring.")
    print("Consider installing it: pip install gputil")
def get_gpu_info():
    """Gets GPU memory and utilization info."""
    if not torch.cuda.is_available():
        return None
    if gputil_available:
        gpus = GPUtil.getGPUs()
        if gpus:
            gpu = gpus[0]
            return {
                'memory_used_mb': gpu.memoryUsed,
                'memory_total_mb': gpu.memoryTotal,
                'memory_free_mb': gpu.memoryFree,
                'memory_percent_used': gpu.memoryUtil * 100,
                'util_percent': gpu.load * 100,
                'temperature': gpu.temperature
            }
    # Fallback: torch only reports memory allocated by this process, not overall GPU usage
    allocated = torch.cuda.memory_allocated()
    total = torch.cuda.get_device_properties(0).total_memory
    return {
        'memory_used_mb': allocated / 1024**2,
        'memory_total_mb': total / 1024**2,
        'memory_free_mb': (total - allocated) / 1024**2,
        'memory_percent_used': (allocated / total) * 100,
        'util_percent': 'N/A',
        'temperature': 'N/A'
    }
def print_gpu_status(): # Keep for potential non-curses use
    gpu_info = get_gpu_info()
    if gpu_info:
        print(f"GPU Memory: {gpu_info['memory_used_mb']:.1f}MB / {gpu_info['memory_total_mb']:.1f}MB "
              f"({gpu_info['memory_percent_used']:.1f}% used)")
    else:
        print("CUDA not available or GPU info unavailable.")
class CursesChatInterface:
    def __init__(self, stdscr, model_path, offload_to_cpu=True):
        self.stdscr = stdscr
        self.model_path = model_path
        self.offload_to_cpu = offload_to_cpu
        self.model = None
        self.tokenizer = None
        self.conversation_history = []
        self.input_buffer = ""
        self.scroll_pos = 0 # For scrolling chat history
        self.gpu_info = get_gpu_info() # Initial GPU info
        self.gpu_info_lock = threading.Lock() # Lock to protect shared gpu_info
        self.gpu_monitor_thread = None
        self.stop_gpu_monitor = threading.Event() # Event to signal the thread to stop
        self.is_generating = False # Flag to indicate generation status
        self.inference_time = 0.0 # Time for last inference
        self.quantization_type = "4bit" # Default to 4-bit
        self.load_model()
        self.start_gpu_monitoring() # Start the background thread
        self.setup_curses()
    def setup_curses(self):
        # Initialize colors
        curses.start_color()
        curses.init_pair(1, curses.COLOR_CYAN, curses.COLOR_BLACK)  # User text
        curses.init_pair(2, curses.COLOR_GREEN, curses.COLOR_BLACK) # Assistant text
        curses.init_pair(3, curses.COLOR_YELLOW, curses.COLOR_BLACK) # System/GPU monitor
        curses.init_pair(4, curses.COLOR_WHITE, curses.COLOR_BLUE)   # Input prompt
        curses.init_pair(5, curses.COLOR_RED, curses.COLOR_BLACK)    # Errors
        curses.init_pair(6, curses.COLOR_MAGENTA, curses.COLOR_BLACK) # Computing indicator
        # Disable cursor blinking
        curses.curs_set(0)
    def load_model(self, ask_quantization=True):
        # Runs in the main curses thread. ask_quantization=False is used by the
        # fallback paths below so the user is not prompted a second time.
        if ask_quantization:
            self.stdscr.clear()
            self.stdscr.addstr(0, 0, "Select quantization:", curses.color_pair(3))
            self.stdscr.addstr(2, 0, "Press '2' for 2-bit, '4' for 4-bit, '8' for 8-bit, 'n' for no quantization")
            self.stdscr.refresh()

            # Get quantization selection (getch() blocks here, nodelay is not set yet)
            while True:
                key = self.stdscr.getch()
                if key == ord('2'):
                    self.quantization_type = "2bit"
                    break
                elif key == ord('4'):
                    self.quantization_type = "4bit"
                    break
                elif key == ord('8'):
                    self.quantization_type = "8bit"
                    break
                elif key == ord('n'):
                    self.quantization_type = "none"
                    break
        
        self.stdscr.clear()
        self.stdscr.addstr(0, 0, f"Loading model with {self.quantization_type} quantization... Please wait...", curses.color_pair(3))
        self.stdscr.refresh()
        
        try:
            # Suppress quantization warnings
            warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast from.*to float16 during quantization")
            warnings.filterwarnings("ignore", message=".*output of backward is a view.*")
            
            if self.quantization_type == "2bit":
                try:
                    from transformers import BitsAndBytesConfig
                    bnb_config = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_use_double_quant=True,
                        bnb_4bit_quant_type="nf4",
                        bnb_4bit_compute_dtype=torch.bfloat16,
                    )
                    # Note: bitsandbytes does not support true 2-bit quantization,
                    # so this branch uses the closest available option: 4-bit NF4.
                    model_kwargs = {
                        "torch_dtype": torch.bfloat16,
                        "device_map": "auto",
                        "trust_remote_code": True,
                        "quantization_config": bnb_config,
                    }
                    print("Attempting to load model with 4-bit config (closest to 2-bit)...")
                    self.model = AutoModelForCausalLM.from_pretrained(self.model_path, **model_kwargs)
                    print("Model loaded successfully with 4-bit config.")
                except ImportError:
                    print("2-bit quantization failed. Loading with 4-bit...")
                    self.quantization_type = "4bit"
                    self.load_model(ask_quantization=False)  # Retry with 4-bit without re-prompting
                    return
            elif self.quantization_type == "4bit":
                try:
                    from transformers import BitsAndBytesConfig
                    bnb_config = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_use_double_quant=True,
                        bnb_4bit_quant_type="nf4",
                        bnb_4bit_compute_dtype=torch.bfloat16,
                    )
                    model_kwargs = {
                        "torch_dtype": torch.bfloat16,
                        "device_map": "auto",
                        "trust_remote_code": True,
                        "quantization_config": bnb_config,
                    }
                    print("Attempting to load model with 4-bit quantization...")
                    self.model = AutoModelForCausalLM.from_pretrained(self.model_path, **model_kwargs)
                    print("Model loaded successfully with 4-bit quantization.")
                except ImportError:
                    print("4-bit quantization failed. Loading with 8-bit...")
                    self.quantization_type = "8bit"
                    self.load_model(ask_quantization=False)  # Retry with 8-bit without re-prompting
                    return
            elif self.quantization_type == "8bit":
                try:
                    from transformers import BitsAndBytesConfig
                    # Newer transformers versions deprecate passing load_in_8bit directly
                    # to from_pretrained in favor of a quantization_config.
                    bnb_config = BitsAndBytesConfig(load_in_8bit=True)
                    model_kwargs = {
                        "torch_dtype": torch.bfloat16,
                        "device_map": "auto",
                        "trust_remote_code": True,
                        "quantization_config": bnb_config,
                    }
                    print("Attempting to load model with 8-bit quantization...")
                    self.model = AutoModelForCausalLM.from_pretrained(self.model_path, **model_kwargs)
                    print("Model loaded successfully with 8-bit quantization.")
                except ImportError:
                    print("8-bit quantization failed. Loading without quantization...")
                    self.quantization_type = "none"
                    model_kwargs = {
                        "torch_dtype": torch.bfloat16,
                        "device_map": "auto",
                        "trust_remote_code": True,
                    }
                    self.model = AutoModelForCausalLM.from_pretrained(self.model_path, **model_kwargs)
                    print("Model loaded successfully without quantization.")
            elif self.quantization_type == "none":
                model_kwargs = {
                    "torch_dtype": torch.bfloat16,
                    "device_map": "auto",
                    "trust_remote_code": True,
                }
                self.model = AutoModelForCausalLM.from_pretrained(self.model_path, **model_kwargs)
                print("Model loaded successfully without quantization.")
            
            # Compile model for faster inference (if supported)
            if hasattr(torch, 'compile'):
                print("Compiling model for faster inference...")
                self.model = torch.compile(self.model, mode='reduce-overhead', fullgraph=True)
                print("Model compiled successfully.")
            
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_path, 
                trust_remote_code=True,
                padding_side="left",  # Better for generation
                use_fast=True  # Use fast tokenizer if available
            )
            if not self.tokenizer.pad_token:
                self.tokenizer.pad_token = self.tokenizer.eos_token
        except Exception as e:
            self.stdscr.addstr(1, 0, f"Error loading model: {e}", curses.color_pair(5))
            self.stdscr.refresh()
            time.sleep(3) # Pause to show error before exiting
            raise e # Re-raise to be caught by curses.wrapper
    def start_gpu_monitoring(self):
        """Starts the background thread for GPU monitoring."""
        self.gpu_monitor_thread = threading.Thread(target=self._gpu_monitor_loop, daemon=True)
        self.gpu_monitor_thread.start()
    def _gpu_monitor_loop(self):
        """Background loop to update GPU info every 0.5 seconds."""
        while not self.stop_gpu_monitor.is_set():
            new_gpu_info = get_gpu_info()
            # Use lock to safely update the shared variable
            with self.gpu_info_lock:
                self.gpu_info = new_gpu_info
            # Sleep for 0.5 seconds (2 updates per second)
            self.stop_gpu_monitor.wait(timeout=0.5) # Use event.wait for potential early exit
    def generate_streaming(self, prompt, max_new_tokens=512, **kwargs):
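        # Note: despite the name, this does not stream token-by-token. It runs
        # model.generate() in a worker thread, yields None heartbeats while the
        # worker is busy, and yields the full response text when it finishes.
        # (True token streaming would need something like TextIteratorStreamer.)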
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ]
        text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
        stream_queue = queue.Queue()
        def generate_tokens():
            start_time = time.time()
            try:
                # Clear cache before generation
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                
                generated_ids = self.model.generate(
                    inputs["input_ids"],
                    max_new_tokens=max_new_tokens,
                    do_sample=kwargs.get('do_sample', True),
                    temperature=kwargs.get('temperature', 0.7),
                    top_p=kwargs.get('top_p', 0.9),
                    top_k=kwargs.get('top_k', 50),
                    repetition_penalty=kwargs.get('repetition_penalty', 1.1),
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    use_cache=True,
                )
                
                full_response = self.tokenizer.decode(
                    generated_ids[0][inputs["input_ids"].shape[1]:], 
                    skip_special_tokens=True
                )
                end_time = time.time()
                self.inference_time = end_time - start_time
                stream_queue.put(('complete', full_response))
            except torch.cuda.OutOfMemoryError as e:
                stream_queue.put(('error', f'CUDA OOM: {str(e)}'))
                torch.cuda.empty_cache()
            except Exception as e:
                # Prefix with "Error" so the caller's startswith('Error') check matches
                stream_queue.put(('error', f"Error: {e}"))
            finally:
                # Clear cache after generation
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
        generation_thread = threading.Thread(target=generate_tokens)
        generation_thread.start()
        while True:
            try:
                item = stream_queue.get(timeout=0.05) # Very short timeout for frequent checks
                if item[0] == 'complete':
                    yield item[1]
                    break
                elif item[0] == 'error':
                    yield item[1] # Yield error message
                    break
            except queue.Empty:
                # Still generating: the background thread keeps GPU info fresh.
                # Yield a None heartbeat so the caller can redraw the UI while waiting.
                self.is_generating = True
                yield None
                continue
        generation_thread.join()
        self.is_generating = False # Reset flag when generation finishes
    def draw_ui(self):
        self.stdscr.clear()
        height, width = self.stdscr.getmaxyx()
        # Calculate sections
        gpu_monitor_height = 6  # Fixed height for GPU monitor (increased for inference time)
        input_height = 3        # Height for input prompt
        chat_height = height - gpu_monitor_height - input_height
        # --- 1. Draw Chat History (Main Content Area) ---
        chat_win = self.stdscr.subwin(chat_height, width, 0, 0)
        chat_win.scrollok(True) # Allow scrolling within the window
        chat_win.idlok(True)    # Enable hardware scrolling if available
        chat_win.setscrreg(0, chat_height - 1) # Define scrolling region
        # Display chat history with colors
        y_pos = 0
        for msg in self.conversation_history:
            role = msg["role"]
            content = msg["content"]
            color_pair = curses.color_pair(1) if role == "user" else curses.color_pair(2) if role == "assistant" else curses.color_pair(3)
            
            # Add "Computing..." indicator if this is the assistant's placeholder during generation
            # Specifically check if the last message is the placeholder and generation is active
            if (role == "assistant" and 
                content == "Thinking..." and 
                self.is_generating and 
                self.conversation_history.index(msg) == len(self.conversation_history) - 1): # Check if it's the last message
                 content = "[Computing...]" # Replace placeholder with indicator
            
            # Split content into lines that fit the window width
            lines = []
            for paragraph in content.split('\n'):
                if len(paragraph) == 0:
                    lines.append("")
                    continue
                for i in range(0, len(paragraph), width - 1):
                    lines.append(paragraph[i:i + width - 1])
            
            for line in lines:
                if y_pos >= self.scroll_pos and y_pos < self.scroll_pos + chat_height:
                    chat_win.addstr(y_pos - self.scroll_pos, 0, line, color_pair)
                y_pos += 1
        # --- 2. Draw GPU Monitor (Top, Fixed Size) ---
        gpu_win = self.stdscr.subwin(gpu_monitor_height, width, chat_height, 0)
        gpu_win.border()
        gpu_win.addstr(0, 1, " GPU/VRAM Monitor ", curses.color_pair(3) | curses.A_BOLD)
        # Read the latest GPU info safely using the lock
        with self.gpu_info_lock:
            gpu_info = self.gpu_info
        
        if gpu_info:
            gpu_win.addstr(1, 2, f"Mem: {gpu_info['memory_used_mb']:.1f}MB / {gpu_info['memory_total_mb']:.1f}MB ({gpu_info['memory_percent_used']:.1f}%)", curses.color_pair(3))
            gpu_win.addstr(2, 2, f"Util: {gpu_info['util_percent']}% | Temp: {gpu_info['temperature']}°C", curses.color_pair(3))
        else:
            gpu_win.addstr(1, 2, "CUDA unavailable", curses.color_pair(5))
        
        cpu_percent = psutil.cpu_percent()
        ram_percent = psutil.virtual_memory().percent
        gpu_win.addstr(3, 2, f"CPU: {cpu_percent}% | RAM: {ram_percent}%", curses.color_pair(3))
        
        # Display quantization type and inference time
        gpu_win.addstr(4, 2, f"Quant: {self.quantization_type.upper()} | Time: {self.inference_time:.2f}s", curses.color_pair(3))
        # --- 3. Draw Input Prompt (Bottom) ---
        input_win = self.stdscr.subwin(input_height, width, height - input_height, 0)
        input_win.border()
        
        # Show "Computing..." indicator in input area if generation is active
        if self.is_generating:
            input_win.addstr(1, 1, "Computing... please wait", curses.color_pair(6))
        else:
            input_win.addstr(1, 1, f"You: {self.input_buffer}", curses.color_pair(4))
            input_win.addstr(1, 6 + len(self.input_buffer), "_", curses.color_pair(4) | curses.A_BLINK) # Cursor after "You: "
        self.stdscr.refresh()
    def run(self):
        self.stdscr.nodelay(True) # Make getch() non-blocking, so UI updates happen constantly
        ui_update_interval = 0.1 # Update UI every 0.1 seconds (10 times per second) for smooth updates
        last_ui_update = time.time()
        
        try:
            while True:
                current_time = time.time()
                # Always check if it's time to update the UI
                if current_time - last_ui_update >= ui_update_interval:
                    self.draw_ui()
                    last_ui_update = current_time
                key = self.stdscr.getch()
                if key != -1: # If a key was pressed
                    # Process the key immediately
                    if key == curses.KEY_RESIZE:
                        # Handle terminal resize if needed
                        pass
                    elif key == curses.KEY_UP:
                        # Scroll chat history up (older messages)
                        if len(self.conversation_history) > 0 and self.scroll_pos > 0:
                            self.scroll_pos -= 1
                    elif key == curses.KEY_DOWN:
                        # Scroll chat history down (newer messages)
                        # Calculate total lines in chat history
                        total_lines = 0
                        for msg in self.conversation_history:
                            content = msg["content"]
                            for paragraph in content.split('\n'):
                                if len(paragraph) == 0:
                                    total_lines += 1
                                    continue
                                total_lines += (len(paragraph) + self.stdscr.getmaxyx()[1] - 2) // (self.stdscr.getmaxyx()[1] - 1)
                        # Allow scrolling down if there are more lines than visible window
                        if self.scroll_pos < max(0, total_lines - self.stdscr.getmaxyx()[0] + 8): # Approximate offset for GPU and input
                            self.scroll_pos += 1
                    elif key == ord('\n'): # Enter key
                        user_input = self.input_buffer.strip()
                        if user_input.lower() == 'quit':
                            break
                        elif user_input.lower() == 'clear':
                            self.conversation_history = []
                            self.scroll_pos = 0
                            self.input_buffer = ""
                            continue
                        elif user_input and not self.is_generating: # Process non-empty input only if not already generating
                            self.conversation_history.append({"role": "user", "content": user_input})
                            self.input_buffer = "" # Clear input buffer after sending
                            # Add placeholder for assistant response
                            self.conversation_history.append({"role": "assistant", "content": "Thinking..."})
                            self.is_generating = True # Set flag before starting generation
                            
                            # Generate and update assistant response
                            full_response = ""
                            for response_chunk in self.generate_streaming(
                                user_input, 
                                max_new_tokens=2048, 
                                temperature=0.28,
                                top_p=0.9,
                                top_k=4,
                                repetition_penalty=1.1
                            ):
                                if response_chunk is None:
                                    # Heartbeat from the generator: redraw so the
                                    # "[Computing...]" indicator stays visible.
                                    self.draw_ui()
                                    continue
                                elif response_chunk.startswith('CUDA OOM') or response_chunk.startswith('Error'):
                                    # Handle error during generation
                                    self.conversation_history[-1]["content"] = response_chunk # Replace placeholder
                                    self.is_generating = False # Ensure flag is reset on error
                                    break
                                else:
                                    full_response = response_chunk # The complete response
                                    self.conversation_history[-1]["content"] = full_response # Replace placeholder

                            # Nothing else to do here: errors were written above and the
                            # successful response has already replaced the placeholder.
                            # self.is_generating is reset by generate_streaming when it finishes
                    elif key == curses.KEY_BACKSPACE or key == 127 or key == 8: # Backspace
                        if len(self.input_buffer) > 0:
                            self.input_buffer = self.input_buffer[:-1]
                    elif 32 <= key < 127: # Printable ASCII characters only (skip control codes)
                        self.input_buffer += chr(key)
                
                # Small sleep to prevent excessive CPU usage, but keep it low for responsiveness
                # The UI updates frequently due to the timer logic above
                time.sleep(0.005) 
        finally:
            # Ensure the GPU monitoring thread stops when the main loop exits
            self.stop_gpu_monitor.set()
            if self.gpu_monitor_thread:
                self.gpu_monitor_thread.join(timeout=2) # Wait for thread to finish, max 2 seconds
def main(stdscr):
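    # Replace "your/model/path" below with the local directory (or Hub repo id) of your model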
    model_name = "your/model/path"
    chat_interface = CursesChatInterface(
        stdscr, 
        model_name, 
        offload_to_cpu=True
    )
    chat_interface.run()
if __name__ == "__main__":
    # Set environment variable to potentially help with memory fragmentation
    # os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
    curses.wrapper(main)
Edit: replace
model_name = "your/model/path"
with the path to your model directory.
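If you do not want to edit the source, here is a small untested sketch (my own addition, not part of the AI-generated script) that takes the model path from the first command-line argument instead:

import sys
import curses

def main(stdscr):
    # Use the first command-line argument as the model directory,
    # falling back to the placeholder if none is given.
    model_name = sys.argv[1] if len(sys.argv) > 1 else "your/model/path"
    chat_interface = CursesChatInterface(stdscr, model_name, offload_to_cpu=True)
    chat_interface.run()

if __name__ == "__main__":
    curses.wrapper(main)

Run it as, for example: python chat_script.py /path/to/your/model (use whatever file name you saved the script under).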
It is AI-generated but seems to work. I tested it with 2-bit quantization; it reports loading with the 4-bit config instead, and I think inference is faster and GPU memory usage is lower. I can't verify the correctness of the Python code because I don't commonly use, like, or understand Python.
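A quick sanity check of what actually got loaded (a sketch, assuming bitsandbytes is installed; get_memory_footprint() and the Linear4bit / Linear8bitLt class names are standard transformers / bitsandbytes APIs, and the path is the same placeholder as in the script):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Load the model roughly the same way the script's 4-bit branch does
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "your/model/path",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
    quantization_config=bnb_config,
)

# The footprint should be roughly a quarter of the fp16 size if 4-bit quantization worked
print(f"Memory footprint: {model.get_memory_footprint() / 1024**3:.2f} GB")

# bitsandbytes swaps nn.Linear layers for Linear4bit / Linear8bitLt when it quantizes
quantized_layers = [m for m in model.modules()
                    if type(m).__name__ in ("Linear4bit", "Linear8bitLt")]
print(f"Quantized linear layers found: {len(quantized_layers)}")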

