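"""Inference utilities for the loggenix-moe SFT model with simple tool calling.

Loads the model and tokenizer once, exposes preset inference configurations, and
post-processes generated text so that [TOOL_CALL:...] / <tool_call>...</tool_call>
blocks are executed against a small local tool registry.
"""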
import torch
import time
import gc
import json
import re
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from typing import Dict, Any, Optional
# Performance optimizations
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Global model and tokenizer variables
model = None
tokenizer = None
MODEL_ID = "kshitijthakkar/loggenix-moe-0.3B-A0.1B-e3-lr7e5-b16-4090-v7-sft-v1"
# Inference configurations
INFERENCE_CONFIGS = {
"Optimized for Speed": {
"max_new_tokens_base": 512,
"max_new_tokens_cap": 512,
"min_tokens": 50,
"temperature": 0.7,
"top_p": 0.9,
"do_sample": True,
"use_cache": False,
"description": "Fast responses with limited output length"
},
"Middle-ground": {
"max_new_tokens_base": 4096,
"max_new_tokens_cap": 4096,
"min_tokens": 50,
"temperature": 0.7,
"top_p": 0.9,
"do_sample": True,
"use_cache": False,
"description": "Balanced performance and output quality"
},
"Full Capacity": {
"max_new_tokens_base": 8192,
"max_new_tokens_cap": 8192,
"min_tokens": 1,
"temperature": 0.7,
"top_p": 0.9,
"do_sample": True,
"use_cache": False,
"description": "Maximum output length with dynamic allocation"
}
}
def get_inference_configs():
"""Get available inference configurations"""
return INFERENCE_CONFIGS
def load_model():
"""Load model and tokenizer with optimizations"""
global model, tokenizer
if model is not None and tokenizer is not None:
return model, tokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # Optional 8-bit quantization config (prepared here but not applied below;
    # pass quantization_config= to from_pretrained to use it)
quantization_config = BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
)
# # Or 4-bit for even more memory savings
# quantization_config = BitsAndBytesConfig(
# load_in_4bit=True,
# bnb_4bit_compute_dtype=torch.float16,
# bnb_4bit_quant_type="nf4",
# bnb_4bit_use_double_quant=True,
# )
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
device_map="auto",
        torch_dtype=torch.float16,  # half precision for speed and lower memory
        # flash_attention_2 requires the separate flash-attn package; use PyTorch's
        # SDPA kernel when available and let transformers choose a default otherwise
        attn_implementation="sdpa" if hasattr(torch.nn.functional, "scaled_dot_product_attention") else None,
use_cache=True,
#quantization_config=quantization_config,
).eval()
    # Gradient checkpointing is a training-time memory optimization; it has no effect
    # on an eval-mode model used only for generation, so it is not enabled here.
    # Ensure a pad token exists (many causal LM tokenizers ship without one)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    if model.config.pad_token_id is None and tokenizer.pad_token_id is not None:
        model.config.pad_token_id = tokenizer.pad_token_id
    # Left padding is the correct choice for decoder-only models when batching
    tokenizer.padding_side = "left"
memory = model.get_memory_footprint() / 1e6
print(f"Memory footprint: {memory:,.1f} MB")
return model, tokenizer
# ===== TOOL DEFINITIONS =====
def calculate_numbers(operation: str, num1: float, num2: float) -> Dict[str, Any]:
"""
Sample tool to perform basic mathematical operations on two numbers.
Args:
operation: The operation to perform ('add', 'subtract', 'multiply', 'divide')
num1: First number
num2: Second number
Returns:
Dictionary with result and operation details
"""
try:
num1, num2 = float(num1), float(num2)
if operation.lower() == 'add':
result = num1 + num2
elif operation.lower() == 'subtract':
result = num1 - num2
elif operation.lower() == 'multiply':
result = num1 * num2
elif operation.lower() == 'divide':
if num2 == 0:
return {"error": "Division by zero is not allowed"}
result = num1 / num2
else:
return {"error": f"Unknown operation: {operation}"}
return {
"result": result,
"operation": operation,
"operands": [num1, num2],
"formatted": f"{num1} {operation} {num2} = {result}"
}
except ValueError as e:
return {"error": f"Invalid number format: {str(e)}"}
except Exception as e:
return {"error": f"Calculation error: {str(e)}"}
# Tool registry
AVAILABLE_TOOLS = {
"calculate_numbers": {
"function": calculate_numbers,
"description": "Perform basic mathematical operations (add, subtract, multiply, divide) on two numbers",
"parameters": {
"operation": "The mathematical operation to perform",
"num1": "First number",
"num2": "Second number"
}
}
}
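# Additional tools can be registered by adding entries of the same shape, e.g.
#   AVAILABLE_TOOLS["my_tool"] = {"function": my_tool, "description": "...", "parameters": {...}}
# ("my_tool" is illustrative only and not defined in this module).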
def execute_tool_call(tool_name: str, **kwargs) -> Dict[str, Any]:
"""Execute a tool call with given parameters"""
print(f"Executing tool: {tool_name} with parameters: {kwargs}")
if tool_name not in AVAILABLE_TOOLS:
return {"error": f"Unknown tool: {tool_name}"}
try:
tool_function = AVAILABLE_TOOLS[tool_name]["function"]
result = tool_function(**kwargs)
return {
"tool_name": tool_name,
"parameters": kwargs,
"result": result
}
except Exception as e:
print(f"Tool execution failed: {str(e)}")
return {
"tool_name": tool_name,
"parameters": kwargs,
"error": f"Tool execution error: {str(e)}"
}
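# Example (illustrative, not executed at import time):
#   execute_tool_call("calculate_numbers", operation="add", num1=2, num2=3)
#   -> {"tool_name": "calculate_numbers",
#       "parameters": {"operation": "add", "num1": 2, "num2": 3},
#       "result": {"result": 5.0, "operation": "add", "operands": [2.0, 3.0],
#                  "formatted": "2.0 add 3.0 = 5.0"}}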
def parse_tool_calls(text: str) -> list:
"""
Parse tool calls from model output.
Supports both formats:
- [TOOL_CALL:tool_name(param1=value1, param2=value2)]
- <tool_call>{"name": "tool_name", "parameters": {"param1": "value1", "param2": "value2"}}</tool_call>
"""
tool_calls = []
# Pattern for both formats
pattern = r'(\[TOOL_CALL:(\w+)\((.*?)\)\]|<tool_call>\s*{"name":\s*"(\w+)",\s*"parameters":\s*{([^}]*)}\s*}\s*</tool_call>)'
matches = re.findall(pattern, text)
print("Raw matches:", matches)
for match in matches:
full_match, old_tool_name, old_params, json_tool_name, json_params = match
# Determine which format was matched
if old_tool_name: # Old format: [TOOL_CALL:tool_name(params)]
tool_name = old_tool_name
params_str = old_params
original_call = f"[TOOL_CALL:{tool_name}({params_str})]"
try:
params = {}
if params_str.strip():
param_pairs = params_str.split(',')
for pair in param_pairs:
if '=' in pair:
key, value = pair.split('=', 1)
key = key.strip()
value = value.strip().strip('"\'') # Remove quotes
params[key] = value
tool_calls.append({
"tool_name": tool_name,
"parameters": params,
"original_call": original_call
})
except Exception as e:
print(f"Error parsing old format tool call '{tool_name}({params_str})': {e}")
continue
elif json_tool_name: # JSON format: <tool_call>...</tool_call>
tool_name = json_tool_name
params_str = json_params
original_call = full_match
try:
                params = {}
                if params_str.strip():
                    # The regex captured the body of the "parameters" object, e.g.
                    # "operation": "add", "num1": "125", "num2": "675"
                    try:
                        # Prefer strict JSON parsing of the captured body
                        params = json.loads("{" + params_str + "}")
                    except json.JSONDecodeError:
                        # Fall back to naive key/value splitting for slightly malformed output
                        for pair in params_str.split(','):
                            if ':' in pair:
                                key, value = pair.split(':', 1)
                                params[key.strip().strip('"\'')] = value.strip().strip('"\'')
tool_calls.append({
"tool_name": tool_name,
"parameters": params,
"original_call": original_call
})
except Exception as e:
print(f"Error parsing JSON format tool call '{tool_name}': {e}")
continue
return tool_calls
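# Example (illustrative, not executed at import time):
#   parse_tool_calls('[TOOL_CALL:calculate_numbers(operation=add, num1=2, num2=3)]')
#   -> [{"tool_name": "calculate_numbers",
#        "parameters": {"operation": "add", "num1": "2", "num2": "3"},
#        "original_call": "[TOOL_CALL:calculate_numbers(operation=add, num1=2, num2=3)]"}]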
def process_tool_calls(text: str) -> str:
"""Process tool calls in the generated text and replace with results"""
tool_calls = parse_tool_calls(text)
if not tool_calls:
return text
processed_text = text
for tool_call in tool_calls:
tool_name = tool_call["tool_name"]
parameters = tool_call["parameters"]
original_call = tool_call["original_call"]
try:
# Validate parameters before execution
if not isinstance(parameters, dict):
raise ValueError(f"Invalid parameters for tool {tool_name}: {parameters}")
# Execute tool
result = execute_tool_call(tool_name, **parameters)
            # Create replacement text, surfacing errors reported either by the
            # dispatcher or by the tool itself (e.g. division by zero)
            if "error" in result:
                replacement = f"[TOOL_ERROR: {result['error']}]"
            elif isinstance(result.get("result"), dict) and "error" in result["result"]:
                replacement = f"[TOOL_ERROR: {result['result']['error']}]"
            elif isinstance(result.get("result"), dict) and "formatted" in result["result"]:
                replacement = f"[TOOL_RESULT: {result['result']['formatted']}]"
            else:
                replacement = f"[TOOL_RESULT: {result['result']}]"
# Replace tool call with result
processed_text = processed_text.replace(original_call, replacement)
except Exception as e:
print(f"Error processing tool call '{tool_name}': {e}")
replacement = f"[TOOL_ERROR: Failed to process tool call: {str(e)}]"
processed_text = processed_text.replace(original_call, replacement)
return processed_text
def monitor_memory():
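    """Print currently allocated and reserved GPU memory (does nothing without CUDA)."""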
if torch.cuda.is_available():
allocated = torch.cuda.memory_allocated() / 1e9
cached = torch.cuda.memory_reserved() / 1e9
print(f"GPU Memory - Allocated: {allocated:.2f}GB, Cached: {cached:.2f}GB")
def generate_response(system_prompt: str, user_input: str, config_name: str = "Middle-ground") -> str:
"""
Run inference with the given task (system prompt) and user input using the specified config.
"""
load_model()
config = INFERENCE_CONFIGS[config_name]
input_messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_input}
]
prompt_text = tokenizer.apply_chat_template(
input_messages,
tokenize=False,
add_generation_prompt=True
)
input_length = len(tokenizer.encode(prompt_text))
context_length = min(input_length, 3584) # Leave room for generation
inputs = tokenizer(
prompt_text,
return_tensors="pt",
truncation=True,
max_length=context_length,
padding=False
).to(model.device)
actual_input_length = inputs['input_ids'].shape[1]
max_new_tokens = min(config["max_new_tokens_cap"], 4096 - actual_input_length - 10)
max_new_tokens = max(config["min_tokens"], max_new_tokens)
with torch.no_grad():
start_time = time.time()
outputs = model.generate(
**inputs,
do_sample=config["do_sample"],
temperature=config["temperature"],
top_p=config["top_p"],
use_cache=config["use_cache"],
max_new_tokens=max_new_tokens,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
# Memory optimizations
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
)
inference_time = time.time() - start_time
print(f"Inference time: {inference_time:.2f} seconds")
memory = model.get_memory_footprint() / 1e6
monitor_memory()
print(f"Memory footprint: {memory:,.1f} MB")
# Clean up
gc.collect()
    # Decode only the newly generated tokens; matching prompt_text against the decoded
    # output is unreliable because skip_special_tokens strips the chat-template markers
    # that prompt_text still contains.
    generated_response = tokenizer.decode(
        outputs[0][actual_input_length:], skip_special_tokens=True
    ).strip()
    if not generated_response:
        # Fallback: decode the full sequence and heuristically strip the prompt
        full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_response = full_text.strip()
try:
# Look for common assistant/response indicators
response_indicators = ["Assistant:", "<|assistant|>", "[/INST]", "Response:"]
for indicator in response_indicators:
if indicator in full_text:
parts = full_text.split(indicator)
if len(parts) > 1:
generated_response = parts[-1].strip()
break
# If no indicator found, try to remove the input part
user_message = user_input
if user_message in full_text:
parts = full_text.split(user_message)
if len(parts) > 1:
generated_response = parts[-1].strip()
except Exception:
generated_response = full_text.strip()
# Process any tool calls in the generated response
generated_response = process_tool_calls(generated_response)
    # Debug token counts (uncomment if needed):
    # print('Input tokens:', inputs['input_ids'].numel())
    # print('Generated tokens:', outputs[0].numel() - inputs['input_ids'].numel())
    return generated_response
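
if __name__ == "__main__":
    # Minimal usage sketch; assumes a GPU (or enough RAM) to hold the model.
    # The prompts below are illustrative only.
    demo_system_prompt = (
        "You are a helpful assistant. When a calculation is needed, emit a tool call "
        "in the form [TOOL_CALL:calculate_numbers(operation=add, num1=..., num2=...)]."
    )
    demo_user_input = "What is 125 plus 675?"
    print(generate_response(demo_system_prompt, demo_user_input, config_name="Optimized for Speed"))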