Spaces:
Paused
Paused
Replit Deployment
commited on
Commit
·
89ae94f
1
Parent(s):
94b87a8
Deployment from Replit
Browse files- hf_database.py +6 -1
- security_hf.py +16 -0
- src/__init__.py +3 -0
- src/api/__init__.py +3 -0
- src/api/auth.py +105 -0
- src/api/database.py +73 -0
- src/api/main.py +73 -0
- src/api/routers/auth_router.py +74 -0
- src/api/routers/scraping_router.py +161 -0
- src/api/routers/threats_router.py +217 -0
- src/api/schemas.py +310 -0
- src/api/security.py +382 -0
- src/api/services/__init__.py +3 -0
- src/api/services/alert_service.py +316 -0
- src/api/services/dark_web_content_service.py +357 -0
- src/api/services/report_service.py +436 -0
- src/api/services/search_history_service.py +609 -0
- src/api/services/subscription_service.py +681 -0
- src/api/services/threat_service.py +411 -0
- src/api/services/user_service.py +166 -0
hf_database.py
CHANGED
|
@@ -10,7 +10,12 @@ from sqlalchemy.orm import sessionmaker
|
|
| 10 |
from sqlalchemy.pool import StaticPool
|
| 11 |
from src.models.base import Base
|
| 12 |
from src.models.user import User
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
# Configure logging
|
| 16 |
logging.basicConfig(level=logging.INFO)
|
|
|
|
| 10 |
from sqlalchemy.pool import StaticPool
|
| 11 |
from src.models.base import Base
|
| 12 |
from src.models.user import User
|
| 13 |
+
try:
|
| 14 |
+
# Try to import from src.api.security first (for local development)
|
| 15 |
+
from src.api.security import get_password_hash
|
| 16 |
+
except ImportError:
|
| 17 |
+
# Fall back to simplified security module for HF (copied during deployment)
|
| 18 |
+
from security_hf import get_password_hash
|
| 19 |
|
| 20 |
# Configure logging
|
| 21 |
logging.basicConfig(level=logging.INFO)
|
security_hf.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simplified security module for Hugging Face deployment.
|
| 3 |
+
Contains only the essential functions needed for HF deployment.
|
| 4 |
+
"""
|
| 5 |
+
from passlib.context import CryptContext
|
| 6 |
+
|
| 7 |
+
# Set up password hashing
|
| 8 |
+
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
| 9 |
+
|
| 10 |
+
def get_password_hash(password: str) -> str:
|
| 11 |
+
"""Hash a password for storage"""
|
| 12 |
+
return pwd_context.hash(password)
|
| 13 |
+
|
| 14 |
+
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
| 15 |
+
"""Verify a password against a hash"""
|
| 16 |
+
return pwd_context.verify(plain_password, hashed_password)
|
src/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Package initialization for src.
|
| 3 |
+
"""
|
src/api/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Package initialization for API.
|
| 3 |
+
"""
|
src/api/auth.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import Depends, HTTPException, status
|
| 2 |
+
from fastapi.security import OAuth2PasswordBearer
|
| 3 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 4 |
+
from jose import JWTError, jwt
|
| 5 |
+
from datetime import datetime, timedelta
|
| 6 |
+
from typing import Optional, Dict, Any
|
| 7 |
+
import os
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
from src.api.database import get_db
|
| 11 |
+
from src.api.schemas import TokenData, UserInDB
|
| 12 |
+
from src.api.services.user_service import get_user_by_username
|
| 13 |
+
|
| 14 |
+
# Configure logger
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
# Constants for JWT
|
| 18 |
+
SECRET_KEY = os.getenv("JWT_SECRET_KEY", "your-secret-key-for-jwt-please-change-in-production")
|
| 19 |
+
ALGORITHM = "HS256"
|
| 20 |
+
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
| 21 |
+
|
| 22 |
+
# OAuth2PasswordBearer for token extraction
|
| 23 |
+
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth/token")
|
| 24 |
+
|
| 25 |
+
def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
| 26 |
+
"""
|
| 27 |
+
Create a JWT access token.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
data: Dictionary of data to encode in the token
|
| 31 |
+
expires_delta: Optional expiration time delta
|
| 32 |
+
|
| 33 |
+
Returns:
|
| 34 |
+
str: JWT token
|
| 35 |
+
"""
|
| 36 |
+
to_encode = data.copy()
|
| 37 |
+
|
| 38 |
+
if expires_delta:
|
| 39 |
+
expire = datetime.utcnow() + expires_delta
|
| 40 |
+
else:
|
| 41 |
+
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
| 42 |
+
|
| 43 |
+
to_encode.update({"exp": expire})
|
| 44 |
+
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
| 45 |
+
|
| 46 |
+
return encoded_jwt
|
| 47 |
+
|
| 48 |
+
async def get_current_user(token: str = Depends(oauth2_scheme), db: AsyncSession = Depends(get_db)) -> UserInDB:
|
| 49 |
+
"""
|
| 50 |
+
Get the current authenticated user based on the JWT token.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
token: JWT token
|
| 54 |
+
db: Database session
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
UserInDB: User data
|
| 58 |
+
|
| 59 |
+
Raises:
|
| 60 |
+
HTTPException: If authentication fails
|
| 61 |
+
"""
|
| 62 |
+
credentials_exception = HTTPException(
|
| 63 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 64 |
+
detail="Could not validate credentials",
|
| 65 |
+
headers={"WWW-Authenticate": "Bearer"},
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
try:
|
| 69 |
+
# Decode JWT
|
| 70 |
+
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
| 71 |
+
username: str = payload.get("sub")
|
| 72 |
+
|
| 73 |
+
if username is None:
|
| 74 |
+
raise credentials_exception
|
| 75 |
+
|
| 76 |
+
token_data = TokenData(username=username)
|
| 77 |
+
except JWTError as e:
|
| 78 |
+
logger.error(f"JWT error: {e}")
|
| 79 |
+
raise credentials_exception
|
| 80 |
+
|
| 81 |
+
# Get user from database
|
| 82 |
+
user = await get_user_by_username(db, username=token_data.username)
|
| 83 |
+
|
| 84 |
+
if user is None:
|
| 85 |
+
raise credentials_exception
|
| 86 |
+
|
| 87 |
+
return user
|
| 88 |
+
|
| 89 |
+
async def get_current_active_user(current_user: UserInDB = Depends(get_current_user)) -> UserInDB:
|
| 90 |
+
"""
|
| 91 |
+
Get the current active user.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
current_user: Current authenticated user
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
UserInDB: User data
|
| 98 |
+
|
| 99 |
+
Raises:
|
| 100 |
+
HTTPException: If user is inactive
|
| 101 |
+
"""
|
| 102 |
+
if not current_user.is_active:
|
| 103 |
+
raise HTTPException(status_code=400, detail="Inactive user")
|
| 104 |
+
|
| 105 |
+
return current_user
|
src/api/database.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Database configuration and setup for API.
|
| 3 |
+
"""
|
| 4 |
+
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
| 5 |
+
from sqlalchemy.orm import sessionmaker
|
| 6 |
+
import os
|
| 7 |
+
from typing import AsyncGenerator
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
# Configure logger
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
# Get database URL from environment (convert synchronous to async URL)
|
| 14 |
+
db_url = os.getenv("DATABASE_URL", "")
|
| 15 |
+
if db_url.startswith("postgresql://"):
|
| 16 |
+
# Remove sslmode parameter if present which causes issues with asyncpg
|
| 17 |
+
if "?" in db_url:
|
| 18 |
+
base_url, params = db_url.split("?", 1)
|
| 19 |
+
param_list = params.split("&")
|
| 20 |
+
filtered_params = [p for p in param_list if not p.startswith("sslmode=")]
|
| 21 |
+
if filtered_params:
|
| 22 |
+
db_url = f"{base_url}?{'&'.join(filtered_params)}"
|
| 23 |
+
else:
|
| 24 |
+
db_url = base_url
|
| 25 |
+
|
| 26 |
+
ASYNC_DATABASE_URL = db_url.replace("postgresql://", "postgresql+asyncpg://", 1)
|
| 27 |
+
else:
|
| 28 |
+
ASYNC_DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/postgres"
|
| 29 |
+
|
| 30 |
+
# Create async engine
|
| 31 |
+
engine = create_async_engine(
|
| 32 |
+
ASYNC_DATABASE_URL,
|
| 33 |
+
echo=False, # Set to True for debugging
|
| 34 |
+
future=True,
|
| 35 |
+
pool_size=5,
|
| 36 |
+
max_overflow=10
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
# Create async session factory
|
| 40 |
+
async_session = sessionmaker(
|
| 41 |
+
engine,
|
| 42 |
+
class_=AsyncSession,
|
| 43 |
+
expire_on_commit=False
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
| 47 |
+
"""
|
| 48 |
+
Get database session generator.
|
| 49 |
+
|
| 50 |
+
Yields:
|
| 51 |
+
AsyncSession: Database session
|
| 52 |
+
"""
|
| 53 |
+
session = async_session()
|
| 54 |
+
try:
|
| 55 |
+
yield session
|
| 56 |
+
await session.commit()
|
| 57 |
+
except Exception as e:
|
| 58 |
+
await session.rollback()
|
| 59 |
+
logger.error(f"Database error: {e}")
|
| 60 |
+
raise
|
| 61 |
+
finally:
|
| 62 |
+
await session.close()
|
| 63 |
+
|
| 64 |
+
# Dependency for getting DB session
|
| 65 |
+
async def get_db_session() -> AsyncSession:
|
| 66 |
+
"""
|
| 67 |
+
Get database session.
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
AsyncSession: Database session
|
| 71 |
+
"""
|
| 72 |
+
async with async_session() as session:
|
| 73 |
+
return session
|
src/api/main.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, Depends, HTTPException, status
|
| 2 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
+
import logging
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
from src.api.database import get_db
|
| 7 |
+
from src.api.auth import get_current_user
|
| 8 |
+
from src.api.routers import threats_router, indicators_router, auth_router, admin_router
|
| 9 |
+
|
| 10 |
+
# Configure logging
|
| 11 |
+
logging.basicConfig(
|
| 12 |
+
level=logging.INFO,
|
| 13 |
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
| 14 |
+
handlers=[
|
| 15 |
+
logging.StreamHandler(),
|
| 16 |
+
logging.FileHandler("app.log")
|
| 17 |
+
]
|
| 18 |
+
)
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
# Create FastAPI app
|
| 22 |
+
app = FastAPI(
|
| 23 |
+
title="CyberForge OSINT API",
|
| 24 |
+
description="API for Dark Web OSINT platform",
|
| 25 |
+
version="1.0.0"
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
# Add CORS middleware
|
| 29 |
+
app.add_middleware(
|
| 30 |
+
CORSMiddleware,
|
| 31 |
+
allow_origins=["*"], # Update for production
|
| 32 |
+
allow_credentials=True,
|
| 33 |
+
allow_methods=["*"],
|
| 34 |
+
allow_headers=["*"],
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
# Include routers for different endpoints
|
| 38 |
+
app.include_router(auth_router.router, prefix="/api/v1")
|
| 39 |
+
app.include_router(
|
| 40 |
+
threats_router.router,
|
| 41 |
+
prefix="/api/v1/threats",
|
| 42 |
+
tags=["threats"],
|
| 43 |
+
dependencies=[Depends(get_current_user)]
|
| 44 |
+
)
|
| 45 |
+
app.include_router(
|
| 46 |
+
indicators_router.router,
|
| 47 |
+
prefix="/api/v1/indicators",
|
| 48 |
+
tags=["indicators"],
|
| 49 |
+
dependencies=[Depends(get_current_user)]
|
| 50 |
+
)
|
| 51 |
+
app.include_router(
|
| 52 |
+
admin_router.router,
|
| 53 |
+
prefix="/api/v1/admin",
|
| 54 |
+
tags=["admin"],
|
| 55 |
+
dependencies=[Depends(get_current_user)]
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
@app.get("/api/health")
|
| 59 |
+
async def health_check():
|
| 60 |
+
"""Health check endpoint for monitoring."""
|
| 61 |
+
return {"status": "healthy", "version": "1.0.0"}
|
| 62 |
+
|
| 63 |
+
@app.on_event("startup")
|
| 64 |
+
async def startup_event():
|
| 65 |
+
"""Event handler for application startup."""
|
| 66 |
+
logger.info("Starting the CyberForge OSINT API")
|
| 67 |
+
# Add any startup tasks here (database connection, cache warming, etc.)
|
| 68 |
+
|
| 69 |
+
@app.on_event("shutdown")
|
| 70 |
+
async def shutdown_event():
|
| 71 |
+
"""Event handler for application shutdown."""
|
| 72 |
+
logger.info("Shutting down the CyberForge OSINT API")
|
| 73 |
+
# Add any cleanup tasks here (close connections, save state, etc.)
|
src/api/routers/auth_router.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Authentication router.
|
| 3 |
+
|
| 4 |
+
This module provides authentication endpoints for the API.
|
| 5 |
+
"""
|
| 6 |
+
from datetime import datetime, timedelta
|
| 7 |
+
from typing import Any, Dict, Optional
|
| 8 |
+
|
| 9 |
+
from fastapi import APIRouter, Depends, HTTPException, status
|
| 10 |
+
from fastapi.security import OAuth2PasswordRequestForm
|
| 11 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 12 |
+
|
| 13 |
+
from src.api.database import get_db
|
| 14 |
+
from src.api.security import (
|
| 15 |
+
ACCESS_TOKEN_EXPIRE_MINUTES,
|
| 16 |
+
Token,
|
| 17 |
+
UserInDB,
|
| 18 |
+
authenticate_user,
|
| 19 |
+
create_access_token,
|
| 20 |
+
get_current_active_user,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
router = APIRouter(tags=["authentication"])
|
| 24 |
+
|
| 25 |
+
@router.post("/token", response_model=Token)
|
| 26 |
+
async def login_for_access_token(
|
| 27 |
+
form_data: OAuth2PasswordRequestForm = Depends(),
|
| 28 |
+
db: AsyncSession = Depends(get_db)
|
| 29 |
+
) -> Dict[str, Any]:
|
| 30 |
+
"""
|
| 31 |
+
OAuth2 compatible token login, get an access token for future requests.
|
| 32 |
+
"""
|
| 33 |
+
user = await authenticate_user(db, form_data.username, form_data.password)
|
| 34 |
+
if not user:
|
| 35 |
+
raise HTTPException(
|
| 36 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 37 |
+
detail="Incorrect username or password",
|
| 38 |
+
headers={"WWW-Authenticate": "Bearer"},
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
| 42 |
+
access_token = create_access_token(
|
| 43 |
+
data={"sub": user.username, "scopes": user.scopes},
|
| 44 |
+
expires_delta=access_token_expires
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
expires_at = datetime.utcnow() + access_token_expires
|
| 48 |
+
|
| 49 |
+
return {
|
| 50 |
+
"access_token": access_token,
|
| 51 |
+
"token_type": "bearer",
|
| 52 |
+
"expires_at": expires_at
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
@router.get("/users/me", response_model=UserInDB)
|
| 56 |
+
async def read_users_me(
|
| 57 |
+
current_user: UserInDB = Depends(get_current_active_user)
|
| 58 |
+
) -> UserInDB:
|
| 59 |
+
"""
|
| 60 |
+
Get current user.
|
| 61 |
+
"""
|
| 62 |
+
return current_user
|
| 63 |
+
|
| 64 |
+
@router.get("/users/me/scopes")
|
| 65 |
+
async def read_own_scopes(
|
| 66 |
+
current_user: UserInDB = Depends(get_current_active_user)
|
| 67 |
+
) -> Dict[str, Any]:
|
| 68 |
+
"""
|
| 69 |
+
Get current user's scopes.
|
| 70 |
+
"""
|
| 71 |
+
return {
|
| 72 |
+
"username": current_user.username,
|
| 73 |
+
"scopes": current_user.scopes
|
| 74 |
+
}
|
src/api/routers/scraping_router.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends, HTTPException, status, Body, BackgroundTasks
|
| 2 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 3 |
+
from typing import List, Optional, Dict, Any
|
| 4 |
+
import logging
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
|
| 7 |
+
from src.api.database import get_db
|
| 8 |
+
from src.api.auth import get_current_user
|
| 9 |
+
from src.api.schemas import User, CrawlRequest, CrawlResult
|
| 10 |
+
from src.services.scraper import WebScraper, ScraperError
|
| 11 |
+
from src.services.tor_proxy import TorProxyService, TorProxyError
|
| 12 |
+
|
| 13 |
+
# Configure logger
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
router = APIRouter(
|
| 17 |
+
prefix="/scraping",
|
| 18 |
+
tags=["scraping"],
|
| 19 |
+
responses={404: {"description": "Not found"}}
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
# Initialize services
|
| 23 |
+
scraper = WebScraper()
|
| 24 |
+
|
| 25 |
+
@router.post("/test-tor", response_model=Dict[str, Any])
|
| 26 |
+
async def test_tor_connection(
|
| 27 |
+
current_user: User = Depends(get_current_user)
|
| 28 |
+
):
|
| 29 |
+
"""
|
| 30 |
+
Test Tor connection.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
current_user: Current authenticated user
|
| 34 |
+
|
| 35 |
+
Returns:
|
| 36 |
+
Dict[str, Any]: Connection status
|
| 37 |
+
"""
|
| 38 |
+
try:
|
| 39 |
+
tor_proxy = TorProxyService()
|
| 40 |
+
is_connected = await tor_proxy.check_connection()
|
| 41 |
+
|
| 42 |
+
return {
|
| 43 |
+
"status": "success",
|
| 44 |
+
"is_connected": is_connected,
|
| 45 |
+
"timestamp": datetime.utcnow().isoformat()
|
| 46 |
+
}
|
| 47 |
+
except TorProxyError as e:
|
| 48 |
+
logger.error(f"Tor proxy error: {e}")
|
| 49 |
+
raise HTTPException(
|
| 50 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 51 |
+
detail=f"Tor proxy error: {str(e)}"
|
| 52 |
+
)
|
| 53 |
+
except Exception as e:
|
| 54 |
+
logger.error(f"Error testing Tor connection: {e}")
|
| 55 |
+
raise HTTPException(
|
| 56 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 57 |
+
detail=f"An error occurred: {str(e)}"
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
@router.post("/scrape", response_model=Dict[str, Any])
|
| 61 |
+
async def scrape_page(
|
| 62 |
+
url: str,
|
| 63 |
+
use_tor: bool = Body(False),
|
| 64 |
+
current_user: User = Depends(get_current_user)
|
| 65 |
+
):
|
| 66 |
+
"""
|
| 67 |
+
Scrape a single page.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
url: URL to scrape
|
| 71 |
+
use_tor: Whether to use Tor proxy
|
| 72 |
+
current_user: Current authenticated user
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
Dict[str, Any]: Scraped content
|
| 76 |
+
"""
|
| 77 |
+
try:
|
| 78 |
+
result = await scraper.extract_content(url, use_tor=use_tor)
|
| 79 |
+
|
| 80 |
+
return {
|
| 81 |
+
"status": "success",
|
| 82 |
+
"data": result,
|
| 83 |
+
"timestamp": datetime.utcnow().isoformat()
|
| 84 |
+
}
|
| 85 |
+
except ScraperError as e:
|
| 86 |
+
logger.error(f"Scraper error: {e}")
|
| 87 |
+
raise HTTPException(
|
| 88 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 89 |
+
detail=f"Scraper error: {str(e)}"
|
| 90 |
+
)
|
| 91 |
+
except Exception as e:
|
| 92 |
+
logger.error(f"Error scraping page: {e}")
|
| 93 |
+
raise HTTPException(
|
| 94 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 95 |
+
detail=f"An error occurred: {str(e)}"
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
@router.post("/crawl", response_model=Dict[str, Any])
|
| 99 |
+
async def crawl_site(
|
| 100 |
+
crawl_request: CrawlRequest,
|
| 101 |
+
background_tasks: BackgroundTasks,
|
| 102 |
+
current_user: User = Depends(get_current_user)
|
| 103 |
+
):
|
| 104 |
+
"""
|
| 105 |
+
Crawl a site.
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
crawl_request: Crawl request data
|
| 109 |
+
background_tasks: Background tasks
|
| 110 |
+
current_user: Current authenticated user
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
Dict[str, Any]: Crawl status
|
| 114 |
+
"""
|
| 115 |
+
# For longer crawls, we add them as background tasks
|
| 116 |
+
# This prevents timeouts on the API request
|
| 117 |
+
|
| 118 |
+
# Start crawl in background
|
| 119 |
+
if crawl_request.max_depth > 1 or '.onion' in crawl_request.url:
|
| 120 |
+
# Add to background tasks
|
| 121 |
+
background_tasks.add_task(
|
| 122 |
+
scraper.crawl,
|
| 123 |
+
crawl_request.url,
|
| 124 |
+
max_depth=crawl_request.max_depth,
|
| 125 |
+
max_pages=50,
|
| 126 |
+
keyword_filter=crawl_request.keywords
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
return {
|
| 130 |
+
"status": "started",
|
| 131 |
+
"message": "Crawl started in background",
|
| 132 |
+
"timestamp": datetime.utcnow().isoformat()
|
| 133 |
+
}
|
| 134 |
+
else:
|
| 135 |
+
# For simple crawls, we perform them synchronously
|
| 136 |
+
try:
|
| 137 |
+
results = await scraper.crawl(
|
| 138 |
+
crawl_request.url,
|
| 139 |
+
max_depth=crawl_request.max_depth,
|
| 140 |
+
max_pages=10,
|
| 141 |
+
keyword_filter=crawl_request.keywords
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
return {
|
| 145 |
+
"status": "completed",
|
| 146 |
+
"results": results,
|
| 147 |
+
"count": len(results),
|
| 148 |
+
"timestamp": datetime.utcnow().isoformat()
|
| 149 |
+
}
|
| 150 |
+
except ScraperError as e:
|
| 151 |
+
logger.error(f"Scraper error: {e}")
|
| 152 |
+
raise HTTPException(
|
| 153 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 154 |
+
detail=f"Scraper error: {str(e)}"
|
| 155 |
+
)
|
| 156 |
+
except Exception as e:
|
| 157 |
+
logger.error(f"Error crawling site: {e}")
|
| 158 |
+
raise HTTPException(
|
| 159 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 160 |
+
detail=f"An error occurred: {str(e)}"
|
| 161 |
+
)
|
src/api/routers/threats_router.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends, HTTPException, status, Body, Query
|
| 2 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 3 |
+
from typing import List, Optional
|
| 4 |
+
import logging
|
| 5 |
+
|
| 6 |
+
from src.api.database import get_db
|
| 7 |
+
from src.api.auth import get_current_user
|
| 8 |
+
from src.api.schemas import (
|
| 9 |
+
Threat, ThreatCreate, ThreatUpdate, ThreatFilter,
|
| 10 |
+
PaginationParams, User
|
| 11 |
+
)
|
| 12 |
+
from src.api.services.threat_service import (
|
| 13 |
+
create_threat, get_threat_by_id, update_threat,
|
| 14 |
+
delete_threat, get_threats, get_threat_statistics
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
# Configure logger
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
router = APIRouter(
|
| 21 |
+
tags=["threats"],
|
| 22 |
+
responses={404: {"description": "Not found"}}
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
@router.post("/", response_model=Threat, status_code=status.HTTP_201_CREATED)
|
| 26 |
+
async def create_threat_endpoint(
|
| 27 |
+
threat_data: ThreatCreate,
|
| 28 |
+
db: AsyncSession = Depends(get_db),
|
| 29 |
+
current_user: User = Depends(get_current_user)
|
| 30 |
+
):
|
| 31 |
+
"""
|
| 32 |
+
Create a new threat.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
threat_data: Threat data
|
| 36 |
+
db: Database session
|
| 37 |
+
current_user: Current authenticated user
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
Threat: Created threat
|
| 41 |
+
"""
|
| 42 |
+
try:
|
| 43 |
+
threat = await create_threat(db, threat_data)
|
| 44 |
+
|
| 45 |
+
if not threat:
|
| 46 |
+
raise HTTPException(
|
| 47 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
| 48 |
+
detail="Failed to create threat"
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
return threat
|
| 52 |
+
except Exception as e:
|
| 53 |
+
logger.error(f"Error creating threat: {e}")
|
| 54 |
+
raise HTTPException(
|
| 55 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 56 |
+
detail=f"An error occurred: {str(e)}"
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
@router.get("/{threat_id}", response_model=Threat)
|
| 60 |
+
async def get_threat_endpoint(
|
| 61 |
+
threat_id: int,
|
| 62 |
+
db: AsyncSession = Depends(get_db),
|
| 63 |
+
current_user: User = Depends(get_current_user)
|
| 64 |
+
):
|
| 65 |
+
"""
|
| 66 |
+
Get threat by ID.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
threat_id: Threat ID
|
| 70 |
+
db: Database session
|
| 71 |
+
current_user: Current authenticated user
|
| 72 |
+
|
| 73 |
+
Returns:
|
| 74 |
+
Threat: Threat data
|
| 75 |
+
"""
|
| 76 |
+
threat = await get_threat_by_id(db, threat_id)
|
| 77 |
+
|
| 78 |
+
if not threat:
|
| 79 |
+
raise HTTPException(
|
| 80 |
+
status_code=status.HTTP_404_NOT_FOUND,
|
| 81 |
+
detail=f"Threat with ID {threat_id} not found"
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
return threat
|
| 85 |
+
|
| 86 |
+
@router.put("/{threat_id}", response_model=Threat)
|
| 87 |
+
async def update_threat_endpoint(
|
| 88 |
+
threat_id: int,
|
| 89 |
+
threat_data: ThreatUpdate,
|
| 90 |
+
db: AsyncSession = Depends(get_db),
|
| 91 |
+
current_user: User = Depends(get_current_user)
|
| 92 |
+
):
|
| 93 |
+
"""
|
| 94 |
+
Update threat.
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
threat_id: Threat ID
|
| 98 |
+
threat_data: Threat data
|
| 99 |
+
db: Database session
|
| 100 |
+
current_user: Current authenticated user
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
Threat: Updated threat
|
| 104 |
+
"""
|
| 105 |
+
# Check if threat exists
|
| 106 |
+
threat = await get_threat_by_id(db, threat_id)
|
| 107 |
+
|
| 108 |
+
if not threat:
|
| 109 |
+
raise HTTPException(
|
| 110 |
+
status_code=status.HTTP_404_NOT_FOUND,
|
| 111 |
+
detail=f"Threat with ID {threat_id} not found"
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
# Update threat
|
| 115 |
+
updated_threat = await update_threat(db, threat_id, threat_data)
|
| 116 |
+
|
| 117 |
+
if not updated_threat:
|
| 118 |
+
raise HTTPException(
|
| 119 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
| 120 |
+
detail="Failed to update threat"
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
return updated_threat
|
| 124 |
+
|
| 125 |
+
@router.delete("/{threat_id}", status_code=status.HTTP_204_NO_CONTENT)
|
| 126 |
+
async def delete_threat_endpoint(
|
| 127 |
+
threat_id: int,
|
| 128 |
+
db: AsyncSession = Depends(get_db),
|
| 129 |
+
current_user: User = Depends(get_current_user)
|
| 130 |
+
):
|
| 131 |
+
"""
|
| 132 |
+
Delete threat.
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
threat_id: Threat ID
|
| 136 |
+
db: Database session
|
| 137 |
+
current_user: Current authenticated user
|
| 138 |
+
"""
|
| 139 |
+
# Check if threat exists
|
| 140 |
+
threat = await get_threat_by_id(db, threat_id)
|
| 141 |
+
|
| 142 |
+
if not threat:
|
| 143 |
+
raise HTTPException(
|
| 144 |
+
status_code=status.HTTP_404_NOT_FOUND,
|
| 145 |
+
detail=f"Threat with ID {threat_id} not found"
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
# Delete threat
|
| 149 |
+
deleted = await delete_threat(db, threat_id)
|
| 150 |
+
|
| 151 |
+
if not deleted:
|
| 152 |
+
raise HTTPException(
|
| 153 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
| 154 |
+
detail="Failed to delete threat"
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
@router.get("/", response_model=List[Threat])
|
| 158 |
+
async def get_threats_endpoint(
|
| 159 |
+
pagination: PaginationParams = Depends(),
|
| 160 |
+
severity: Optional[List[str]] = Query(None),
|
| 161 |
+
status: Optional[List[str]] = Query(None),
|
| 162 |
+
category: Optional[List[str]] = Query(None),
|
| 163 |
+
search: Optional[str] = Query(None),
|
| 164 |
+
from_date: Optional[str] = Query(None),
|
| 165 |
+
to_date: Optional[str] = Query(None),
|
| 166 |
+
db: AsyncSession = Depends(get_db),
|
| 167 |
+
current_user: User = Depends(get_current_user)
|
| 168 |
+
):
|
| 169 |
+
"""
|
| 170 |
+
Get threats with filtering and pagination.
|
| 171 |
+
|
| 172 |
+
Args:
|
| 173 |
+
pagination: Pagination parameters
|
| 174 |
+
severity: Filter by severity
|
| 175 |
+
status: Filter by status
|
| 176 |
+
category: Filter by category
|
| 177 |
+
search: Search in title and description
|
| 178 |
+
from_date: Filter from date
|
| 179 |
+
to_date: Filter to date
|
| 180 |
+
db: Database session
|
| 181 |
+
current_user: Current authenticated user
|
| 182 |
+
|
| 183 |
+
Returns:
|
| 184 |
+
List[Threat]: List of threats
|
| 185 |
+
"""
|
| 186 |
+
# Create filter params
|
| 187 |
+
filter_params = ThreatFilter(
|
| 188 |
+
severity=severity,
|
| 189 |
+
status=status,
|
| 190 |
+
category=category,
|
| 191 |
+
search=search,
|
| 192 |
+
from_date=from_date,
|
| 193 |
+
to_date=to_date
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
# Get threats
|
| 197 |
+
threats, total = await get_threats(db, filter_params, pagination)
|
| 198 |
+
|
| 199 |
+
return threats
|
| 200 |
+
|
| 201 |
+
@router.get("/statistics", response_model=dict)
|
| 202 |
+
async def get_threat_statistics_endpoint(
|
| 203 |
+
db: AsyncSession = Depends(get_db),
|
| 204 |
+
current_user: User = Depends(get_current_user)
|
| 205 |
+
):
|
| 206 |
+
"""
|
| 207 |
+
Get threat statistics.
|
| 208 |
+
|
| 209 |
+
Args:
|
| 210 |
+
db: Database session
|
| 211 |
+
current_user: Current authenticated user
|
| 212 |
+
|
| 213 |
+
Returns:
|
| 214 |
+
dict: Threat statistics
|
| 215 |
+
"""
|
| 216 |
+
statistics = await get_threat_statistics(db)
|
| 217 |
+
return statistics
|
src/api/schemas.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
API schemas for data validation and serialization.
|
| 3 |
+
"""
|
| 4 |
+
from pydantic import BaseModel, Field, validator, EmailStr
|
| 5 |
+
from typing import Optional, List, Dict, Any, Union
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from enum import Enum
|
| 8 |
+
|
| 9 |
+
# Pagination
|
| 10 |
+
class PaginationParams(BaseModel):
|
| 11 |
+
"""Pagination parameters."""
|
| 12 |
+
page: int = Field(1, ge=1, description="Page number")
|
| 13 |
+
size: int = Field(10, ge=1, le=100, description="Items per page")
|
| 14 |
+
|
| 15 |
+
# User schemas
|
| 16 |
+
class UserBase(BaseModel):
|
| 17 |
+
"""Base user schema."""
|
| 18 |
+
username: str
|
| 19 |
+
email: EmailStr
|
| 20 |
+
full_name: Optional[str] = None
|
| 21 |
+
is_active: bool = True
|
| 22 |
+
|
| 23 |
+
class UserCreate(UserBase):
|
| 24 |
+
"""User creation schema."""
|
| 25 |
+
password: str
|
| 26 |
+
|
| 27 |
+
class UserUpdate(BaseModel):
|
| 28 |
+
"""User update schema."""
|
| 29 |
+
username: Optional[str] = None
|
| 30 |
+
email: Optional[EmailStr] = None
|
| 31 |
+
full_name: Optional[str] = None
|
| 32 |
+
is_active: Optional[bool] = None
|
| 33 |
+
password: Optional[str] = None
|
| 34 |
+
|
| 35 |
+
class UserResponse(UserBase):
|
| 36 |
+
"""User response schema."""
|
| 37 |
+
id: int
|
| 38 |
+
is_superuser: bool = False
|
| 39 |
+
|
| 40 |
+
class Config:
|
| 41 |
+
orm_mode = True
|
| 42 |
+
|
| 43 |
+
# Token schemas
|
| 44 |
+
class Token(BaseModel):
|
| 45 |
+
"""Token schema."""
|
| 46 |
+
access_token: str
|
| 47 |
+
token_type: str = "bearer"
|
| 48 |
+
|
| 49 |
+
class TokenPayload(BaseModel):
|
| 50 |
+
"""Token payload schema."""
|
| 51 |
+
sub: Optional[int] = None
|
| 52 |
+
|
| 53 |
+
# Dark Web Content schemas
|
| 54 |
+
class DarkWebContentBase(BaseModel):
|
| 55 |
+
"""Base schema for dark web content."""
|
| 56 |
+
url: str
|
| 57 |
+
title: Optional[str] = None
|
| 58 |
+
content: str
|
| 59 |
+
content_type: str
|
| 60 |
+
source_name: Optional[str] = None
|
| 61 |
+
source_type: Optional[str] = None
|
| 62 |
+
language: Optional[str] = None
|
| 63 |
+
|
| 64 |
+
class DarkWebContentCreate(DarkWebContentBase):
|
| 65 |
+
"""Schema for creating dark web content."""
|
| 66 |
+
relevance_score: float = 0.0
|
| 67 |
+
sentiment_score: float = 0.0
|
| 68 |
+
entity_data: Optional[str] = None
|
| 69 |
+
|
| 70 |
+
class DarkWebContentUpdate(BaseModel):
|
| 71 |
+
"""Schema for updating dark web content."""
|
| 72 |
+
title: Optional[str] = None
|
| 73 |
+
content_status: Optional[str] = None
|
| 74 |
+
relevance_score: Optional[float] = None
|
| 75 |
+
sentiment_score: Optional[float] = None
|
| 76 |
+
entity_data: Optional[str] = None
|
| 77 |
+
|
| 78 |
+
class DarkWebContentResponse(DarkWebContentBase):
|
| 79 |
+
"""Schema for dark web content response."""
|
| 80 |
+
id: int
|
| 81 |
+
domain: Optional[str] = None
|
| 82 |
+
content_status: str
|
| 83 |
+
scraped_at: datetime
|
| 84 |
+
relevance_score: float
|
| 85 |
+
sentiment_score: float
|
| 86 |
+
entity_data: Optional[str] = None
|
| 87 |
+
|
| 88 |
+
class Config:
|
| 89 |
+
orm_mode = True
|
| 90 |
+
|
| 91 |
+
# Dark Web Mention schemas
|
| 92 |
+
class DarkWebMentionBase(BaseModel):
|
| 93 |
+
"""Base schema for dark web mention."""
|
| 94 |
+
content_id: int
|
| 95 |
+
keyword: str
|
| 96 |
+
keyword_category: Optional[str] = None
|
| 97 |
+
context: Optional[str] = None
|
| 98 |
+
snippet: Optional[str] = None
|
| 99 |
+
mention_type: Optional[str] = None
|
| 100 |
+
|
| 101 |
+
class DarkWebMentionCreate(DarkWebMentionBase):
|
| 102 |
+
"""Schema for creating dark web mention."""
|
| 103 |
+
confidence: float = 0.0
|
| 104 |
+
is_verified: bool = False
|
| 105 |
+
|
| 106 |
+
class DarkWebMentionUpdate(BaseModel):
|
| 107 |
+
"""Schema for updating dark web mention."""
|
| 108 |
+
keyword_category: Optional[str] = None
|
| 109 |
+
mention_type: Optional[str] = None
|
| 110 |
+
confidence: Optional[float] = None
|
| 111 |
+
is_verified: Optional[bool] = None
|
| 112 |
+
|
| 113 |
+
class DarkWebMentionResponse(DarkWebMentionBase):
|
| 114 |
+
"""Schema for dark web mention response."""
|
| 115 |
+
id: int
|
| 116 |
+
confidence: float
|
| 117 |
+
is_verified: bool
|
| 118 |
+
created_at: datetime
|
| 119 |
+
|
| 120 |
+
class Config:
|
| 121 |
+
orm_mode = True
|
| 122 |
+
|
| 123 |
+
# Threat schemas
|
| 124 |
+
class ThreatBase(BaseModel):
|
| 125 |
+
"""Base schema for threat."""
|
| 126 |
+
title: str
|
| 127 |
+
description: str
|
| 128 |
+
severity: str
|
| 129 |
+
category: str
|
| 130 |
+
|
| 131 |
+
class ThreatCreate(ThreatBase):
|
| 132 |
+
"""Schema for creating threat."""
|
| 133 |
+
status: str = "New"
|
| 134 |
+
source_url: Optional[str] = None
|
| 135 |
+
source_name: Optional[str] = None
|
| 136 |
+
source_type: Optional[str] = None
|
| 137 |
+
affected_entity: Optional[str] = None
|
| 138 |
+
affected_entity_type: Optional[str] = None
|
| 139 |
+
confidence_score: float = 0.0
|
| 140 |
+
risk_score: float = 0.0
|
| 141 |
+
|
| 142 |
+
class ThreatUpdate(BaseModel):
|
| 143 |
+
"""Schema for updating threat."""
|
| 144 |
+
title: Optional[str] = None
|
| 145 |
+
description: Optional[str] = None
|
| 146 |
+
severity: Optional[str] = None
|
| 147 |
+
status: Optional[str] = None
|
| 148 |
+
category: Optional[str] = None
|
| 149 |
+
affected_entity: Optional[str] = None
|
| 150 |
+
affected_entity_type: Optional[str] = None
|
| 151 |
+
confidence_score: Optional[float] = None
|
| 152 |
+
risk_score: Optional[float] = None
|
| 153 |
+
|
| 154 |
+
class ThreatResponse(ThreatBase):
|
| 155 |
+
"""Schema for threat response."""
|
| 156 |
+
id: int
|
| 157 |
+
status: str
|
| 158 |
+
source_url: Optional[str] = None
|
| 159 |
+
source_name: Optional[str] = None
|
| 160 |
+
source_type: Optional[str] = None
|
| 161 |
+
discovered_at: datetime
|
| 162 |
+
affected_entity: Optional[str] = None
|
| 163 |
+
affected_entity_type: Optional[str] = None
|
| 164 |
+
confidence_score: float
|
| 165 |
+
risk_score: float
|
| 166 |
+
|
| 167 |
+
class Config:
|
| 168 |
+
orm_mode = True
|
| 169 |
+
|
| 170 |
+
# Indicator schemas
|
| 171 |
+
class IndicatorBase(BaseModel):
|
| 172 |
+
"""Base schema for indicator."""
|
| 173 |
+
threat_id: int
|
| 174 |
+
value: str
|
| 175 |
+
indicator_type: str
|
| 176 |
+
description: Optional[str] = None
|
| 177 |
+
|
| 178 |
+
class IndicatorCreate(IndicatorBase):
|
| 179 |
+
"""Schema for creating indicator."""
|
| 180 |
+
is_verified: bool = False
|
| 181 |
+
context: Optional[str] = None
|
| 182 |
+
source: Optional[str] = None
|
| 183 |
+
confidence_score: float = 0.0
|
| 184 |
+
|
| 185 |
+
class IndicatorUpdate(BaseModel):
|
| 186 |
+
"""Schema for updating indicator."""
|
| 187 |
+
description: Optional[str] = None
|
| 188 |
+
is_verified: Optional[bool] = None
|
| 189 |
+
context: Optional[str] = None
|
| 190 |
+
source: Optional[str] = None
|
| 191 |
+
confidence_score: Optional[float] = None
|
| 192 |
+
|
| 193 |
+
class IndicatorResponse(IndicatorBase):
|
| 194 |
+
"""Schema for indicator response."""
|
| 195 |
+
id: int
|
| 196 |
+
is_verified: bool
|
| 197 |
+
context: Optional[str] = None
|
| 198 |
+
source: Optional[str] = None
|
| 199 |
+
confidence_score: float
|
| 200 |
+
first_seen: datetime
|
| 201 |
+
last_seen: datetime
|
| 202 |
+
|
| 203 |
+
class Config:
|
| 204 |
+
orm_mode = True
|
| 205 |
+
|
| 206 |
+
# Alert schemas
|
| 207 |
+
class AlertBase(BaseModel):
|
| 208 |
+
"""Base schema for alert."""
|
| 209 |
+
title: str
|
| 210 |
+
description: str
|
| 211 |
+
severity: str
|
| 212 |
+
category: str
|
| 213 |
+
|
| 214 |
+
class AlertCreate(AlertBase):
|
| 215 |
+
"""Schema for creating alert."""
|
| 216 |
+
source_url: Optional[str] = None
|
| 217 |
+
threat_id: Optional[int] = None
|
| 218 |
+
mention_id: Optional[int] = None
|
| 219 |
+
|
| 220 |
+
class AlertUpdate(BaseModel):
|
| 221 |
+
"""Schema for updating alert."""
|
| 222 |
+
status: str
|
| 223 |
+
action_taken: Optional[str] = None
|
| 224 |
+
assigned_to_id: Optional[int] = None
|
| 225 |
+
is_read: Optional[bool] = None
|
| 226 |
+
|
| 227 |
+
class AlertResponse(AlertBase):
|
| 228 |
+
"""Schema for alert response."""
|
| 229 |
+
id: int
|
| 230 |
+
status: str
|
| 231 |
+
generated_at: datetime
|
| 232 |
+
source_url: Optional[str] = None
|
| 233 |
+
is_read: bool
|
| 234 |
+
threat_id: Optional[int] = None
|
| 235 |
+
mention_id: Optional[int] = None
|
| 236 |
+
assigned_to_id: Optional[int] = None
|
| 237 |
+
action_taken: Optional[str] = None
|
| 238 |
+
resolved_at: Optional[datetime] = None
|
| 239 |
+
|
| 240 |
+
class Config:
|
| 241 |
+
orm_mode = True
|
| 242 |
+
|
| 243 |
+
# Report schemas
|
| 244 |
+
class ReportBase(BaseModel):
|
| 245 |
+
"""Base schema for report."""
|
| 246 |
+
report_id: str
|
| 247 |
+
title: str
|
| 248 |
+
summary: str
|
| 249 |
+
content: str
|
| 250 |
+
report_type: str
|
| 251 |
+
|
| 252 |
+
class ReportCreate(ReportBase):
|
| 253 |
+
"""Schema for creating report."""
|
| 254 |
+
status: str = "Draft"
|
| 255 |
+
severity: Optional[str] = None
|
| 256 |
+
publish_date: Optional[datetime] = None
|
| 257 |
+
time_period_start: Optional[datetime] = None
|
| 258 |
+
time_period_end: Optional[datetime] = None
|
| 259 |
+
keywords: Optional[str] = None
|
| 260 |
+
author_id: int
|
| 261 |
+
threat_ids: List[int] = []
|
| 262 |
+
|
| 263 |
+
class ReportUpdate(BaseModel):
|
| 264 |
+
"""Schema for updating report."""
|
| 265 |
+
title: Optional[str] = None
|
| 266 |
+
summary: Optional[str] = None
|
| 267 |
+
content: Optional[str] = None
|
| 268 |
+
report_type: Optional[str] = None
|
| 269 |
+
status: Optional[str] = None
|
| 270 |
+
severity: Optional[str] = None
|
| 271 |
+
publish_date: Optional[datetime] = None
|
| 272 |
+
time_period_start: Optional[datetime] = None
|
| 273 |
+
time_period_end: Optional[datetime] = None
|
| 274 |
+
keywords: Optional[str] = None
|
| 275 |
+
threat_ids: Optional[List[int]] = None
|
| 276 |
+
|
| 277 |
+
class ReportResponse(ReportBase):
|
| 278 |
+
"""Schema for report response."""
|
| 279 |
+
id: int
|
| 280 |
+
status: str
|
| 281 |
+
severity: Optional[str] = None
|
| 282 |
+
publish_date: Optional[datetime] = None
|
| 283 |
+
time_period_start: Optional[datetime] = None
|
| 284 |
+
time_period_end: Optional[datetime] = None
|
| 285 |
+
keywords: Optional[str] = None
|
| 286 |
+
author_id: int
|
| 287 |
+
|
| 288 |
+
class Config:
|
| 289 |
+
orm_mode = True
|
| 290 |
+
|
| 291 |
+
# Statistics response schemas
|
| 292 |
+
class ThreatStatisticsResponse(BaseModel):
|
| 293 |
+
"""Schema for threat statistics response."""
|
| 294 |
+
total_count: int
|
| 295 |
+
severity_counts: Dict[str, int]
|
| 296 |
+
status_counts: Dict[str, int]
|
| 297 |
+
category_counts: Dict[str, int]
|
| 298 |
+
time_series: List[Dict[str, Any]]
|
| 299 |
+
from_date: str
|
| 300 |
+
to_date: str
|
| 301 |
+
|
| 302 |
+
class ContentStatisticsResponse(BaseModel):
|
| 303 |
+
"""Schema for content statistics response."""
|
| 304 |
+
total_count: int
|
| 305 |
+
content_type_counts: Dict[str, int]
|
| 306 |
+
content_status_counts: Dict[str, int]
|
| 307 |
+
source_counts: Dict[str, int]
|
| 308 |
+
time_series: List[Dict[str, Any]]
|
| 309 |
+
from_date: str
|
| 310 |
+
to_date: str
|
src/api/security.py
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
API Security Module
|
| 3 |
+
|
| 4 |
+
This module provides security features for the API, including:
|
| 5 |
+
1. Authentication using JWT tokens
|
| 6 |
+
2. Rate limiting to prevent abuse
|
| 7 |
+
3. Role-based access control
|
| 8 |
+
4. Request validation
|
| 9 |
+
5. Audit logging
|
| 10 |
+
"""
|
| 11 |
+
import os
|
| 12 |
+
import time
|
| 13 |
+
import logging
|
| 14 |
+
import secrets
|
| 15 |
+
from datetime import datetime, timedelta
|
| 16 |
+
from typing import Dict, List, Optional, Union, Any, Callable
|
| 17 |
+
|
| 18 |
+
from fastapi import Depends, HTTPException, Security, status, Request
|
| 19 |
+
from fastapi.security import OAuth2PasswordBearer, APIKeyHeader
|
| 20 |
+
from jose import JWTError, jwt
|
| 21 |
+
from passlib.context import CryptContext
|
| 22 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 23 |
+
from sqlalchemy.future import select
|
| 24 |
+
from pydantic import BaseModel, EmailStr
|
| 25 |
+
|
| 26 |
+
from src.models.user import User
|
| 27 |
+
from src.api.database import get_db
|
| 28 |
+
|
| 29 |
+
# Configure logging
|
| 30 |
+
logger = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
# Security configuration
|
| 33 |
+
SECRET_KEY = os.getenv("JWT_SECRET_KEY", secrets.token_hex(32))
|
| 34 |
+
ALGORITHM = "HS256"
|
| 35 |
+
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
| 36 |
+
API_KEY_NAME = "X-API-Key"
|
| 37 |
+
|
| 38 |
+
# Set up password hashing
|
| 39 |
+
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
| 40 |
+
|
| 41 |
+
# Set up security schemes
|
| 42 |
+
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
| 43 |
+
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
|
| 44 |
+
|
| 45 |
+
# User models
|
| 46 |
+
class Token(BaseModel):
|
| 47 |
+
access_token: str
|
| 48 |
+
token_type: str
|
| 49 |
+
expires_at: datetime
|
| 50 |
+
|
| 51 |
+
class TokenData(BaseModel):
|
| 52 |
+
username: Optional[str] = None
|
| 53 |
+
scopes: List[str] = []
|
| 54 |
+
|
| 55 |
+
class UserInDB(BaseModel):
|
| 56 |
+
id: int
|
| 57 |
+
username: str
|
| 58 |
+
email: EmailStr
|
| 59 |
+
full_name: Optional[str] = None
|
| 60 |
+
is_active: bool = True
|
| 61 |
+
is_superuser: bool = False
|
| 62 |
+
scopes: List[str] = []
|
| 63 |
+
|
| 64 |
+
class Config:
|
| 65 |
+
from_attributes = True
|
| 66 |
+
|
| 67 |
+
# Rate limiting
|
| 68 |
+
class RateLimiter:
|
| 69 |
+
"""Simple in-memory rate limiter"""
|
| 70 |
+
|
| 71 |
+
def __init__(self, rate_limit: int = 100, time_window: int = 60):
|
| 72 |
+
"""
|
| 73 |
+
Initialize rate limiter.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
rate_limit: Maximum number of requests per time window
|
| 77 |
+
time_window: Time window in seconds
|
| 78 |
+
"""
|
| 79 |
+
self.rate_limit = rate_limit
|
| 80 |
+
self.time_window = time_window
|
| 81 |
+
self.requests = {}
|
| 82 |
+
|
| 83 |
+
def is_rate_limited(self, key: str) -> bool:
|
| 84 |
+
"""
|
| 85 |
+
Check if a key is rate limited.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
key: Identifier for the client (IP address, API key, etc.)
|
| 89 |
+
|
| 90 |
+
Returns:
|
| 91 |
+
True if rate limited, False otherwise
|
| 92 |
+
"""
|
| 93 |
+
current_time = time.time()
|
| 94 |
+
|
| 95 |
+
# Initialize or clean up old requests
|
| 96 |
+
if key not in self.requests:
|
| 97 |
+
self.requests[key] = []
|
| 98 |
+
else:
|
| 99 |
+
# Remove requests outside the time window
|
| 100 |
+
self.requests[key] = [t for t in self.requests[key] if t > current_time - self.time_window]
|
| 101 |
+
|
| 102 |
+
# Check if rate limit is exceeded
|
| 103 |
+
if len(self.requests[key]) >= self.rate_limit:
|
| 104 |
+
return True
|
| 105 |
+
|
| 106 |
+
# Add the current request
|
| 107 |
+
self.requests[key].append(current_time)
|
| 108 |
+
return False
|
| 109 |
+
|
| 110 |
+
# Create global rate limiter instance
|
| 111 |
+
rate_limiter = RateLimiter()
|
| 112 |
+
|
| 113 |
+
# Role-based access control
|
| 114 |
+
# Define roles and permissions
|
| 115 |
+
ROLES = {
|
| 116 |
+
"admin": ["read:all", "write:all", "delete:all"],
|
| 117 |
+
"analyst": ["read:all", "write:threats", "write:indicators", "write:reports"],
|
| 118 |
+
"user": ["read:threats", "read:reports", "read:dashboard"],
|
| 119 |
+
"api": ["read:all", "write:threats", "write:indicators"]
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
# Security utility functions
|
| 123 |
+
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
| 124 |
+
"""Verify a password against a hash"""
|
| 125 |
+
return pwd_context.verify(plain_password, hashed_password)
|
| 126 |
+
|
| 127 |
+
def get_password_hash(password: str) -> str:
|
| 128 |
+
"""Hash a password for storage"""
|
| 129 |
+
return pwd_context.hash(password)
|
| 130 |
+
|
| 131 |
+
async def get_user(db: AsyncSession, username: str) -> Optional[UserInDB]:
|
| 132 |
+
"""Get a user from the database by username"""
|
| 133 |
+
result = await db.execute(select(User).filter(User.username == username))
|
| 134 |
+
user_db = result.scalars().first()
|
| 135 |
+
|
| 136 |
+
if not user_db:
|
| 137 |
+
return None
|
| 138 |
+
|
| 139 |
+
# Get user roles and scopes
|
| 140 |
+
scopes = []
|
| 141 |
+
if user_db.is_superuser:
|
| 142 |
+
scopes = ROLES["admin"]
|
| 143 |
+
else:
|
| 144 |
+
# In a real application, you would look up user roles in a database
|
| 145 |
+
# For simplicity, we'll assume non-superusers have the "user" role
|
| 146 |
+
scopes = ROLES["user"]
|
| 147 |
+
|
| 148 |
+
return UserInDB(
|
| 149 |
+
id=user_db.id,
|
| 150 |
+
username=user_db.username,
|
| 151 |
+
email=user_db.email,
|
| 152 |
+
full_name=user_db.full_name,
|
| 153 |
+
is_active=user_db.is_active,
|
| 154 |
+
is_superuser=user_db.is_superuser,
|
| 155 |
+
scopes=scopes
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
async def authenticate_user(db: AsyncSession, username: str, password: str) -> Optional[UserInDB]:
|
| 159 |
+
"""Authenticate a user with username and password"""
|
| 160 |
+
user = await get_user(db, username)
|
| 161 |
+
if not user:
|
| 162 |
+
return None
|
| 163 |
+
|
| 164 |
+
# Get the user from the database again to get the hashed password
|
| 165 |
+
result = await db.execute(select(User).filter(User.username == username))
|
| 166 |
+
user_db = result.scalars().first()
|
| 167 |
+
|
| 168 |
+
if not verify_password(password, user_db.hashed_password):
|
| 169 |
+
return None
|
| 170 |
+
|
| 171 |
+
return user
|
| 172 |
+
|
| 173 |
+
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
| 174 |
+
"""Create a JWT access token"""
|
| 175 |
+
to_encode = data.copy()
|
| 176 |
+
|
| 177 |
+
if expires_delta:
|
| 178 |
+
expire = datetime.utcnow() + expires_delta
|
| 179 |
+
else:
|
| 180 |
+
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
| 181 |
+
|
| 182 |
+
to_encode.update({"exp": expire})
|
| 183 |
+
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
| 184 |
+
return encoded_jwt
|
| 185 |
+
|
| 186 |
+
async def get_api_key_user(
|
| 187 |
+
api_key: str,
|
| 188 |
+
db: AsyncSession
|
| 189 |
+
) -> Optional[UserInDB]:
|
| 190 |
+
"""Get user associated with an API key"""
|
| 191 |
+
# In a real application, you would look up API keys in a database
|
| 192 |
+
# For simplicity, we'll use a simple hardcoded mapping
|
| 193 |
+
# TODO: Replace with database-backed API key storage
|
| 194 |
+
API_KEYS = {
|
| 195 |
+
"test-api-key": "api_user",
|
| 196 |
+
# Add more API keys here
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
if api_key not in API_KEYS:
|
| 200 |
+
return None
|
| 201 |
+
|
| 202 |
+
username = API_KEYS[api_key]
|
| 203 |
+
user = await get_user(db, username)
|
| 204 |
+
|
| 205 |
+
if not user:
|
| 206 |
+
return None
|
| 207 |
+
|
| 208 |
+
# Override scopes with API role scopes
|
| 209 |
+
user.scopes = ROLES["api"]
|
| 210 |
+
|
| 211 |
+
return user
|
| 212 |
+
|
| 213 |
+
# Dependencies for FastAPI
|
| 214 |
+
async def rate_limit(request: Request):
|
| 215 |
+
"""Rate limiting dependency"""
|
| 216 |
+
# Use API key or IP address as the rate limit key
|
| 217 |
+
client_key = request.headers.get(API_KEY_NAME) or request.client.host
|
| 218 |
+
|
| 219 |
+
if rate_limiter.is_rate_limited(client_key):
|
| 220 |
+
logger.warning(f"Rate limit exceeded for {client_key}")
|
| 221 |
+
raise HTTPException(
|
| 222 |
+
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
| 223 |
+
detail="Rate limit exceeded. Please try again later."
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
return True
|
| 227 |
+
|
| 228 |
+
async def get_current_user(
|
| 229 |
+
token: str = Depends(oauth2_scheme),
|
| 230 |
+
api_key: str = Security(api_key_header),
|
| 231 |
+
db: AsyncSession = Depends(get_db)
|
| 232 |
+
) -> UserInDB:
|
| 233 |
+
"""
|
| 234 |
+
Get the current user from either JWT token or API key.
|
| 235 |
+
|
| 236 |
+
This dependency can be used to require authentication for endpoints.
|
| 237 |
+
"""
|
| 238 |
+
credentials_exception = HTTPException(
|
| 239 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 240 |
+
detail="Could not validate credentials",
|
| 241 |
+
headers={"WWW-Authenticate": "Bearer"},
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
# Check API key first
|
| 245 |
+
if api_key:
|
| 246 |
+
user = await get_api_key_user(api_key, db)
|
| 247 |
+
if user:
|
| 248 |
+
return user
|
| 249 |
+
|
| 250 |
+
# Then check JWT token
|
| 251 |
+
try:
|
| 252 |
+
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
| 253 |
+
username = payload.get("sub")
|
| 254 |
+
if username is None:
|
| 255 |
+
raise credentials_exception
|
| 256 |
+
|
| 257 |
+
token_data = TokenData(
|
| 258 |
+
username=username,
|
| 259 |
+
scopes=payload.get("scopes", [])
|
| 260 |
+
)
|
| 261 |
+
except JWTError:
|
| 262 |
+
raise credentials_exception
|
| 263 |
+
|
| 264 |
+
user = await get_user(db, username=token_data.username)
|
| 265 |
+
if user is None:
|
| 266 |
+
raise credentials_exception
|
| 267 |
+
|
| 268 |
+
return user
|
| 269 |
+
|
| 270 |
+
async def get_current_active_user(
|
| 271 |
+
current_user: UserInDB = Depends(get_current_user)
|
| 272 |
+
) -> UserInDB:
|
| 273 |
+
"""
|
| 274 |
+
Get the current active user.
|
| 275 |
+
|
| 276 |
+
This dependency can be used to require an active user for endpoints.
|
| 277 |
+
"""
|
| 278 |
+
if not current_user.is_active:
|
| 279 |
+
raise HTTPException(status_code=400, detail="Inactive user")
|
| 280 |
+
|
| 281 |
+
return current_user
|
| 282 |
+
|
| 283 |
+
def has_scope(required_scopes: List[str]):
|
| 284 |
+
"""
|
| 285 |
+
Create a dependency that checks if the user has the required scopes.
|
| 286 |
+
|
| 287 |
+
Args:
|
| 288 |
+
required_scopes: List of required scopes
|
| 289 |
+
|
| 290 |
+
Returns:
|
| 291 |
+
A dependency function that checks if the user has the required scopes
|
| 292 |
+
"""
|
| 293 |
+
async def _has_scope(
|
| 294 |
+
current_user: UserInDB = Depends(get_current_active_user)
|
| 295 |
+
) -> UserInDB:
|
| 296 |
+
for scope in required_scopes:
|
| 297 |
+
if scope not in current_user.scopes:
|
| 298 |
+
raise HTTPException(
|
| 299 |
+
status_code=status.HTTP_403_FORBIDDEN,
|
| 300 |
+
detail=f"Permission denied. Required scope: {scope}"
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
return current_user
|
| 304 |
+
|
| 305 |
+
return _has_scope
|
| 306 |
+
|
| 307 |
+
def admin_only(
|
| 308 |
+
current_user: UserInDB = Depends(get_current_active_user)
|
| 309 |
+
) -> UserInDB:
|
| 310 |
+
"""
|
| 311 |
+
Dependency that requires an admin user.
|
| 312 |
+
"""
|
| 313 |
+
if not current_user.is_superuser:
|
| 314 |
+
raise HTTPException(
|
| 315 |
+
status_code=status.HTTP_403_FORBIDDEN,
|
| 316 |
+
detail="Permission denied. Admin access required."
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
return current_user
|
| 320 |
+
|
| 321 |
+
# Audit logging middleware
|
| 322 |
+
async def audit_log_middleware(request: Request, call_next):
|
| 323 |
+
"""
|
| 324 |
+
Middleware for audit logging.
|
| 325 |
+
|
| 326 |
+
Records details about API requests.
|
| 327 |
+
"""
|
| 328 |
+
# Get request details
|
| 329 |
+
method = request.method
|
| 330 |
+
path = request.url.path
|
| 331 |
+
client_host = request.client.host
|
| 332 |
+
user_agent = request.headers.get("User-Agent", "Unknown")
|
| 333 |
+
|
| 334 |
+
# Get user details if available
|
| 335 |
+
user = getattr(request.state, "user", None)
|
| 336 |
+
username = user.username if user else "Anonymous"
|
| 337 |
+
|
| 338 |
+
# Log request
|
| 339 |
+
logger.info(
|
| 340 |
+
f"API Request: {method} {path} | User: {username} | "
|
| 341 |
+
f"Client: {client_host} | User-Agent: {user_agent}"
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
# Process the request
|
| 345 |
+
start_time = time.time()
|
| 346 |
+
response = await call_next(request)
|
| 347 |
+
process_time = time.time() - start_time
|
| 348 |
+
|
| 349 |
+
# Log response
|
| 350 |
+
logger.info(
|
| 351 |
+
f"API Response: {method} {path} | Status: {response.status_code} | "
|
| 352 |
+
f"Time: {process_time:.4f}s | User: {username}"
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
return response
|
| 356 |
+
|
| 357 |
+
# API key validation middleware
|
| 358 |
+
def validate_api_key(request: Request):
|
| 359 |
+
"""
|
| 360 |
+
Middleware function to validate API keys.
|
| 361 |
+
|
| 362 |
+
This can be used as a dependency for FastAPI routes.
|
| 363 |
+
"""
|
| 364 |
+
api_key = request.headers.get(API_KEY_NAME)
|
| 365 |
+
if not api_key:
|
| 366 |
+
raise HTTPException(
|
| 367 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 368 |
+
detail="API key required",
|
| 369 |
+
headers={"WWW-Authenticate": f"{API_KEY_NAME}"},
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
# In a real application, you would validate the API key against a database
|
| 373 |
+
# For simplicity, we'll use a hardcoded list
|
| 374 |
+
valid_keys = ["test-api-key"] # Replace with database lookup
|
| 375 |
+
if api_key not in valid_keys:
|
| 376 |
+
raise HTTPException(
|
| 377 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 378 |
+
detail="Invalid API key",
|
| 379 |
+
headers={"WWW-Authenticate": f"{API_KEY_NAME}"},
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
return True
|
src/api/services/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Package initialization for API services.
|
| 3 |
+
"""
|
src/api/services/alert_service.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Service for alert operations.
|
| 3 |
+
"""
|
| 4 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 5 |
+
from sqlalchemy.future import select
|
| 6 |
+
from sqlalchemy import func, or_, and_
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from typing import List, Optional, Dict, Any, Union
|
| 9 |
+
|
| 10 |
+
from src.models.alert import Alert, AlertStatus, AlertCategory
|
| 11 |
+
from src.models.threat import ThreatSeverity
|
| 12 |
+
from src.api.schemas import PaginationParams
|
| 13 |
+
|
| 14 |
+
async def create_alert(
|
| 15 |
+
db: AsyncSession,
|
| 16 |
+
title: str,
|
| 17 |
+
description: str,
|
| 18 |
+
severity: ThreatSeverity,
|
| 19 |
+
category: AlertCategory,
|
| 20 |
+
source_url: Optional[str] = None,
|
| 21 |
+
threat_id: Optional[int] = None,
|
| 22 |
+
mention_id: Optional[int] = None,
|
| 23 |
+
) -> Alert:
|
| 24 |
+
"""
|
| 25 |
+
Create a new alert.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
db: Database session
|
| 29 |
+
title: Alert title
|
| 30 |
+
description: Alert description
|
| 31 |
+
severity: Alert severity
|
| 32 |
+
category: Alert category
|
| 33 |
+
source_url: Source URL for the alert
|
| 34 |
+
threat_id: ID of related threat
|
| 35 |
+
mention_id: ID of related dark web mention
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
Alert: Created alert
|
| 39 |
+
"""
|
| 40 |
+
db_alert = Alert(
|
| 41 |
+
title=title,
|
| 42 |
+
description=description,
|
| 43 |
+
severity=severity,
|
| 44 |
+
status=AlertStatus.NEW,
|
| 45 |
+
category=category,
|
| 46 |
+
generated_at=datetime.utcnow(),
|
| 47 |
+
source_url=source_url,
|
| 48 |
+
is_read=False,
|
| 49 |
+
threat_id=threat_id,
|
| 50 |
+
mention_id=mention_id,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
db.add(db_alert)
|
| 54 |
+
await db.commit()
|
| 55 |
+
await db.refresh(db_alert)
|
| 56 |
+
|
| 57 |
+
return db_alert
|
| 58 |
+
|
| 59 |
+
async def get_alert_by_id(db: AsyncSession, alert_id: int) -> Optional[Alert]:
|
| 60 |
+
"""
|
| 61 |
+
Get alert by ID.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
db: Database session
|
| 65 |
+
alert_id: Alert ID
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
Optional[Alert]: Alert or None if not found
|
| 69 |
+
"""
|
| 70 |
+
result = await db.execute(select(Alert).filter(Alert.id == alert_id))
|
| 71 |
+
return result.scalars().first()
|
| 72 |
+
|
| 73 |
+
async def get_alerts(
|
| 74 |
+
db: AsyncSession,
|
| 75 |
+
pagination: PaginationParams,
|
| 76 |
+
severity: Optional[List[ThreatSeverity]] = None,
|
| 77 |
+
status: Optional[List[AlertStatus]] = None,
|
| 78 |
+
category: Optional[List[AlertCategory]] = None,
|
| 79 |
+
is_read: Optional[bool] = None,
|
| 80 |
+
search_query: Optional[str] = None,
|
| 81 |
+
from_date: Optional[datetime] = None,
|
| 82 |
+
to_date: Optional[datetime] = None,
|
| 83 |
+
) -> List[Alert]:
|
| 84 |
+
"""
|
| 85 |
+
Get alerts with filtering and pagination.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
db: Database session
|
| 89 |
+
pagination: Pagination parameters
|
| 90 |
+
severity: Filter by severity
|
| 91 |
+
status: Filter by status
|
| 92 |
+
category: Filter by category
|
| 93 |
+
is_read: Filter by read status
|
| 94 |
+
search_query: Search in title and description
|
| 95 |
+
from_date: Filter by generated_at >= from_date
|
| 96 |
+
to_date: Filter by generated_at <= to_date
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
List[Alert]: List of alerts
|
| 100 |
+
"""
|
| 101 |
+
query = select(Alert)
|
| 102 |
+
|
| 103 |
+
# Apply filters
|
| 104 |
+
if severity:
|
| 105 |
+
query = query.filter(Alert.severity.in_(severity))
|
| 106 |
+
|
| 107 |
+
if status:
|
| 108 |
+
query = query.filter(Alert.status.in_(status))
|
| 109 |
+
|
| 110 |
+
if category:
|
| 111 |
+
query = query.filter(Alert.category.in_(category))
|
| 112 |
+
|
| 113 |
+
if is_read is not None:
|
| 114 |
+
query = query.filter(Alert.is_read == is_read)
|
| 115 |
+
|
| 116 |
+
if search_query:
|
| 117 |
+
search_filter = or_(
|
| 118 |
+
Alert.title.ilike(f"%{search_query}%"),
|
| 119 |
+
Alert.description.ilike(f"%{search_query}%")
|
| 120 |
+
)
|
| 121 |
+
query = query.filter(search_filter)
|
| 122 |
+
|
| 123 |
+
if from_date:
|
| 124 |
+
query = query.filter(Alert.generated_at >= from_date)
|
| 125 |
+
|
| 126 |
+
if to_date:
|
| 127 |
+
query = query.filter(Alert.generated_at <= to_date)
|
| 128 |
+
|
| 129 |
+
# Apply pagination
|
| 130 |
+
query = query.order_by(Alert.generated_at.desc())
|
| 131 |
+
query = query.offset((pagination.page - 1) * pagination.size).limit(pagination.size)
|
| 132 |
+
|
| 133 |
+
result = await db.execute(query)
|
| 134 |
+
return result.scalars().all()
|
| 135 |
+
|
| 136 |
+
async def count_alerts(
|
| 137 |
+
db: AsyncSession,
|
| 138 |
+
severity: Optional[List[ThreatSeverity]] = None,
|
| 139 |
+
status: Optional[List[AlertStatus]] = None,
|
| 140 |
+
category: Optional[List[AlertCategory]] = None,
|
| 141 |
+
is_read: Optional[bool] = None,
|
| 142 |
+
search_query: Optional[str] = None,
|
| 143 |
+
from_date: Optional[datetime] = None,
|
| 144 |
+
to_date: Optional[datetime] = None,
|
| 145 |
+
) -> int:
|
| 146 |
+
"""
|
| 147 |
+
Count alerts with filtering.
|
| 148 |
+
|
| 149 |
+
Args:
|
| 150 |
+
db: Database session
|
| 151 |
+
severity: Filter by severity
|
| 152 |
+
status: Filter by status
|
| 153 |
+
category: Filter by category
|
| 154 |
+
is_read: Filter by read status
|
| 155 |
+
search_query: Search in title and description
|
| 156 |
+
from_date: Filter by generated_at >= from_date
|
| 157 |
+
to_date: Filter by generated_at <= to_date
|
| 158 |
+
|
| 159 |
+
Returns:
|
| 160 |
+
int: Count of alerts
|
| 161 |
+
"""
|
| 162 |
+
query = select(func.count(Alert.id))
|
| 163 |
+
|
| 164 |
+
# Apply filters (same as in get_alerts)
|
| 165 |
+
if severity:
|
| 166 |
+
query = query.filter(Alert.severity.in_(severity))
|
| 167 |
+
|
| 168 |
+
if status:
|
| 169 |
+
query = query.filter(Alert.status.in_(status))
|
| 170 |
+
|
| 171 |
+
if category:
|
| 172 |
+
query = query.filter(Alert.category.in_(category))
|
| 173 |
+
|
| 174 |
+
if is_read is not None:
|
| 175 |
+
query = query.filter(Alert.is_read == is_read)
|
| 176 |
+
|
| 177 |
+
if search_query:
|
| 178 |
+
search_filter = or_(
|
| 179 |
+
Alert.title.ilike(f"%{search_query}%"),
|
| 180 |
+
Alert.description.ilike(f"%{search_query}%")
|
| 181 |
+
)
|
| 182 |
+
query = query.filter(search_filter)
|
| 183 |
+
|
| 184 |
+
if from_date:
|
| 185 |
+
query = query.filter(Alert.generated_at >= from_date)
|
| 186 |
+
|
| 187 |
+
if to_date:
|
| 188 |
+
query = query.filter(Alert.generated_at <= to_date)
|
| 189 |
+
|
| 190 |
+
result = await db.execute(query)
|
| 191 |
+
return result.scalar()
|
| 192 |
+
|
| 193 |
+
async def update_alert_status(
|
| 194 |
+
db: AsyncSession,
|
| 195 |
+
alert_id: int,
|
| 196 |
+
status: AlertStatus,
|
| 197 |
+
action_taken: Optional[str] = None,
|
| 198 |
+
) -> Optional[Alert]:
|
| 199 |
+
"""
|
| 200 |
+
Update alert status.
|
| 201 |
+
|
| 202 |
+
Args:
|
| 203 |
+
db: Database session
|
| 204 |
+
alert_id: Alert ID
|
| 205 |
+
status: New status
|
| 206 |
+
action_taken: Description of action taken
|
| 207 |
+
|
| 208 |
+
Returns:
|
| 209 |
+
Optional[Alert]: Updated alert or None if not found
|
| 210 |
+
"""
|
| 211 |
+
alert = await get_alert_by_id(db, alert_id)
|
| 212 |
+
if not alert:
|
| 213 |
+
return None
|
| 214 |
+
|
| 215 |
+
alert.status = status
|
| 216 |
+
|
| 217 |
+
if action_taken:
|
| 218 |
+
alert.action_taken = action_taken
|
| 219 |
+
|
| 220 |
+
if status == AlertStatus.RESOLVED:
|
| 221 |
+
alert.resolved_at = datetime.utcnow()
|
| 222 |
+
|
| 223 |
+
alert.updated_at = datetime.utcnow()
|
| 224 |
+
|
| 225 |
+
await db.commit()
|
| 226 |
+
await db.refresh(alert)
|
| 227 |
+
|
| 228 |
+
return alert
|
| 229 |
+
|
| 230 |
+
async def mark_alert_as_read(
|
| 231 |
+
db: AsyncSession,
|
| 232 |
+
alert_id: int,
|
| 233 |
+
) -> Optional[Alert]:
|
| 234 |
+
"""
|
| 235 |
+
Mark alert as read.
|
| 236 |
+
|
| 237 |
+
Args:
|
| 238 |
+
db: Database session
|
| 239 |
+
alert_id: Alert ID
|
| 240 |
+
|
| 241 |
+
Returns:
|
| 242 |
+
Optional[Alert]: Updated alert or None if not found
|
| 243 |
+
"""
|
| 244 |
+
alert = await get_alert_by_id(db, alert_id)
|
| 245 |
+
if not alert:
|
| 246 |
+
return None
|
| 247 |
+
|
| 248 |
+
alert.is_read = True
|
| 249 |
+
alert.updated_at = datetime.utcnow()
|
| 250 |
+
|
| 251 |
+
await db.commit()
|
| 252 |
+
await db.refresh(alert)
|
| 253 |
+
|
| 254 |
+
return alert
|
| 255 |
+
|
| 256 |
+
async def assign_alert(
|
| 257 |
+
db: AsyncSession,
|
| 258 |
+
alert_id: int,
|
| 259 |
+
user_id: int,
|
| 260 |
+
) -> Optional[Alert]:
|
| 261 |
+
"""
|
| 262 |
+
Assign alert to a user.
|
| 263 |
+
|
| 264 |
+
Args:
|
| 265 |
+
db: Database session
|
| 266 |
+
alert_id: Alert ID
|
| 267 |
+
user_id: User ID to assign to
|
| 268 |
+
|
| 269 |
+
Returns:
|
| 270 |
+
Optional[Alert]: Updated alert or None if not found
|
| 271 |
+
"""
|
| 272 |
+
alert = await get_alert_by_id(db, alert_id)
|
| 273 |
+
if not alert:
|
| 274 |
+
return None
|
| 275 |
+
|
| 276 |
+
alert.assigned_to_id = user_id
|
| 277 |
+
alert.status = AlertStatus.ASSIGNED
|
| 278 |
+
alert.updated_at = datetime.utcnow()
|
| 279 |
+
|
| 280 |
+
await db.commit()
|
| 281 |
+
await db.refresh(alert)
|
| 282 |
+
|
| 283 |
+
return alert
|
| 284 |
+
|
| 285 |
+
async def get_alert_counts_by_severity(
|
| 286 |
+
db: AsyncSession,
|
| 287 |
+
from_date: Optional[datetime] = None,
|
| 288 |
+
to_date: Optional[datetime] = None,
|
| 289 |
+
) -> Dict[str, int]:
|
| 290 |
+
"""
|
| 291 |
+
Get count of alerts by severity.
|
| 292 |
+
|
| 293 |
+
Args:
|
| 294 |
+
db: Database session
|
| 295 |
+
from_date: Filter by generated_at >= from_date
|
| 296 |
+
to_date: Filter by generated_at <= to_date
|
| 297 |
+
|
| 298 |
+
Returns:
|
| 299 |
+
Dict[str, int]: Mapping of severity to count
|
| 300 |
+
"""
|
| 301 |
+
result = {}
|
| 302 |
+
|
| 303 |
+
for severity in ThreatSeverity:
|
| 304 |
+
query = select(func.count(Alert.id)).filter(Alert.severity == severity)
|
| 305 |
+
|
| 306 |
+
if from_date:
|
| 307 |
+
query = query.filter(Alert.generated_at >= from_date)
|
| 308 |
+
|
| 309 |
+
if to_date:
|
| 310 |
+
query = query.filter(Alert.generated_at <= to_date)
|
| 311 |
+
|
| 312 |
+
count_result = await db.execute(query)
|
| 313 |
+
count = count_result.scalar() or 0
|
| 314 |
+
result[severity.value] = count
|
| 315 |
+
|
| 316 |
+
return result
|
src/api/services/dark_web_content_service.py
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Service for dark web content operations.
|
| 3 |
+
"""
|
| 4 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 5 |
+
from sqlalchemy.future import select
|
| 6 |
+
from sqlalchemy import func, or_, text
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from typing import List, Optional, Dict, Any, Union
|
| 9 |
+
|
| 10 |
+
from src.models.dark_web_content import DarkWebContent, DarkWebMention, ContentType, ContentStatus
|
| 11 |
+
from src.models.threat import Threat, ThreatCategory, ThreatSeverity, ThreatStatus
|
| 12 |
+
from src.api.schemas import PaginationParams
|
| 13 |
+
|
| 14 |
+
async def create_content(
|
| 15 |
+
db: AsyncSession,
|
| 16 |
+
url: str,
|
| 17 |
+
content: str,
|
| 18 |
+
title: Optional[str] = None,
|
| 19 |
+
content_type: ContentType = ContentType.OTHER,
|
| 20 |
+
content_status: ContentStatus = ContentStatus.NEW,
|
| 21 |
+
source_name: Optional[str] = None,
|
| 22 |
+
source_type: Optional[str] = None,
|
| 23 |
+
language: Optional[str] = None,
|
| 24 |
+
relevance_score: float = 0.0,
|
| 25 |
+
sentiment_score: float = 0.0,
|
| 26 |
+
entity_data: Optional[str] = None,
|
| 27 |
+
) -> DarkWebContent:
|
| 28 |
+
"""
|
| 29 |
+
Create a new dark web content entry.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
db: Database session
|
| 33 |
+
url: URL of the content
|
| 34 |
+
content: Text content
|
| 35 |
+
title: Title of the content
|
| 36 |
+
content_type: Type of content
|
| 37 |
+
content_status: Status of content
|
| 38 |
+
source_name: Name of the source
|
| 39 |
+
source_type: Type of source
|
| 40 |
+
language: Language of the content
|
| 41 |
+
relevance_score: Relevance score (0-1)
|
| 42 |
+
sentiment_score: Sentiment score (-1 to 1)
|
| 43 |
+
entity_data: JSON string of extracted entities
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
DarkWebContent: Created content
|
| 47 |
+
"""
|
| 48 |
+
# Extract domain from URL if possible
|
| 49 |
+
domain = None
|
| 50 |
+
if url:
|
| 51 |
+
try:
|
| 52 |
+
from urllib.parse import urlparse
|
| 53 |
+
parsed_url = urlparse(url)
|
| 54 |
+
domain = parsed_url.netloc
|
| 55 |
+
except:
|
| 56 |
+
pass
|
| 57 |
+
|
| 58 |
+
db_content = DarkWebContent(
|
| 59 |
+
url=url,
|
| 60 |
+
domain=domain,
|
| 61 |
+
title=title,
|
| 62 |
+
content=content,
|
| 63 |
+
content_type=content_type,
|
| 64 |
+
content_status=content_status,
|
| 65 |
+
source_name=source_name,
|
| 66 |
+
source_type=source_type,
|
| 67 |
+
language=language,
|
| 68 |
+
scraped_at=datetime.utcnow(),
|
| 69 |
+
relevance_score=relevance_score,
|
| 70 |
+
sentiment_score=sentiment_score,
|
| 71 |
+
entity_data=entity_data,
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
db.add(db_content)
|
| 75 |
+
await db.commit()
|
| 76 |
+
await db.refresh(db_content)
|
| 77 |
+
|
| 78 |
+
return db_content
|
| 79 |
+
|
| 80 |
+
async def get_content_by_id(db: AsyncSession, content_id: int) -> Optional[DarkWebContent]:
|
| 81 |
+
"""
|
| 82 |
+
Get dark web content by ID.
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
db: Database session
|
| 86 |
+
content_id: Content ID
|
| 87 |
+
|
| 88 |
+
Returns:
|
| 89 |
+
Optional[DarkWebContent]: Content or None if not found
|
| 90 |
+
"""
|
| 91 |
+
result = await db.execute(select(DarkWebContent).filter(DarkWebContent.id == content_id))
|
| 92 |
+
return result.scalars().first()
|
| 93 |
+
|
| 94 |
+
async def get_contents(
|
| 95 |
+
db: AsyncSession,
|
| 96 |
+
pagination: PaginationParams,
|
| 97 |
+
content_type: Optional[List[ContentType]] = None,
|
| 98 |
+
content_status: Optional[List[ContentStatus]] = None,
|
| 99 |
+
source_name: Optional[str] = None,
|
| 100 |
+
search_query: Optional[str] = None,
|
| 101 |
+
from_date: Optional[datetime] = None,
|
| 102 |
+
to_date: Optional[datetime] = None,
|
| 103 |
+
) -> List[DarkWebContent]:
|
| 104 |
+
"""
|
| 105 |
+
Get dark web contents with filtering and pagination.
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
db: Database session
|
| 109 |
+
pagination: Pagination parameters
|
| 110 |
+
content_type: Filter by content type
|
| 111 |
+
content_status: Filter by content status
|
| 112 |
+
source_name: Filter by source name
|
| 113 |
+
search_query: Search in title and content
|
| 114 |
+
from_date: Filter by scraped_at >= from_date
|
| 115 |
+
to_date: Filter by scraped_at <= to_date
|
| 116 |
+
|
| 117 |
+
Returns:
|
| 118 |
+
List[DarkWebContent]: List of dark web contents
|
| 119 |
+
"""
|
| 120 |
+
query = select(DarkWebContent)
|
| 121 |
+
|
| 122 |
+
# Apply filters
|
| 123 |
+
if content_type:
|
| 124 |
+
query = query.filter(DarkWebContent.content_type.in_(content_type))
|
| 125 |
+
|
| 126 |
+
if content_status:
|
| 127 |
+
query = query.filter(DarkWebContent.content_status.in_(content_status))
|
| 128 |
+
|
| 129 |
+
if source_name:
|
| 130 |
+
query = query.filter(DarkWebContent.source_name == source_name)
|
| 131 |
+
|
| 132 |
+
if search_query:
|
| 133 |
+
search_filter = or_(
|
| 134 |
+
DarkWebContent.title.ilike(f"%{search_query}%"),
|
| 135 |
+
DarkWebContent.content.ilike(f"%{search_query}%")
|
| 136 |
+
)
|
| 137 |
+
query = query.filter(search_filter)
|
| 138 |
+
|
| 139 |
+
if from_date:
|
| 140 |
+
query = query.filter(DarkWebContent.scraped_at >= from_date)
|
| 141 |
+
|
| 142 |
+
if to_date:
|
| 143 |
+
query = query.filter(DarkWebContent.scraped_at <= to_date)
|
| 144 |
+
|
| 145 |
+
# Apply pagination
|
| 146 |
+
query = query.order_by(DarkWebContent.scraped_at.desc())
|
| 147 |
+
query = query.offset((pagination.page - 1) * pagination.size).limit(pagination.size)
|
| 148 |
+
|
| 149 |
+
result = await db.execute(query)
|
| 150 |
+
return result.scalars().all()
|
| 151 |
+
|
| 152 |
+
async def count_contents(
|
| 153 |
+
db: AsyncSession,
|
| 154 |
+
content_type: Optional[List[ContentType]] = None,
|
| 155 |
+
content_status: Optional[List[ContentStatus]] = None,
|
| 156 |
+
source_name: Optional[str] = None,
|
| 157 |
+
search_query: Optional[str] = None,
|
| 158 |
+
from_date: Optional[datetime] = None,
|
| 159 |
+
to_date: Optional[datetime] = None,
|
| 160 |
+
) -> int:
|
| 161 |
+
"""
|
| 162 |
+
Count dark web contents with filtering.
|
| 163 |
+
|
| 164 |
+
Args:
|
| 165 |
+
db: Database session
|
| 166 |
+
content_type: Filter by content type
|
| 167 |
+
content_status: Filter by content status
|
| 168 |
+
source_name: Filter by source name
|
| 169 |
+
search_query: Search in title and content
|
| 170 |
+
from_date: Filter by scraped_at >= from_date
|
| 171 |
+
to_date: Filter by scraped_at <= to_date
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
int: Count of dark web contents
|
| 175 |
+
"""
|
| 176 |
+
query = select(func.count(DarkWebContent.id))
|
| 177 |
+
|
| 178 |
+
# Apply filters (same as in get_contents)
|
| 179 |
+
if content_type:
|
| 180 |
+
query = query.filter(DarkWebContent.content_type.in_(content_type))
|
| 181 |
+
|
| 182 |
+
if content_status:
|
| 183 |
+
query = query.filter(DarkWebContent.content_status.in_(content_status))
|
| 184 |
+
|
| 185 |
+
if source_name:
|
| 186 |
+
query = query.filter(DarkWebContent.source_name == source_name)
|
| 187 |
+
|
| 188 |
+
if search_query:
|
| 189 |
+
search_filter = or_(
|
| 190 |
+
DarkWebContent.title.ilike(f"%{search_query}%"),
|
| 191 |
+
DarkWebContent.content.ilike(f"%{search_query}%")
|
| 192 |
+
)
|
| 193 |
+
query = query.filter(search_filter)
|
| 194 |
+
|
| 195 |
+
if from_date:
|
| 196 |
+
query = query.filter(DarkWebContent.scraped_at >= from_date)
|
| 197 |
+
|
| 198 |
+
if to_date:
|
| 199 |
+
query = query.filter(DarkWebContent.scraped_at <= to_date)
|
| 200 |
+
|
| 201 |
+
result = await db.execute(query)
|
| 202 |
+
return result.scalar()
|
| 203 |
+
|
| 204 |
+
async def create_mention(
|
| 205 |
+
db: AsyncSession,
|
| 206 |
+
content_id: int,
|
| 207 |
+
keyword: str,
|
| 208 |
+
keyword_category: Optional[str] = None,
|
| 209 |
+
context: Optional[str] = None,
|
| 210 |
+
snippet: Optional[str] = None,
|
| 211 |
+
mention_type: Optional[str] = None,
|
| 212 |
+
confidence: float = 0.0,
|
| 213 |
+
is_verified: bool = False,
|
| 214 |
+
) -> DarkWebMention:
|
| 215 |
+
"""
|
| 216 |
+
Create a new dark web mention.
|
| 217 |
+
|
| 218 |
+
Args:
|
| 219 |
+
db: Database session
|
| 220 |
+
content_id: ID of the content where the mention was found
|
| 221 |
+
keyword: Keyword that was mentioned
|
| 222 |
+
keyword_category: Category of the keyword
|
| 223 |
+
context: Text surrounding the mention
|
| 224 |
+
snippet: Extract of text containing the mention
|
| 225 |
+
mention_type: Type of mention
|
| 226 |
+
confidence: Confidence score (0-1)
|
| 227 |
+
is_verified: Whether the mention is verified
|
| 228 |
+
|
| 229 |
+
Returns:
|
| 230 |
+
DarkWebMention: Created mention
|
| 231 |
+
"""
|
| 232 |
+
db_mention = DarkWebMention(
|
| 233 |
+
content_id=content_id,
|
| 234 |
+
keyword=keyword,
|
| 235 |
+
keyword_category=keyword_category,
|
| 236 |
+
context=context,
|
| 237 |
+
snippet=snippet,
|
| 238 |
+
mention_type=mention_type,
|
| 239 |
+
confidence=confidence,
|
| 240 |
+
is_verified=is_verified,
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
db.add(db_mention)
|
| 244 |
+
await db.commit()
|
| 245 |
+
await db.refresh(db_mention)
|
| 246 |
+
|
| 247 |
+
return db_mention
|
| 248 |
+
|
| 249 |
+
async def get_mention_by_id(db: AsyncSession, mention_id: int) -> Optional[DarkWebMention]:
|
| 250 |
+
"""
|
| 251 |
+
Get dark web mention by ID.
|
| 252 |
+
|
| 253 |
+
Args:
|
| 254 |
+
db: Database session
|
| 255 |
+
mention_id: Mention ID
|
| 256 |
+
|
| 257 |
+
Returns:
|
| 258 |
+
Optional[DarkWebMention]: Mention or None if not found
|
| 259 |
+
"""
|
| 260 |
+
result = await db.execute(select(DarkWebMention).filter(DarkWebMention.id == mention_id))
|
| 261 |
+
return result.scalars().first()
|
| 262 |
+
|
| 263 |
+
async def get_mentions(
|
| 264 |
+
db: AsyncSession,
|
| 265 |
+
pagination: PaginationParams,
|
| 266 |
+
keyword: Optional[str] = None,
|
| 267 |
+
content_id: Optional[int] = None,
|
| 268 |
+
is_verified: Optional[bool] = None,
|
| 269 |
+
from_date: Optional[datetime] = None,
|
| 270 |
+
to_date: Optional[datetime] = None,
|
| 271 |
+
) -> List[DarkWebMention]:
|
| 272 |
+
"""
|
| 273 |
+
Get dark web mentions with filtering and pagination.
|
| 274 |
+
|
| 275 |
+
Args:
|
| 276 |
+
db: Database session
|
| 277 |
+
pagination: Pagination parameters
|
| 278 |
+
keyword: Filter by keyword
|
| 279 |
+
content_id: Filter by content ID
|
| 280 |
+
is_verified: Filter by verification status
|
| 281 |
+
from_date: Filter by created_at >= from_date
|
| 282 |
+
to_date: Filter by created_at <= to_date
|
| 283 |
+
|
| 284 |
+
Returns:
|
| 285 |
+
List[DarkWebMention]: List of dark web mentions
|
| 286 |
+
"""
|
| 287 |
+
query = select(DarkWebMention)
|
| 288 |
+
|
| 289 |
+
# Apply filters
|
| 290 |
+
if keyword:
|
| 291 |
+
query = query.filter(DarkWebMention.keyword.ilike(f"%{keyword}%"))
|
| 292 |
+
|
| 293 |
+
if content_id:
|
| 294 |
+
query = query.filter(DarkWebMention.content_id == content_id)
|
| 295 |
+
|
| 296 |
+
if is_verified is not None:
|
| 297 |
+
query = query.filter(DarkWebMention.is_verified == is_verified)
|
| 298 |
+
|
| 299 |
+
if from_date:
|
| 300 |
+
query = query.filter(DarkWebMention.created_at >= from_date)
|
| 301 |
+
|
| 302 |
+
if to_date:
|
| 303 |
+
query = query.filter(DarkWebMention.created_at <= to_date)
|
| 304 |
+
|
| 305 |
+
# Apply pagination
|
| 306 |
+
query = query.order_by(DarkWebMention.created_at.desc())
|
| 307 |
+
query = query.offset((pagination.page - 1) * pagination.size).limit(pagination.size)
|
| 308 |
+
|
| 309 |
+
result = await db.execute(query)
|
| 310 |
+
return result.scalars().all()
|
| 311 |
+
|
| 312 |
+
async def create_threat_from_content(
|
| 313 |
+
db: AsyncSession,
|
| 314 |
+
content_id: int,
|
| 315 |
+
title: str,
|
| 316 |
+
description: str,
|
| 317 |
+
severity: ThreatSeverity,
|
| 318 |
+
category: ThreatCategory,
|
| 319 |
+
confidence_score: float = 0.0,
|
| 320 |
+
) -> Threat:
|
| 321 |
+
"""
|
| 322 |
+
Create a threat from dark web content.
|
| 323 |
+
|
| 324 |
+
Args:
|
| 325 |
+
db: Database session
|
| 326 |
+
content_id: ID of the content
|
| 327 |
+
title: Threat title
|
| 328 |
+
description: Threat description
|
| 329 |
+
severity: Threat severity
|
| 330 |
+
category: Threat category
|
| 331 |
+
confidence_score: Confidence score (0-1)
|
| 332 |
+
|
| 333 |
+
Returns:
|
| 334 |
+
Threat: Created threat
|
| 335 |
+
"""
|
| 336 |
+
# Get the content
|
| 337 |
+
content = await get_content_by_id(db, content_id)
|
| 338 |
+
if not content:
|
| 339 |
+
raise ValueError(f"Content with ID {content_id} not found")
|
| 340 |
+
|
| 341 |
+
# Create the threat
|
| 342 |
+
from src.api.services.threat_service import create_threat
|
| 343 |
+
|
| 344 |
+
threat = await create_threat(
|
| 345 |
+
db=db,
|
| 346 |
+
title=title,
|
| 347 |
+
description=description,
|
| 348 |
+
severity=severity,
|
| 349 |
+
category=category,
|
| 350 |
+
status=ThreatStatus.NEW,
|
| 351 |
+
source_url=content.url,
|
| 352 |
+
source_name=content.source_name,
|
| 353 |
+
source_type=content.source_type,
|
| 354 |
+
confidence_score=confidence_score,
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
return threat
|
src/api/services/report_service.py
ADDED
|
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Service for working with intelligence reports.
|
| 3 |
+
"""
|
| 4 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 5 |
+
from sqlalchemy.future import select
|
| 6 |
+
from sqlalchemy import update, delete, func, desc, and_, or_
|
| 7 |
+
from typing import List, Optional, Dict, Any, Union
|
| 8 |
+
import logging
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
|
| 11 |
+
from src.models.report import Report, ReportType, ReportStatus
|
| 12 |
+
from src.models.threat import ThreatSeverity
|
| 13 |
+
from src.api.schemas import PaginationParams
|
| 14 |
+
|
| 15 |
+
# Configure logger
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
async def create_report(
|
| 20 |
+
db: AsyncSession,
|
| 21 |
+
title: str,
|
| 22 |
+
summary: str,
|
| 23 |
+
content: str,
|
| 24 |
+
report_type: ReportType,
|
| 25 |
+
report_id: str,
|
| 26 |
+
status: ReportStatus = ReportStatus.DRAFT,
|
| 27 |
+
severity: Optional[ThreatSeverity] = None,
|
| 28 |
+
publish_date: Optional[datetime] = None,
|
| 29 |
+
time_period_start: Optional[datetime] = None,
|
| 30 |
+
time_period_end: Optional[datetime] = None,
|
| 31 |
+
keywords: Optional[List[str]] = None,
|
| 32 |
+
source_data: Optional[Dict[str, Any]] = None,
|
| 33 |
+
author_id: Optional[int] = None,
|
| 34 |
+
) -> Report:
|
| 35 |
+
"""
|
| 36 |
+
Create a new intelligence report.
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
db: Database session
|
| 40 |
+
title: Report title
|
| 41 |
+
summary: Report summary
|
| 42 |
+
content: Report content
|
| 43 |
+
report_type: Type of report
|
| 44 |
+
report_id: Custom ID for the report
|
| 45 |
+
status: Report status
|
| 46 |
+
severity: Report severity
|
| 47 |
+
publish_date: Publication date
|
| 48 |
+
time_period_start: Start of time period covered
|
| 49 |
+
time_period_end: End of time period covered
|
| 50 |
+
keywords: List of keywords related to the report
|
| 51 |
+
source_data: Sources and references
|
| 52 |
+
author_id: ID of the report author
|
| 53 |
+
|
| 54 |
+
Returns:
|
| 55 |
+
Report: Created report
|
| 56 |
+
"""
|
| 57 |
+
report = Report(
|
| 58 |
+
title=title,
|
| 59 |
+
summary=summary,
|
| 60 |
+
content=content,
|
| 61 |
+
report_type=report_type,
|
| 62 |
+
report_id=report_id,
|
| 63 |
+
status=status,
|
| 64 |
+
severity=severity,
|
| 65 |
+
publish_date=publish_date,
|
| 66 |
+
time_period_start=time_period_start,
|
| 67 |
+
time_period_end=time_period_end,
|
| 68 |
+
keywords=keywords or [],
|
| 69 |
+
source_data=source_data or {},
|
| 70 |
+
author_id=author_id,
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
db.add(report)
|
| 74 |
+
await db.commit()
|
| 75 |
+
await db.refresh(report)
|
| 76 |
+
|
| 77 |
+
return report
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
async def get_report_by_id(db: AsyncSession, report_id: int) -> Optional[Report]:
|
| 81 |
+
"""
|
| 82 |
+
Get report by ID.
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
db: Database session
|
| 86 |
+
report_id: Report ID
|
| 87 |
+
|
| 88 |
+
Returns:
|
| 89 |
+
Optional[Report]: Found report or None
|
| 90 |
+
"""
|
| 91 |
+
result = await db.execute(
|
| 92 |
+
select(Report).where(Report.id == report_id)
|
| 93 |
+
)
|
| 94 |
+
return result.scalars().first()
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
async def get_report_by_custom_id(db: AsyncSession, custom_id: str) -> Optional[Report]:
|
| 98 |
+
"""
|
| 99 |
+
Get report by custom ID.
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
db: Database session
|
| 103 |
+
custom_id: Custom report ID (e.g., RPT-2023-0001)
|
| 104 |
+
|
| 105 |
+
Returns:
|
| 106 |
+
Optional[Report]: Found report or None
|
| 107 |
+
"""
|
| 108 |
+
result = await db.execute(
|
| 109 |
+
select(Report).where(Report.report_id == custom_id)
|
| 110 |
+
)
|
| 111 |
+
return result.scalars().first()
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
async def get_reports(
|
| 115 |
+
db: AsyncSession,
|
| 116 |
+
pagination: PaginationParams,
|
| 117 |
+
report_type: Optional[List[ReportType]] = None,
|
| 118 |
+
status: Optional[List[ReportStatus]] = None,
|
| 119 |
+
severity: Optional[List[ThreatSeverity]] = None,
|
| 120 |
+
search_query: Optional[str] = None,
|
| 121 |
+
keywords: Optional[List[str]] = None,
|
| 122 |
+
author_id: Optional[int] = None,
|
| 123 |
+
from_date: Optional[datetime] = None,
|
| 124 |
+
to_date: Optional[datetime] = None,
|
| 125 |
+
) -> List[Report]:
|
| 126 |
+
"""
|
| 127 |
+
Get reports with filtering.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
db: Database session
|
| 131 |
+
pagination: Pagination parameters
|
| 132 |
+
report_type: Filter by report type
|
| 133 |
+
status: Filter by status
|
| 134 |
+
severity: Filter by severity
|
| 135 |
+
search_query: Search in title and summary
|
| 136 |
+
keywords: Filter by keywords
|
| 137 |
+
author_id: Filter by author ID
|
| 138 |
+
from_date: Filter by created_at >= from_date
|
| 139 |
+
to_date: Filter by created_at <= to_date
|
| 140 |
+
|
| 141 |
+
Returns:
|
| 142 |
+
List[Report]: List of reports
|
| 143 |
+
"""
|
| 144 |
+
query = select(Report)
|
| 145 |
+
|
| 146 |
+
# Apply filters
|
| 147 |
+
if report_type:
|
| 148 |
+
query = query.where(Report.report_type.in_(report_type))
|
| 149 |
+
|
| 150 |
+
if status:
|
| 151 |
+
query = query.where(Report.status.in_(status))
|
| 152 |
+
|
| 153 |
+
if severity:
|
| 154 |
+
query = query.where(Report.severity.in_(severity))
|
| 155 |
+
|
| 156 |
+
if search_query:
|
| 157 |
+
search_filter = or_(
|
| 158 |
+
Report.title.ilike(f"%{search_query}%"),
|
| 159 |
+
Report.summary.ilike(f"%{search_query}%"),
|
| 160 |
+
Report.content.ilike(f"%{search_query}%"),
|
| 161 |
+
)
|
| 162 |
+
query = query.where(search_filter)
|
| 163 |
+
|
| 164 |
+
if keywords:
|
| 165 |
+
# For JSON arrays, need to use a more complex query
|
| 166 |
+
for keyword in keywords:
|
| 167 |
+
query = query.where(Report.keywords.contains([keyword]))
|
| 168 |
+
|
| 169 |
+
if author_id:
|
| 170 |
+
query = query.where(Report.author_id == author_id)
|
| 171 |
+
|
| 172 |
+
if from_date:
|
| 173 |
+
query = query.where(Report.created_at >= from_date)
|
| 174 |
+
|
| 175 |
+
if to_date:
|
| 176 |
+
query = query.where(Report.created_at <= to_date)
|
| 177 |
+
|
| 178 |
+
# Apply pagination
|
| 179 |
+
query = query.order_by(desc(Report.created_at))
|
| 180 |
+
query = query.offset((pagination.page - 1) * pagination.size).limit(pagination.size)
|
| 181 |
+
|
| 182 |
+
result = await db.execute(query)
|
| 183 |
+
return result.scalars().all()
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
async def count_reports(
|
| 187 |
+
db: AsyncSession,
|
| 188 |
+
report_type: Optional[List[ReportType]] = None,
|
| 189 |
+
status: Optional[List[ReportStatus]] = None,
|
| 190 |
+
severity: Optional[List[ThreatSeverity]] = None,
|
| 191 |
+
search_query: Optional[str] = None,
|
| 192 |
+
keywords: Optional[List[str]] = None,
|
| 193 |
+
author_id: Optional[int] = None,
|
| 194 |
+
from_date: Optional[datetime] = None,
|
| 195 |
+
to_date: Optional[datetime] = None,
|
| 196 |
+
) -> int:
|
| 197 |
+
"""
|
| 198 |
+
Count reports with filtering.
|
| 199 |
+
|
| 200 |
+
Args are the same as get_reports, except pagination.
|
| 201 |
+
|
| 202 |
+
Returns:
|
| 203 |
+
int: Count of matching reports
|
| 204 |
+
"""
|
| 205 |
+
query = select(func.count(Report.id))
|
| 206 |
+
|
| 207 |
+
# Apply filters
|
| 208 |
+
if report_type:
|
| 209 |
+
query = query.where(Report.report_type.in_(report_type))
|
| 210 |
+
|
| 211 |
+
if status:
|
| 212 |
+
query = query.where(Report.status.in_(status))
|
| 213 |
+
|
| 214 |
+
if severity:
|
| 215 |
+
query = query.where(Report.severity.in_(severity))
|
| 216 |
+
|
| 217 |
+
if search_query:
|
| 218 |
+
search_filter = or_(
|
| 219 |
+
Report.title.ilike(f"%{search_query}%"),
|
| 220 |
+
Report.summary.ilike(f"%{search_query}%"),
|
| 221 |
+
Report.content.ilike(f"%{search_query}%"),
|
| 222 |
+
)
|
| 223 |
+
query = query.where(search_filter)
|
| 224 |
+
|
| 225 |
+
if keywords:
|
| 226 |
+
# For JSON arrays, need to use a more complex query
|
| 227 |
+
for keyword in keywords:
|
| 228 |
+
query = query.where(Report.keywords.contains([keyword]))
|
| 229 |
+
|
| 230 |
+
if author_id:
|
| 231 |
+
query = query.where(Report.author_id == author_id)
|
| 232 |
+
|
| 233 |
+
if from_date:
|
| 234 |
+
query = query.where(Report.created_at >= from_date)
|
| 235 |
+
|
| 236 |
+
if to_date:
|
| 237 |
+
query = query.where(Report.created_at <= to_date)
|
| 238 |
+
|
| 239 |
+
result = await db.execute(query)
|
| 240 |
+
return result.scalar()
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
async def update_report(
|
| 244 |
+
db: AsyncSession,
|
| 245 |
+
report_id: int,
|
| 246 |
+
title: Optional[str] = None,
|
| 247 |
+
summary: Optional[str] = None,
|
| 248 |
+
content: Optional[str] = None,
|
| 249 |
+
report_type: Optional[ReportType] = None,
|
| 250 |
+
status: Optional[ReportStatus] = None,
|
| 251 |
+
severity: Optional[ThreatSeverity] = None,
|
| 252 |
+
publish_date: Optional[datetime] = None,
|
| 253 |
+
time_period_start: Optional[datetime] = None,
|
| 254 |
+
time_period_end: Optional[datetime] = None,
|
| 255 |
+
keywords: Optional[List[str]] = None,
|
| 256 |
+
source_data: Optional[Dict[str, Any]] = None,
|
| 257 |
+
) -> Optional[Report]:
|
| 258 |
+
"""
|
| 259 |
+
Update a report.
|
| 260 |
+
|
| 261 |
+
Args:
|
| 262 |
+
db: Database session
|
| 263 |
+
report_id: Report ID
|
| 264 |
+
Other args: Fields to update
|
| 265 |
+
|
| 266 |
+
Returns:
|
| 267 |
+
Optional[Report]: Updated report or None
|
| 268 |
+
"""
|
| 269 |
+
report = await get_report_by_id(db, report_id)
|
| 270 |
+
|
| 271 |
+
if not report:
|
| 272 |
+
return None
|
| 273 |
+
|
| 274 |
+
# Update fields if provided
|
| 275 |
+
if title is not None:
|
| 276 |
+
report.title = title
|
| 277 |
+
|
| 278 |
+
if summary is not None:
|
| 279 |
+
report.summary = summary
|
| 280 |
+
|
| 281 |
+
if content is not None:
|
| 282 |
+
report.content = content
|
| 283 |
+
|
| 284 |
+
if report_type is not None:
|
| 285 |
+
report.report_type = report_type
|
| 286 |
+
|
| 287 |
+
if status is not None:
|
| 288 |
+
report.status = status
|
| 289 |
+
|
| 290 |
+
if severity is not None:
|
| 291 |
+
report.severity = severity
|
| 292 |
+
|
| 293 |
+
if publish_date is not None:
|
| 294 |
+
report.publish_date = publish_date
|
| 295 |
+
|
| 296 |
+
if time_period_start is not None:
|
| 297 |
+
report.time_period_start = time_period_start
|
| 298 |
+
|
| 299 |
+
if time_period_end is not None:
|
| 300 |
+
report.time_period_end = time_period_end
|
| 301 |
+
|
| 302 |
+
if keywords is not None:
|
| 303 |
+
report.keywords = keywords
|
| 304 |
+
|
| 305 |
+
if source_data is not None:
|
| 306 |
+
report.source_data = source_data
|
| 307 |
+
|
| 308 |
+
await db.commit()
|
| 309 |
+
await db.refresh(report)
|
| 310 |
+
|
| 311 |
+
return report
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
async def add_threat_to_report(
|
| 315 |
+
db: AsyncSession,
|
| 316 |
+
report_id: int,
|
| 317 |
+
threat_id: int,
|
| 318 |
+
) -> Optional[Report]:
|
| 319 |
+
"""
|
| 320 |
+
Add a threat to a report.
|
| 321 |
+
|
| 322 |
+
Args:
|
| 323 |
+
db: Database session
|
| 324 |
+
report_id: Report ID
|
| 325 |
+
threat_id: Threat ID
|
| 326 |
+
|
| 327 |
+
Returns:
|
| 328 |
+
Optional[Report]: Updated report or None
|
| 329 |
+
"""
|
| 330 |
+
from src.api.services.threat_service import get_threat_by_id
|
| 331 |
+
|
| 332 |
+
# Get report and threat
|
| 333 |
+
report = await get_report_by_id(db, report_id)
|
| 334 |
+
threat = await get_threat_by_id(db, threat_id)
|
| 335 |
+
|
| 336 |
+
if not report or not threat:
|
| 337 |
+
return None
|
| 338 |
+
|
| 339 |
+
# Add threat to report
|
| 340 |
+
report.threats.append(threat)
|
| 341 |
+
await db.commit()
|
| 342 |
+
await db.refresh(report)
|
| 343 |
+
|
| 344 |
+
return report
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
async def remove_threat_from_report(
|
| 348 |
+
db: AsyncSession,
|
| 349 |
+
report_id: int,
|
| 350 |
+
threat_id: int,
|
| 351 |
+
) -> Optional[Report]:
|
| 352 |
+
"""
|
| 353 |
+
Remove a threat from a report.
|
| 354 |
+
|
| 355 |
+
Args:
|
| 356 |
+
db: Database session
|
| 357 |
+
report_id: Report ID
|
| 358 |
+
threat_id: Threat ID
|
| 359 |
+
|
| 360 |
+
Returns:
|
| 361 |
+
Optional[Report]: Updated report or None
|
| 362 |
+
"""
|
| 363 |
+
from src.api.services.threat_service import get_threat_by_id
|
| 364 |
+
|
| 365 |
+
# Get report and threat
|
| 366 |
+
report = await get_report_by_id(db, report_id)
|
| 367 |
+
threat = await get_threat_by_id(db, threat_id)
|
| 368 |
+
|
| 369 |
+
if not report or not threat:
|
| 370 |
+
return None
|
| 371 |
+
|
| 372 |
+
# Remove threat from report
|
| 373 |
+
if threat in report.threats:
|
| 374 |
+
report.threats.remove(threat)
|
| 375 |
+
await db.commit()
|
| 376 |
+
await db.refresh(report)
|
| 377 |
+
|
| 378 |
+
return report
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
async def publish_report(
|
| 382 |
+
db: AsyncSession,
|
| 383 |
+
report_id: int,
|
| 384 |
+
) -> Optional[Report]:
|
| 385 |
+
"""
|
| 386 |
+
Publish a report.
|
| 387 |
+
|
| 388 |
+
Args:
|
| 389 |
+
db: Database session
|
| 390 |
+
report_id: Report ID
|
| 391 |
+
|
| 392 |
+
Returns:
|
| 393 |
+
Optional[Report]: Updated report or None
|
| 394 |
+
"""
|
| 395 |
+
report = await get_report_by_id(db, report_id)
|
| 396 |
+
|
| 397 |
+
if not report:
|
| 398 |
+
return None
|
| 399 |
+
|
| 400 |
+
# Update status and publish date
|
| 401 |
+
report.status = ReportStatus.PUBLISHED
|
| 402 |
+
|
| 403 |
+
if not report.publish_date:
|
| 404 |
+
report.publish_date = datetime.utcnow()
|
| 405 |
+
|
| 406 |
+
await db.commit()
|
| 407 |
+
await db.refresh(report)
|
| 408 |
+
|
| 409 |
+
return report
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
async def archive_report(
|
| 413 |
+
db: AsyncSession,
|
| 414 |
+
report_id: int,
|
| 415 |
+
) -> Optional[Report]:
|
| 416 |
+
"""
|
| 417 |
+
Archive a report.
|
| 418 |
+
|
| 419 |
+
Args:
|
| 420 |
+
db: Database session
|
| 421 |
+
report_id: Report ID
|
| 422 |
+
|
| 423 |
+
Returns:
|
| 424 |
+
Optional[Report]: Updated report or None
|
| 425 |
+
"""
|
| 426 |
+
report = await get_report_by_id(db, report_id)
|
| 427 |
+
|
| 428 |
+
if not report:
|
| 429 |
+
return None
|
| 430 |
+
|
| 431 |
+
# Update status
|
| 432 |
+
report.status = ReportStatus.ARCHIVED
|
| 433 |
+
await db.commit()
|
| 434 |
+
await db.refresh(report)
|
| 435 |
+
|
| 436 |
+
return report
|
src/api/services/search_history_service.py
ADDED
|
@@ -0,0 +1,609 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Search History and Trends Service
|
| 3 |
+
|
| 4 |
+
This service manages search history, saved searches, and trend analysis.
|
| 5 |
+
"""
|
| 6 |
+
import logging
|
| 7 |
+
import json
|
| 8 |
+
from datetime import datetime, timedelta
|
| 9 |
+
from typing import List, Dict, Any, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
from sqlalchemy import func, desc, and_, or_, text
|
| 12 |
+
from sqlalchemy.future import select
|
| 13 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 14 |
+
from sqlalchemy.orm import selectinload
|
| 15 |
+
|
| 16 |
+
from src.models.search_history import SearchHistory, SearchResult, SavedSearch, TrendTopic
|
| 17 |
+
from src.models.dark_web_content import DarkWebContent
|
| 18 |
+
from src.models.user import User
|
| 19 |
+
|
| 20 |
+
# Configure logging
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
async def add_search_history(
|
| 24 |
+
db: AsyncSession,
|
| 25 |
+
query: str,
|
| 26 |
+
user_id: Optional[int] = None,
|
| 27 |
+
result_count: int = 0,
|
| 28 |
+
category: Optional[str] = None,
|
| 29 |
+
is_saved: bool = False,
|
| 30 |
+
notes: Optional[str] = None,
|
| 31 |
+
tags: Optional[str] = None
|
| 32 |
+
) -> SearchHistory:
|
| 33 |
+
"""
|
| 34 |
+
Add a new search history entry.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
db: Database session
|
| 38 |
+
query: Search query
|
| 39 |
+
user_id: ID of the user who performed the search (optional)
|
| 40 |
+
result_count: Number of results returned
|
| 41 |
+
category: Category of the search
|
| 42 |
+
is_saved: Whether this is a saved search
|
| 43 |
+
notes: Optional notes
|
| 44 |
+
tags: Optional tags (comma-separated)
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
The created SearchHistory object
|
| 48 |
+
"""
|
| 49 |
+
search_history = SearchHistory(
|
| 50 |
+
query=query,
|
| 51 |
+
user_id=user_id,
|
| 52 |
+
result_count=result_count,
|
| 53 |
+
category=category,
|
| 54 |
+
is_saved=is_saved,
|
| 55 |
+
notes=notes,
|
| 56 |
+
tags=tags
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
db.add(search_history)
|
| 60 |
+
await db.commit()
|
| 61 |
+
await db.refresh(search_history)
|
| 62 |
+
|
| 63 |
+
# Update trend data
|
| 64 |
+
await update_trend_data(db, query, category)
|
| 65 |
+
|
| 66 |
+
return search_history
|
| 67 |
+
|
| 68 |
+
async def add_search_result(
|
| 69 |
+
db: AsyncSession,
|
| 70 |
+
search_id: int,
|
| 71 |
+
url: str,
|
| 72 |
+
title: Optional[str] = None,
|
| 73 |
+
snippet: Optional[str] = None,
|
| 74 |
+
source: Optional[str] = None,
|
| 75 |
+
relevance_score: float = 0.0,
|
| 76 |
+
content_id: Optional[int] = None
|
| 77 |
+
) -> SearchResult:
|
| 78 |
+
"""
|
| 79 |
+
Add a new search result.
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
db: Database session
|
| 83 |
+
search_id: ID of the parent search
|
| 84 |
+
url: URL of the result
|
| 85 |
+
title: Title of the result
|
| 86 |
+
snippet: Text snippet from the result
|
| 87 |
+
source: Source of the result
|
| 88 |
+
relevance_score: Score indicating relevance to the search query
|
| 89 |
+
content_id: ID of the content in our database (if applicable)
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
The created SearchResult object
|
| 93 |
+
"""
|
| 94 |
+
search_result = SearchResult(
|
| 95 |
+
search_id=search_id,
|
| 96 |
+
url=url,
|
| 97 |
+
title=title,
|
| 98 |
+
snippet=snippet,
|
| 99 |
+
source=source,
|
| 100 |
+
relevance_score=relevance_score,
|
| 101 |
+
content_id=content_id
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
db.add(search_result)
|
| 105 |
+
await db.commit()
|
| 106 |
+
await db.refresh(search_result)
|
| 107 |
+
|
| 108 |
+
return search_result
|
| 109 |
+
|
| 110 |
+
async def get_search_history(
|
| 111 |
+
db: AsyncSession,
|
| 112 |
+
skip: int = 0,
|
| 113 |
+
limit: int = 100,
|
| 114 |
+
user_id: Optional[int] = None,
|
| 115 |
+
query_filter: Optional[str] = None,
|
| 116 |
+
date_from: Optional[datetime] = None,
|
| 117 |
+
date_to: Optional[datetime] = None,
|
| 118 |
+
category: Optional[str] = None,
|
| 119 |
+
is_saved: Optional[bool] = None,
|
| 120 |
+
include_results: bool = False
|
| 121 |
+
) -> List[SearchHistory]:
|
| 122 |
+
"""
|
| 123 |
+
Get search history with filtering options.
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
db: Database session
|
| 127 |
+
skip: Number of items to skip
|
| 128 |
+
limit: Maximum number of items to return
|
| 129 |
+
user_id: Filter by user ID
|
| 130 |
+
query_filter: Filter by search query (partial match)
|
| 131 |
+
date_from: Filter by timestamp (from)
|
| 132 |
+
date_to: Filter by timestamp (to)
|
| 133 |
+
category: Filter by category
|
| 134 |
+
is_saved: Filter by saved status
|
| 135 |
+
include_results: Whether to include search results
|
| 136 |
+
|
| 137 |
+
Returns:
|
| 138 |
+
List of SearchHistory objects
|
| 139 |
+
"""
|
| 140 |
+
statement = select(SearchHistory)
|
| 141 |
+
|
| 142 |
+
# Apply filters
|
| 143 |
+
if user_id is not None:
|
| 144 |
+
statement = statement.where(SearchHistory.user_id == user_id)
|
| 145 |
+
|
| 146 |
+
if query_filter:
|
| 147 |
+
statement = statement.where(SearchHistory.query.ilike(f"%{query_filter}%"))
|
| 148 |
+
|
| 149 |
+
if date_from:
|
| 150 |
+
statement = statement.where(SearchHistory.timestamp >= date_from)
|
| 151 |
+
|
| 152 |
+
if date_to:
|
| 153 |
+
statement = statement.where(SearchHistory.timestamp <= date_to)
|
| 154 |
+
|
| 155 |
+
if category:
|
| 156 |
+
statement = statement.where(SearchHistory.category == category)
|
| 157 |
+
|
| 158 |
+
if is_saved is not None:
|
| 159 |
+
statement = statement.where(SearchHistory.is_saved == is_saved)
|
| 160 |
+
|
| 161 |
+
# Load related data if requested
|
| 162 |
+
if include_results:
|
| 163 |
+
statement = statement.options(selectinload(SearchHistory.search_results))
|
| 164 |
+
|
| 165 |
+
# Apply pagination
|
| 166 |
+
statement = statement.order_by(desc(SearchHistory.timestamp)).offset(skip).limit(limit)
|
| 167 |
+
|
| 168 |
+
result = await db.execute(statement)
|
| 169 |
+
return result.scalars().all()
|
| 170 |
+
|
| 171 |
+
async def get_search_by_id(
|
| 172 |
+
db: AsyncSession,
|
| 173 |
+
search_id: int,
|
| 174 |
+
include_results: bool = False
|
| 175 |
+
) -> Optional[SearchHistory]:
|
| 176 |
+
"""
|
| 177 |
+
Get a search history entry by ID.
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
db: Database session
|
| 181 |
+
search_id: Search history ID
|
| 182 |
+
include_results: Whether to include search results
|
| 183 |
+
|
| 184 |
+
Returns:
|
| 185 |
+
SearchHistory object or None if not found
|
| 186 |
+
"""
|
| 187 |
+
statement = select(SearchHistory).where(SearchHistory.id == search_id)
|
| 188 |
+
|
| 189 |
+
if include_results:
|
| 190 |
+
statement = statement.options(selectinload(SearchHistory.search_results))
|
| 191 |
+
|
| 192 |
+
result = await db.execute(statement)
|
| 193 |
+
return result.scalars().first()
|
| 194 |
+
|
| 195 |
+
async def delete_search_history(db: AsyncSession, search_id: int) -> bool:
|
| 196 |
+
"""
|
| 197 |
+
Delete a search history entry.
|
| 198 |
+
|
| 199 |
+
Args:
|
| 200 |
+
db: Database session
|
| 201 |
+
search_id: ID of the search to delete
|
| 202 |
+
|
| 203 |
+
Returns:
|
| 204 |
+
True if successful, False otherwise
|
| 205 |
+
"""
|
| 206 |
+
search = await get_search_by_id(db, search_id)
|
| 207 |
+
if not search:
|
| 208 |
+
return False
|
| 209 |
+
|
| 210 |
+
await db.delete(search)
|
| 211 |
+
await db.commit()
|
| 212 |
+
return True
|
| 213 |
+
|
| 214 |
+
async def save_search(
|
| 215 |
+
db: AsyncSession,
|
| 216 |
+
search_id: int,
|
| 217 |
+
is_saved: bool = True,
|
| 218 |
+
notes: Optional[str] = None,
|
| 219 |
+
tags: Optional[str] = None
|
| 220 |
+
) -> Optional[SearchHistory]:
|
| 221 |
+
"""
|
| 222 |
+
Save or unsave a search history entry.
|
| 223 |
+
|
| 224 |
+
Args:
|
| 225 |
+
db: Database session
|
| 226 |
+
search_id: ID of the search
|
| 227 |
+
is_saved: Whether to save or unsave
|
| 228 |
+
notes: Optional notes to add
|
| 229 |
+
tags: Optional tags to add (comma-separated)
|
| 230 |
+
|
| 231 |
+
Returns:
|
| 232 |
+
Updated SearchHistory object or None if not found
|
| 233 |
+
"""
|
| 234 |
+
search = await get_search_by_id(db, search_id)
|
| 235 |
+
if not search:
|
| 236 |
+
return None
|
| 237 |
+
|
| 238 |
+
search.is_saved = is_saved
|
| 239 |
+
|
| 240 |
+
if notes:
|
| 241 |
+
search.notes = notes
|
| 242 |
+
|
| 243 |
+
if tags:
|
| 244 |
+
search.tags = tags
|
| 245 |
+
|
| 246 |
+
await db.commit()
|
| 247 |
+
await db.refresh(search)
|
| 248 |
+
return search
|
| 249 |
+
|
| 250 |
+
async def create_saved_search(
|
| 251 |
+
db: AsyncSession,
|
| 252 |
+
name: str,
|
| 253 |
+
query: str,
|
| 254 |
+
user_id: int,
|
| 255 |
+
frequency: int = 24,
|
| 256 |
+
notification_enabled: bool = True,
|
| 257 |
+
threshold: int = 1,
|
| 258 |
+
category: Optional[str] = None
|
| 259 |
+
) -> SavedSearch:
|
| 260 |
+
"""
|
| 261 |
+
Create a new saved search with periodic monitoring.
|
| 262 |
+
|
| 263 |
+
Args:
|
| 264 |
+
db: Database session
|
| 265 |
+
name: Name of the saved search
|
| 266 |
+
query: Search query
|
| 267 |
+
user_id: ID of the user
|
| 268 |
+
frequency: How often to run this search (in hours, 0 for manual only)
|
| 269 |
+
notification_enabled: Whether to send notifications for new results
|
| 270 |
+
threshold: Minimum number of new results for notification
|
| 271 |
+
category: Category of the search
|
| 272 |
+
|
| 273 |
+
Returns:
|
| 274 |
+
The created SavedSearch object
|
| 275 |
+
"""
|
| 276 |
+
saved_search = SavedSearch(
|
| 277 |
+
name=name,
|
| 278 |
+
query=query,
|
| 279 |
+
user_id=user_id,
|
| 280 |
+
frequency=frequency,
|
| 281 |
+
notification_enabled=notification_enabled,
|
| 282 |
+
threshold=threshold,
|
| 283 |
+
category=category
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
db.add(saved_search)
|
| 287 |
+
await db.commit()
|
| 288 |
+
await db.refresh(saved_search)
|
| 289 |
+
|
| 290 |
+
return saved_search
|
| 291 |
+
|
| 292 |
+
async def get_saved_searches(
|
| 293 |
+
db: AsyncSession,
|
| 294 |
+
user_id: Optional[int] = None,
|
| 295 |
+
is_active: Optional[bool] = None,
|
| 296 |
+
skip: int = 0,
|
| 297 |
+
limit: int = 100
|
| 298 |
+
) -> List[SavedSearch]:
|
| 299 |
+
"""
|
| 300 |
+
Get saved searches with filtering options.
|
| 301 |
+
|
| 302 |
+
Args:
|
| 303 |
+
db: Database session
|
| 304 |
+
user_id: Filter by user ID
|
| 305 |
+
is_active: Filter by active status
|
| 306 |
+
skip: Number of items to skip
|
| 307 |
+
limit: Maximum number of items to return
|
| 308 |
+
|
| 309 |
+
Returns:
|
| 310 |
+
List of SavedSearch objects
|
| 311 |
+
"""
|
| 312 |
+
statement = select(SavedSearch)
|
| 313 |
+
|
| 314 |
+
# Apply filters
|
| 315 |
+
if user_id is not None:
|
| 316 |
+
statement = statement.where(SavedSearch.user_id == user_id)
|
| 317 |
+
|
| 318 |
+
if is_active is not None:
|
| 319 |
+
statement = statement.where(SavedSearch.is_active == is_active)
|
| 320 |
+
|
| 321 |
+
# Apply pagination
|
| 322 |
+
statement = statement.order_by(SavedSearch.name).offset(skip).limit(limit)
|
| 323 |
+
|
| 324 |
+
result = await db.execute(statement)
|
| 325 |
+
return result.scalars().all()
|
| 326 |
+
|
| 327 |
+
async def update_trend_data(
|
| 328 |
+
db: AsyncSession,
|
| 329 |
+
query: str,
|
| 330 |
+
category: Optional[str] = None
|
| 331 |
+
) -> None:
|
| 332 |
+
"""
|
| 333 |
+
Update trend data based on search queries.
|
| 334 |
+
|
| 335 |
+
Args:
|
| 336 |
+
db: Database session
|
| 337 |
+
query: Search query
|
| 338 |
+
category: Category of the search
|
| 339 |
+
"""
|
| 340 |
+
# Split query into individual terms/topics
|
| 341 |
+
topics = [t.strip() for t in query.split() if len(t.strip()) > 3]
|
| 342 |
+
|
| 343 |
+
# Process each topic
|
| 344 |
+
for topic in topics:
|
| 345 |
+
# Check if topic already exists
|
| 346 |
+
statement = select(TrendTopic).where(TrendTopic.topic == topic)
|
| 347 |
+
result = await db.execute(statement)
|
| 348 |
+
trend_topic = result.scalars().first()
|
| 349 |
+
|
| 350 |
+
if trend_topic:
|
| 351 |
+
# Update existing topic
|
| 352 |
+
trend_topic.last_seen = datetime.utcnow()
|
| 353 |
+
trend_topic.mention_count += 1
|
| 354 |
+
|
| 355 |
+
# Calculate growth rate (percentage change over the last 24 hours)
|
| 356 |
+
time_diff = (trend_topic.last_seen - trend_topic.first_seen).total_seconds() / 3600 # hours
|
| 357 |
+
if time_diff > 0:
|
| 358 |
+
hourly_rate = trend_topic.mention_count / time_diff
|
| 359 |
+
trend_topic.growth_rate = hourly_rate * 24 # daily growth rate
|
| 360 |
+
|
| 361 |
+
# Update category if provided and not already set
|
| 362 |
+
if category and not trend_topic.category:
|
| 363 |
+
trend_topic.category = category
|
| 364 |
+
else:
|
| 365 |
+
# Create a new trend topic
|
| 366 |
+
trend_topic = TrendTopic(
|
| 367 |
+
topic=topic,
|
| 368 |
+
category=category,
|
| 369 |
+
mention_count=1,
|
| 370 |
+
growth_rate=1.0 # Initial growth rate
|
| 371 |
+
)
|
| 372 |
+
db.add(trend_topic)
|
| 373 |
+
|
| 374 |
+
await db.commit()
|
| 375 |
+
|
| 376 |
+
async def get_trending_topics(
|
| 377 |
+
db: AsyncSession,
|
| 378 |
+
days: int = 7,
|
| 379 |
+
limit: int = 20,
|
| 380 |
+
category: Optional[str] = None,
|
| 381 |
+
min_mentions: int = 3
|
| 382 |
+
) -> List[TrendTopic]:
|
| 383 |
+
"""
|
| 384 |
+
Get trending topics over a specific time period.
|
| 385 |
+
|
| 386 |
+
Args:
|
| 387 |
+
db: Database session
|
| 388 |
+
days: Number of days to consider
|
| 389 |
+
limit: Maximum number of topics to return
|
| 390 |
+
category: Filter by category
|
| 391 |
+
min_mentions: Minimum number of mentions
|
| 392 |
+
|
| 393 |
+
Returns:
|
| 394 |
+
List of TrendTopic objects sorted by growth rate
|
| 395 |
+
"""
|
| 396 |
+
cutoff_date = datetime.utcnow() - timedelta(days=days)
|
| 397 |
+
|
| 398 |
+
statement = select(TrendTopic).where(
|
| 399 |
+
and_(
|
| 400 |
+
TrendTopic.last_seen >= cutoff_date,
|
| 401 |
+
TrendTopic.mention_count >= min_mentions,
|
| 402 |
+
TrendTopic.is_active == True
|
| 403 |
+
)
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
if category:
|
| 407 |
+
statement = statement.where(TrendTopic.category == category)
|
| 408 |
+
|
| 409 |
+
statement = statement.order_by(desc(TrendTopic.growth_rate)).limit(limit)
|
| 410 |
+
|
| 411 |
+
result = await db.execute(statement)
|
| 412 |
+
return result.scalars().all()
|
| 413 |
+
|
| 414 |
+
async def get_search_frequency(
|
| 415 |
+
db: AsyncSession,
|
| 416 |
+
days: int = 30,
|
| 417 |
+
interval: str = 'day'
|
| 418 |
+
) -> List[Dict[str, Any]]:
|
| 419 |
+
"""
|
| 420 |
+
Get search frequency over time for visualization.
|
| 421 |
+
|
| 422 |
+
Args:
|
| 423 |
+
db: Database session
|
| 424 |
+
days: Number of days to analyze
|
| 425 |
+
interval: Time interval ('hour', 'day', 'week', 'month')
|
| 426 |
+
|
| 427 |
+
Returns:
|
| 428 |
+
List of dictionaries with time intervals and search counts
|
| 429 |
+
"""
|
| 430 |
+
cutoff_date = datetime.utcnow() - timedelta(days=days)
|
| 431 |
+
|
| 432 |
+
# SQL query depends on the interval
|
| 433 |
+
if interval == 'hour':
|
| 434 |
+
date_format = "YYYY-MM-DD HH24:00"
|
| 435 |
+
trunc_expr = func.date_trunc('hour', SearchHistory.timestamp)
|
| 436 |
+
elif interval == 'day':
|
| 437 |
+
date_format = "YYYY-MM-DD"
|
| 438 |
+
trunc_expr = func.date_trunc('day', SearchHistory.timestamp)
|
| 439 |
+
elif interval == 'week':
|
| 440 |
+
date_format = "YYYY-WW"
|
| 441 |
+
trunc_expr = func.date_trunc('week', SearchHistory.timestamp)
|
| 442 |
+
else: # month
|
| 443 |
+
date_format = "YYYY-MM"
|
| 444 |
+
trunc_expr = func.date_trunc('month', SearchHistory.timestamp)
|
| 445 |
+
|
| 446 |
+
# Query for search count by interval
|
| 447 |
+
statement = select(
|
| 448 |
+
trunc_expr.label('interval'),
|
| 449 |
+
func.count(SearchHistory.id).label('count')
|
| 450 |
+
).where(
|
| 451 |
+
SearchHistory.timestamp >= cutoff_date
|
| 452 |
+
).group_by(
|
| 453 |
+
'interval'
|
| 454 |
+
).order_by(
|
| 455 |
+
'interval'
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
result = await db.execute(statement)
|
| 459 |
+
rows = result.all()
|
| 460 |
+
|
| 461 |
+
# Convert to list of dictionaries
|
| 462 |
+
return [{"interval": row.interval, "count": row.count} for row in rows]
|
| 463 |
+
|
| 464 |
+
async def get_popular_searches(
|
| 465 |
+
db: AsyncSession,
|
| 466 |
+
days: int = 30,
|
| 467 |
+
limit: int = 10
|
| 468 |
+
) -> List[Dict[str, Any]]:
|
| 469 |
+
"""
|
| 470 |
+
Get the most popular search terms.
|
| 471 |
+
|
| 472 |
+
Args:
|
| 473 |
+
db: Database session
|
| 474 |
+
days: Number of days to analyze
|
| 475 |
+
limit: Maximum number of terms to return
|
| 476 |
+
|
| 477 |
+
Returns:
|
| 478 |
+
List of dictionaries with search queries and counts
|
| 479 |
+
"""
|
| 480 |
+
cutoff_date = datetime.utcnow() - timedelta(days=days)
|
| 481 |
+
|
| 482 |
+
statement = select(
|
| 483 |
+
SearchHistory.query,
|
| 484 |
+
func.count(SearchHistory.id).label('count')
|
| 485 |
+
).where(
|
| 486 |
+
SearchHistory.timestamp >= cutoff_date
|
| 487 |
+
).group_by(
|
| 488 |
+
SearchHistory.query
|
| 489 |
+
).order_by(
|
| 490 |
+
desc('count')
|
| 491 |
+
).limit(limit)
|
| 492 |
+
|
| 493 |
+
result = await db.execute(statement)
|
| 494 |
+
rows = result.all()
|
| 495 |
+
|
| 496 |
+
return [{"query": row.query, "count": row.count} for row in rows]
|
| 497 |
+
|
| 498 |
+
async def get_search_categories(
|
| 499 |
+
db: AsyncSession,
|
| 500 |
+
days: int = 30
|
| 501 |
+
) -> List[Dict[str, Any]]:
|
| 502 |
+
"""
|
| 503 |
+
Get distribution of search categories.
|
| 504 |
+
|
| 505 |
+
Args:
|
| 506 |
+
db: Database session
|
| 507 |
+
days: Number of days to analyze
|
| 508 |
+
|
| 509 |
+
Returns:
|
| 510 |
+
List of dictionaries with categories and counts
|
| 511 |
+
"""
|
| 512 |
+
cutoff_date = datetime.utcnow() - timedelta(days=days)
|
| 513 |
+
|
| 514 |
+
statement = select(
|
| 515 |
+
SearchHistory.category,
|
| 516 |
+
func.count(SearchHistory.id).label('count')
|
| 517 |
+
).where(
|
| 518 |
+
and_(
|
| 519 |
+
SearchHistory.timestamp >= cutoff_date,
|
| 520 |
+
SearchHistory.category.is_not(None)
|
| 521 |
+
)
|
| 522 |
+
).group_by(
|
| 523 |
+
SearchHistory.category
|
| 524 |
+
).order_by(
|
| 525 |
+
desc('count')
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
result = await db.execute(statement)
|
| 529 |
+
rows = result.all()
|
| 530 |
+
|
| 531 |
+
return [{"category": row.category or "Uncategorized", "count": row.count} for row in rows]
|
| 532 |
+
|
| 533 |
+
async def get_search_trend_analysis(
|
| 534 |
+
db: AsyncSession,
|
| 535 |
+
days: int = 90,
|
| 536 |
+
trend_days: int = 7,
|
| 537 |
+
limit: int = 10
|
| 538 |
+
) -> Dict[str, Any]:
|
| 539 |
+
"""
|
| 540 |
+
Get comprehensive analysis of search trends.
|
| 541 |
+
|
| 542 |
+
Args:
|
| 543 |
+
db: Database session
|
| 544 |
+
days: Total days to analyze
|
| 545 |
+
trend_days: Days to calculate short-term trends
|
| 546 |
+
limit: Maximum number of items in each category
|
| 547 |
+
|
| 548 |
+
Returns:
|
| 549 |
+
Dictionary with various trend analyses
|
| 550 |
+
"""
|
| 551 |
+
# Get overall search frequency
|
| 552 |
+
frequency = await get_search_frequency(db, days, 'day')
|
| 553 |
+
|
| 554 |
+
# Get popular searches
|
| 555 |
+
popular = await get_popular_searches(db, days, limit)
|
| 556 |
+
|
| 557 |
+
# Get recent trending topics
|
| 558 |
+
trending = await get_trending_topics(db, trend_days, limit)
|
| 559 |
+
|
| 560 |
+
# Get category distribution
|
| 561 |
+
categories = await get_search_categories(db, days)
|
| 562 |
+
|
| 563 |
+
# Get recent (last 24 hours) vs. overall popular terms
|
| 564 |
+
recent_popular = await get_popular_searches(db, 1, limit)
|
| 565 |
+
|
| 566 |
+
# Calculate velocity (rate of change)
|
| 567 |
+
# This compares the last 7 days to the previous 7 days
|
| 568 |
+
cutoff_recent = datetime.utcnow() - timedelta(days=trend_days)
|
| 569 |
+
cutoff_previous = cutoff_recent - timedelta(days=trend_days)
|
| 570 |
+
|
| 571 |
+
# Query for velocity calculation
|
| 572 |
+
statement_recent = select(func.count(SearchHistory.id)).where(
|
| 573 |
+
SearchHistory.timestamp >= cutoff_recent
|
| 574 |
+
)
|
| 575 |
+
statement_previous = select(func.count(SearchHistory.id)).where(
|
| 576 |
+
and_(
|
| 577 |
+
SearchHistory.timestamp >= cutoff_previous,
|
| 578 |
+
SearchHistory.timestamp < cutoff_recent
|
| 579 |
+
)
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
+
result_recent = await db.execute(statement_recent)
|
| 583 |
+
result_previous = await db.execute(statement_previous)
|
| 584 |
+
|
| 585 |
+
count_recent = result_recent.scalar() or 0
|
| 586 |
+
count_previous = result_previous.scalar() or 0
|
| 587 |
+
|
| 588 |
+
if count_previous > 0:
|
| 589 |
+
velocity = (count_recent - count_previous) / count_previous * 100 # percentage change
|
| 590 |
+
else:
|
| 591 |
+
velocity = 100.0 if count_recent > 0 else 0.0
|
| 592 |
+
|
| 593 |
+
# Compile the results
|
| 594 |
+
return {
|
| 595 |
+
"frequency": frequency,
|
| 596 |
+
"popular_searches": popular,
|
| 597 |
+
"trending_topics": [
|
| 598 |
+
{"topic": t.topic, "mentions": t.mention_count, "growth_rate": t.growth_rate}
|
| 599 |
+
for t in trending
|
| 600 |
+
],
|
| 601 |
+
"categories": categories,
|
| 602 |
+
"recent_popular": recent_popular,
|
| 603 |
+
"velocity": velocity,
|
| 604 |
+
"total_searches": {
|
| 605 |
+
"total": count_recent + count_previous,
|
| 606 |
+
"recent": count_recent,
|
| 607 |
+
"previous": count_previous
|
| 608 |
+
}
|
| 609 |
+
}
|
src/api/services/subscription_service.py
ADDED
|
@@ -0,0 +1,681 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Subscription service.
|
| 3 |
+
|
| 4 |
+
This module provides functions for managing subscriptions.
|
| 5 |
+
"""
|
| 6 |
+
import os
|
| 7 |
+
import logging
|
| 8 |
+
from datetime import datetime, timedelta
|
| 9 |
+
from typing import List, Dict, Any, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import stripe
|
| 12 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 13 |
+
from sqlalchemy import select, update, delete
|
| 14 |
+
from sqlalchemy.orm import joinedload
|
| 15 |
+
|
| 16 |
+
from src.models.subscription import (
|
| 17 |
+
SubscriptionPlan, UserSubscription, PaymentHistory,
|
| 18 |
+
SubscriptionTier, BillingPeriod, SubscriptionStatus, PaymentStatus
|
| 19 |
+
)
|
| 20 |
+
from src.models.user import User
|
| 21 |
+
|
| 22 |
+
# Set up Stripe API key
|
| 23 |
+
stripe.api_key = os.environ.get("STRIPE_SECRET_KEY")
|
| 24 |
+
STRIPE_PUBLISHABLE_KEY = os.environ.get("STRIPE_PUBLISHABLE_KEY")
|
| 25 |
+
|
| 26 |
+
# Set up logging
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
async def get_subscription_plans(
|
| 31 |
+
db: AsyncSession,
|
| 32 |
+
active_only: bool = True
|
| 33 |
+
) -> List[SubscriptionPlan]:
|
| 34 |
+
"""
|
| 35 |
+
Get all subscription plans.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
db: Database session
|
| 39 |
+
active_only: If True, only return active plans
|
| 40 |
+
|
| 41 |
+
Returns:
|
| 42 |
+
List of subscription plans
|
| 43 |
+
"""
|
| 44 |
+
query = select(SubscriptionPlan)
|
| 45 |
+
|
| 46 |
+
if active_only:
|
| 47 |
+
query = query.where(SubscriptionPlan.is_active == True)
|
| 48 |
+
|
| 49 |
+
result = await db.execute(query)
|
| 50 |
+
plans = result.scalars().all()
|
| 51 |
+
|
| 52 |
+
return plans
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
async def get_subscription_plan_by_id(
|
| 56 |
+
db: AsyncSession,
|
| 57 |
+
plan_id: int
|
| 58 |
+
) -> Optional[SubscriptionPlan]:
|
| 59 |
+
"""
|
| 60 |
+
Get a subscription plan by ID.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
db: Database session
|
| 64 |
+
plan_id: ID of the plan to get
|
| 65 |
+
|
| 66 |
+
Returns:
|
| 67 |
+
Subscription plan or None if not found
|
| 68 |
+
"""
|
| 69 |
+
query = select(SubscriptionPlan).where(SubscriptionPlan.id == plan_id)
|
| 70 |
+
result = await db.execute(query)
|
| 71 |
+
plan = result.scalars().first()
|
| 72 |
+
|
| 73 |
+
return plan
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
async def get_subscription_plan_by_tier(
|
| 77 |
+
db: AsyncSession,
|
| 78 |
+
tier: SubscriptionTier
|
| 79 |
+
) -> Optional[SubscriptionPlan]:
|
| 80 |
+
"""
|
| 81 |
+
Get a subscription plan by tier.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
db: Database session
|
| 85 |
+
tier: Tier of the plan to get
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
Subscription plan or None if not found
|
| 89 |
+
"""
|
| 90 |
+
query = select(SubscriptionPlan).where(SubscriptionPlan.tier == tier)
|
| 91 |
+
result = await db.execute(query)
|
| 92 |
+
plan = result.scalars().first()
|
| 93 |
+
|
| 94 |
+
return plan
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
async def create_subscription_plan(
|
| 98 |
+
db: AsyncSession,
|
| 99 |
+
name: str,
|
| 100 |
+
tier: SubscriptionTier,
|
| 101 |
+
description: str,
|
| 102 |
+
price_monthly: float,
|
| 103 |
+
price_annually: float,
|
| 104 |
+
max_alerts: int = 10,
|
| 105 |
+
max_reports: int = 5,
|
| 106 |
+
max_searches_per_day: int = 20,
|
| 107 |
+
max_monitoring_keywords: int = 10,
|
| 108 |
+
max_data_retention_days: int = 30,
|
| 109 |
+
supports_api_access: bool = False,
|
| 110 |
+
supports_live_feed: bool = False,
|
| 111 |
+
supports_dark_web_monitoring: bool = False,
|
| 112 |
+
supports_export: bool = False,
|
| 113 |
+
supports_advanced_analytics: bool = False,
|
| 114 |
+
create_stripe_product: bool = True
|
| 115 |
+
) -> Optional[SubscriptionPlan]:
|
| 116 |
+
"""
|
| 117 |
+
Create a new subscription plan.
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
db: Database session
|
| 121 |
+
name: Name of the plan
|
| 122 |
+
tier: Tier of the plan
|
| 123 |
+
description: Description of the plan
|
| 124 |
+
price_monthly: Monthly price of the plan
|
| 125 |
+
price_annually: Annual price of the plan
|
| 126 |
+
max_alerts: Maximum number of alerts allowed
|
| 127 |
+
max_reports: Maximum number of reports allowed
|
| 128 |
+
max_searches_per_day: Maximum number of searches per day
|
| 129 |
+
max_monitoring_keywords: Maximum number of monitoring keywords
|
| 130 |
+
max_data_retention_days: Maximum number of days to retain data
|
| 131 |
+
supports_api_access: Whether the plan supports API access
|
| 132 |
+
supports_live_feed: Whether the plan supports live feed
|
| 133 |
+
supports_dark_web_monitoring: Whether the plan supports dark web monitoring
|
| 134 |
+
supports_export: Whether the plan supports data export
|
| 135 |
+
supports_advanced_analytics: Whether the plan supports advanced analytics
|
| 136 |
+
create_stripe_product: Whether to create a Stripe product for this plan
|
| 137 |
+
|
| 138 |
+
Returns:
|
| 139 |
+
Created subscription plan or None if creation failed
|
| 140 |
+
"""
|
| 141 |
+
# Check if plan with the same tier already exists
|
| 142 |
+
existing_plan = await get_subscription_plan_by_tier(db, tier)
|
| 143 |
+
|
| 144 |
+
if existing_plan:
|
| 145 |
+
logger.warning(f"Subscription plan with tier {tier} already exists")
|
| 146 |
+
return None
|
| 147 |
+
|
| 148 |
+
# Create Stripe product if requested
|
| 149 |
+
stripe_product_id = None
|
| 150 |
+
stripe_monthly_price_id = None
|
| 151 |
+
stripe_annual_price_id = None
|
| 152 |
+
|
| 153 |
+
if create_stripe_product and stripe.api_key:
|
| 154 |
+
try:
|
| 155 |
+
# Create Stripe product
|
| 156 |
+
product = stripe.Product.create(
|
| 157 |
+
name=name,
|
| 158 |
+
description=description,
|
| 159 |
+
metadata={
|
| 160 |
+
"tier": tier.value,
|
| 161 |
+
"max_alerts": max_alerts,
|
| 162 |
+
"max_reports": max_reports,
|
| 163 |
+
"max_searches_per_day": max_searches_per_day,
|
| 164 |
+
"max_monitoring_keywords": max_monitoring_keywords,
|
| 165 |
+
"max_data_retention_days": max_data_retention_days,
|
| 166 |
+
"supports_api_access": "yes" if supports_api_access else "no",
|
| 167 |
+
"supports_live_feed": "yes" if supports_live_feed else "no",
|
| 168 |
+
"supports_dark_web_monitoring": "yes" if supports_dark_web_monitoring else "no",
|
| 169 |
+
"supports_export": "yes" if supports_export else "no",
|
| 170 |
+
"supports_advanced_analytics": "yes" if supports_advanced_analytics else "no"
|
| 171 |
+
}
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
stripe_product_id = product.id
|
| 175 |
+
|
| 176 |
+
# Create monthly price
|
| 177 |
+
monthly_price = stripe.Price.create(
|
| 178 |
+
product=product.id,
|
| 179 |
+
unit_amount=int(price_monthly * 100), # Stripe uses cents
|
| 180 |
+
currency="usd",
|
| 181 |
+
recurring={"interval": "month"},
|
| 182 |
+
metadata={"billing_period": "monthly"}
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
stripe_monthly_price_id = monthly_price.id
|
| 186 |
+
|
| 187 |
+
# Create annual price
|
| 188 |
+
annual_price = stripe.Price.create(
|
| 189 |
+
product=product.id,
|
| 190 |
+
unit_amount=int(price_annually * 100), # Stripe uses cents
|
| 191 |
+
currency="usd",
|
| 192 |
+
recurring={"interval": "year"},
|
| 193 |
+
metadata={"billing_period": "annually"}
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
stripe_annual_price_id = annual_price.id
|
| 197 |
+
|
| 198 |
+
logger.info(f"Created Stripe product {product.id} for plan {name}")
|
| 199 |
+
except Exception as e:
|
| 200 |
+
logger.error(f"Failed to create Stripe product for plan {name}: {e}")
|
| 201 |
+
|
| 202 |
+
# Create plan in database
|
| 203 |
+
plan = SubscriptionPlan(
|
| 204 |
+
name=name,
|
| 205 |
+
tier=tier,
|
| 206 |
+
description=description,
|
| 207 |
+
price_monthly=price_monthly,
|
| 208 |
+
price_annually=price_annually,
|
| 209 |
+
max_alerts=max_alerts,
|
| 210 |
+
max_reports=max_reports,
|
| 211 |
+
max_searches_per_day=max_searches_per_day,
|
| 212 |
+
max_monitoring_keywords=max_monitoring_keywords,
|
| 213 |
+
max_data_retention_days=max_data_retention_days,
|
| 214 |
+
supports_api_access=supports_api_access,
|
| 215 |
+
supports_live_feed=supports_live_feed,
|
| 216 |
+
supports_dark_web_monitoring=supports_dark_web_monitoring,
|
| 217 |
+
supports_export=supports_export,
|
| 218 |
+
supports_advanced_analytics=supports_advanced_analytics,
|
| 219 |
+
stripe_product_id=stripe_product_id,
|
| 220 |
+
stripe_monthly_price_id=stripe_monthly_price_id,
|
| 221 |
+
stripe_annual_price_id=stripe_annual_price_id
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
db.add(plan)
|
| 225 |
+
await db.commit()
|
| 226 |
+
await db.refresh(plan)
|
| 227 |
+
|
| 228 |
+
return plan
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
async def update_subscription_plan(
|
| 232 |
+
db: AsyncSession,
|
| 233 |
+
plan_id: int,
|
| 234 |
+
name: Optional[str] = None,
|
| 235 |
+
description: Optional[str] = None,
|
| 236 |
+
price_monthly: Optional[float] = None,
|
| 237 |
+
price_annually: Optional[float] = None,
|
| 238 |
+
is_active: Optional[bool] = None,
|
| 239 |
+
max_alerts: Optional[int] = None,
|
| 240 |
+
max_reports: Optional[int] = None,
|
| 241 |
+
max_searches_per_day: Optional[int] = None,
|
| 242 |
+
max_monitoring_keywords: Optional[int] = None,
|
| 243 |
+
max_data_retention_days: Optional[int] = None,
|
| 244 |
+
supports_api_access: Optional[bool] = None,
|
| 245 |
+
supports_live_feed: Optional[bool] = None,
|
| 246 |
+
supports_dark_web_monitoring: Optional[bool] = None,
|
| 247 |
+
supports_export: Optional[bool] = None,
|
| 248 |
+
supports_advanced_analytics: Optional[bool] = None,
|
| 249 |
+
update_stripe_product: bool = True
|
| 250 |
+
) -> Optional[SubscriptionPlan]:
|
| 251 |
+
"""
|
| 252 |
+
Update a subscription plan.
|
| 253 |
+
|
| 254 |
+
Args:
|
| 255 |
+
db: Database session
|
| 256 |
+
plan_id: ID of the plan to update
|
| 257 |
+
name: New name of the plan
|
| 258 |
+
description: New description of the plan
|
| 259 |
+
price_monthly: New monthly price of the plan
|
| 260 |
+
price_annually: New annual price of the plan
|
| 261 |
+
is_active: New active status of the plan
|
| 262 |
+
max_alerts: New maximum number of alerts allowed
|
| 263 |
+
max_reports: New maximum number of reports allowed
|
| 264 |
+
max_searches_per_day: New maximum number of searches per day
|
| 265 |
+
max_monitoring_keywords: New maximum number of monitoring keywords
|
| 266 |
+
max_data_retention_days: New maximum number of days to retain data
|
| 267 |
+
supports_api_access: New API access support status
|
| 268 |
+
supports_live_feed: New live feed support status
|
| 269 |
+
supports_dark_web_monitoring: New dark web monitoring support status
|
| 270 |
+
supports_export: New data export support status
|
| 271 |
+
supports_advanced_analytics: New advanced analytics support status
|
| 272 |
+
update_stripe_product: Whether to update the Stripe product for this plan
|
| 273 |
+
|
| 274 |
+
Returns:
|
| 275 |
+
Updated subscription plan or None if update failed
|
| 276 |
+
"""
|
| 277 |
+
# Get existing plan
|
| 278 |
+
plan = await get_subscription_plan_by_id(db, plan_id)
|
| 279 |
+
|
| 280 |
+
if not plan:
|
| 281 |
+
logger.warning(f"Subscription plan with ID {plan_id} not found")
|
| 282 |
+
return None
|
| 283 |
+
|
| 284 |
+
# Prepare update data
|
| 285 |
+
update_data = {}
|
| 286 |
+
|
| 287 |
+
if name is not None:
|
| 288 |
+
update_data["name"] = name
|
| 289 |
+
|
| 290 |
+
if description is not None:
|
| 291 |
+
update_data["description"] = description
|
| 292 |
+
|
| 293 |
+
if price_monthly is not None:
|
| 294 |
+
update_data["price_monthly"] = price_monthly
|
| 295 |
+
|
| 296 |
+
if price_annually is not None:
|
| 297 |
+
update_data["price_annually"] = price_annually
|
| 298 |
+
|
| 299 |
+
if is_active is not None:
|
| 300 |
+
update_data["is_active"] = is_active
|
| 301 |
+
|
| 302 |
+
if max_alerts is not None:
|
| 303 |
+
update_data["max_alerts"] = max_alerts
|
| 304 |
+
|
| 305 |
+
if max_reports is not None:
|
| 306 |
+
update_data["max_reports"] = max_reports
|
| 307 |
+
|
| 308 |
+
if max_searches_per_day is not None:
|
| 309 |
+
update_data["max_searches_per_day"] = max_searches_per_day
|
| 310 |
+
|
| 311 |
+
if max_monitoring_keywords is not None:
|
| 312 |
+
update_data["max_monitoring_keywords"] = max_monitoring_keywords
|
| 313 |
+
|
| 314 |
+
if max_data_retention_days is not None:
|
| 315 |
+
update_data["max_data_retention_days"] = max_data_retention_days
|
| 316 |
+
|
| 317 |
+
if supports_api_access is not None:
|
| 318 |
+
update_data["supports_api_access"] = supports_api_access
|
| 319 |
+
|
| 320 |
+
if supports_live_feed is not None:
|
| 321 |
+
update_data["supports_live_feed"] = supports_live_feed
|
| 322 |
+
|
| 323 |
+
if supports_dark_web_monitoring is not None:
|
| 324 |
+
update_data["supports_dark_web_monitoring"] = supports_dark_web_monitoring
|
| 325 |
+
|
| 326 |
+
if supports_export is not None:
|
| 327 |
+
update_data["supports_export"] = supports_export
|
| 328 |
+
|
| 329 |
+
if supports_advanced_analytics is not None:
|
| 330 |
+
update_data["supports_advanced_analytics"] = supports_advanced_analytics
|
| 331 |
+
|
| 332 |
+
# Update Stripe product if requested
|
| 333 |
+
if update_stripe_product and plan.stripe_product_id and stripe.api_key:
|
| 334 |
+
try:
|
| 335 |
+
# Update Stripe product
|
| 336 |
+
product_update_data = {}
|
| 337 |
+
|
| 338 |
+
if name is not None:
|
| 339 |
+
product_update_data["name"] = name
|
| 340 |
+
|
| 341 |
+
if description is not None:
|
| 342 |
+
product_update_data["description"] = description
|
| 343 |
+
|
| 344 |
+
metadata_update = {}
|
| 345 |
+
|
| 346 |
+
if max_alerts is not None:
|
| 347 |
+
metadata_update["max_alerts"] = max_alerts
|
| 348 |
+
|
| 349 |
+
if max_reports is not None:
|
| 350 |
+
metadata_update["max_reports"] = max_reports
|
| 351 |
+
|
| 352 |
+
if max_searches_per_day is not None:
|
| 353 |
+
metadata_update["max_searches_per_day"] = max_searches_per_day
|
| 354 |
+
|
| 355 |
+
if max_monitoring_keywords is not None:
|
| 356 |
+
metadata_update["max_monitoring_keywords"] = max_monitoring_keywords
|
| 357 |
+
|
| 358 |
+
if max_data_retention_days is not None:
|
| 359 |
+
metadata_update["max_data_retention_days"] = max_data_retention_days
|
| 360 |
+
|
| 361 |
+
if supports_api_access is not None:
|
| 362 |
+
metadata_update["supports_api_access"] = "yes" if supports_api_access else "no"
|
| 363 |
+
|
| 364 |
+
if supports_live_feed is not None:
|
| 365 |
+
metadata_update["supports_live_feed"] = "yes" if supports_live_feed else "no"
|
| 366 |
+
|
| 367 |
+
if supports_dark_web_monitoring is not None:
|
| 368 |
+
metadata_update["supports_dark_web_monitoring"] = "yes" if supports_dark_web_monitoring else "no"
|
| 369 |
+
|
| 370 |
+
if supports_export is not None:
|
| 371 |
+
metadata_update["supports_export"] = "yes" if supports_export else "no"
|
| 372 |
+
|
| 373 |
+
if supports_advanced_analytics is not None:
|
| 374 |
+
metadata_update["supports_advanced_analytics"] = "yes" if supports_advanced_analytics else "no"
|
| 375 |
+
|
| 376 |
+
if metadata_update:
|
| 377 |
+
product_update_data["metadata"] = metadata_update
|
| 378 |
+
|
| 379 |
+
if product_update_data:
|
| 380 |
+
stripe.Product.modify(plan.stripe_product_id, **product_update_data)
|
| 381 |
+
|
| 382 |
+
# Update prices if needed
|
| 383 |
+
if price_monthly is not None and plan.stripe_monthly_price_id:
|
| 384 |
+
# Can't update existing price in Stripe, create a new one
|
| 385 |
+
new_monthly_price = stripe.Price.create(
|
| 386 |
+
product=plan.stripe_product_id,
|
| 387 |
+
unit_amount=int(price_monthly * 100), # Stripe uses cents
|
| 388 |
+
currency="usd",
|
| 389 |
+
recurring={"interval": "month"},
|
| 390 |
+
metadata={"billing_period": "monthly"}
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
update_data["stripe_monthly_price_id"] = new_monthly_price.id
|
| 394 |
+
|
| 395 |
+
if price_annually is not None and plan.stripe_annual_price_id:
|
| 396 |
+
# Can't update existing price in Stripe, create a new one
|
| 397 |
+
new_annual_price = stripe.Price.create(
|
| 398 |
+
product=plan.stripe_product_id,
|
| 399 |
+
unit_amount=int(price_annually * 100), # Stripe uses cents
|
| 400 |
+
currency="usd",
|
| 401 |
+
recurring={"interval": "year"},
|
| 402 |
+
metadata={"billing_period": "annually"}
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
update_data["stripe_annual_price_id"] = new_annual_price.id
|
| 406 |
+
|
| 407 |
+
logger.info(f"Updated Stripe product {plan.stripe_product_id} for plan {plan.name}")
|
| 408 |
+
except Exception as e:
|
| 409 |
+
logger.error(f"Failed to update Stripe product for plan {plan.name}: {e}")
|
| 410 |
+
|
| 411 |
+
# Update plan in database
|
| 412 |
+
if update_data:
|
| 413 |
+
await db.execute(
|
| 414 |
+
update(SubscriptionPlan)
|
| 415 |
+
.where(SubscriptionPlan.id == plan_id)
|
| 416 |
+
.values(**update_data)
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
await db.commit()
|
| 420 |
+
|
| 421 |
+
# Refresh plan
|
| 422 |
+
plan = await get_subscription_plan_by_id(db, plan_id)
|
| 423 |
+
|
| 424 |
+
return plan
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
async def get_user_subscription(
|
| 428 |
+
db: AsyncSession,
|
| 429 |
+
user_id: int
|
| 430 |
+
) -> Optional[UserSubscription]:
|
| 431 |
+
"""
|
| 432 |
+
Get a user's active subscription.
|
| 433 |
+
|
| 434 |
+
Args:
|
| 435 |
+
db: Database session
|
| 436 |
+
user_id: ID of the user
|
| 437 |
+
|
| 438 |
+
Returns:
|
| 439 |
+
User subscription or None if not found
|
| 440 |
+
"""
|
| 441 |
+
query = (
|
| 442 |
+
select(UserSubscription)
|
| 443 |
+
.where(UserSubscription.user_id == user_id)
|
| 444 |
+
.where(UserSubscription.status != SubscriptionStatus.CANCELED)
|
| 445 |
+
.options(joinedload(UserSubscription.plan))
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
result = await db.execute(query)
|
| 449 |
+
subscription = result.scalars().first()
|
| 450 |
+
|
| 451 |
+
return subscription
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
async def get_user_subscription_by_id(
|
| 455 |
+
db: AsyncSession,
|
| 456 |
+
subscription_id: int
|
| 457 |
+
) -> Optional[UserSubscription]:
|
| 458 |
+
"""
|
| 459 |
+
Get a user subscription by ID.
|
| 460 |
+
|
| 461 |
+
Args:
|
| 462 |
+
db: Database session
|
| 463 |
+
subscription_id: ID of the subscription
|
| 464 |
+
|
| 465 |
+
Returns:
|
| 466 |
+
User subscription or None if not found
|
| 467 |
+
"""
|
| 468 |
+
query = (
|
| 469 |
+
select(UserSubscription)
|
| 470 |
+
.where(UserSubscription.id == subscription_id)
|
| 471 |
+
.options(joinedload(UserSubscription.plan))
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
result = await db.execute(query)
|
| 475 |
+
subscription = result.scalars().first()
|
| 476 |
+
|
| 477 |
+
return subscription
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
async def create_user_subscription(
|
| 481 |
+
db: AsyncSession,
|
| 482 |
+
user_id: int,
|
| 483 |
+
plan_id: int,
|
| 484 |
+
billing_period: BillingPeriod = BillingPeriod.MONTHLY,
|
| 485 |
+
create_stripe_subscription: bool = True,
|
| 486 |
+
payment_method_id: Optional[str] = None
|
| 487 |
+
) -> Optional[UserSubscription]:
|
| 488 |
+
"""
|
| 489 |
+
Create a new user subscription.
|
| 490 |
+
|
| 491 |
+
Args:
|
| 492 |
+
db: Database session
|
| 493 |
+
user_id: ID of the user
|
| 494 |
+
plan_id: ID of the subscription plan
|
| 495 |
+
billing_period: Billing period (monthly or annually)
|
| 496 |
+
create_stripe_subscription: Whether to create a Stripe subscription
|
| 497 |
+
payment_method_id: ID of the payment method to use (required if create_stripe_subscription is True)
|
| 498 |
+
|
| 499 |
+
Returns:
|
| 500 |
+
Created user subscription or None if creation failed
|
| 501 |
+
"""
|
| 502 |
+
# Check if user exists
|
| 503 |
+
query = select(User).where(User.id == user_id)
|
| 504 |
+
result = await db.execute(query)
|
| 505 |
+
user = result.scalars().first()
|
| 506 |
+
|
| 507 |
+
if not user:
|
| 508 |
+
logger.warning(f"User with ID {user_id} not found")
|
| 509 |
+
return None
|
| 510 |
+
|
| 511 |
+
# Check if plan exists
|
| 512 |
+
plan = await get_subscription_plan_by_id(db, plan_id)
|
| 513 |
+
|
| 514 |
+
if not plan:
|
| 515 |
+
logger.warning(f"Subscription plan with ID {plan_id} not found")
|
| 516 |
+
return None
|
| 517 |
+
|
| 518 |
+
# Check if user already has an active subscription
|
| 519 |
+
existing_subscription = await get_user_subscription(db, user_id)
|
| 520 |
+
|
| 521 |
+
if existing_subscription:
|
| 522 |
+
logger.warning(f"User with ID {user_id} already has an active subscription")
|
| 523 |
+
return None
|
| 524 |
+
|
| 525 |
+
# Calculate subscription period
|
| 526 |
+
now = datetime.utcnow()
|
| 527 |
+
|
| 528 |
+
if billing_period == BillingPeriod.MONTHLY:
|
| 529 |
+
current_period_end = now + timedelta(days=30)
|
| 530 |
+
price = plan.price_monthly
|
| 531 |
+
stripe_price_id = plan.stripe_monthly_price_id
|
| 532 |
+
elif billing_period == BillingPeriod.ANNUALLY:
|
| 533 |
+
current_period_end = now + timedelta(days=365)
|
| 534 |
+
price = plan.price_annually
|
| 535 |
+
stripe_price_id = plan.stripe_annual_price_id
|
| 536 |
+
else:
|
| 537 |
+
logger.warning(f"Invalid billing period: {billing_period}")
|
| 538 |
+
return None
|
| 539 |
+
|
| 540 |
+
# Create Stripe subscription if requested
|
| 541 |
+
stripe_subscription_id = None
|
| 542 |
+
stripe_customer_id = None
|
| 543 |
+
|
| 544 |
+
if create_stripe_subscription and stripe.api_key and plan.stripe_product_id:
|
| 545 |
+
if not payment_method_id:
|
| 546 |
+
logger.warning("Payment method ID is required to create a Stripe subscription")
|
| 547 |
+
return None
|
| 548 |
+
|
| 549 |
+
try:
|
| 550 |
+
# Create or retrieve Stripe customer
|
| 551 |
+
customers = stripe.Customer.list(email=user.email)
|
| 552 |
+
|
| 553 |
+
if customers.data:
|
| 554 |
+
customer = customers.data[0]
|
| 555 |
+
stripe_customer_id = customer.id
|
| 556 |
+
else:
|
| 557 |
+
customer = stripe.Customer.create(
|
| 558 |
+
email=user.email,
|
| 559 |
+
name=user.full_name,
|
| 560 |
+
metadata={"user_id": user_id}
|
| 561 |
+
)
|
| 562 |
+
|
| 563 |
+
stripe_customer_id = customer.id
|
| 564 |
+
|
| 565 |
+
# Attach payment method to customer
|
| 566 |
+
stripe.PaymentMethod.attach(
|
| 567 |
+
payment_method_id,
|
| 568 |
+
customer=stripe_customer_id
|
| 569 |
+
)
|
| 570 |
+
|
| 571 |
+
# Set as default payment method
|
| 572 |
+
stripe.Customer.modify(
|
| 573 |
+
stripe_customer_id,
|
| 574 |
+
invoice_settings={
|
| 575 |
+
"default_payment_method": payment_method_id
|
| 576 |
+
}
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
# Create subscription
|
| 580 |
+
subscription = stripe.Subscription.create(
|
| 581 |
+
customer=stripe_customer_id,
|
| 582 |
+
items=[
|
| 583 |
+
{"price": stripe_price_id}
|
| 584 |
+
],
|
| 585 |
+
expand=["latest_invoice.payment_intent"]
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
+
stripe_subscription_id = subscription.id
|
| 589 |
+
|
| 590 |
+
logger.info(f"Created Stripe subscription {subscription.id} for user {user_id}")
|
| 591 |
+
except Exception as e:
|
| 592 |
+
logger.error(f"Failed to create Stripe subscription for user {user_id}: {e}")
|
| 593 |
+
return None
|
| 594 |
+
|
| 595 |
+
# Create subscription in database
|
| 596 |
+
subscription = UserSubscription(
|
| 597 |
+
user_id=user_id,
|
| 598 |
+
plan_id=plan_id,
|
| 599 |
+
status=SubscriptionStatus.ACTIVE,
|
| 600 |
+
billing_period=billing_period,
|
| 601 |
+
current_period_start=now,
|
| 602 |
+
current_period_end=current_period_end,
|
| 603 |
+
stripe_subscription_id=stripe_subscription_id,
|
| 604 |
+
stripe_customer_id=stripe_customer_id
|
| 605 |
+
)
|
| 606 |
+
|
| 607 |
+
db.add(subscription)
|
| 608 |
+
await db.commit()
|
| 609 |
+
await db.refresh(subscription)
|
| 610 |
+
|
| 611 |
+
# Record payment
|
| 612 |
+
if subscription.id:
|
| 613 |
+
payment_status = PaymentStatus.SUCCEEDED if stripe_subscription_id else PaymentStatus.PENDING
|
| 614 |
+
|
| 615 |
+
payment = PaymentHistory(
|
| 616 |
+
user_id=user_id,
|
| 617 |
+
subscription_id=subscription.id,
|
| 618 |
+
amount=price,
|
| 619 |
+
currency="USD",
|
| 620 |
+
status=payment_status
|
| 621 |
+
)
|
| 622 |
+
|
| 623 |
+
db.add(payment)
|
| 624 |
+
await db.commit()
|
| 625 |
+
|
| 626 |
+
return subscription
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
async def cancel_user_subscription(
|
| 630 |
+
db: AsyncSession,
|
| 631 |
+
subscription_id: int,
|
| 632 |
+
cancel_stripe_subscription: bool = True
|
| 633 |
+
) -> Optional[UserSubscription]:
|
| 634 |
+
"""
|
| 635 |
+
Cancel a user subscription.
|
| 636 |
+
|
| 637 |
+
Args:
|
| 638 |
+
db: Database session
|
| 639 |
+
subscription_id: ID of the subscription to cancel
|
| 640 |
+
cancel_stripe_subscription: Whether to cancel the Stripe subscription
|
| 641 |
+
|
| 642 |
+
Returns:
|
| 643 |
+
Canceled user subscription or None if cancellation failed
|
| 644 |
+
"""
|
| 645 |
+
# Get subscription
|
| 646 |
+
subscription = await get_user_subscription_by_id(db, subscription_id)
|
| 647 |
+
|
| 648 |
+
if not subscription:
|
| 649 |
+
logger.warning(f"Subscription with ID {subscription_id} not found")
|
| 650 |
+
return None
|
| 651 |
+
|
| 652 |
+
# Cancel Stripe subscription if requested
|
| 653 |
+
if cancel_stripe_subscription and subscription.stripe_subscription_id and stripe.api_key:
|
| 654 |
+
try:
|
| 655 |
+
stripe.Subscription.modify(
|
| 656 |
+
subscription.stripe_subscription_id,
|
| 657 |
+
cancel_at_period_end=True
|
| 658 |
+
)
|
| 659 |
+
|
| 660 |
+
logger.info(f"Canceled Stripe subscription {subscription.stripe_subscription_id} at period end")
|
| 661 |
+
except Exception as e:
|
| 662 |
+
logger.error(f"Failed to cancel Stripe subscription {subscription.stripe_subscription_id}: {e}")
|
| 663 |
+
|
| 664 |
+
# Update subscription in database
|
| 665 |
+
now = datetime.utcnow()
|
| 666 |
+
|
| 667 |
+
await db.execute(
|
| 668 |
+
update(UserSubscription)
|
| 669 |
+
.where(UserSubscription.id == subscription_id)
|
| 670 |
+
.values(
|
| 671 |
+
status=SubscriptionStatus.CANCELED,
|
| 672 |
+
canceled_at=now
|
| 673 |
+
)
|
| 674 |
+
)
|
| 675 |
+
|
| 676 |
+
await db.commit()
|
| 677 |
+
|
| 678 |
+
# Refresh subscription
|
| 679 |
+
subscription = await get_user_subscription_by_id(db, subscription_id)
|
| 680 |
+
|
| 681 |
+
return subscription
|
src/api/services/threat_service.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Service for threat operations.
|
| 3 |
+
"""
|
| 4 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 5 |
+
from sqlalchemy.future import select
|
| 6 |
+
from sqlalchemy import func, or_, and_
|
| 7 |
+
from datetime import datetime, timedelta
|
| 8 |
+
from typing import List, Optional, Dict, Any, Union
|
| 9 |
+
|
| 10 |
+
from src.models.threat import Threat, ThreatSeverity, ThreatStatus, ThreatCategory
|
| 11 |
+
from src.models.indicator import Indicator, IndicatorType
|
| 12 |
+
from src.api.schemas import PaginationParams
|
| 13 |
+
|
| 14 |
+
async def create_threat(
|
| 15 |
+
db: AsyncSession,
|
| 16 |
+
title: str,
|
| 17 |
+
description: str,
|
| 18 |
+
severity: ThreatSeverity,
|
| 19 |
+
category: ThreatCategory,
|
| 20 |
+
status: ThreatStatus = ThreatStatus.NEW,
|
| 21 |
+
source_url: Optional[str] = None,
|
| 22 |
+
source_name: Optional[str] = None,
|
| 23 |
+
source_type: Optional[str] = None,
|
| 24 |
+
affected_entity: Optional[str] = None,
|
| 25 |
+
affected_entity_type: Optional[str] = None,
|
| 26 |
+
confidence_score: float = 0.0,
|
| 27 |
+
risk_score: float = 0.0,
|
| 28 |
+
) -> Threat:
|
| 29 |
+
"""
|
| 30 |
+
Create a new threat.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
db: Database session
|
| 34 |
+
title: Threat title
|
| 35 |
+
description: Threat description
|
| 36 |
+
severity: Threat severity
|
| 37 |
+
category: Threat category
|
| 38 |
+
status: Threat status
|
| 39 |
+
source_url: URL of the source
|
| 40 |
+
source_name: Name of the source
|
| 41 |
+
source_type: Type of source
|
| 42 |
+
affected_entity: Name of affected entity
|
| 43 |
+
affected_entity_type: Type of affected entity
|
| 44 |
+
confidence_score: Confidence score (0-1)
|
| 45 |
+
risk_score: Risk score (0-1)
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
Threat: Created threat
|
| 49 |
+
"""
|
| 50 |
+
db_threat = Threat(
|
| 51 |
+
title=title,
|
| 52 |
+
description=description,
|
| 53 |
+
severity=severity,
|
| 54 |
+
category=category,
|
| 55 |
+
status=status,
|
| 56 |
+
source_url=source_url,
|
| 57 |
+
source_name=source_name,
|
| 58 |
+
source_type=source_type,
|
| 59 |
+
discovered_at=datetime.utcnow(),
|
| 60 |
+
affected_entity=affected_entity,
|
| 61 |
+
affected_entity_type=affected_entity_type,
|
| 62 |
+
confidence_score=confidence_score,
|
| 63 |
+
risk_score=risk_score,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
db.add(db_threat)
|
| 67 |
+
await db.commit()
|
| 68 |
+
await db.refresh(db_threat)
|
| 69 |
+
|
| 70 |
+
return db_threat
|
| 71 |
+
|
| 72 |
+
async def get_threat_by_id(db: AsyncSession, threat_id: int) -> Optional[Threat]:
|
| 73 |
+
"""
|
| 74 |
+
Get threat by ID.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
db: Database session
|
| 78 |
+
threat_id: Threat ID
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
Optional[Threat]: Threat or None if not found
|
| 82 |
+
"""
|
| 83 |
+
result = await db.execute(select(Threat).filter(Threat.id == threat_id))
|
| 84 |
+
return result.scalars().first()
|
| 85 |
+
|
| 86 |
+
async def get_threats(
|
| 87 |
+
db: AsyncSession,
|
| 88 |
+
pagination: PaginationParams,
|
| 89 |
+
severity: Optional[List[ThreatSeverity]] = None,
|
| 90 |
+
status: Optional[List[ThreatStatus]] = None,
|
| 91 |
+
category: Optional[List[ThreatCategory]] = None,
|
| 92 |
+
search_query: Optional[str] = None,
|
| 93 |
+
from_date: Optional[datetime] = None,
|
| 94 |
+
to_date: Optional[datetime] = None,
|
| 95 |
+
) -> List[Threat]:
|
| 96 |
+
"""
|
| 97 |
+
Get threats with filtering and pagination.
|
| 98 |
+
|
| 99 |
+
Args:
|
| 100 |
+
db: Database session
|
| 101 |
+
pagination: Pagination parameters
|
| 102 |
+
severity: Filter by severity
|
| 103 |
+
status: Filter by status
|
| 104 |
+
category: Filter by category
|
| 105 |
+
search_query: Search in title and description
|
| 106 |
+
from_date: Filter by discovered_at >= from_date
|
| 107 |
+
to_date: Filter by discovered_at <= to_date
|
| 108 |
+
|
| 109 |
+
Returns:
|
| 110 |
+
List[Threat]: List of threats
|
| 111 |
+
"""
|
| 112 |
+
query = select(Threat)
|
| 113 |
+
|
| 114 |
+
# Apply filters
|
| 115 |
+
if severity:
|
| 116 |
+
query = query.filter(Threat.severity.in_(severity))
|
| 117 |
+
|
| 118 |
+
if status:
|
| 119 |
+
query = query.filter(Threat.status.in_(status))
|
| 120 |
+
|
| 121 |
+
if category:
|
| 122 |
+
query = query.filter(Threat.category.in_(category))
|
| 123 |
+
|
| 124 |
+
if search_query:
|
| 125 |
+
search_filter = or_(
|
| 126 |
+
Threat.title.ilike(f"%{search_query}%"),
|
| 127 |
+
Threat.description.ilike(f"%{search_query}%")
|
| 128 |
+
)
|
| 129 |
+
query = query.filter(search_filter)
|
| 130 |
+
|
| 131 |
+
if from_date:
|
| 132 |
+
query = query.filter(Threat.discovered_at >= from_date)
|
| 133 |
+
|
| 134 |
+
if to_date:
|
| 135 |
+
query = query.filter(Threat.discovered_at <= to_date)
|
| 136 |
+
|
| 137 |
+
# Apply pagination
|
| 138 |
+
query = query.order_by(Threat.discovered_at.desc())
|
| 139 |
+
query = query.offset((pagination.page - 1) * pagination.size).limit(pagination.size)
|
| 140 |
+
|
| 141 |
+
result = await db.execute(query)
|
| 142 |
+
return result.scalars().all()
|
| 143 |
+
|
| 144 |
+
async def count_threats(
|
| 145 |
+
db: AsyncSession,
|
| 146 |
+
severity: Optional[List[ThreatSeverity]] = None,
|
| 147 |
+
status: Optional[List[ThreatStatus]] = None,
|
| 148 |
+
category: Optional[List[ThreatCategory]] = None,
|
| 149 |
+
search_query: Optional[str] = None,
|
| 150 |
+
from_date: Optional[datetime] = None,
|
| 151 |
+
to_date: Optional[datetime] = None,
|
| 152 |
+
) -> int:
|
| 153 |
+
"""
|
| 154 |
+
Count threats with filtering.
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
db: Database session
|
| 158 |
+
severity: Filter by severity
|
| 159 |
+
status: Filter by status
|
| 160 |
+
category: Filter by category
|
| 161 |
+
search_query: Search in title and description
|
| 162 |
+
from_date: Filter by discovered_at >= from_date
|
| 163 |
+
to_date: Filter by discovered_at <= to_date
|
| 164 |
+
|
| 165 |
+
Returns:
|
| 166 |
+
int: Count of threats
|
| 167 |
+
"""
|
| 168 |
+
query = select(func.count(Threat.id))
|
| 169 |
+
|
| 170 |
+
# Apply filters (same as in get_threats)
|
| 171 |
+
if severity:
|
| 172 |
+
query = query.filter(Threat.severity.in_(severity))
|
| 173 |
+
|
| 174 |
+
if status:
|
| 175 |
+
query = query.filter(Threat.status.in_(status))
|
| 176 |
+
|
| 177 |
+
if category:
|
| 178 |
+
query = query.filter(Threat.category.in_(category))
|
| 179 |
+
|
| 180 |
+
if search_query:
|
| 181 |
+
search_filter = or_(
|
| 182 |
+
Threat.title.ilike(f"%{search_query}%"),
|
| 183 |
+
Threat.description.ilike(f"%{search_query}%")
|
| 184 |
+
)
|
| 185 |
+
query = query.filter(search_filter)
|
| 186 |
+
|
| 187 |
+
if from_date:
|
| 188 |
+
query = query.filter(Threat.discovered_at >= from_date)
|
| 189 |
+
|
| 190 |
+
if to_date:
|
| 191 |
+
query = query.filter(Threat.discovered_at <= to_date)
|
| 192 |
+
|
| 193 |
+
result = await db.execute(query)
|
| 194 |
+
return result.scalar()
|
| 195 |
+
|
| 196 |
+
async def update_threat(
|
| 197 |
+
db: AsyncSession,
|
| 198 |
+
threat_id: int,
|
| 199 |
+
title: Optional[str] = None,
|
| 200 |
+
description: Optional[str] = None,
|
| 201 |
+
severity: Optional[ThreatSeverity] = None,
|
| 202 |
+
status: Optional[ThreatStatus] = None,
|
| 203 |
+
category: Optional[ThreatCategory] = None,
|
| 204 |
+
affected_entity: Optional[str] = None,
|
| 205 |
+
affected_entity_type: Optional[str] = None,
|
| 206 |
+
confidence_score: Optional[float] = None,
|
| 207 |
+
risk_score: Optional[float] = None,
|
| 208 |
+
) -> Optional[Threat]:
|
| 209 |
+
"""
|
| 210 |
+
Update threat.
|
| 211 |
+
|
| 212 |
+
Args:
|
| 213 |
+
db: Database session
|
| 214 |
+
threat_id: Threat ID
|
| 215 |
+
title: New title
|
| 216 |
+
description: New description
|
| 217 |
+
severity: New severity
|
| 218 |
+
status: New status
|
| 219 |
+
category: New category
|
| 220 |
+
affected_entity: New affected entity
|
| 221 |
+
affected_entity_type: New affected entity type
|
| 222 |
+
confidence_score: New confidence score
|
| 223 |
+
risk_score: New risk score
|
| 224 |
+
|
| 225 |
+
Returns:
|
| 226 |
+
Optional[Threat]: Updated threat or None if not found
|
| 227 |
+
"""
|
| 228 |
+
threat = await get_threat_by_id(db, threat_id)
|
| 229 |
+
if not threat:
|
| 230 |
+
return None
|
| 231 |
+
|
| 232 |
+
if title is not None:
|
| 233 |
+
threat.title = title
|
| 234 |
+
|
| 235 |
+
if description is not None:
|
| 236 |
+
threat.description = description
|
| 237 |
+
|
| 238 |
+
if severity is not None:
|
| 239 |
+
threat.severity = severity
|
| 240 |
+
|
| 241 |
+
if status is not None:
|
| 242 |
+
threat.status = status
|
| 243 |
+
|
| 244 |
+
if category is not None:
|
| 245 |
+
threat.category = category
|
| 246 |
+
|
| 247 |
+
if affected_entity is not None:
|
| 248 |
+
threat.affected_entity = affected_entity
|
| 249 |
+
|
| 250 |
+
if affected_entity_type is not None:
|
| 251 |
+
threat.affected_entity_type = affected_entity_type
|
| 252 |
+
|
| 253 |
+
if confidence_score is not None:
|
| 254 |
+
threat.confidence_score = confidence_score
|
| 255 |
+
|
| 256 |
+
if risk_score is not None:
|
| 257 |
+
threat.risk_score = risk_score
|
| 258 |
+
|
| 259 |
+
threat.updated_at = datetime.utcnow()
|
| 260 |
+
|
| 261 |
+
await db.commit()
|
| 262 |
+
await db.refresh(threat)
|
| 263 |
+
|
| 264 |
+
return threat
|
| 265 |
+
|
| 266 |
+
async def add_indicator_to_threat(
|
| 267 |
+
db: AsyncSession,
|
| 268 |
+
threat_id: int,
|
| 269 |
+
value: str,
|
| 270 |
+
indicator_type: IndicatorType,
|
| 271 |
+
description: Optional[str] = None,
|
| 272 |
+
is_verified: bool = False,
|
| 273 |
+
context: Optional[str] = None,
|
| 274 |
+
source: Optional[str] = None,
|
| 275 |
+
confidence_score: float = 0.0,
|
| 276 |
+
) -> Indicator:
|
| 277 |
+
"""
|
| 278 |
+
Add an indicator to a threat.
|
| 279 |
+
|
| 280 |
+
Args:
|
| 281 |
+
db: Database session
|
| 282 |
+
threat_id: Threat ID
|
| 283 |
+
value: Indicator value
|
| 284 |
+
indicator_type: Indicator type
|
| 285 |
+
description: Description of the indicator
|
| 286 |
+
is_verified: Whether the indicator is verified
|
| 287 |
+
context: Context of the indicator
|
| 288 |
+
source: Source of the indicator
|
| 289 |
+
confidence_score: Confidence score (0-1)
|
| 290 |
+
|
| 291 |
+
Returns:
|
| 292 |
+
Indicator: Created indicator
|
| 293 |
+
"""
|
| 294 |
+
# Check if threat exists
|
| 295 |
+
threat = await get_threat_by_id(db, threat_id)
|
| 296 |
+
if not threat:
|
| 297 |
+
raise ValueError(f"Threat with ID {threat_id} not found")
|
| 298 |
+
|
| 299 |
+
# Create indicator
|
| 300 |
+
db_indicator = Indicator(
|
| 301 |
+
threat_id=threat_id,
|
| 302 |
+
value=value,
|
| 303 |
+
indicator_type=indicator_type,
|
| 304 |
+
description=description,
|
| 305 |
+
is_verified=is_verified,
|
| 306 |
+
context=context,
|
| 307 |
+
source=source,
|
| 308 |
+
confidence_score=confidence_score,
|
| 309 |
+
first_seen=datetime.utcnow(),
|
| 310 |
+
last_seen=datetime.utcnow(),
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
db.add(db_indicator)
|
| 314 |
+
await db.commit()
|
| 315 |
+
await db.refresh(db_indicator)
|
| 316 |
+
|
| 317 |
+
return db_indicator
|
| 318 |
+
|
| 319 |
+
async def get_threat_statistics(
|
| 320 |
+
db: AsyncSession,
|
| 321 |
+
from_date: Optional[datetime] = None,
|
| 322 |
+
to_date: Optional[datetime] = None,
|
| 323 |
+
) -> Dict[str, Any]:
|
| 324 |
+
"""
|
| 325 |
+
Get threat statistics.
|
| 326 |
+
|
| 327 |
+
Args:
|
| 328 |
+
db: Database session
|
| 329 |
+
from_date: Filter by discovered_at >= from_date
|
| 330 |
+
to_date: Filter by discovered_at <= to_date
|
| 331 |
+
|
| 332 |
+
Returns:
|
| 333 |
+
Dict[str, Any]: Threat statistics
|
| 334 |
+
"""
|
| 335 |
+
# Set default time range if not provided
|
| 336 |
+
if not to_date:
|
| 337 |
+
to_date = datetime.utcnow()
|
| 338 |
+
|
| 339 |
+
if not from_date:
|
| 340 |
+
from_date = to_date - timedelta(days=30)
|
| 341 |
+
|
| 342 |
+
# Get count by severity
|
| 343 |
+
severity_counts = {}
|
| 344 |
+
for severity in ThreatSeverity:
|
| 345 |
+
query = select(func.count(Threat.id)).filter(and_(
|
| 346 |
+
Threat.severity == severity,
|
| 347 |
+
Threat.discovered_at >= from_date,
|
| 348 |
+
Threat.discovered_at <= to_date,
|
| 349 |
+
))
|
| 350 |
+
result = await db.execute(query)
|
| 351 |
+
severity_counts[severity.value] = result.scalar() or 0
|
| 352 |
+
|
| 353 |
+
# Get count by status
|
| 354 |
+
status_counts = {}
|
| 355 |
+
for status in ThreatStatus:
|
| 356 |
+
query = select(func.count(Threat.id)).filter(and_(
|
| 357 |
+
Threat.status == status,
|
| 358 |
+
Threat.discovered_at >= from_date,
|
| 359 |
+
Threat.discovered_at <= to_date,
|
| 360 |
+
))
|
| 361 |
+
result = await db.execute(query)
|
| 362 |
+
status_counts[status.value] = result.scalar() or 0
|
| 363 |
+
|
| 364 |
+
# Get count by category
|
| 365 |
+
category_counts = {}
|
| 366 |
+
for category in ThreatCategory:
|
| 367 |
+
query = select(func.count(Threat.id)).filter(and_(
|
| 368 |
+
Threat.category == category,
|
| 369 |
+
Threat.discovered_at >= from_date,
|
| 370 |
+
Threat.discovered_at <= to_date,
|
| 371 |
+
))
|
| 372 |
+
result = await db.execute(query)
|
| 373 |
+
category_counts[category.value] = result.scalar() or 0
|
| 374 |
+
|
| 375 |
+
# Get total count
|
| 376 |
+
query = select(func.count(Threat.id)).filter(and_(
|
| 377 |
+
Threat.discovered_at >= from_date,
|
| 378 |
+
Threat.discovered_at <= to_date,
|
| 379 |
+
))
|
| 380 |
+
result = await db.execute(query)
|
| 381 |
+
total_count = result.scalar() or 0
|
| 382 |
+
|
| 383 |
+
# Get count by day
|
| 384 |
+
time_series = []
|
| 385 |
+
current_date = from_date.date()
|
| 386 |
+
end_date = to_date.date()
|
| 387 |
+
|
| 388 |
+
while current_date <= end_date:
|
| 389 |
+
next_date = current_date + timedelta(days=1)
|
| 390 |
+
query = select(func.count(Threat.id)).filter(and_(
|
| 391 |
+
Threat.discovered_at >= datetime.combine(current_date, datetime.min.time()),
|
| 392 |
+
Threat.discovered_at < datetime.combine(next_date, datetime.min.time()),
|
| 393 |
+
))
|
| 394 |
+
result = await db.execute(query)
|
| 395 |
+
count = result.scalar() or 0
|
| 396 |
+
time_series.append({
|
| 397 |
+
"date": current_date.isoformat(),
|
| 398 |
+
"count": count
|
| 399 |
+
})
|
| 400 |
+
current_date = next_date
|
| 401 |
+
|
| 402 |
+
# Return statistics
|
| 403 |
+
return {
|
| 404 |
+
"total_count": total_count,
|
| 405 |
+
"severity_counts": severity_counts,
|
| 406 |
+
"status_counts": status_counts,
|
| 407 |
+
"category_counts": category_counts,
|
| 408 |
+
"time_series": time_series,
|
| 409 |
+
"from_date": from_date.isoformat(),
|
| 410 |
+
"to_date": to_date.isoformat(),
|
| 411 |
+
}
|
src/api/services/user_service.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 2 |
+
from sqlalchemy.future import select
|
| 3 |
+
from sqlalchemy import update
|
| 4 |
+
from passlib.context import CryptContext
|
| 5 |
+
from typing import Optional, List, Dict, Any
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
from src.models.user import User
|
| 9 |
+
from src.api.schemas import UserCreate, UserUpdate, UserInDB
|
| 10 |
+
|
| 11 |
+
# Configure logger
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
# Password context for hashing and verification
|
| 15 |
+
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
| 16 |
+
|
| 17 |
+
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
| 18 |
+
"""
|
| 19 |
+
Verify password against hash.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
plain_password: Plain password
|
| 23 |
+
hashed_password: Hashed password
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
bool: True if password is correct
|
| 27 |
+
"""
|
| 28 |
+
return pwd_context.verify(plain_password, hashed_password)
|
| 29 |
+
|
| 30 |
+
def get_password_hash(password: str) -> str:
|
| 31 |
+
"""
|
| 32 |
+
Hash password.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
password: Plain password
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
str: Hashed password
|
| 39 |
+
"""
|
| 40 |
+
return pwd_context.hash(password)
|
| 41 |
+
|
| 42 |
+
async def get_user_by_username(db: AsyncSession, username: str) -> Optional[UserInDB]:
|
| 43 |
+
"""
|
| 44 |
+
Get user by username.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
db: Database session
|
| 48 |
+
username: Username
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
Optional[UserInDB]: User if found, None otherwise
|
| 52 |
+
"""
|
| 53 |
+
try:
|
| 54 |
+
result = await db.execute(select(User).where(User.username == username))
|
| 55 |
+
user = result.scalars().first()
|
| 56 |
+
|
| 57 |
+
if not user:
|
| 58 |
+
return None
|
| 59 |
+
|
| 60 |
+
# Convert SQLAlchemy model to Pydantic model
|
| 61 |
+
user_dict = {c.name: getattr(user, c.name) for c in user.__table__.columns}
|
| 62 |
+
return UserInDB(**user_dict)
|
| 63 |
+
except Exception as e:
|
| 64 |
+
logger.error(f"Error getting user by username: {e}")
|
| 65 |
+
return None
|
| 66 |
+
|
| 67 |
+
async def authenticate_user(db: AsyncSession, username: str, password: str) -> Optional[UserInDB]:
|
| 68 |
+
"""
|
| 69 |
+
Authenticate user.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
db: Database session
|
| 73 |
+
username: Username
|
| 74 |
+
password: Plain password
|
| 75 |
+
|
| 76 |
+
Returns:
|
| 77 |
+
Optional[UserInDB]: User if authenticated, None otherwise
|
| 78 |
+
"""
|
| 79 |
+
user = await get_user_by_username(db, username)
|
| 80 |
+
|
| 81 |
+
if not user:
|
| 82 |
+
return None
|
| 83 |
+
|
| 84 |
+
if not verify_password(password, user.hashed_password):
|
| 85 |
+
return None
|
| 86 |
+
|
| 87 |
+
return user
|
| 88 |
+
|
| 89 |
+
async def create_user(db: AsyncSession, user_data: UserCreate) -> Optional[UserInDB]:
|
| 90 |
+
"""
|
| 91 |
+
Create a new user.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
db: Database session
|
| 95 |
+
user_data: User data
|
| 96 |
+
|
| 97 |
+
Returns:
|
| 98 |
+
Optional[UserInDB]: Created user
|
| 99 |
+
"""
|
| 100 |
+
try:
|
| 101 |
+
# Check if user already exists
|
| 102 |
+
existing_user = await get_user_by_username(db, user_data.username)
|
| 103 |
+
if existing_user:
|
| 104 |
+
return None
|
| 105 |
+
|
| 106 |
+
# Create new user
|
| 107 |
+
hashed_password = get_password_hash(user_data.password)
|
| 108 |
+
user = User(
|
| 109 |
+
username=user_data.username,
|
| 110 |
+
email=user_data.email,
|
| 111 |
+
full_name=user_data.full_name,
|
| 112 |
+
hashed_password=hashed_password,
|
| 113 |
+
is_active=user_data.is_active
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
db.add(user)
|
| 117 |
+
await db.commit()
|
| 118 |
+
await db.refresh(user)
|
| 119 |
+
|
| 120 |
+
# Convert SQLAlchemy model to Pydantic model
|
| 121 |
+
user_dict = {c.name: getattr(user, c.name) for c in user.__table__.columns}
|
| 122 |
+
return UserInDB(**user_dict)
|
| 123 |
+
except Exception as e:
|
| 124 |
+
logger.error(f"Error creating user: {e}")
|
| 125 |
+
await db.rollback()
|
| 126 |
+
return None
|
| 127 |
+
|
| 128 |
+
async def update_user(db: AsyncSession, user_id: int, user_data: UserUpdate) -> Optional[UserInDB]:
|
| 129 |
+
"""
|
| 130 |
+
Update user.
|
| 131 |
+
|
| 132 |
+
Args:
|
| 133 |
+
db: Database session
|
| 134 |
+
user_id: User ID
|
| 135 |
+
user_data: User data
|
| 136 |
+
|
| 137 |
+
Returns:
|
| 138 |
+
Optional[UserInDB]: Updated user
|
| 139 |
+
"""
|
| 140 |
+
try:
|
| 141 |
+
# Create update dictionary
|
| 142 |
+
update_data = user_data.dict(exclude_unset=True)
|
| 143 |
+
|
| 144 |
+
# Hash password if provided
|
| 145 |
+
if "password" in update_data:
|
| 146 |
+
update_data["hashed_password"] = get_password_hash(update_data.pop("password"))
|
| 147 |
+
|
| 148 |
+
# Update user
|
| 149 |
+
stmt = update(User).where(User.id == user_id).values(**update_data)
|
| 150 |
+
await db.execute(stmt)
|
| 151 |
+
await db.commit()
|
| 152 |
+
|
| 153 |
+
# Get updated user
|
| 154 |
+
result = await db.execute(select(User).where(User.id == user_id))
|
| 155 |
+
user = result.scalars().first()
|
| 156 |
+
|
| 157 |
+
if not user:
|
| 158 |
+
return None
|
| 159 |
+
|
| 160 |
+
# Convert SQLAlchemy model to Pydantic model
|
| 161 |
+
user_dict = {c.name: getattr(user, c.name) for c in user.__table__.columns}
|
| 162 |
+
return UserInDB(**user_dict)
|
| 163 |
+
except Exception as e:
|
| 164 |
+
logger.error(f"Error updating user: {e}")
|
| 165 |
+
await db.rollback()
|
| 166 |
+
return None
|