"""Module for querying the W&B GraphQL API."""
import copy
import logging
import traceback
from typing import Any, Dict, List, Optional
import wandb
from graphql import parse
from graphql.language import ast as gql_ast
from graphql.language import printer as gql_printer
from graphql.language import visitor as gql_visitor
from wandb_gql import gql # This must be imported after wandb
from wandb_mcp_server.utils import get_rich_logger
logger = get_rich_logger(__name__)
QUERY_WANDB_GQL_TOOL_DESCRIPTION = """Execute an arbitrary GraphQL query against the Weights & Biases (W&B) Models API.
Use this tool to query data from Weights & Biases Models features, including experiment tracking runs,
the model registry, reports, artifacts, and sweeps.
**IMPORTANT PRODUCT DISTINCTION:**
W&B offers two distinct products with different purposes:
1. W&B Models: A system for ML experiment tracking, hyperparameter optimization, and model
lifecycle management. Use `query_wandb_tool` for questions about:
- Experiment runs, metrics, and performance comparisons
- Artifact management and model registry
- Hyperparameter optimization and sweeps
- Project dashboards and reports
2. W&B Weave: A toolkit for LLM and GenAI application observability and evaluation. Use
`query_weave_traces_tool` for questions about:
- Execution traces and paths of LLM operations
- LLM inputs, outputs, and intermediate results
- Chain of thought visualization and debugging
- LLM evaluation results and feedback
FYI: The Weights & Biases platform is owned by CoreWeave. Queries that mention W&B, wandb, \
or weave together with CoreWeave may be related to W&B products or features that leverage CoreWeave's \
GPU or compute infrastructure.
**USE CASE SELECTOR - READ FIRST:**
- For runs, metrics, experiments, artifacts, sweeps etc → use query_wandb_tool (this tool)
- For traces, LLM calls, chain-of-thought, LLM evaluations, AI agent traces, AI apps etc → use query_weave_traces_tool
=====================================================================
⚠️ TOOL SELECTION WARNING ⚠️
This tool is ONLY for WANDB MODELS DATA (MLOps), NOT for LLM TRACES or GENAI APPS!
=====================================================================
**KEYWORD GUIDE:**
If user question contains:
- "runs", "experiments", "metrics" → Use query_wandb_tool (this tool)
- "traces", "LLM calls" etc → Use query_weave_traces_tool
**COMMON MISUSE CASES:**
❌ "Looking at performance of my latest weave evals" - Use query_weave_traces_tool
❌ "what system prompt was used for my openai call" - Use query_weave_traces_tool
❌ "Show me the traces for my weave evals" - Use query_weave_traces_tool
**STEP 1: ANALYZE THE USER QUERY FIRST!**
Before constructing the GraphQL query, determine how the user is referring to W&B entities, especially runs:
- Is the user providing a short, 8-character **Run ID** (e.g., `gtng2y4l`, `h0fm5qp5`)?
- Or are they providing a longer, human-readable **Display Name** (e.g., `transformer_train_run_123`, `eval_on_benchmark_v2`)?
Your choice of query structure depends heavily on this analysis (see Key Concepts and Examples below).
**KEY CONCEPTS - READ CAREFULLY:**
* **Run ID vs. Display Name:**
* To fetch a **single, specific run** using its unique 8-character ID (e.g., `gtng2y4l`), \
use the `run(name: $runId)` field. The variable `$runId` MUST be the ID, not the display name.
* To **find runs based on their human-readable `displayName`** (e.g., `my-cool-experiment-1`), \
use the `runs` collection field with a `filters` argument like: `runs(filters: "{\\"displayName\\":\
{\\"$eq\\":\\"my-cool-experiment-1\\"}}")`. This might return multiple runs if display names are not unique.
* **Filters require JSON Strings:** When using the `filters` argument (e.g., for `runs`, `artifacts`), \
the value provided in the `variables` dictionary MUST be a JSON formatted *string*. Use `json.dumps()` in Python to create it (see the sketch after this list).
* **Collections Require Pagination Structure:** Queries fetching lists/collections (like `project.runs`, \
`artifact.files`) MUST include the `edges { node { ... } } pageInfo { endCursor hasNextPage }` pattern.
* **Summary Metrics:** Use the `summaryMetrics` field (returns a JSON string) to access a run's summary \
dictionary, not the deprecated `summary` field.
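For example, a minimal sketch of building a `filters` value:
```python
import json

# Build the JSON-string value for a `filters` variable
filters = json.dumps({"state": "finished"})  # -> '{"state": "finished"}'
variables = {"entity": "my-entity", "project": "my-project", "filters": filters}
```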
This function allows interaction with W&B data (Projects, Runs, Artifacts, Sweeps, Reports, etc.)
using the GraphQL query language.
Parameters
----------
query : str
The GraphQL query string. This defines the operation (query/mutation),
the data to fetch (selection set), and any variables used.
variables : dict[str, Any] | None, optional
A dictionary of variables to pass to the query.
Keys should match variable names defined in the query
(e.g., $entity, $project). Values should match the
expected types (String, Int, Float, Boolean, ID, JSONString).
**Crucially, complex arguments like `filters` MUST be provided
as a JSON formatted *string*. Use `json.dumps()` in Python
to create this string.**
max_items : int, optional
Maximum number of items to fetch across all pages. Default is 100.
items_per_page : int, optional
Number of items to request per page. Default is 50.
Returns
-------
Dict[str, Any]
The aggregated GraphQL response dictionary.
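Illustrative call (a sketch; the entity/project values are placeholders):
```python
import json

# Sketch: invoking this tool with the parameters described above
graphql_query = '''
query LimitedRuns($entity: String!, $project: String!, $filters: JSONString) {
  project(name: $project, entityName: $entity) {
    runs(first: 5, filters: $filters) {
      edges { node { id name state } }
      pageInfo { endCursor hasNextPage }
    }
  }
}
'''
result = query_paginated_wandb_gql(
    query=graphql_query,
    variables={
        "entity": "my-entity",
        "project": "my-project",
        "filters": json.dumps({"state": "finished"}),
    },
    max_items=100,
    items_per_page=50,
)
```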
**⚠️ CRITICAL WARNING: Run ID vs. Display Name ⚠️**
If the user query mentions a run using its **long, human-readable name** (Display Name), you **MUST** use the `runs(filters: ...)` approach shown in the examples.
**DO NOT** use `run(name: ...)` with a Display Name; it will fail because `name` expects the short Run ID. Use `run(name: ...)` **ONLY** when the user provides the 8-character Run ID.
Review the "Minimal Example: Run ID vs Display Name" and "Get Run by Display Name" examples carefully.
**⚠️ REQUIRED PAGINATION STRUCTURE ⚠️**
All collection queries MUST include the complete W&B connection pattern with these elements:
1. `edges` array containing nodes
2. `node` objects inside edges containing your data fields
3. `pageInfo` object with:
- `endCursor` field (to enable pagination)
- `hasNextPage` field (to determine if more data exists)
This is a strict requirement enforced by the pagination system. Queries without this
structure will fail with the error "Query doesn't follow the W&B connection pattern."
Example of required pagination structure for any collection:
```graphql
runs(first: 10) { # or artifacts, files, etc.
edges {
node {
id
name
# ... other fields you need
}
# cursor # Optional: include cursor if needed for specific pagination logic
}
pageInfo {
endCursor
hasNextPage
}
}
```
**LLM CONTEXT WINDOW MANAGEMENT**
The results of this tool are returned to an LLM. Be mindful of the LLM's context window!
**WARNING: AVOID OPEN-ENDED QUERIES!**
Open-ended queries should be strictly avoided when:
- There are a lot of runs in the project (e.g., hundreds or thousands)
- There are runs with large amounts of data (e.g., many metrics, large configs, etc.)
Examples of problematic open-ended queries:
- Requesting all runs in a project without limits
- Requesting complete run histories without filtering specific metrics
- Requesting all files from artifacts without specifying names/types
Instead, always:
- Use the `first` parameter to limit the number of items returned (start small, e.g., 5-10)
- Apply specific filters to narrow down results (e.g., state, creation time, metrics)
- Request only the specific fields needed, avoid selecting everything
- Consider paginating results if necessary (don't request everything at once)
Bad:
```graphql
query AllRuns($entity: String!, $project: String!) {
  project(name: $project, entityName: $entity) {
    # Potentially huge response: requests all fields for all runs
    runs { edges { node { id name state history summaryMetrics config files { edges { node { name size } } } } } }
  }
}
```
Good:
```graphql
query LimitedRuns($entity: String!, $project: String!) {
project(name: $project, entityName: $entity) {
# Limits runs, specifies filters, and selects only necessary fields
runs(first: 5, filters: "{\\"state\\":\\"finished\\"}") {
edges {
node {
id
name
createdAt
summaryMetrics # Get summary JSON, parse later if needed
}
}
pageInfo { endCursor hasNextPage } # Always include pageInfo for collections
}
}
}
```
Some tactics to consider to avoid exceeding the context window of the LLM when using this tool:
- First return just metadata about the W&B project or run before fetching the full data.
- Select only a subset of the data such as just particular columns or rows.
- If you need to return a large amount of data, consider calling `query_wandb_tool` in a loop.
- Break up the query into smaller chunks.
If you return just a sample subset of the data, warn the user that it is a sample and that they should
use the tool again with additional filters or pagination to get a more complete view.
**Constructing GraphQL Queries:**
1. **Operation Type:** Start with `query` for fetching data or `mutation` for modifying data.
2. **Operation Name:** (Optional but recommended) A descriptive name (e.g., `ProjectInfo`).
3. **Variables Definition:** Define variables used in the query with their types (e.g., `($entity: String!, $project: String!)`). `!` means required.
4. **Selection Set:** Specify the fields you want to retrieve, nesting as needed based on the W&B schema.
**W&B Schema Overview:**
* **Core Types:** `Entity`, `Project`, `Run`, `Artifact`, `Sweep`, `Report`, `User`, `Team`.
* **Relationships:** Entities contain Projects. Projects contain Runs, Sweeps, Artifacts. Runs use/are used by Artifacts. Sweeps contain Runs.
* **Common Fields:** `id`, `name`, `description`, `createdAt`, `config` (JSONString), `summaryMetrics` (JSONString - **Note:** use this field,
not `summary`, to access the run's summary dictionary as a JSON string), `historyKeys` (List of String), etc.
* **Connections (Lists):** Many lists (like `project.runs`, `artifact.files`) use a connection pattern:
```graphql
runs(first: Int, after: String, filters: JSONString, order: String) {
edges { node { id name ... } cursor }
pageInfo { hasNextPage endCursor }
}
```
Use `first` for limit, `after` with `pageInfo.endCursor` for pagination, `filters` (as a JSON string) for complex filtering, and `order` for sorting.
* **Field Type Handling:**
- Some fields require subfield selection (e.g., `tags { name }`) while others are scalar (e.g., `historyKeys`).
- Check the schema if you get errors like "must have a selection of subfields" or "must not have a selection".
**Query Examples:**
* **Minimal Example: Run ID vs Display Name:**
* **A) User provides Run ID (e.g., "get info for run h0fm5qp5"):**
```graphql
query GetRunById($entity: String!, $project: String!, $runId: String!) {
project(name: $project, entityName: $entity) {
# Use run(name: ...) with the Run ID
run(name: $runId) {
id
name # This will be the Run ID
displayName # This is the human-readable name
}
}
}
```
```python
variables = {"entity": "...", "project": "...", "runId": "h0fm5qp5"}
```
* **B) User provides Display Name (e.g., "get info for run transformer_train_123"):**
```graphql
# Note: Querying *runs* collection and filtering
query GetRunByDisplayNameMinimal($project: String!, $entity: String!, $displayNameFilter: JSONString) {
project(name: $project, entityName: $entity) {
# Use runs(filters: ...) with the Display Name
runs(first: 1, filters: $displayNameFilter) {
edges {
node {
id
name # Run ID
displayName # Display Name provided by user
}
}
pageInfo { endCursor hasNextPage } # Required for collections
}
}
}
```
```python
import json
variables = {
"entity": "...",
"project": "...",
"displayNameFilter": json.dumps({"displayName": {"$eq": "transformer_train_123"}})
}
```
* **Get Project Info:** (Doesn't retrieve a collection, no pagination needed)
```graphql
query ProjectInfo($entity: String!, $project: String!) {
project(name: $project, entityName: $entity) {
id
name
entityName
description
runCount
}
}
```
```python
variables = {"entity": "my-entity", "project": "my-project"}
```
* **Get Sorted Runs:** (Retrieves a collection, requires pagination structure)
```graphql
query SortedRuns($project: String!, $entity: String!, $limit: Int, $order: String) {
project(name: $project, entityName: $entity) {
runs(first: $limit, order: $order) {
edges {
node { id name displayName state createdAt summaryMetrics }
cursor # Optional cursor
}
pageInfo { # Required for collections
hasNextPage
endCursor
}
}
}
}
```
```python
variables = {
"entity": "my-entity",
"project": "my-project",
"limit": 10,
"order": "+summary_metrics.accuracy" # Ascending order by accuracy
# Use "-createdAt" for newest first (default if order omitted)
# Use "+createdAt" for oldest first
}
```
* **Get Runs with Pagination and Filtering:** (Requires pagination structure)
```graphql
query FilteredRuns($project: String!, $entity: String!, $limit: Int, $cursor: String, $filters: JSONString, $order: String) {
project(name: $project, entityName: $entity) {
runs(first: $limit, after: $cursor, filters: $filters, order: $order) {
edges {
node { id name state createdAt summaryMetrics }
cursor # Optional cursor
}
pageInfo { endCursor hasNextPage } # Required
}
}
}
```
```python
# Provide filters as the required escaped JSON string
variables = {
"entity": "my-entity",
"project": "my-project",
"limit": 10,
"order": "-summary_metrics.accuracy", # Optional: sort
"filters": "{\"state\": \"finished\", \"summary_metrics.accuracy\": {\"$gt\": 0.9}}", # Escaped JSON string
# "cursor": previous_pageInfo_endCursor # Optional for next page
}
# Note: The *content* of the `filters` JSON string must adhere to the specific
# filtering syntax supported by the W&B API (e.g., using operators like `$gt`, `$eq`, `$in`).
# Refer to W&B documentation for the full filter specification.
```
* **Get Run History Keys:** (Run is not a collection, historyKeys is scalar)
```graphql
query RunHistoryKeys($entity: String!, $project: String!, $runName: String!) {
project(name: $project, entityName: $entity) {
run(name: $runName) {
id
name
historyKeys # Returns ["metric1", "metric2", ...]
}
}
}
```
```python
variables = {"entity": "my-entity", "project": "my-project", "runName": "run-abc"}
```
* **Get Specific Run History Data:** (Uses `sampledHistory` for specific keys)
```graphql
# Uses the specs argument to request specific keys
query RunHistorySampled($entity: String!, $project: String!, $runName: String!, $specs: [JSONString!]!) {
project(name: $project, entityName: $entity) {
run(name: $runName) {
id
name
# Use sampledHistory with specs to get actual values for specific keys
sampledHistory(specs: $specs) {
step # The step number
timestamp # Timestamp of the log
item # JSON string containing {key: value} for requested keys at this step
}
}
}
}
```
```python
# Define the specs variable as a list of escaped JSON strings naming the keys
variables = {
"entity": "my-entity",
"project": "my-project",
"runName": "run-abc",
"specs": ["{\"keys\": [\"loss\", \"val_accuracy\"]}}"] # List containing escaped JSON string
}
# Note: sampledHistory returns rows where *at least one* of the specified keys was logged.
# The 'item' field is a JSON string, you'll need to parse it (e.g., json.loads(row['item']))
# to get the actual key-value pairs for that step. It might not contain all requested keys
# if they weren't logged together at that specific step.
```
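To use the returned data (a sketch, assuming `result` holds this tool's response and the shape shown above):
```python
import json

# Sketch: parse the 'item' JSON string from each sampledHistory row
rows = result["project"]["run"]["sampledHistory"]
for row in rows:
    values = json.loads(row["item"])  # e.g. {"loss": 0.12, "val_accuracy": 0.93}
    print(row["step"], values)
```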
* **Get Run by Display Name:** (Requires filtering and pagination structure)
```graphql
# Note: Querying *runs* collection and filtering, not the singular run(name:...) field
query GetRunByDisplayName($project: String!, $entity: String!, $displayNameFilter: JSONString) {
project(name: $project, entityName: $entity) {
# Filter the runs collection by displayName
runs(first: 1, filters: $displayNameFilter) {
edges {
# Select desired fields from the node (the run)
node { id name displayName state createdAt summaryMetrics }
}
# Required pageInfo for collections
pageInfo { endCursor hasNextPage }
}
}
}
```
```python
# Use json.dumps for the filters argument
import json
target_display_name = "my-experiment-run-123"
variables = {
"entity": "my-entity",
"project": "my-project",
# Filter for the specific display name
"displayNameFilter": json.dumps({"displayName": {"$eq": target_display_name}})
# W&B filter syntax might vary slightly, check docs if needed. Common is {"field": "value"} or {"field": {"$operator": "value"}}
}
# Note: This finds runs where displayName *exactly* matches.
# It might return multiple runs if display names are not unique.
# The `name` field (often the run ID like 'gtng2y4l') is guaranteed unique per project.
# Use `run(name: $runId)` if you know the unique run ID ('name').
```
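Follow-up sketch (assuming `result` holds this tool's response) to recover the unique Run ID for later `run(name: ...)` queries:
```python
# The 'name' field of the node is the short, unique Run ID
edges = result["project"]["runs"]["edges"]
if edges:
    run_id = edges[0]["node"]["name"]  # usable as $runId in run(name: $runId)
```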
* **Get Artifact Details:** (Artifact is not a collection, but `files` is)
```graphql
query ArtifactDetails($entity: String!, $project: String!, $artifactName: String!) {
project(name: $project, entityName: $entity) {
artifact(name: $artifactName) { # Name format often 'artifact-name:version' or 'artifact-name:alias'
id
digest
description
state
size
createdAt
metadata # JSON String
aliases { alias } # Note: the field is 'alias', not 'name'
files { # Files is a collection, requires pagination structure
edges {
node { name url digest } # Note: File nodes here do not expose 'size'
}
pageInfo { endCursor hasNextPage } # Required for files collection
}
}
}
}
```
```python
variables = {"entity": "my-entity", "project": "my-project", "artifactName": "my-dataset:v3"}
```
* **Get Current User Info (Viewer):** (No variables needed)
```graphql
query GetViewerInfo {
viewer {
id
username
email
entity
}
}
```
```python
# No variables needed for this query
variables = {}
```
**Troubleshooting Common Errors:**
* `"Cannot query field 'summary' on type 'Run'"`: Use the `summaryMetrics` field instead of `summary`. It returns a JSON string containing the summary dictionary.
* `"Argument 'filters' has invalid value ... Expected type 'JSONString'"`: Ensure the `filters` argument in your `variables` is a JSON formatted *string*, likely created using `json.dumps()`. Also check the *content* of the filter string for valid W&B filter syntax.
* `"400 Client Error: Bad Request"` (especially when using filters): Double-check the *syntax* inside your `filters` JSON string. Ensure operators (`$eq`, `$gt`, etc.) and structure are valid for the W&B API. Invalid field names or operators within the filter string can cause this.
* `"Unknown argument 'direction' on field 'runs'"`: Control sort direction using `+` (ascending) or `-` (descending) prefixes in the `order` argument string (e.g., `order: "-createdAt"`), not with a separate `direction` argument.
* Errors related to `history` (e.g., `"Unknown argument 'keys' on field 'history'"` or `"Field 'history' must not have a selection..."`): To get *available* metric keys, query the `historyKeys` field (returns `[String!]`). To get *time-series data* for specific keys, use the `sampledHistory(specs: [...])` field as shown in the examples; it returns structured data points. The plain `history` field may return raw data unsuitable for direct querying or be deprecated.
* `"Query doesn't follow the W&B connection pattern"`: Ensure any field returning a list/collection (like `runs`, `files`, `artifacts`, etc.) includes the full `edges { node { ... } } pageInfo { endCursor hasNextPage }` structure. This is mandatory for pagination.
* `"Field must not have a selection"` / `"Field must have a selection"`: Check if the field you are querying is a scalar type (like `String`, `Int`, `JSONString`, `[String!]`) which cannot have sub-fields selected, or an object type which requires you to select sub-fields.
* `"Cannot query field 'step' on type 'Run'"`: The `Run` type does not have a direct `step` field. To find the maximum step count or total steps logged, query the `summaryMetrics` field (look for a key like `_step` or similar in the returned JSON string) or use the `historyLineCount` field which indicates the total number of history rows logged (often corresponding to steps).
**Notes:**
* Refer to the official W&B GraphQL schema (via introspection or documentation) for the most up-to-date field names, types, and available filters/arguments.
* Structure your query to request only the necessary data fields to minimize response size and improve performance.
* **Sorting:** Use the `order` parameter string. Prefix with `+` for ascending, `-` for descending (default).
Common sortable fields: `createdAt`, `updatedAt`, `heartbeatAt`, `config.*`, `summary_metrics.*`.
* Handle potential errors in the returned dictionary (e.g., check for an 'errors' key in the response).
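For example, a minimal sketch of checking the aggregated response for errors (assuming `result` holds this tool's return value):
```python
# Check for GraphQL errors before consuming the data
if "errors" in result:
    for err in result["errors"]:
        print("GraphQL error:", err.get("message"))
```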
"""
def find_paginated_collections(
obj: Dict, current_path: Optional[List[str]] = None
) -> List[List[str]]:
"""Find collections in a response that follow the W&B connection pattern. Returns List[List[str]]."""
if current_path is None:
current_path = []
collections = []
if isinstance(obj, dict):
if (
"edges" in obj
and "pageInfo" in obj
and isinstance(obj.get("edges"), list)
and isinstance(obj.get("pageInfo"), dict)
and "hasNextPage" in obj.get("pageInfo", {})
and "endCursor" in obj.get("pageInfo", {})
):
            collections.append(list(current_path))  # record the path to this connection
        # Recurse into nested values, tracking the current path
for key, value in obj.items():
current_path.append(key)
collections.extend(find_paginated_collections(value, current_path))
current_path.pop()
elif isinstance(obj, list):
for item in obj:
collections.extend(find_paginated_collections(item, current_path))
return collections
def get_nested_value(obj: Dict, path: List[str]) -> Optional[Any]:
"""Get a value from a nested dictionary using a list of keys (path)."""
current = obj
# Iterate directly over the list path
for key in path:
if not isinstance(current, dict) or key not in current:
return None
current = current[key]
return current
def query_paginated_wandb_gql(
query: str,
variables: Optional[Dict[str, Any]] = None,
max_items: int = 100,
items_per_page: int = 50,
) -> Dict[str, Any]:
"""
Execute a GraphQL query against the W&B API with pagination support using AST modification.
Handles a single paginated field detected via the connection pattern.
Modifies the result dictionary in-place.
Args:
query: The GraphQL query string. MUST include pageInfo{hasNextPage, endCursor} for paginated fields.
variables: Variables to pass to the GraphQL query.
max_items: Maximum number of items to fetch across all pages (default: 100).
        items_per_page: Number of items to request per page (default: 50).
            Nodes are deduplicated by node ID across pages automatically.
Returns:
The aggregated GraphQL response dictionary.
"""
result_dict = {}
api = None
limit_key = None
try:
# Use API key from environment (set by auth middleware for HTTP, or by user for STDIO)
# Get API instance with proper key handling
from wandb_mcp_server.api_client import get_wandb_api
api = get_wandb_api()
logger.info(
"--- Inside query_paginated_wandb_gql: Step 0: Execute Initial Query ---"
)
# Determine limit key and set initial page vars
page1_vars_func = variables.copy() if variables is not None else {}
limit_key = None
for k in page1_vars_func:
if k.lower() in ["limit", "first", "count"]:
limit_key = k
break
if limit_key:
# Ensure first page uses items_per_page if limit is too high or missing
page1_vars_func[limit_key] = min(
items_per_page, page1_vars_func.get(limit_key) or items_per_page
)
else:
limit_key = "limit"
page1_vars_func[limit_key] = items_per_page
logger.debug(
f"No limit variable found in input, adding '{limit_key}={items_per_page}'"
)
# Parse for execution
try:
parsed_initial_query = gql(query.strip())
except Exception as e:
logger.error(f"Failed to parse initial query with wandb_gql: {e}")
return {"errors": [{"message": f"Failed to parse initial query: {e}"}]}
# Execute initial query
try:
result1 = api.client.execute(
parsed_initial_query, variable_values=page1_vars_func
)
result_dict = copy.deepcopy(result1) # Work on a copy
if "errors" in result_dict:
logger.error(
f"GraphQL errors in initial response: {result_dict['errors']}"
)
return result_dict # Return errors if found
except Exception as e:
logger.error(f"Failed to execute initial GraphQL query: {e}", exc_info=True)
return {"errors": [{"message": f"Failed to execute initial query: {e}"}]}
# Find Collections
detected_paths = find_paginated_collections(result_dict)
if not detected_paths:
logger.info("No paginated paths detected. Returning initial result.")
return result_dict
# --- Use the first detected path ---
# TODO: Enhance to handle multiple paths if necessary
path_to_paginate = detected_paths[0]
logger.info(f"Using path for pagination: {'/'.join(path_to_paginate)}")
# Extract page 1 data
runs_data1 = get_nested_value(result_dict, path_to_paginate)
if runs_data1 is None:
logger.warning(
f"Could not extract data for pagination path {'/'.join(path_to_paginate)}. Returning initial result."
)
return result_dict
page_info1 = get_nested_value(runs_data1, ["pageInfo"])
if page_info1 is None:
logger.warning(
f"Could not extract pageInfo for pagination path {'/'.join(path_to_paginate)}. Returning initial result."
)
return result_dict
cursor = page_info1.get("endCursor")
has_next = page_info1.get("hasNextPage")
initial_edges = runs_data1.get("edges", [])
logging.info(f"Page 1 Results: {len(initial_edges)} runs.")
logging.info(f"Page 1 PageInfo: {page_info1}")
# Deduplicate initial edges and update result_dict
seen_ids = set()
current_edge_count = 0
temp_initial_edges = []
if initial_edges:
for edge in initial_edges:
try:
# Check max items even on page 1 relative to the limit
if current_edge_count >= max_items:
break
node_id = edge["node"]["id"]
if node_id not in seen_ids:
seen_ids.add(node_id)
temp_initial_edges.append(edge)
current_edge_count += 1
except (KeyError, TypeError):
if current_edge_count < max_items:
temp_initial_edges.append(edge)
current_edge_count += 1
# Update the edges in the result_dict
target_collection_dict = get_nested_value(result_dict, path_to_paginate)
if target_collection_dict:
target_collection_dict["edges"] = temp_initial_edges[
:max_items
] # Ensure initial list respects max_items
current_edge_count = len(target_collection_dict["edges"])
            logger.info(
                f"Stored {current_edge_count} unique edges after page 1 (max: {max_items})."
            )
if not has_next or not cursor or current_edge_count >= max_items:
logger.info(
"No further pages needed based on page 1 info or max_items reached."
)
# Ensure final pageInfo reflects reality
target_pi_dict = get_nested_value(
result_dict, path_to_paginate + ["pageInfo"]
)
if target_pi_dict:
target_pi_dict["hasNextPage"] = False
return result_dict
# Generate Paginated Query String
        logger.info("--- Generating Paginated Query String ---")
generated_paginated_query_string = None
after_variable_name = "after" # Standard name
try:
initial_ast = parse(query.strip())
visitor = AddPaginationArgsVisitor(
field_paths=detected_paths,
first_variable_name=limit_key,
after_variable_name=after_variable_name,
)
modified_ast = gql_visitor.visit(copy.deepcopy(initial_ast), visitor)
generated_paginated_query_string = gql_printer.print_ast(modified_ast)
logger.info("AST modification and printing successful.")
except Exception as e:
logger.error(f"Failed to generate query string via AST: {e}", exc_info=True)
return result_dict # Return what we have if generation fails
if generated_paginated_query_string is None:
return result_dict
        logger.info(
            "--- Loop: Execute, Deduplicate, Aggregate In-Place, Check Limit ---"
        )
page_num = 1
current_cursor = cursor
current_has_next = has_next
final_page_info = page_info1
while current_has_next:
if current_edge_count >= max_items:
logging.info(f"Reached max_items ({max_items}). Stopping loop.")
final_page_info = {**final_page_info, "hasNextPage": False}
break
page_num += 1
            logger.info(f"Fetching Page {page_num}...")
page_vars = (
variables.copy() if variables is not None else {}
) # Start with original vars
page_vars[limit_key] = items_per_page # Set correct page size
page_vars[after_variable_name] = current_cursor # Set cursor
try:
# Parse and execute for the current page
parsed_generated = gql(generated_paginated_query_string)
                logger.info(
                    f"Executing generated query for page {page_num} with vars: {page_vars}"
                )
result_page = api.client.execute(
parsed_generated, variable_values=page_vars
)
if "errors" in result_page:
logger.error(
f"GraphQL errors on page {page_num}: {result_page['errors']}. Stopping pagination."
)
current_has_next = False
final_page_info = {
**final_page_info,
"hasNextPage": False,
} # Update page info on error
continue # Go to end of loop
runs_data = get_nested_value(result_page, path_to_paginate)
if runs_data is None:
                    logger.warning(
                        f"Could not get data for path {'/'.join(path_to_paginate)} on page {page_num}. Stopping."
                    )
current_has_next = False
continue
else:
edges_this_page = get_nested_value(runs_data, ["edges"]) or []
page_info = get_nested_value(runs_data, ["pageInfo"]) or {}
final_page_info = page_info # Store latest page info
                    logger.info(
                        f"Result (Page {page_num}): {len(edges_this_page)} runs returned."
                    )
                    logger.info(f"Page Info (Page {page_num}): {page_info}")
# Deduplicate & Find edges to append
new_edges_for_aggregation = []
duplicates_skipped = 0
if edges_this_page:
for edge in edges_this_page:
if (
current_edge_count + len(new_edges_for_aggregation)
>= max_items
):
                                logger.info(
                                    f"Max items ({max_items}) reached mid-page {page_num}."
                                )
final_page_info = {**final_page_info, "hasNextPage": False}
current_has_next = False
break
try:
node_id = edge["node"]["id"]
if node_id not in seen_ids:
seen_ids.add(node_id)
new_edges_for_aggregation.append(edge)
else:
duplicates_skipped += 1
except (KeyError, TypeError):
new_edges_for_aggregation.append(edge)
if duplicates_skipped > 0:
                            logger.info(
                                f"Skipped {duplicates_skipped} duplicate edges on page {page_num}."
                            )
# Append new unique edges IN-PLACE
if new_edges_for_aggregation:
target_collection_dict_inplace = get_nested_value(
result_dict, path_to_paginate
)
if target_collection_dict_inplace and isinstance(
target_collection_dict_inplace.get("edges"), list
):
target_collection_dict_inplace["edges"].extend(
new_edges_for_aggregation
)
current_edge_count = len(
target_collection_dict_inplace["edges"]
)
                                logger.info(
                                    f"Appended {len(new_edges_for_aggregation)} new edges. Total unique edges: {current_edge_count}"
                                )
else:
                                logger.error(
                                    "Could not find target edges list in result_dict to append in-place."
                                )
current_has_next = False
                        else:
                            # edges_this_page is non-empty here, so all edges were duplicates
                            logger.info(
                                f"No new unique edges found on page {page_num} after deduplication."
                            )
                    else:
                        logger.info(f"No edges returned on page {page_num} to aggregate.")
# Update cursor and has_next for next loop iteration (or final state)
current_cursor = final_page_info.get("endCursor")
# Respect hasNextPage from API unless loop was broken early by max_items or errors
if current_has_next: # Only update if loop didn't break mid-page
current_has_next = final_page_info.get("hasNextPage", False)
# Safety checks
if current_has_next and not current_cursor:
                    logger.warning(
                        "hasNextPage is true but no endCursor received. Stopping loop."
                    )
current_has_next = False
if not edges_this_page:
                    logger.warning(
                        f"No edges received for page {page_num}. Stopping loop."
                    )
current_has_next = False
except Exception as e:
                logger.error(
                    f"Execution failed for page {page_num}: {e}", exc_info=True
                )
current_has_next = False # Stop loop on error
        logger.info(f"--- Pagination Loop Finished after page {page_num} ---")
        logger.info(f"Final aggregated edge count: {current_edge_count}")
# Update the final pageInfo in the result dictionary
target_collection_dict_final = get_nested_value(result_dict, path_to_paginate)
if target_collection_dict_final:
target_collection_dict_final["pageInfo"] = final_page_info
logging.info(f"Updated final pageInfo: {final_page_info}")
return result_dict # Return the modified dictionary
except Exception as e:
error_message = f"Critical error in paginated GraphQL query function: {str(e)}\n{traceback.format_exc()}"
logger.error(error_message)
# Return original dict if possible, else error structure
if result_dict:
if "errors" not in result_dict:
result_dict["errors"] = []
result_dict["errors"].append(
{"message": "Pagination failed", "details": str(e)}
)
return result_dict
else:
return {
"errors": [
{"message": "Pagination failed catastrophically", "details": str(e)}
]
}
class AddPaginationArgsVisitor(gql_visitor.Visitor):
"""Adds first/after args and variables"""
def __init__(
self, field_paths, first_variable_name="limit", after_variable_name="after"
):
super().__init__()
self.field_paths = set(tuple(p) for p in field_paths)
self.first_variable_name = first_variable_name
self.after_variable_name = after_variable_name
self.current_path = []
self.modified_operation = False
def enter_field(self, node, key, parent, path, ancestors):
field_name = node.alias.value if node.alias else node.name.value
self.current_path.append(field_name)
current_path_tuple = tuple(self.current_path)
if current_path_tuple in self.field_paths:
existing_args = list(node.arguments)
args_changed = False
has_first = any(arg.name.value == "first" for arg in existing_args)
if not has_first:
                # Reference the configured limit variable name (defaults to 'limit')
limit_var_node = gql_ast.VariableNode(
name=gql_ast.NameNode(value=self.first_variable_name)
)
existing_args.append(
gql_ast.ArgumentNode(
name=gql_ast.NameNode(value="first"), value=limit_var_node
)
)
args_changed = True
has_after = any(arg.name.value == "after" for arg in existing_args)
if not has_after:
existing_args.append(
gql_ast.ArgumentNode(
name=gql_ast.NameNode(value="after"),
value=gql_ast.VariableNode(
name=gql_ast.NameNode(value=self.after_variable_name)
),
)
)
args_changed = True
if args_changed:
node.arguments = tuple(existing_args)
def leave_field(self, node, key, parent, path, ancestors):
if self.current_path:
self.current_path.pop()
def enter_operation_definition(self, node, key, parent, path, ancestors):
if self.modified_operation:
return
existing_vars = {var.variable.name.value for var in node.variable_definitions}
new_defs_list = list(node.variable_definitions)
defs_changed = False
# Determine limit variable name from existing vars if possible, else default
current_limit_var = self.first_variable_name # Default
for var_name in existing_vars:
if var_name.lower() in ["limit", "first", "count"]:
current_limit_var = var_name
break
if current_limit_var not in existing_vars:
new_defs_list.append(
gql_ast.VariableDefinitionNode(
variable=gql_ast.VariableNode(
name=gql_ast.NameNode(value=current_limit_var)
),
type=gql_ast.NamedTypeNode(name=gql_ast.NameNode(value="Int")),
)
)
defs_changed = True
if self.after_variable_name not in existing_vars:
new_defs_list.append(
gql_ast.VariableDefinitionNode(
variable=gql_ast.VariableNode(
name=gql_ast.NameNode(value=self.after_variable_name)
),
type=gql_ast.NamedTypeNode(name=gql_ast.NameNode(value="String")),
)
)
defs_changed = True
if defs_changed:
node.variable_definitions = tuple(new_defs_list)
self.modified_operation = True