|
|
import os |
|
|
import dspy |
|
|
import mlflow |
|
|
import asyncio |
|
|
|
|
|
from mcp import ClientSession |
|
|
from mcp.client.streamable_http import streamablehttp_client |
|
|
|
|
|
# Fail fast with an actionable message (same style as the Hugging Face token
# check) instead of a bare KeyError when the OpenAI-compatible endpoint is not
# configured. Empty values are treated as missing.
_missing_openai_env = [
    name for name in ("OPENAI_API_KEY", "OPENAI_BASE_URL") if not os.environ.get(name)
]
if _missing_openai_env:
    raise RuntimeError(
        "Please set the following environment variable(s) for the OpenAI-compatible "
        f"LLM endpoint: {', '.join(_missing_openai_env)}"
    )

# gpt-4o-mini at temperature 0, reached through the OpenAI-compatible base URL
# given by OPENAI_BASE_URL.
lm = dspy.LM(
    model='openai/gpt-4o-mini',
    temperature=0,
    api_key=os.environ['OPENAI_API_KEY'],
    api_base=os.environ['OPENAI_BASE_URL'],
)
|
|
|
|
|
# Streamable-HTTP endpoint of the credit-card database MCP server (a Hugging Face Space).
mcp_url = "https://pgurazada1-credit-card-database-mcp-server.hf.space/mcp/"

# Hugging Face user access token, sent as a Bearer token to the MCP server.
# HUGGINGFACE_API_KEY is the primary source. HF_TOKEN is accepted as a fallback
# because earlier versions of the error message below told users to set it.
HF_TOKEN = os.environ.get("HUGGINGFACE_API_KEY") or os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise RuntimeError(
        "Please set your Hugging Face user access token in the "
        "HUGGINGFACE_API_KEY (or HF_TOKEN) environment variable."
    )
|
|
|
|
|
|
|
|
# Register `lm` as the default language model for dspy modules in this process.
dspy.settings.configure(lm=lm)

# Turn on MLflow tracing of dspy calls, and log them under a dedicated experiment.
mlflow.dspy.autolog()
mlflow.set_experiment(experiment_name='sql-react-agent-http')
|
|
|
|
|
|
|
|
class QueryResponse(dspy.Signature):
    # dspy Signature for the SQL ReAct agent: one `query` input, one `answer` output.
    # dspy takes a Signature's docstring as its task instructions, so the
    # docstring below is a runtime prompt, not documentation. Editing it
    # changes agent behavior.
    # NOTE(review): re-run the MLflow experiment after any change to the prompt text.
    # NOTE(review): the prompt names tools sql_db_list_tables, sql_db_schema,
    # sql_db_query_checker and sql_db_query. These presumably match what the MCP
    # server advertises via list_tools() — confirm against the server.
    # NOTE(review): General Restriction 1 lists DROP under "DML"; DROP is DDL.
    # The intent (no write/modify statements) is still conveyed by the explicit
    # list, so the prompt is left unchanged here.
    """
    You are an expert AI assistant specialized in generating and executing SQLite queries against a database.
    Your primary goal is to accurately answer user questions based *only* on the data retrieved. You must be methodical in exploring the database structure.

    <Schema Exploration and Join Path Strategy>
    1. **List All Tables:** Always start with `sql_db_list_tables`.
    2. **Identify Potential Tables:** List tables potentially holding the requested entities (e.g., cities, merchants) and metrics (e.g., spend). Also, identify tables that might *link* these entities (often containing ID columns like `cust_id`, `CARD_ID`, `M_ID`).
    3. **Get Schemas Systematically:** Use `sql_db_schema` to get schemas for *all* tables identified in step 2. This is crucial. Do not skip potential linking tables.
    4. **Map the Join Path:**
        * Explicitly identify the column containing the primary metric (e.g., `transaction.TX_AMOUNT`).
        * Explicitly identify the column containing the target entity (e.g., `customer.city`).
        * **CRITICAL:** Trace the connections between these tables using ID columns revealed in the schemas. Look for sequences like `tableA.ID -> tableB.tableA_ID`, `tableB.ID -> tableC.tableB_ID`.
        * **Example Path:** To link transaction spend to customer city, you MUST verify the path: `transaction.CARD_ID` links to `card.card_number`, AND `card.cust_id` links to `customer.cust_id`. You **MUST** request the schema for the `card` table to confirm this.
        * **State the Path:** Before writing the query, state the full join path you intend to use (e.g., "Found path: transaction JOIN card ON transaction.CARD_ID = card.card_number JOIN customer ON card.cust_id = customer.cust_id").
    5. **Verify Columns:** Double-check that *every* column used in your intended SELECT, JOIN, WHERE, GROUP BY, or ORDER BY clauses exists in the schemas you retrieved.
    </Schema Exploration and Join Path Strategy>

    <Query Construction and Execution>
    6. **Construct Query:** Build the SQLite query using the verified tables, columns, and the full, correct join path.
        * Use explicit JOIN clauses (INNER JOIN is usually appropriate unless otherwise specified).
        * Quote identifiers (like `"transaction"`) if they are keywords or contain special characters.
        * Select only necessary columns. Alias columns for clarity if needed (e.g., `SUM(t.TX_AMOUNT) AS total_spend`).
        * Include calculations like percentage contribution if requested. The total sum for percentage calculation should be derived correctly (e.g., `(SELECT SUM(TX_AMOUNT) FROM "transaction")`).
        * Apply `GROUP BY` to the target entity column (e.g., `c.city`).
        * Apply `ORDER BY` and `LIMIT 5` (unless otherwise specified).
    7. **Validate Query:** Use `sql_db_query_checker`. Revise if syntax errors occur.
    8. **Execute Query:** Use `sql_db_query`.
    9. **Formulate Answer:** Base the final answer *strictly* on the query results. If the query returns no results *after* confirming a valid join path and correct syntax, state that no data matching the criteria was found.
    10. **Handle Missing Information:** If, after thorough schema exploration (including checking potential linking tables), you cannot find the requested column (e.g., 'country') or a valid join path, *then and only then* inform the user the data is unavailable. Do not substitute unrelated columns.
    11. **Final Answer Only:** Provide the answer directly without further tool calls once results are obtained.
    </Query Construction and Execution>

    <General Restrictions>
    1. DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.).
    2. DO NOT MAKE UP ANSWERS.
    </General Restrictions>
    """

    # The user's natural-language question; respond() passes it as query=...
    query: str = dspy.InputField()
    # The agent's final reply; the script reads it from the prediction as `.answer`.
    answer: str = dspy.OutputField(desc="The generated response to the customer query.")
|
|
|
|
|
|
|
|
async def respond(query, *, max_iters=10):
    """Answer ``query`` with a dspy ReAct agent backed by the MCP server's tools.

    Opens a fresh streamable-HTTP MCP session to ``mcp_url``, authenticating
    with ``HF_TOKEN`` as a Bearer token. Every tool the server lists is wrapped
    as a ``dspy.Tool``, and a ``dspy.ReAct`` agent over the ``QueryResponse``
    signature is then run on the question.

    Args:
        query: Natural-language question to answer.
        max_iters: Maximum number of ReAct iterations (default 10, the value
            previously hard-coded).

    Returns:
        The prediction returned by ``react_agent.acall``; callers read the
        final text from its ``answer`` field.
    """
    async with streamablehttp_client(
        url=mcp_url,
        headers={"Authorization": f"Bearer {HF_TOKEN}"},
    ) as (read, write, _):  # third yielded value is not needed here
        async with ClientSession(read, write) as session:
            await session.initialize()
            tools_output = await session.list_tools()

            # Each dspy tool is bound to this session, so the agent must run
            # to completion before the `async with` blocks close it.
            dspy_tools = [dspy.Tool.from_mcp_tool(session, tool) for tool in tools_output.tools]

            react_agent = dspy.ReAct(QueryResponse, tools=dspy_tools, max_iters=max_iters)
            return await react_agent.acall(query=query)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Sequential demo: answer each question in turn. Each asyncio.run() call
# drives one respond() to completion before the next question starts.
for user_query in (
    "Who are the top 5 merchants by total number of transactions?",
    "Which is the highest spend month and amount for each card type?",
    "Which are the top 5 cities with the highest spend and what is their percentage contribution to overall spends?",
):
    pred = asyncio.run(respond(user_query))
    print(pred.answer)
|
|
|
|
|
|
|
|
|
|
|
async def main(user_queries=None):
    """Answer several questions concurrently and return their predictions.

    Each question gets its own ``respond`` coroutine, and therefore its own MCP
    session. All of them are awaited together with ``asyncio.gather``.

    Args:
        user_queries: Iterable of natural-language questions. Defaults to the
            script's three demo questions.

    Returns:
        list: One ``respond`` result per question, in the same order as the
        input (``asyncio.gather`` preserves submission order).
    """
    if user_queries is None:
        # Same wording as the sequential demo run. The earlier "total
        # transactions" phrasing was ambiguous (count vs. summed amount), so
        # the two runs could answer different questions.
        user_queries = [
            "Who are the top 5 merchants by total number of transactions?",
            "Which is the highest spend month and amount for each card type?",
            "Which are the top 5 cities with the highest spend and what is their percentage contribution to overall spends?",
        ]

    tasks_to_run = [respond(query) for query in user_queries]
    results = await asyncio.gather(*tasks_to_run)
    return results
|
|
|
|
|
# Concurrent demo: run main()'s batch of questions in one event loop, then
# print the answers. asyncio.gather() keeps results in submission order, so
# answers appear in the same order as the questions.
results = asyncio.run(main())

answers = (prediction.answer for prediction in results)
for answer in answers:
    print(answer)