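# Hugging Face file-viewer header captured along with the source (kept as
# comments so the module parses):
# pgurazada1's picture
# Update server.py
# 870bab9 verified
# raw
# history blame
# 8.42 kB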
# import os
# import uvicorn
# import inspect
# from mcp.server.fastmcp import FastMCP
# from starlette.requests import Request
# from starlette.responses import PlainTextResponse, JSONResponse
# from langchain_community.utilities import SQLDatabase
# from langchain_community.tools.sql_database.tool import QuerySQLCheckerTool
# from langchain_openai import ChatOpenAI
# llm = ChatOpenAI(
# api_key=os.environ.get('OPENAI_API_KEY', None),
# base_url=os.environ['OPENAI_BASE_URL'],
# model='gpt-4o-mini',
# temperature=0
# )
# # Create an MCP server and the tool registry
# mcp = FastMCP("Credit Card Database Server")
# tool_registry = []
# def register_tool(fn):
# """Decorator to register tool metadata and with MCP."""
# # Register with MCP
# mcp.tool()(fn)
# # Save metadata
# sig = inspect.signature(fn)
# params = [
# {
# "name": param.name,
# "type": str(param.annotation) if param.annotation is not inspect._empty else "Any",
# "default": param.default if param.default is not inspect._empty else None,
# }
# for param in sig.parameters.values()
# ]
# tool_registry.append({
# "name": fn.__name__,
# "description": fn.__doc__.strip() if fn.__doc__ else "",
# "parameters": params,
# })
# return fn
# credit_card_db = SQLDatabase.from_uri(r"sqlite:///data/ccms.db")
# query_checker_tool = QuerySQLCheckerTool(db=credit_card_db, llm=llm)
# @mcp.custom_route("/", methods=["GET"])
# async def home(request: Request) -> PlainTextResponse:
# return PlainTextResponse(
# """
# Credit Card Database MCP Server
# ----
# Use the following URL to connect with this server
# https://pgurazada1-credit-card-database-mcp-server.hf.space/mcp/
# Access the following URL for a list of tools and their documentation.
# https://pgurazada1-credit-card-database-mcp-server.hf.space/tools/
# """
# )
# @register_tool
# def sql_db_list_tables():
# """
# Returns a comma-separated list of table names in the database.
# """
# return credit_card_db.get_usable_table_names()
# @register_tool
# def sql_db_schema(table_names: list[str]) -> str:
# """
# Input 'table_names_str' is a comma-separated string of table names.
# Returns the DDL SQL schema for these tables.
# """
# return credit_card_db.get_table_info(table_names)
# @register_tool
# def sql_db_query_checker(query: str) -> str:
# """
# Input 'query' is a SQL query string.
# Checks if the query is valid.
# If the query is valid, it returns the original query.
# If the query is not valid, it returns the corrected query.
# This tool is used to ensure the query is valid before executing it.
# """
# return query_checker_tool.run(query)
# @register_tool
# def sql_db_query(query: str) -> str:
# """
# Input 'query' is a SQL query string.
# Executes the query (SELECT only) and returns the result.
# """
# return credit_card_db.run(query)
# @mcp.custom_route("/tools", methods=["GET"])
# async def list_tools(request: Request) -> JSONResponse:
# """Return all registered tool metadata as JSON."""
# return JSONResponse(tool_registry)
# if __name__ == "__main__":
# uvicorn.run(mcp.streamable_http_app, host="0.0.0.0", port=8000)
##--- version with Google OAuth---
import contextlib
import inspect
import logging
import os
import re
import secrets

import uvicorn
from authlib.integrations.starlette_client import OAuth, OAuthError
from fastapi import FastAPI
from langchain_community.tools.sql_database.tool import QuerySQLCheckerTool
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from mcp.server.fastmcp import FastMCP
from starlette.middleware.sessions import SessionMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, PlainTextResponse, RedirectResponse
# Chat model handed to QuerySQLCheckerTool below. OPENAI_BASE_URL is required
# (a missing variable raises KeyError at import); the API key is optional.
_openai_settings = {
    "api_key": os.environ.get("OPENAI_API_KEY"),
    "base_url": os.environ["OPENAI_BASE_URL"],
    "model": "gpt-4o-mini",
    "temperature": 0,
}
llm = ChatOpenAI(**_openai_settings)

# SQLite database file, resolved relative to the working directory.
CREDIT_CARD_DB_URI = "sqlite:///data/ccms.db"
credit_card_db = SQLDatabase.from_uri(CREDIT_CARD_DB_URI)
query_checker_tool = QuerySQLCheckerTool(db=credit_card_db, llm=llm)
# Google OAuth config - set these in your environment
GOOGLE_CLIENT_ID = os.environ["GOOGLE_CLIENT_ID"]
GOOGLE_CLIENT_SECRET = os.environ["GOOGLE_CLIENT_SECRET"]
# Key that signs the session cookie carrying the logged-in Google user. A
# publicly known fallback (previously "supersecret") would let anyone forge a
# signed session and skip the login gate. When SESSION_SECRET is unset, a
# random per-process key is used instead; sessions then simply do not survive
# a restart.
SECRET_KEY = os.environ.get("SESSION_SECRET")
if not SECRET_KEY:
    logging.getLogger(__name__).warning(
        "SESSION_SECRET is not set; using a random per-process session key."
    )
    SECRET_KEY = secrets.token_urlsafe(32)
# FastAPI app & session middleware
app = FastAPI()
app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY)
# Set up OAuth: the Google client is configured from Google's OpenID discovery
# document (server_metadata_url) with the openid/email/profile scopes.
oauth = OAuth()
CONF_URL = 'https://accounts.google.com/.well-known/openid-configuration'
oauth.register(
    name='google',
    client_id=GOOGLE_CLIENT_ID,
    client_secret=GOOGLE_CLIENT_SECRET,
    server_metadata_url=CONF_URL,
    client_kwargs={
        'scope': 'openid email profile',
    },
)
# MCP server that hosts the SQL tools; every tool added through register_tool
# also gets a metadata entry in tool_registry, which /tools serves as JSON.
mcp = FastMCP("Credit Card Database Server")
tool_registry: list[dict] = []
def register_tool(fn):
    """Register *fn* as an MCP tool and record its metadata for the /tools endpoint.

    ``fn`` is passed to ``mcp.tool()`` and returned unchanged, so this works as
    a plain decorator. A metadata entry is appended to ``tool_registry``. It
    holds the function's name, its cleaned docstring (empty string if none) and
    one entry per parameter. Each parameter entry records the annotation as a
    string ("Any" when unannotated) and the default (None when there is none).
    """
    mcp.tool()(fn)
    # inspect.Parameter.empty is the public spelling of the "no annotation /
    # no default" sentinel (the same object as the private inspect._empty).
    missing = inspect.Parameter.empty
    params = [
        {
            "name": param.name,
            "type": str(param.annotation) if param.annotation is not missing else "Any",
            "default": param.default if param.default is not missing else None,
        }
        for param in inspect.signature(fn).parameters.values()
    ]
    tool_registry.append({
        "name": fn.__name__,
        # cleandoc strips surrounding whitespace like .strip() does, and also
        # removes the common indentation of multi-line docstrings.
        "description": inspect.cleandoc(fn.__doc__) if fn.__doc__ else "",
        "parameters": params,
    })
    return fn
@register_tool
def sql_db_list_tables() -> str:
    """Returns a comma-separated list of table names in the database."""
    # get_usable_table_names() yields the names individually. Join them so the
    # result matches the description above, which register_tool publishes as
    # the tool description, and the str return of the other tools.
    return ", ".join(credit_card_db.get_usable_table_names())
@register_tool
def sql_db_schema(table_names: list[str] | str) -> str:
    """Input 'table_names' is a list of table names (a comma-separated string of table names is also accepted). Returns the DDL SQL schema for these tables."""
    if isinstance(table_names, str):
        # Clients following the comma-separated convention send "a, b";
        # split it and drop empty entries.
        table_names = [name.strip() for name in table_names.split(",") if name.strip()]
    return credit_card_db.get_table_info(table_names)
@register_tool
def sql_db_query_checker(query: str) -> str:
    """Checks if the query is valid. If valid, returns the original query; if not, returns the corrected query."""
    # Delegates to the LangChain QuerySQLCheckerTool built with the chat model
    # at module level; per the description above, the result is either the
    # query unchanged or a corrected version of it.
    checked_query = query_checker_tool.run(query)
    return checked_query
# SQLDatabase.run() executes whatever statement it is given, so the
# "SELECT only" promise in sql_db_query's description is enforced here.
# First keyword after any leading whitespace and whole SQL comments.
_SQL_FIRST_KEYWORD = re.compile(
    r"(?:\s|--[^\n]*(?:\n|\Z)|/\*.*?\*/)*([A-Za-z]+)", re.DOTALL
)
# Data-modifying statements that SQLite accepts after a WITH clause.
_SQL_CTE_WRITE = re.compile(
    r"\b(?:insert|update|delete)\b|\breplace\s+into\b", re.IGNORECASE
)


def _ensure_select_only(query: str) -> None:
    """Raise ValueError unless *query* is one SELECT statement (optionally led by WITH)."""
    match = _SQL_FIRST_KEYWORD.match(query)
    keyword = match.group(1).lower() if match else ""
    # Any ';' left after trimming trailing ones means a second statement.
    # This is conservative: a ';' inside a string literal is rejected too.
    single_statement = ";" not in query.rstrip().rstrip(";")
    read_only = keyword == "select" or (
        keyword == "with" and not _SQL_CTE_WRITE.search(query)
    )
    if not (single_statement and read_only):
        raise ValueError("Only single SELECT queries are allowed.")


@register_tool
def sql_db_query(query: str) -> str:
    """Executes the query (SELECT only) and returns the result."""
    _ensure_select_only(query)
    return credit_card_db.run(query)
@app.route("/")
async def home(request: Request):
    """Landing page: greet the signed-in user, or point to the Google login."""
    user = request.session.get("user")
    if not user:
        return PlainTextResponse("Hello! Please go to https://pgurazada1-credit-card-database-mcp-server.hf.space/login to sign in with Google.")
    # The "name" claim is optional in the stored user info, so fall back to
    # the e-mail (then a generic greeting) instead of a KeyError -> HTTP 500.
    username = user.get("name") or user.get("email") or "there"
    return PlainTextResponse(f"Hello, {username}! You are logged in with Google.\nAccess /mcp/, /tools/")
@app.route("/login")
async def login(request: Request):
    """Start the Google OAuth flow, with /auth as the callback."""
    # Absolute URL of the /auth route, forced to https. Presumably the app
    # runs behind a TLS-terminating proxy, so url_for() only sees http.
    callback_url = str(request.url_for('auth'))
    callback_url = callback_url.replace("http://", "https://")
    print("Redirect URI:", callback_url)
    # Send the browser to Google's consent screen.
    return await oauth.google.authorize_redirect(request, callback_url)
@app.route("/auth")
async def auth(request: Request):
    """OAuth callback: exchange the code, store the Google user in the session, go home."""
    try:
        token = await oauth.google.authorize_access_token(request)
    except OAuthError as error:
        # e.g. the user denied consent or the state check failed: report it
        # instead of surfacing an unhandled 500.
        return PlainTextResponse(f"Google sign-in failed: {error.error}", status_code=400)
    # Do not print/log the token: it contains live access and ID tokens.
    token_dict = dict(token)
    user_info = token_dict.get("userinfo")
    if not user_info:
        # Fallback: fetch from the userinfo endpoint. authlib's userinfo()
        # accepts keyword arguments only; passing the request positionally
        # raises TypeError.
        user_info = await oauth.google.userinfo(token=token_dict)
    request.session["user"] = dict(user_info)
    return RedirectResponse(url="/")
@app.route("/logout")
async def logout(request: Request):
    """Forget the signed-in user (no-op if nobody is logged in) and go home."""
    session = request.session
    if "user" in session:
        del session["user"]
    return RedirectResponse(url="/")
# Session-based gate shared by the protected endpoints.
def require_google_auth(request: Request):
    """Return the signed-in Google user from the session, or a redirect to /login.

    Callers distinguish the two outcomes with
    ``isinstance(result, RedirectResponse)``.
    """
    session_user = request.session.get("user")
    return session_user if session_user else RedirectResponse(url="/login")
# FastMCP exposes the Streamable-HTTP transport through the factory method
# streamable_http_app() (not an ASGI callable), so build the app once. Serve
# it at the mount root so the public endpoint stays /mcp/.
mcp.settings.streamable_http_path = "/"
_mcp_http_app = mcp.streamable_http_app()


async def mcp_proxy(scope, receive, send):
    """ASGI app mounted at /mcp: forward to the MCP app only for signed-in sessions.

    For HTTP requests without a Google user in the session, the redirect to
    /login from require_google_auth() is sent instead. Other scope types are
    passed straight through.
    """
    if scope["type"] == "http":
        gate = require_google_auth(Request(scope, receive))
        if isinstance(gate, RedirectResponse):
            await gate(scope, receive, send)
            return
    await _mcp_http_app(scope, receive, send)


app.mount("/mcp", mcp_proxy)

# The MCP session manager must be running before it can serve requests, and a
# mounted sub-app's lifespan is not run for it, so chain session_manager.run()
# into the FastAPI app's existing lifespan.
_base_lifespan = app.router.lifespan_context


@contextlib.asynccontextmanager
async def _lifespan_with_mcp(app_):
    async with mcp.session_manager.run():
        async with _base_lifespan(app_) as state:
            yield state


app.router.lifespan_context = _lifespan_with_mcp
@app.route("/tools")
async def list_tools(request: Request):
    """Serve the registered tool metadata as JSON to signed-in users only."""
    gate = require_google_auth(request)
    return gate if isinstance(gate, RedirectResponse) else JSONResponse(tool_registry)
if __name__ == "__main__":
    # Listen on all interfaces. The port defaults to 8000 as before but can be
    # overridden with the PORT environment variable, e.g. to match the host.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))