File size: 32,881 Bytes
f647629
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d190a0
f647629
 
 
 
 
 
 
7d190a0
f647629
 
 
 
7d190a0
 
 
 
 
 
 
 
 
 
f647629
 
7d190a0
f647629
7d190a0
f647629
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
"""
Service layer for Weave API.

This module provides high-level services for querying and processing Weave traces.
It orchestrates the client, query builder, and processor components.
"""

from __future__ import annotations

from typing import Any, Dict, List, Optional, Set

from wandb_mcp_server.utils import get_rich_logger, get_server_args
from wandb_mcp_server.weave_api.client import WeaveApiClient
from wandb_mcp_server.weave_api.models import QueryResult
from wandb_mcp_server.weave_api.processors import TraceProcessor
from wandb_mcp_server.weave_api.query_builder import QueryBuilder

# Work out which top-level column names are accepted for trace queries.
# Preferred source: the field annotations of Weave's CallSchema. When the
# import is unavailable, fall back to a hard-coded set of column names.
try:
    from weave.trace_server.trace_server_interface import CallSchema
except ImportError:
    # CallSchema could not be imported; use the static fallback list.
    HAVE_CALL_SCHEMA = False
    VALID_COLUMNS = {
        "id",
        "project_id",
        "op_name",
        "display_name",
        "trace_id",
        "parent_id",
        "started_at",
        "attributes",
        "inputs",
        "ended_at",
        "exception",
        "output",
        "summary",
        "wb_user_id",
        "wb_run_id",
        "deleted_at",
        "storage_size_bytes",
        "total_storage_size_bytes",
    }
else:
    # Import succeeded: derive the column names from the schema itself.
    HAVE_CALL_SCHEMA = True
    VALID_COLUMNS = set(CallSchema.__annotations__)

# Module-level logger, created by the project's get_rich_logger helper and
# named after this module.
logger = get_rich_logger(__name__)


class TraceService:
    """Service for querying and processing Weave traces."""

    # Define cost fields once as a class constant
    COST_FIELDS = {"total_cost", "completion_cost", "prompt_cost"}

    # Define synthetic columns that shouldn't be passed to the API but can be reconstructed
    SYNTHETIC_COLUMNS = {"costs"}

    # Define latency field mapping
    LATENCY_FIELD_MAPPING = {"latency_ms": "summary.weave.latency_ms"}

    def __init__(
        self,
        api_key: Optional[str] = None,
        server_url: Optional[str] = None,
        retries: int = 3,
        timeout: int = 10,
    ):
        """Set up the service and its underlying ``WeaveApiClient``.

        When ``api_key`` is ``None`` the key is looked up in the
        ``WANDB_API_KEY`` environment variable; if that is unset or empty,
        ``get_server_args().wandb_api_key`` is used instead. An explicitly
        passed key (even an empty string) is forwarded untouched.

        Args:
            api_key: W&B API key, or ``None`` to resolve it as described above.
            server_url: Weave API server URL, forwarded to the client as given.
            retries: Number of retries for failed requests.
            timeout: Request timeout in seconds.
        """
        if api_key is None:
            import os

            # The environment comes first (set by the auth middleware for HTTP
            # or by the user for STDIO); server args are only the fallback.
            api_key = (
                os.environ.get("WANDB_API_KEY") or get_server_args().wandb_api_key
            )

        # No validation here: per the original note, WeaveApiClient raises its
        # own ValueError when the key is None or "".
        self.client = WeaveApiClient(
            api_key=api_key,
            server_url=server_url,
            retries=retries,
            timeout=timeout,
        )

        # Invalid column names seen by the most recent query (for warnings).
        self.invalid_columns = set()

    def _validate_and_filter_columns(
        self, columns: Optional[List[str]]
    ) -> tuple[Optional[List[str]], List[str], Set[str]]:
        """Validate columns against CallSchema and filter out synthetic/invalid columns.

        Handles mapping of 'latency_ms' to 'summary.weave.latency_ms'.

        Args:
            columns: List of columns.

        Returns:
            Tuple of (filtered_columns_for_api, requested_synthetic_columns, invalid_columns_reported)
        """
        if not columns:
            return (
                None,
                [],
                set(),
            )  # Return None for filtered_columns_for_api if input is None

        filtered_columns_for_api: list[str] = []
        requested_synthetic_columns: list[str] = []
        invalid_columns_reported: set[str] = set()

        processed_columns = (
            set()
        )  # To avoid duplicate processing if a column is listed multiple times

        for col_name in columns:
            if col_name in processed_columns:
                continue
            processed_columns.add(col_name)

            if col_name == "latency_ms":
                # 'latency_ms' is synthetic, its data comes from 'summary.weave.latency_ms'
                requested_synthetic_columns.append("latency_ms")
                # Ensure the source field is requested from the API
                source_field = self.LATENCY_FIELD_MAPPING["latency_ms"]
                if source_field not in filtered_columns_for_api:
                    filtered_columns_for_api.append(source_field)
                # Also ensure 'summary' itself is added if not already, as 'summary.weave.latency_ms' implies 'summary'
                if (
                    "summary" not in filtered_columns_for_api
                    and source_field.startswith("summary.")
                ):
                    filtered_columns_for_api.append("summary")
                logger.info(
                    f"Column 'latency_ms' requested: will be synthesized from '{source_field}'. Added '{source_field}' to API columns."
                )

            elif col_name == "costs":
                # 'costs' is synthetic, its data comes from 'summary.weave.costs'
                requested_synthetic_columns.append("costs")
                # Ensure the source field ('summary') is requested
                if "summary" not in filtered_columns_for_api:
                    filtered_columns_for_api.append("summary")
                logger.info(
                    "Column 'costs' requested: will be synthesized from 'summary.weave.costs'. Added 'summary' to API columns."
                )

            elif col_name == "status":
                # 'status' can be top-level or from 'summary.weave.status'
                requested_synthetic_columns.append("status")
                # Add 'status' to API columns to try fetching top-level first.
                # If not present, it will be synthesized from summary.
                if "status" not in filtered_columns_for_api:
                    filtered_columns_for_api.append("status")
                if (
                    "summary" not in filtered_columns_for_api
                ):  # Also ensure summary for fallback
                    filtered_columns_for_api.append("summary")
                logger.info(
                    "Column 'status' requested: will attempt direct fetch or synthesize from 'summary.weave.status'."
                )

            elif col_name in VALID_COLUMNS:
                # Direct valid column
                if col_name not in filtered_columns_for_api:
                    filtered_columns_for_api.append(col_name)

            elif "." in col_name:  # Potentially a dot-separated path
                base_field = col_name.split(".")[0]
                if base_field in VALID_COLUMNS:
                    # Valid nested field (e.g., "summary.weave.latency_ms", "attributes.foo")
                    if col_name not in filtered_columns_for_api:
                        filtered_columns_for_api.append(col_name)
                    logger.info(
                        f"Nested column field '{col_name}' requested, added to API columns."
                    )
                else:
                    logger.warning(
                        f"Invalid base field '{base_field}' in nested column '{col_name}'. It will be ignored."
                    )
                    invalid_columns_reported.add(col_name)
            else:
                # Neither a direct valid column, nor a recognized synthetic, nor a valid-looking nested path
                logger.warning(
                    f"Invalid column '{col_name}' requested. It will be ignored."
                )
                invalid_columns_reported.add(col_name)

        # Ensure filtered_columns_for_api does not have duplicates and maintains order as much as possible
        # (though order to the API might not matter as much as presence)
        final_filtered_columns_for_api = []
        seen_in_final = set()
        for fc in filtered_columns_for_api:
            if fc not in seen_in_final:
                final_filtered_columns_for_api.append(fc)
                seen_in_final.add(fc)

        return (
            final_filtered_columns_for_api,
            requested_synthetic_columns,
            invalid_columns_reported,
        )

    def _ensure_required_columns_for_synthetic(
        self,
        filtered_columns: Optional[List[str]],
        requested_synthetic_columns: List[str],
    ) -> Optional[List[str]]:
        """Ensure required columns for synthetic fields are included.

        Args:
            filtered_columns: List of columns after filtering out synthetic ones.
            requested_synthetic_columns: List of requested synthetic columns.

        Returns:
            Updated filtered columns list with required columns added.
        """
        if not filtered_columns:
            filtered_columns = []

        required_columns = set(filtered_columns)

        # Add required columns for synthesizing costs
        if "costs" in requested_synthetic_columns:
            # Costs data comes from summary.weave.costs
            if "summary" not in required_columns:
                logger.info("Adding 'summary' column as it's required for costs data")
                required_columns.add("summary")

        # Add other required columns for other synthetic fields as needed

        return list(required_columns)

    def _add_synthetic_columns(
        self,
        traces: List[Dict[str, Any]],
        requested_synthetic_columns: List[str],
        invalid_columns: Set[str],
    ) -> List[Dict[str, Any]]:
        """Add synthetic columns back to the traces and add warnings for invalid columns.

        Args:
            traces: List of trace dictionaries.
            requested_synthetic_columns: List of requested synthetic columns.
            invalid_columns: Set of invalid column names that were requested.

        Returns:
            Updated traces with synthetic columns added and invalid column warnings.
        """
        if not requested_synthetic_columns and not invalid_columns:
            return traces

        updated_traces = []

        for trace in traces:
            updated_trace = trace.copy()

            # Add costs data if requested
            if "costs" in requested_synthetic_columns:
                costs_data = trace.get("summary", {}).get("weave", {}).get("costs", {})
                if costs_data:
                    logger.debug(
                        f"Adding synthetic 'costs' column with {len(costs_data)} providers"
                    )
                    updated_trace["costs"] = costs_data
                else:
                    logger.warning(f"No costs data found in trace {trace.get('id')}")
                    updated_trace["costs"] = {}

            # Add status from summary if requested
            if "status" in requested_synthetic_columns:
                status = trace.get("status")  # Check if it's already in the trace
                if not status:
                    # Extract from summary.weave.status
                    status = trace.get("summary", {}).get("weave", {}).get("status")
                    if status:
                        logger.debug(
                            f"Adding synthetic 'status' from summary: {status}"
                        )
                        updated_trace["status"] = status
                    else:
                        logger.warning(
                            f"No status data found in trace {trace.get('id')}"
                        )
                        updated_trace["status"] = None

            # Add latency_ms from summary if requested
            if "latency_ms" in requested_synthetic_columns:
                latency = trace.get("latency_ms")  # Check if it's already in the trace
                if latency is None:
                    # Extract from summary.weave.latency_ms
                    latency = (
                        trace.get("summary", {}).get("weave", {}).get("latency_ms")
                    )
                    if latency is not None:
                        logger.debug(
                            f"Adding synthetic 'latency_ms' from summary: {latency}"
                        )
                        updated_trace["latency_ms"] = latency
                    else:
                        logger.warning(
                            f"No latency_ms data found in trace {trace.get('id')}"
                        )
                        updated_trace["latency_ms"] = None

            # Add warnings for invalid columns
            for col in invalid_columns:
                warning_message = f"{col} is not a valid column name, no data returned"
                updated_trace[col] = warning_message

            updated_traces.append(updated_trace)

        return updated_traces

    def query_traces(
        self,
        entity_name: str,
        project_name: str,
        filters: Optional[Dict[str, Any]] = None,
        sort_by: str = "started_at",
        sort_direction: str = "desc",
        limit: Optional[int] = None,
        offset: int = 0,
        include_costs: bool = True,
        include_feedback: bool = True,
        columns: Optional[List[str]] = None,
        expand_columns: Optional[List[str]] = None,
        truncate_length: Optional[int] = 200,
        return_full_data: bool = False,
        metadata_only: bool = False,
    ) -> QueryResult:
        """Query traces from the Weave API.

        Sorting:
            * Keys of ``LATENCY_FIELD_MAPPING`` (e.g. ``latency_ms``) are
              mapped to their server-side field.
            * Cost fields (``COST_FIELDS``) are sorted on the client: the
              server query is sorted by ``started_at`` and sent without a
              limit, then the results are sorted by cost and ``limit`` is
              applied locally. ``include_costs`` is forced on in this case.
            * Dotted paths are sent to the server when their first segment is
              a valid column; other unknown fields fall back to
              ``started_at`` with a warning.

        Args:
            entity_name: Weights & Biases entity name.
            project_name: Weights & Biases project name.
            filters: Dictionary of filter conditions.
            sort_by: Field to sort by.
            sort_direction: Sort direction ('asc' or 'desc').
            limit: Maximum number of results to return.
            offset: Number of results to skip (for pagination).
            include_costs: Include tracked API cost information in the results.
            include_feedback: Include Weave annotations in the results.
            columns: List of specific columns to include in the results.
            expand_columns: List of columns to expand in the results.
            truncate_length: Maximum length for string values; ``None`` is
                passed to the processor as ``0``.
            return_full_data: Whether to include full untruncated trace data.
            metadata_only: Whether to only include metadata without traces.

        Returns:
            QueryResult object with metadata and optionally traces.
        """
        # Clear invalid columns from previous requests
        self.invalid_columns = set()

        # Cost fields are sorted client-side after fetching (see docstring).
        client_side_cost_sort = sort_by in self.COST_FIELDS

        # Resolve the field the server should sort on; 'started_at' is the
        # default and the fallback for anything that is not recognised.
        server_sort_by = "started_at"
        server_sort_direction = sort_direction
        if sort_by in self.LATENCY_FIELD_MAPPING:
            server_sort_by = self.LATENCY_FIELD_MAPPING[sort_by]
            logger.info(f"Mapping sort field '{sort_by}' to '{server_sort_by}'")
        elif client_side_cost_sort:
            include_costs = True
        elif "." in sort_by:  # Handles general dot-separated paths
            base_field = sort_by.split(".")[0]
            if base_field in VALID_COLUMNS:
                logger.info(f"Using nested sort field for server: {sort_by}")
                server_sort_by = sort_by
            else:
                logger.warning(
                    f"Invalid base field '{base_field}' in sort_by '{sort_by}', falling back to 'started_at'."
                )
        elif sort_by in VALID_COLUMNS:
            server_sort_by = sort_by
        else:
            logger.warning(
                f"Invalid sort field '{sort_by}', falling back to 'started_at'."
            )

        # Validate and filter columns using CallSchema. This also takes care
        # of adding the source columns needed by synthetic fields.
        filtered_api_columns, rs_columns, inv_columns = (
            self._validate_and_filter_columns(columns)
        )

        # Store invalid columns for later
        self.invalid_columns = inv_columns

        # Requesting the synthetic 'costs' column implies fetching cost data.
        if "costs" in rs_columns:
            include_costs = True

        # Prepare query parameters
        query_params = {
            "entity_name": entity_name,
            "project_name": project_name,
            "filters": filters or {},
            "sort_by": server_sort_by,
            "sort_direction": server_sort_direction,
            # No server-side limit when sorting by cost: the limit is applied
            # after the client-side sort below.
            "limit": None if client_side_cost_sort else limit,
            "offset": offset,
            "include_costs": include_costs,
            "include_feedback": include_feedback,
            "columns": filtered_api_columns,  # Use the columns intended for the API
            "expand_columns": expand_columns,
        }

        # Build request body
        request_body = QueryBuilder.prepare_query_params(query_params)

        # The builder may attach a private '_synthetic_fields' entry; it is
        # removed here so it is not part of the body sent to the client.
        synthetic_fields = request_body.pop("_synthetic_fields", [])

        # Make sure all requested synthetic columns are included in synthetic_fields
        for col in rs_columns:
            if col not in synthetic_fields:
                synthetic_fields.append(col)

        # Execute query
        all_traces = list(self.client.query_traces(request_body))

        # Add synthetic columns and invalid column warnings back to the results
        if rs_columns or inv_columns:
            all_traces = self._add_synthetic_columns(
                all_traces, rs_columns, inv_columns
            )

        # Client-side cost-based sorting if needed
        if client_side_cost_sort and all_traces:
            logger.info(f"Performing client-side sorting by {sort_by}")
            all_traces.sort(
                key=lambda t: TraceProcessor.get_cost(t, sort_by),
                reverse=(sort_direction == "desc"),
            )
            # Apply limit if specified
            if limit is not None:
                all_traces = all_traces[:limit]

        # If we need to synthesize fields, do it
        if synthetic_fields:
            logger.info(f"Synthesizing fields: {synthetic_fields}")
            all_traces = [
                TraceProcessor.synthesize_fields(trace, synthetic_fields)
                for trace in all_traces
            ]

        # Process traces
        return TraceProcessor.process_traces(
            traces=all_traces,
            truncate_length=truncate_length or 0,
            return_full_data=return_full_data,
            metadata_only=metadata_only,
        )

    def query_paginated_traces(
        self,
        entity_name: str,
        project_name: str,
        chunk_size: int = 20,
        filters: Optional[Dict[str, Any]] = None,
        sort_by: str = "started_at",
        sort_direction: str = "desc",
        target_limit: Optional[int] = None,
        include_costs: bool = True,
        include_feedback: bool = True,
        columns: Optional[List[str]] = None,
        expand_columns: Optional[List[str]] = None,
        truncate_length: Optional[int] = 200,
        return_full_data: bool = False,
        metadata_only: bool = False,
    ) -> QueryResult:
        """Fetch traces chunk by chunk and return one combined result.

        When ``sort_by`` is a cost field the work is delegated to
        ``_query_for_cost_sorting``. Otherwise ``query_traces`` is called
        repeatedly with an increasing offset until a chunk comes back empty
        or short, or ``target_limit`` traces have been collected.

        Args:
            entity_name: Weights & Biases entity name.
            project_name: Weights & Biases project name.
            chunk_size: Number of traces to request per chunk.
            filters: Dictionary of filter conditions.
            sort_by: Field to sort by.
            sort_direction: Sort direction ('asc' or 'desc').
            target_limit: Maximum total number of results to return; a falsy
                value means no overall cap.
            include_costs: Include tracked API cost information in the results.
            include_feedback: Include Weave annotations in the results.
            columns: List of specific columns to include in the results.
            expand_columns: List of columns to expand in the results.
            truncate_length: Maximum length for string values.
            return_full_data: Whether to include full untruncated trace data.
            metadata_only: Whether to only include metadata without traces.

        Returns:
            QueryResult object with metadata and optionally traces.
        """
        # Cost fields are handled by a dedicated client-side sorting path.
        client_side_cost_sort = sort_by in self.COST_FIELDS

        # Pick the sort field sent to the server; 'started_at' is the default.
        effective_sort_by = "started_at"
        if sort_by == "latency_ms":
            effective_sort_by = self.LATENCY_FIELD_MAPPING["latency_ms"]
            logger.info(
                f"Paginated sort by 'latency_ms', server will use '{effective_sort_by}'."
            )
        elif "." in sort_by:
            base_field = sort_by.split(".")[0]
            if base_field in VALID_COLUMNS:
                effective_sort_by = sort_by
                logger.info(
                    f"Paginated sort by nested field '{sort_by}', server will use it directly."
                )
            else:
                logger.warning(
                    f"Paginated sort by invalid nested field '{sort_by}', defaulting to 'started_at'."
                )
        elif not client_side_cost_sort:
            # Cost fields keep the default silently; they are sorted later.
            if sort_by in VALID_COLUMNS:
                effective_sort_by = sort_by
            else:
                logger.warning(
                    f"Paginated sort by invalid field '{sort_by}', defaulting to 'started_at'."
                )

        # Classify the caller's original column list.
        filtered_api_columns, rs_columns, inv_columns = (
            self._validate_and_filter_columns(columns)
        )

        # Remember the invalid columns on the instance.
        self.invalid_columns = inv_columns

        # A requested 'costs' column forces cost data to be fetched.
        if "costs" in rs_columns:
            include_costs = True

        if client_side_cost_sort:
            logger.info(f"Cost-based sorting detected: {sort_by}")
            collected = self._query_for_cost_sorting(
                entity_name=entity_name,
                project_name=project_name,
                filters=filters,
                sort_by=sort_by,
                sort_direction=sort_direction,
                target_limit=target_limit,
                columns=filtered_api_columns,
                expand_columns=expand_columns,
                include_costs=True,
                include_feedback=include_feedback,
                requested_synthetic_columns=rs_columns,
                invalid_columns=inv_columns,
            )
        else:
            collected = []
            next_offset = 0

            while True:
                logger.info(
                    f"Querying chunk with offset {next_offset}, size {chunk_size}"
                )
                # Shrink the last request so the overall cap is not exceeded.
                if target_limit:
                    page_size = min(chunk_size, target_limit - len(collected))
                else:
                    page_size = chunk_size

                # The original 'columns' are passed on purpose: query_traces
                # runs its own validation, which handles synthetic columns
                # such as 'latency_ms'.
                chunk_result = self.query_traces(
                    entity_name=entity_name,
                    project_name=project_name,
                    filters=filters,
                    sort_by=effective_sort_by,
                    sort_direction=sort_direction,
                    limit=page_size,
                    offset=next_offset,
                    include_costs=include_costs,
                    include_feedback=include_feedback,
                    columns=columns,
                    expand_columns=expand_columns,
                    return_full_data=True,  # raw data for now; processed below
                    metadata_only=False,
                )

                # Treat a missing result, None traces and [] traces alike.
                page = (chunk_result.traces if chunk_result else None) or []
                if not page:
                    break

                collected.extend(page)

                # A short page means the data is exhausted.
                if len(page) < page_size:
                    break
                if target_limit and len(collected) >= target_limit:
                    break

                next_offset += chunk_size

        # Enforce the overall cap, then process everything in one pass.
        if target_limit and collected:
            collected = collected[:target_limit]

        result = TraceProcessor.process_traces(
            traces=collected,
            truncate_length=truncate_length or 0,
            return_full_data=return_full_data,
            metadata_only=metadata_only,
        )
        logger.debug(
            f"Final result from query_paginated_traces:\n\n{len(result.model_dump_json(indent=2))}\n"
        )
        assert isinstance(result, QueryResult), (
            f"Result type must be a QueryResult, found: {type(result)}"
        )
        return result

    def _query_for_cost_sorting(
        self,
        entity_name: str,
        project_name: str,
        filters: Optional[Dict[str, Any]] = None,
        sort_by: str = "total_cost",
        sort_direction: str = "desc",
        target_limit: Optional[int] = None,
        columns: Optional[List[str]] = None,
        expand_columns: Optional[List[str]] = None,
        include_costs: bool = True,
        include_feedback: bool = True,
        requested_synthetic_columns: Optional[List[str]] = None,
        invalid_columns: Optional[Set[str]] = None,
    ) -> List[Dict[str, Any]]:
        """Special two-stage query logic for cost-based sorting.

        The first pass requests only ``id`` and ``summary`` for the traces
        matching ``filters``, drops traces for which
        ``TraceProcessor.get_cost`` returns ``None``, and sorts the rest by
        that cost locally. The second pass re-queries the selected trace IDs
        with the caller's column/expansion options and returns them in the
        cost-sorted order.

        Args:
            entity_name: Weights & Biases entity name.
            project_name: Weights & Biases project name.
            filters: Dictionary of filter conditions (applied to the first
                pass only; the second pass filters by the selected IDs).
            sort_by: Cost field to sort by.
            sort_direction: Sort direction ('asc' or 'desc'). Any value other
                than 'desc' sorts ascending.
            target_limit: Maximum number of results to return. A falsy value
                (``None`` or ``0``) means no limit.
            columns: List of specific columns to include in the results.
                This list is never modified.
            expand_columns: List of columns to expand in the results.
            include_costs: Include tracked API cost information in the results.
            include_feedback: Include Weave annotations in the results.
            requested_synthetic_columns: List of synthetic columns requested by the user.
            invalid_columns: Set of invalid column names that were requested.

        Returns:
            List of trace dictionaries sorted by the specified cost field.
            Empty when no trace has a cost value or a usable ``id``.
        """
        if invalid_columns is None:
            invalid_columns = set()

        # First pass: Fetch all trace IDs and costs
        first_pass_query = {
            "entity_name": entity_name,
            "project_name": project_name,
            "filters": filters or {},
            "sort_by": "started_at",  # Use a standard sort for the first pass
            "sort_direction": "desc",
            "limit": 1000000,  # Explicitly set a large limit to get all traces
            "include_costs": True,  # We need costs for sorting
            "include_feedback": False,  # Don't need feedback for the first pass
            "columns": ["id", "summary"],  # Need summary for costs data
        }

        first_pass_request = QueryBuilder.prepare_query_params(first_pass_query)
        first_pass_results = list(self.client.query_traces(first_pass_request))

        logger.info(
            f"First pass of cost sorting request retrieved {len(first_pass_results)} traces"
        )

        # Compute each trace's cost exactly once, keeping only traces that
        # actually have a cost value for the requested field.
        costed_traces = []
        for trace in first_pass_results:
            cost = TraceProcessor.get_cost(trace, sort_by)
            if cost is not None:
                costed_traces.append((cost, trace))

        # Sort on the cost alone (never on the trace itself) so ties keep the
        # first-pass ordering thanks to sort stability.
        costed_traces.sort(
            key=lambda pair: pair[0],
            reverse=(sort_direction == "desc"),
        )

        # Get the IDs of the top N traces (all of them when no limit is set)
        selected = costed_traces[:target_limit] if target_limit else costed_traces
        top_ids = [trace["id"] for _, trace in selected if "id" in trace]

        logger.info(f"After sorting by {sort_by}, selected {len(top_ids)} trace IDs")

        if not top_ids:
            return []

        # Work on a copy so that adding 'summary' below never mutates the
        # caller's list; None/empty is passed through unchanged.
        second_pass_columns = list(columns) if columns else columns

        # Make sure we request summary if costs were requested
        if requested_synthetic_columns and "costs" in requested_synthetic_columns:
            if not second_pass_columns:
                second_pass_columns = ["summary"]
                logger.info("Added 'summary' to columns for cost data retrieval")
            elif "summary" not in second_pass_columns:
                second_pass_columns.append("summary")
                logger.info("Added 'summary' to columns for cost data retrieval")

        # Second pass: Fetch the full details for the selected traces
        second_pass_query = {
            "entity_name": entity_name,
            "project_name": project_name,
            "filters": {"call_ids": top_ids},
            "include_costs": include_costs,
            "include_feedback": include_feedback,
            "columns": second_pass_columns,
            "expand_columns": expand_columns,
        }

        second_pass_request = QueryBuilder.prepare_query_params(second_pass_query)
        second_pass_results = list(self.client.query_traces(second_pass_request))

        logger.info(f"Second pass retrieved {len(second_pass_results)} traces")

        # Add synthetic columns and invalid column warnings back to the results
        if requested_synthetic_columns or invalid_columns:
            second_pass_results = self._add_synthetic_columns(
                second_pass_results,
                requested_synthetic_columns or [],
                invalid_columns,
            )

        # Ensure the results are in the same order as the IDs; any trace whose
        # ID is not among the selected ones sorts to the end.
        id_to_index = {trace_id: index for index, trace_id in enumerate(top_ids)}
        second_pass_results.sort(
            key=lambda t: id_to_index.get(t.get("id"), float("inf"))
        )

        return second_pass_results