import torch
import time
import gc
import json
import re
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from typing import Dict, Any, Optional

# Performance optimizations
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Global model and tokenizer variables
model = None
tokenizer = None
MODEL_ID = "kshitijthakkar/loggenix-moe-0.3B-A0.1B-e3-lr7e5-b16-4090-v7-sft-v1"

# Inference configurations
INFERENCE_CONFIGS = {
    "Optimized for Speed": {
        "max_new_tokens_base": 512,
        "max_new_tokens_cap": 512,
        "min_tokens": 50,
        "temperature": 0.7,
        "top_p": 0.9,
        "do_sample": True,
        "use_cache": False,
        "description": "Fast responses with limited output length"
    },
    "Middle-ground": {
        "max_new_tokens_base": 4096,
        "max_new_tokens_cap": 4096,
        "min_tokens": 50,
        "temperature": 0.7,
        "top_p": 0.9,
        "do_sample": True,
        "use_cache": False,
        "description": "Balanced performance and output quality"
    },
    "Full Capacity": {
        "max_new_tokens_base": 8192,
        "max_new_tokens_cap": 8192,
        "min_tokens": 1,
        "temperature": 0.7,
        "top_p": 0.9,
        "do_sample": True,
        "use_cache": False,
        "description": "Maximum output length with dynamic allocation"
    }
}
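
# Worked example of how a config bounds output length (hypothetical numbers):
# with "Middle-ground" (cap 4096) and a 500-token prompt, generate_response()
# computes max_new_tokens = min(4096, 4096 - 500 - 10) = 3586, so the cap only
# binds for very short prompts.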


def get_inference_configs():
    """Get available inference configurations"""
    return INFERENCE_CONFIGS



def load_model():
    """Load model and tokenizer with optimizations"""
    global model, tokenizer

    if model is not None and tokenizer is not None:
        return model, tokenizer

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    # 8-bit quantization config (built here but currently disabled below;
    # pass quantization_config=quantization_config in from_pretrained to enable)
    quantization_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
    )
    # # Or 4-bit for even more memory savings
    # quantization_config = BitsAndBytesConfig(
    #     load_in_4bit=True,
    #     bnb_4bit_compute_dtype=torch.float16,
    #     bnb_4bit_quant_type="nf4",
    #     bnb_4bit_use_double_quant=True,
    # )

    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto",
        dtype=torch.float16,  # Use half precision for speed
        attn_implementation="flash_attention_2" if hasattr(torch.nn, 'scaled_dot_product_attention') else None,
        use_cache=True,
        #quantization_config=quantization_config,
    ).eval()

    # Gradient checkpointing is a training-time memory optimization (it recomputes
    # activations during backward) and conflicts with use_cache=True, so it is
    # deliberately not enabled for inference.

    # Ensure a pad token exists; fall back to EOS so generate() never warns
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    if model.config.pad_token_id is None:
        model.config.pad_token_id = tokenizer.pad_token_id

    # Set padding side to left for better batching
    tokenizer.padding_side = "left"

    memory = model.get_memory_footprint() / 1e6
    print(f"Memory footprint: {memory:,.1f} MB")

    return model, tokenizer
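
# Optional warm-up (hypothetical helper, not called anywhere by default): a tiny
# generation right after load_model() pays one-off CUDA/cuDNN autotuning costs
# (cudnn.benchmark above) so the first real request is faster.
# _m, _t = load_model()
# _ = _m.generate(**_t("hi", return_tensors="pt").to(_m.device), max_new_tokens=1)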


# ===== TOOL DEFINITIONS =====

def calculate_numbers(operation: str, num1: float, num2: float) -> Dict[str, Any]:
    """
    Sample tool to perform basic mathematical operations on two numbers.

    Args:
        operation: The operation to perform ('add', 'subtract', 'multiply', 'divide')
        num1: First number
        num2: Second number

    Returns:
        Dictionary with result and operation details
    """
    try:
        num1, num2 = float(num1), float(num2)

        if operation.lower() == 'add':
            result = num1 + num2
        elif operation.lower() == 'subtract':
            result = num1 - num2
        elif operation.lower() == 'multiply':
            result = num1 * num2
        elif operation.lower() == 'divide':
            if num2 == 0:
                return {"error": "Division by zero is not allowed"}
            result = num1 / num2
        else:
            return {"error": f"Unknown operation: {operation}"}

        return {
            "result": result,
            "operation": operation,
            "operands": [num1, num2],
            "formatted": f"{num1} {operation} {num2} = {result}"
        }
    except ValueError as e:
        return {"error": f"Invalid number format: {str(e)}"}
    except Exception as e:
        return {"error": f"Calculation error: {str(e)}"}


# Tool registry
AVAILABLE_TOOLS = {
    "calculate_numbers": {
        "function": calculate_numbers,
        "description": "Perform basic mathematical operations (add, subtract, multiply, divide) on two numbers",
        "parameters": {
            "operation": "The mathematical operation to perform",
            "num1": "First number",
            "num2": "Second number"
        }
    }
}
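
# Sketch of how another tool could be registered (hypothetical entry, same
# shape as the one above):
# AVAILABLE_TOOLS["echo_text"] = {
#     "function": lambda text: {"result": text, "formatted": text},
#     "description": "Echo the input text back unchanged",
#     "parameters": {"text": "Text to echo"},
# }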


def execute_tool_call(tool_name: str, **kwargs) -> Dict[str, Any]:
    """Execute a tool call with given parameters"""
    print(f"Executing tool: {tool_name} with parameters: {kwargs}")
    if tool_name not in AVAILABLE_TOOLS:
        return {"error": f"Unknown tool: {tool_name}"}

    try:
        tool_function = AVAILABLE_TOOLS[tool_name]["function"]
        result = tool_function(**kwargs)
        return {
            "tool_name": tool_name,
            "parameters": kwargs,
            "result": result
        }
    except Exception as e:
        print(f"Tool execution failed: {str(e)}")
        return {
            "tool_name": tool_name,
            "parameters": kwargs,
            "error": f"Tool execution error: {str(e)}"
        }


def parse_tool_calls(text: str) -> list:
    """
    Parse tool calls from model output.
    Supports both formats:
    - [TOOL_CALL:tool_name(param1=value1, param2=value2)]
    - <tool_call>{"name": "tool_name", "parameters": {"param1": "value1", "param2": "value2"}}</tool_call>
    """
    tool_calls = []

    # One regex covering both formats; whitespace inside the JSON form is tolerated
    pattern = r'(\[TOOL_CALL:(\w+)\((.*?)\)\]|<tool_call>\s*{\s*"name":\s*"(\w+)"\s*,\s*"parameters":\s*{([^}]*)}\s*}\s*</tool_call>)'
    matches = re.findall(pattern, text)
    print("Raw matches:", matches)

    for match in matches:
        full_match, old_tool_name, old_params, json_tool_name, json_params = match

        # Determine which format was matched
        if old_tool_name:  # Old format: [TOOL_CALL:tool_name(params)]
            tool_name = old_tool_name
            params_str = old_params
            original_call = f"[TOOL_CALL:{tool_name}({params_str})]"

            try:
                params = {}
                if params_str.strip():
                    param_pairs = params_str.split(',')
                    for pair in param_pairs:
                        if '=' in pair:
                            key, value = pair.split('=', 1)
                            key = key.strip()
                            value = value.strip().strip('"\'')  # Remove quotes
                            params[key] = value

                tool_calls.append({
                    "tool_name": tool_name,
                    "parameters": params,
                    "original_call": original_call
                })

            except Exception as e:
                print(f"Error parsing old format tool call '{tool_name}({params_str})': {e}")
                continue

        elif json_tool_name:  # JSON format: <tool_call>...</tool_call>
            tool_name = json_tool_name
            params_str = json_params
            original_call = full_match

            try:
                params = {}
                if params_str.strip():
                    # The capture is the inside of a JSON object, e.g.
                    #   "operation": "add", "num1": "125", "num2": "675"
                    # Prefer a real JSON parse; fall back to naive comma/colon
                    # splitting (which breaks on commas inside values).
                    try:
                        params = json.loads("{" + params_str + "}")
                    except json.JSONDecodeError:
                        for pair in params_str.split(','):
                            if ':' in pair:
                                key, value = pair.split(':', 1)
                                params[key.strip().strip('"\'')] = value.strip().strip('"\'')

                tool_calls.append({
                    "tool_name": tool_name,
                    "parameters": params,
                    "original_call": original_call
                })

            except Exception as e:
                print(f"Error parsing JSON format tool call '{tool_name}': {e}")
                continue

    return tool_calls
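
# Illustrative: both supported formats parse to the same structure.
#   parse_tool_calls('[TOOL_CALL:calculate_numbers(operation=add, num1=2, num2=3)]')
#   parse_tool_calls('<tool_call>{"name": "calculate_numbers", "parameters": '
#                    '{"operation": "add", "num1": "2", "num2": "3"}}</tool_call>')
# both yield:
#   [{"tool_name": "calculate_numbers",
#     "parameters": {"operation": "add", "num1": "2", "num2": "3"},
#     "original_call": ...}]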

def process_tool_calls(text: str) -> str:
    """Process tool calls in the generated text and replace with results"""
    tool_calls = parse_tool_calls(text)

    if not tool_calls:
        return text

    processed_text = text

    for tool_call in tool_calls:
        tool_name = tool_call["tool_name"]
        parameters = tool_call["parameters"]
        original_call = tool_call["original_call"]

        try:
            # Validate parameters before execution
            if not isinstance(parameters, dict):
                raise ValueError(f"Invalid parameters for tool {tool_name}: {parameters}")

            # Execute tool
            result = execute_tool_call(tool_name, **parameters)

            # Build the replacement text. Failures surface at two levels:
            # dispatch errors (top-level "error") and errors raised inside the
            # tool itself (nested under "result").
            inner = result.get("result")
            if "error" in result:
                replacement = f"[TOOL_ERROR: {result['error']}]"
            elif isinstance(inner, dict) and "error" in inner:
                replacement = f"[TOOL_ERROR: {inner['error']}]"
            elif isinstance(inner, dict) and "formatted" in inner:
                replacement = f"[TOOL_RESULT: {inner['formatted']}]"
            else:
                replacement = f"[TOOL_RESULT: {inner}]"

            # Replace tool call with result
            processed_text = processed_text.replace(original_call, replacement)

        except Exception as e:
            print(f"Error processing tool call '{tool_name}': {e}")
            replacement = f"[TOOL_ERROR: Failed to process tool call: {str(e)}]"
            processed_text = processed_text.replace(original_call, replacement)

    return processed_text
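
# Illustrative end-to-end rewrite (result produced by calculate_numbers):
#   process_tool_calls("Sum: [TOOL_CALL:calculate_numbers(operation=add, num1=2, num2=3)]")
#   -> "Sum: [TOOL_RESULT: 2.0 add 3.0 = 5.0]"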

def monitor_memory():
    """Print current GPU memory usage; no-op on CPU-only machines."""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1e9
        cached = torch.cuda.memory_reserved() / 1e9
        print(f"GPU Memory - Allocated: {allocated:.2f}GB, Cached: {cached:.2f}GB")

def generate_response(system_prompt: str, user_input: str, config_name: str = "Middle-ground") -> str:
    """
    Run inference with the given task (system prompt) and user input using the specified config.
    """
    load_model()

    config = INFERENCE_CONFIGS[config_name]

    input_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_input}
    ]

    prompt_text = tokenizer.apply_chat_template(
        input_messages,
        tokenize=False,
        add_generation_prompt=True
    )

    input_length = len(tokenizer.encode(prompt_text))
    context_length = min(input_length, 3584)  # Leave room for generation

    inputs = tokenizer(
        prompt_text,
        return_tensors="pt",
        truncation=True,
        max_length=context_length,
        padding=False
    ).to(model.device)

    actual_input_length = inputs['input_ids'].shape[1]
    max_new_tokens = min(config["max_new_tokens_cap"], 4096 - actual_input_length - 10)
    max_new_tokens = max(config["min_tokens"], max_new_tokens)

    with torch.no_grad():
        start_time = time.time()
        outputs = model.generate(
            **inputs,
            do_sample=config["do_sample"],
            temperature=config["temperature"],
            top_p=config["top_p"],
            use_cache=config["use_cache"],
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            # Memory optimizations
            output_attentions=False,
            output_hidden_states=False,
            return_dict_in_generate=False,
        )
        inference_time = time.time() - start_time
        print(f"Inference time: {inference_time:.2f} seconds")

        memory = model.get_memory_footprint() / 1e6
        monitor_memory()
        print(f"Memory footprint: {memory:,.1f} MB")

    # Clean up host and device memory between requests
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Decode only the newly generated tokens. Searching for prompt_text in the
    # decoded output is unreliable here: skip_special_tokens=True strips the
    # chat template's special tokens, so the prompt string rarely matches.
    generated_tokens = outputs[0][actual_input_length:]
    generated_response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    # Process any tool calls in the generated response
    generated_response = process_tool_calls(generated_response)
    # Optional debug output:
    # print('Input tokens:', actual_input_length)
    # print('Output tokens:', outputs[0].shape[0] - actual_input_length)
    return generated_response
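

if __name__ == "__main__":
    # Minimal usage sketch: assumes a GPU with enough memory for the model
    # above; the system prompt and config choice here are illustrative.
    reply = generate_response(
        system_prompt=(
            "You are a helpful assistant. For arithmetic, emit a tool call "
            "like [TOOL_CALL:calculate_numbers(operation=add, num1=2, num2=3)]."
        ),
        user_input="What is 125 + 675?",
        config_name="Optimized for Speed",
    )
    print(reply)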