# credit-card-database-mcp-server/sql-react-agent-mcp.py
import os
import dspy
import mlflow
import asyncio
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
lm = dspy.LM(
    model='openai/gpt-4o-mini',
    temperature=0,
    api_key=os.environ['OPENAI_API_KEY'],
    api_base=os.environ['OPENAI_BASE_URL']
)
mcp_url = "https://pgurazada1-credit-card-database-mcp-server.hf.space/mcp/"
# IMPORTANT: Set your Hugging Face user access token in the HF_TOKEN environment
# variable (HUGGINGFACE_API_KEY is also accepted as a fallback).
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_API_KEY")
if not HF_TOKEN:
    raise RuntimeError("Please set your Hugging Face user access token in the HF_TOKEN environment variable.")
dspy.configure(lm=lm)
mlflow.dspy.autolog()
mlflow.set_experiment('sql-react-agent-http')
class QueryResponse(dspy.Signature):
    """
    You are an expert AI assistant specialized in generating and executing SQLite queries against a database.
    Your primary goal is to accurately answer user questions based *only* on the data retrieved. You must be methodical in exploring the database structure.

    <Schema Exploration and Join Path Strategy>
    1. **List All Tables:** Always start with `sql_db_list_tables`.
    2. **Identify Potential Tables:** List tables potentially holding the requested entities (e.g., cities, merchants) and metrics (e.g., spend). Also, identify tables that might *link* these entities (often containing ID columns like `cust_id`, `CARD_ID`, `M_ID`).
    3. **Get Schemas Systematically:** Use `sql_db_schema` to get schemas for *all* tables identified in step 2. This is crucial. Do not skip potential linking tables.
    4. **Map the Join Path:**
        * Explicitly identify the column containing the primary metric (e.g., `transaction.TX_AMOUNT`).
        * Explicitly identify the column containing the target entity (e.g., `customer.city`).
        * **CRITICAL:** Trace the connections between these tables using ID columns revealed in the schemas. Look for sequences like `tableA.ID -> tableB.tableA_ID`, `tableB.ID -> tableC.tableB_ID`.
        * **Example Path:** To link transaction spend to customer city, you MUST verify the path: `transaction.CARD_ID` links to `card.card_number`, AND `card.cust_id` links to `customer.cust_id`. You **MUST** request the schema for the `card` table to confirm this.
        * **State the Path:** Before writing the query, state the full join path you intend to use (e.g., "Found path: transaction JOIN card ON transaction.CARD_ID = card.card_number JOIN customer ON card.cust_id = customer.cust_id").
    5. **Verify Columns:** Double-check that *every* column used in your intended SELECT, JOIN, WHERE, GROUP BY, or ORDER BY clauses exists in the schemas you retrieved.
    </Schema Exploration and Join Path Strategy>

    <Query Construction and Execution>
    6. **Construct Query:** Build the SQLite query using the verified tables, columns, and the full, correct join path.
        * Use explicit JOIN clauses (INNER JOIN is usually appropriate unless otherwise specified).
        * Quote identifiers (like `"transaction"`) if they are keywords or contain special characters.
        * Select only necessary columns. Alias columns for clarity if needed (e.g., `SUM(t.TX_AMOUNT) AS total_spend`).
        * Include calculations like percentage contribution if requested. Derive the total used as the denominator from the same table (e.g., `(SELECT SUM(TX_AMOUNT) FROM "transaction")`).
        * Apply `GROUP BY` to the target entity column (e.g., `c.city`).
        * Apply `ORDER BY` and `LIMIT 5` (unless otherwise specified).
    7. **Validate Query:** Use `sql_db_query_checker`. Revise if syntax errors occur.
    8. **Execute Query:** Use `sql_db_query`.
    9. **Formulate Answer:** Base the final answer *strictly* on the query results. If the query returns no results *after* confirming a valid join path and correct syntax, state that no data matching the criteria was found.
    10. **Handle Missing Information:** If, after thorough schema exploration (including checking potential linking tables), you cannot find the requested column (e.g., 'country') or a valid join path, *then and only then* inform the user the data is unavailable. Do not substitute unrelated columns.
    11. **Final Answer Only:** Provide the answer directly without further tool calls once results are obtained.
    </Query Construction and Execution>

    <General Restrictions>
    1. DO NOT run any DML or DDL statements (INSERT, UPDATE, DELETE, DROP, etc.).
    2. DO NOT MAKE UP ANSWERS.
    </General Restrictions>
    """
    query: str = dspy.InputField()
    answer: str = dspy.OutputField(desc="The generated response to the customer query.")
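
# Illustrative sketch (not part of the original script): the QueryResponse
# prompt asks the agent to trace the join path
# transaction.CARD_ID -> card.card_number, card.cust_id -> customer.cust_id
# and to compute each city's percentage contribution to overall spend. The
# hypothetical helper below runs that kind of query on a tiny in-memory SQLite
# database. The table and column names come from the prompt's own examples;
# whether the MCP server's real schema matches them is an assumption, and the
# sample rows are made up purely to exercise the query.
import sqlite3


def demo_city_spend_share():
    """Return (city, total_spend, pct_of_total) rows, top 5 cities by spend."""
    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE customer (cust_id INTEGER PRIMARY KEY, city TEXT);
        CREATE TABLE card (card_number TEXT PRIMARY KEY, cust_id INTEGER);
        -- "transaction" is an SQL keyword, so the identifier must be quoted.
        CREATE TABLE "transaction" (tx_id INTEGER PRIMARY KEY, CARD_ID TEXT, TX_AMOUNT REAL);
        INSERT INTO customer VALUES (1, 'Pune'), (2, 'Delhi');
        INSERT INTO card VALUES ('C1', 1), ('C2', 2);
        INSERT INTO "transaction" (CARD_ID, TX_AMOUNT) VALUES ('C1', 300), ('C1', 100), ('C2', 100);
        """
    )
    query = """
        SELECT c.city,
               SUM(t.TX_AMOUNT) AS total_spend,
               ROUND(100.0 * SUM(t.TX_AMOUNT)
                     / (SELECT SUM(TX_AMOUNT) FROM "transaction"), 2) AS pct_of_total
        FROM "transaction" AS t
        INNER JOIN card AS cd ON t.CARD_ID = cd.card_number
        INNER JOIN customer AS c ON cd.cust_id = c.cust_id
        GROUP BY c.city
        ORDER BY total_spend DESC
        LIMIT 5
    """
    try:
        return conn.execute(query).fetchall()
    finally:
        conn.close()

# Usage: demo_city_spend_share() returns one row per city, largest spend first.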
async def respond(query):
    async with streamablehttp_client(
        url=mcp_url,
        headers={"Authorization": f"Bearer {HF_TOKEN}"}
    ) as (read, write, _):
        async with ClientSession(read, write) as session:
            # Initialize the connection
            await session.initialize()
            # List available tools
            tools_output = await session.list_tools()
            # Convert MCP tools to DSPy tools
            dspy_tools = []
            for tool in tools_output.tools:
                dspy_tools.append(dspy.Tool.from_mcp_tool(session, tool))
            # Create the agent and run it while the session is still open,
            # because every tool call goes back through this session.
            react_agent = dspy.ReAct(QueryResponse, tools=dspy_tools, max_iters=10)
            output = await react_agent.acall(query=query)
            return output
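
# Illustrative sketch (not part of the original script): the "General
# Restrictions" in QueryResponse forbid write statements, but that rule lives
# only in the prompt; the tools built in respond() forward whatever SQL the
# model produces to the MCP server. A client-side check such as the
# hypothetical helper below could vet a statement before it is sent. It is a
# coarse keyword filter, not a SQL parser: it may reject harmless queries
# (e.g. ones calling replace() or containing these words inside string
# literals), and read-only credentials on the server remain the stronger control.
import re

_WRITE_KEYWORDS = re.compile(
    r"\b(INSERT|UPDATE|DELETE|REPLACE|DROP|CREATE|ALTER|TRUNCATE|ATTACH|DETACH"
    r"|PRAGMA|VACUUM|REINDEX|GRANT|REVOKE)\b",
    re.IGNORECASE,
)


def is_read_only_sql(sql: str) -> bool:
    """Return True if `sql` looks like a single SELECT or WITH ... SELECT statement."""
    statement = sql.strip().rstrip(";").strip()
    if not statement or ";" in statement:
        # Reject empty input and anything that chains multiple statements.
        return False
    if not re.match(r"(?i)(SELECT|WITH)\b", statement):
        return False
    # Word boundaries keep identifiers such as created_at from matching CREATE.
    return _WRITE_KEYWORDS.search(statement) is None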
# Example 1
user_query = "Who are the top 5 merchants by total number of transactions?"
pred = asyncio.run(respond(user_query))
print(pred.answer)
# Example 2
user_query = "Which is the highest spend month and amount for each card type?"
pred = asyncio.run(respond(user_query))
print(pred.answer)
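
# Illustrative sketch (not part of the original script): Example 2 is a
# "top-1 per group" question, i.e. for each card type, the month with the
# highest spend. One way to express this in SQLite is to aggregate spend per
# (card type, month) and keep the first row per card type with ROW_NUMBER()
# (window functions need SQLite 3.25 or later). The columns card.card_type and
# "transaction".TX_DATE are hypothetical; the real schema exposed by the MCP
# server may use different names or date formats, and the rows are made up.
import sqlite3


def demo_top_month_per_card_type():
    """Return (card_type, month, total_spend) for each card type's best month."""
    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE card (card_number TEXT PRIMARY KEY, card_type TEXT);
        CREATE TABLE "transaction" (CARD_ID TEXT, TX_AMOUNT REAL, TX_DATE TEXT);
        INSERT INTO card VALUES ('C1', 'Gold'), ('C2', 'Silver');
        INSERT INTO "transaction" VALUES
            ('C1', 50, '2024-01-05'), ('C1', 70, '2024-02-10'), ('C1', 30, '2024-02-11'),
            ('C2', 90, '2024-01-20'), ('C2', 40, '2024-03-02');
        """
    )
    query = """
        WITH monthly AS (
            SELECT cd.card_type,
                   strftime('%Y-%m', t.TX_DATE) AS month,
                   SUM(t.TX_AMOUNT) AS total_spend
            FROM "transaction" AS t
            INNER JOIN card AS cd ON t.CARD_ID = cd.card_number
            GROUP BY cd.card_type, strftime('%Y-%m', t.TX_DATE)
        ),
        ranked AS (
            SELECT card_type, month, total_spend,
                   ROW_NUMBER() OVER (PARTITION BY card_type ORDER BY total_spend DESC) AS rn
            FROM monthly
        )
        SELECT card_type, month, total_spend
        FROM ranked
        WHERE rn = 1
        ORDER BY card_type
    """
    try:
        return conn.execute(query).fetchall()
    finally:
        conn.close()

# ROW_NUMBER() breaks ties arbitrarily; RANK() would keep tied months instead.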
# Example 3
user_query = "Which are the top 5 cities with the highest spend and what is their percentage contribution to overall spends?"
pred = asyncio.run(respond(user_query))
print(pred.answer)
# Parallelism: run the same queries concurrently, each in its own MCP session
async def main():
    user_queries = [
        "Who are the top 5 merchants by total number of transactions?",
        "Which is the highest spend month and amount for each card type?",
        "Which are the top 5 cities with the highest spend and what is their percentage contribution to overall spends?"
    ]
    tasks_to_run = [respond(query) for query in user_queries]
    results = await asyncio.gather(*tasks_to_run)
    return results
results = asyncio.run(main())
for result in results:
    print(result.answer)
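
# Illustrative sketch (not part of the original script): main() opens one MCP
# session per query and runs them all at once with asyncio.gather. For a longer
# list of questions, an asyncio.Semaphore can cap how many sessions are open at
# the same time, and return_exceptions=True keeps one failed query from
# discarding the other results. The helper name gather_bounded and its default
# limit are hypothetical; it accepts any async callable, so respond fits.
import asyncio
from typing import Any, Awaitable, Callable, Iterable, List


async def gather_bounded(
    fn: Callable[[str], Awaitable[Any]],
    queries: Iterable[str],
    max_concurrency: int = 2,
) -> List[Any]:
    """Run fn(query) for every query with at most max_concurrency in flight."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def run_one(query: str) -> Any:
        # Each call waits here until one of the max_concurrency slots is free.
        async with semaphore:
            return await fn(query)

    # Results keep the input order; a raised exception is returned in its slot.
    return await asyncio.gather(
        *(run_one(q) for q in queries), return_exceptions=True
    )

# Usage: asyncio.run(gather_bounded(respond, list_of_questions, max_concurrency=2))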