Spaces:
Paused
Paused
| import re | |
| import uuid | |
| import time | |
| import datetime | |
| import logging | |
| from aiohttp import ClientSession | |
| from open_webui.models.auths import ( | |
| AddUserForm, | |
| ApiKey, | |
| Auths, | |
| Token, | |
| LdapForm, | |
| SigninForm, | |
| SigninResponse, | |
| SignupForm, | |
| UpdatePasswordForm, | |
| UpdateProfileForm, | |
| UserResponse, | |
| ) | |
| from open_webui.models.users import Users | |
| from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES | |
| from open_webui.env import ( | |
| WEBUI_AUTH, | |
| WEBUI_AUTH_TRUSTED_EMAIL_HEADER, | |
| WEBUI_AUTH_TRUSTED_NAME_HEADER, | |
| WEBUI_AUTH_COOKIE_SAME_SITE, | |
| WEBUI_AUTH_COOKIE_SECURE, | |
| WEBUI_AUTH_SIGNOUT_REDIRECT_URL, | |
| SRC_LOG_LEVELS, | |
| ) | |
| from fastapi import APIRouter, Depends, HTTPException, Request, status | |
| from fastapi.responses import RedirectResponse, Response | |
| from open_webui.config import OPENID_PROVIDER_URL, ENABLE_OAUTH_SIGNUP, ENABLE_LDAP | |
| from pydantic import BaseModel | |
| from open_webui.utils.misc import parse_duration, validate_email_format | |
| from open_webui.utils.auth import ( | |
| decode_token, | |
| create_api_key, | |
| create_token, | |
| get_admin_user, | |
| get_verified_user, | |
| get_current_user, | |
| get_password_hash, | |
| get_http_authorization_cred, | |
| ) | |
| from open_webui.utils.webhook import post_webhook | |
| from open_webui.utils.access_control import get_permissions | |
| from typing import Optional, List | |
| from ssl import CERT_REQUIRED, PROTOCOL_TLS | |
# ldap3 is imported only when LDAP is enabled at import time (ENABLE_LDAP.value).
# NOTE(review): ENABLE_LDAP can also be changed at runtime through
# request.app.state.config (see update_ldap_config), in which case these names
# would be unbound at module level — confirm handlers cope with that.
if ENABLE_LDAP.value:
    from ldap3 import Server, Connection, NONE, Tls
    from ldap3.utils.conv import escape_filter_chars

# APIRouter instance for this module.
# NOTE(review): no route decorators referencing `router` are visible in this
# copy of the file; confirm the handlers below are registered on it.
router = APIRouter()

# Module logger; level taken from the "MAIN" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
############################
# GetSessionUser
############################


# Response model combining the project's `Token` and `UserResponse` models with
# session-specific extras.
class SessionUserResponse(Token, UserResponse):
    # Token expiry as a Unix timestamp in seconds; the handlers in this module
    # leave it None when the token has no expiry.
    expires_at: Optional[int] = None
    # Result of get_permissions(user.id, USER_PERMISSIONS) for the user.
    permissions: Optional[dict] = None
async def get_session_user(
    request: Request, response: Response, user=Depends(get_current_user)
):
    """Return the current session's token and user details, refreshing the cookie.

    The token is taken from the ``Authorization`` header. When the header is
    absent or yields no credentials, the ``token`` cookie is used instead; the
    sign-in handlers in this module set that cookie.

    Raises:
        HTTPException(401): no token could be found, or the token's ``exp``
            claim lies in the past.
    """
    token = None
    auth_header = request.headers.get("Authorization")
    if auth_header:
        auth_token = get_http_authorization_cred(auth_header)
        if auth_token is not None:
            token = auth_token.credentials
    if not token:
        # Cookie-only browser sessions carry no Authorization header; fall back
        # to the cookie instead of dereferencing a missing credential.
        token = request.cookies.get("token")
    if not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.INVALID_TOKEN,
        )

    data = decode_token(token)
    expires_at = data.get("exp") if data else None
    if expires_at is not None and int(time.time()) > expires_at:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.INVALID_TOKEN,
        )

    # Re-issue the cookie so it tracks the token's expiry.
    response.set_cookie(
        key="token",
        value=token,
        expires=(
            datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
            if expires_at
            else None
        ),
        httponly=True,  # Ensures the cookie is not accessible via JavaScript
        samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
        secure=WEBUI_AUTH_COOKIE_SECURE,
    )

    user_permissions = get_permissions(
        user.id, request.app.state.config.USER_PERMISSIONS
    )

    return {
        "token": token,
        "token_type": "Bearer",
        "expires_at": expires_at,
        "id": user.id,
        "email": user.email,
        "name": user.name,
        "role": user.role,
        "profile_image_url": user.profile_image_url,
        "permissions": user_permissions,
    }
############################
# Update Profile
############################


async def update_profile(
    form_data: UpdateProfileForm, session_user=Depends(get_verified_user)
):
    """Change the signed-in user's display name and profile image URL.

    Raises:
        HTTPException(400): INVALID_CRED when there is no session user, or the
            default error message when the store returns no updated user.
    """
    if not session_user:
        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

    changes = {
        "profile_image_url": form_data.profile_image_url,
        "name": form_data.name,
    }
    updated_user = Users.update_user_by_id(session_user.id, changes)
    if not updated_user:
        raise HTTPException(400, detail=ERROR_MESSAGES.DEFAULT())
    return updated_user
############################
# Update Password
############################


async def update_password(
    form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
):
    """Change the signed-in user's password after re-checking the current one.

    Returns whatever Auths.update_user_password_by_id returns.

    Raises:
        HTTPException(400): ACTION_PROHIBITED under trusted-header auth,
            INVALID_CRED without a session user, INVALID_PASSWORD when the
            current password is wrong, PASSWORD_TOO_LONG when the new password
            exceeds 72 UTF-8 bytes.
    """
    # Passwords are not managed locally when an upstream proxy supplies identity.
    if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
        raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
    if not session_user:
        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

    user = Auths.authenticate_user(session_user.email, form_data.password)
    if not user:
        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD)

    # Same limit signup enforces: bcrypt only uses the first 72 bytes, so a
    # longer password would be silently truncated before hashing.
    if len(form_data.new_password.encode("utf-8")) > 72:
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.PASSWORD_TOO_LONG,
        )

    hashed = get_password_hash(form_data.new_password)
    return Auths.update_user_password_by_id(user.id, hashed)
############################
# LDAP Authentication
############################


async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
    """Authenticate a user against the configured LDAP server and start a session.

    Steps:
      1. Bind with the application account, or anonymously when no app DN is
         configured.
      2. Search for the entry whose username attribute equals the lower-cased
         ``form_data.user``.
      3. Bind as that entry's DN with the supplied password.
      4. Create a local account on first login.
      5. Issue a JWT, return it, and also set it as the ``token`` cookie.

    Raises:
        HTTPException(400): LDAP is disabled or ldap3 cannot be imported, the
            password is empty, TLS configuration fails, or any step of the LDAP
            flow fails. Failures inside the flow are logged and reported to the
            client with the generic detail "LDAP authentication failed.".
    """
    config = request.app.state.config
    if not config.ENABLE_LDAP:
        raise HTTPException(400, detail="LDAP authentication is not enabled")

    # The module imports ldap3 only if LDAP was enabled at startup, but
    # update_ldap_config can enable it at runtime; import here so the names
    # are always bound.
    try:
        from ldap3 import NONE, Connection, Server, Tls
        from ldap3.utils.conv import escape_filter_chars
    except ImportError as e:
        log.error(f"LDAP support unavailable: {str(e)}")
        raise HTTPException(400, detail="LDAP authentication is not available.")

    attr_username = config.LDAP_ATTRIBUTE_FOR_USERNAME
    attr_mail = config.LDAP_ATTRIBUTE_FOR_MAIL
    login = form_data.user.lower()

    # A simple bind with a DN and an empty password is an "unauthenticated
    # bind" (RFC 4513 5.1.2) that some servers report as successful; it must
    # never count as a login.
    if not form_data.password:
        raise HTTPException(400, detail="LDAP authentication failed.")

    try:
        tls = Tls(
            validate=CERT_REQUIRED,
            version=PROTOCOL_TLS,
            ca_certs_file=config.LDAP_CA_CERT_FILE,
            ciphers=config.LDAP_CIPHERS or "ALL",
        )
    except Exception as e:
        log.error(f"TLS configuration error: {str(e)}")
        raise HTTPException(400, detail="Failed to configure TLS for LDAP connection.")

    connection_app = None
    connection_user = None
    try:
        server = Server(
            host=config.LDAP_SERVER_HOST,
            port=config.LDAP_SERVER_PORT,
            get_info=NONE,
            use_ssl=config.LDAP_USE_TLS,
            tls=tls,
        )
        app_dn = config.LDAP_APP_DN
        connection_app = Connection(
            server,
            app_dn,
            config.LDAP_APP_PASSWORD,
            auto_bind="NONE",
            authentication="SIMPLE" if app_dn else "ANONYMOUS",
        )
        if not connection_app.bind():
            raise HTTPException(400, detail="Application account bind failed")

        search_success = connection_app.search(
            search_base=config.LDAP_SEARCH_BASE,
            # The user-supplied name is escaped; LDAP_SEARCH_FILTERS is admin
            # configuration appended verbatim as extra AND terms.
            search_filter=f"(&({attr_username}={escape_filter_chars(login)}){config.LDAP_SEARCH_FILTERS})",
            attributes=[
                f"{attr_username}",
                f"{attr_mail}",
                "cn",
            ],
        )
        if not search_success or not connection_app.entries:
            raise HTTPException(400, detail="User not found in the LDAP server")

        entry = connection_app.entries[0]
        username = str(entry[f"{attr_username}"]).lower()
        email = entry[f"{attr_mail}"].value  # retrieve the Attribute value
        if not email:
            raise HTTPException(400, "User does not have a valid email address.")
        elif isinstance(email, str):
            email = email.lower()
        elif isinstance(email, list):
            # Multi-valued mail attribute: use the first address.
            email = email[0].lower()
        else:
            email = str(email).lower()
        cn = str(entry["cn"])
        user_dn = entry.entry_dn

        if username != login:
            raise HTTPException(400, "User record mismatch.")

        # Verify the password by binding as the user's own DN.
        connection_user = Connection(
            server,
            user_dn,
            form_data.password,
            auto_bind="NONE",
            authentication="SIMPLE",
        )
        if not connection_user.bind():
            raise HTTPException(400, "Authentication failed.")

        user = Users.get_user_by_email(email)
        if not user:
            # First LDAP login: provision a local account with a random,
            # unusable password (LDAP remains the credential authority).
            try:
                user_count = Users.get_num_users()
                role = "admin" if user_count == 0 else config.DEFAULT_USER_ROLE
                user = Auths.insert_new_auth(
                    email=email,
                    password=str(uuid.uuid4()),
                    name=cn,
                    role=role,
                )
                if not user:
                    raise HTTPException(
                        500, detail=ERROR_MESSAGES.CREATE_USER_ERROR
                    )
            except HTTPException:
                raise
            except Exception as err:
                log.error(f"LDAP user creation error: {str(err)}")
                raise HTTPException(
                    500, detail="Internal error occurred during LDAP user creation."
                )

        user = Auths.authenticate_user_by_trusted_header(email)
        if not user:
            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

        expires_delta = parse_duration(config.JWT_EXPIRES_IN)
        expires_at = (
            int(time.time()) + int(expires_delta.total_seconds())
            if expires_delta
            else None
        )
        token = create_token(
            data={"id": user.id},
            expires_delta=expires_delta,
        )

        # Set the cookie token
        response.set_cookie(
            key="token",
            value=token,
            expires=(
                datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
                if expires_at
                else None
            ),
            httponly=True,  # Ensures the cookie is not accessible via JavaScript
            samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
            secure=WEBUI_AUTH_COOKIE_SECURE,
        )

        user_permissions = get_permissions(user.id, config.USER_PERMISSIONS)

        return {
            "token": token,
            "token_type": "Bearer",
            "expires_at": expires_at,
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "role": user.role,
            "profile_image_url": user.profile_image_url,
            "permissions": user_permissions,
        }
    except Exception as e:
        log.error(f"LDAP authentication error: {str(e)}")
        raise HTTPException(400, detail="LDAP authentication failed.")
    finally:
        # Release both LDAP sessions on every path (success or failure).
        for conn in (connection_user, connection_app):
            if conn is not None:
                try:
                    conn.unbind()
                except Exception as e:
                    log.debug(f"LDAP unbind failed: {str(e)}")
############################
# SignIn
############################


async def signin(request: Request, response: Response, form_data: SigninForm):
    """Sign a user in, set the ``token`` cookie, and return the session payload.

    The identity source is chosen in this order:
      * WEBUI_AUTH_TRUSTED_EMAIL_HEADER set: email (and optionally name) are
        read from request headers. Unknown users are created via `signup`
        with a random password.
      * WEBUI_AUTH == False: the fixed "admin@localhost" account is used. It is
        created on first call, which is only allowed while there are no users.
      * otherwise: the email/password in `form_data`.

    Raises:
        HTTPException(400): missing trusted header, existing users while auth
            is disabled, or authentication failure.
    """
    if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
        if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)

        header_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
        header_name = (
            request.headers.get(WEBUI_AUTH_TRUSTED_NAME_HEADER, header_email)
            if WEBUI_AUTH_TRUSTED_NAME_HEADER
            else header_email
        )

        if not Users.get_user_by_email(header_email.lower()):
            new_account = SignupForm(
                email=header_email, password=str(uuid.uuid4()), name=header_name
            )
            await signup(request, response, new_account)

        user = Auths.authenticate_user_by_trusted_header(header_email)

    elif WEBUI_AUTH == False:
        local_email = "admin@localhost"
        local_password = "admin"

        if not Users.get_user_by_email(local_email):
            # Auto-provisioning the local admin is only allowed on an empty DB.
            if Users.get_num_users() != 0:
                raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
            await signup(
                request,
                response,
                SignupForm(email=local_email, password=local_password, name="User"),
            )

        user = Auths.authenticate_user(local_email, local_password)

    else:
        user = Auths.authenticate_user(form_data.email.lower(), form_data.password)

    if not user:
        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

    expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
    expires_at = (
        int(time.time()) + int(expires_delta.total_seconds())
        if expires_delta
        else None
    )
    token = create_token(data={"id": user.id}, expires_delta=expires_delta)

    cookie_expiry = None
    if expires_at:
        cookie_expiry = datetime.datetime.fromtimestamp(
            expires_at, datetime.timezone.utc
        )

    # httponly keeps the session token out of reach of page JavaScript.
    response.set_cookie(
        key="token",
        value=token,
        expires=cookie_expiry,
        httponly=True,
        samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
        secure=WEBUI_AUTH_COOKIE_SECURE,
    )

    permissions = get_permissions(user.id, request.app.state.config.USER_PERMISSIONS)

    return {
        "token": token,
        "token_type": "Bearer",
        "expires_at": expires_at,
        "id": user.id,
        "email": user.email,
        "name": user.name,
        "role": user.role,
        "profile_image_url": user.profile_image_url,
        "permissions": permissions,
    }
############################
# SignUp
############################


async def signup(request: Request, response: Response, form_data: SignupForm):
    """Register a new local account, set the ``token`` cookie, return the session.

    When WEBUI_AUTH is on, signup requires both ENABLE_SIGNUP and
    ENABLE_LOGIN_FORM. When it is off, signup is only allowed while there are
    no users.

    The first user becomes "admin", and self-signup is then switched off;
    later users get DEFAULT_USER_ROLE. A signup webhook is posted when
    WEBHOOK_URL is set.

    Raises:
        HTTPException(403): signup is not allowed.
        HTTPException(400): invalid email format, email already taken, or a
            password longer than 72 UTF-8 bytes.
        HTTPException(500): the account could not be created.
    """
    config = request.app.state.config
    if WEBUI_AUTH:
        if not config.ENABLE_SIGNUP or not config.ENABLE_LOGIN_FORM:
            raise HTTPException(
                status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
            )
    elif Users.get_num_users() != 0:
        raise HTTPException(
            status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
        )

    user_count = Users.get_num_users()
    email = form_data.email.lower()

    if not validate_email_format(email):
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
        )
    if Users.get_user_by_email(email):
        raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)

    # The password passed to bcrypt must be 72 bytes or fewer; longer input is
    # truncated before hashing, so reject it. Kept outside the try below so
    # its blanket `except Exception` cannot turn this 400 into a 500.
    if len(form_data.password.encode("utf-8")) > 72:
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.PASSWORD_TOO_LONG,
        )

    try:
        role = "admin" if user_count == 0 else config.DEFAULT_USER_ROLE
        hashed = get_password_hash(form_data.password)
        user = Auths.insert_new_auth(
            email,
            hashed,
            form_data.name,
            form_data.profile_image_url,
            role,
        )
        if not user:
            raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)

        expires_delta = parse_duration(config.JWT_EXPIRES_IN)
        expires_at = (
            int(time.time()) + int(expires_delta.total_seconds())
            if expires_delta
            else None
        )
        token = create_token(
            data={"id": user.id},
            expires_delta=expires_delta,
        )

        datetime_expires_at = (
            datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
            if expires_at
            else None
        )

        # Set the cookie token
        response.set_cookie(
            key="token",
            value=token,
            expires=datetime_expires_at,
            httponly=True,  # Ensures the cookie is not accessible via JavaScript
            samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
            secure=WEBUI_AUTH_COOKIE_SECURE,
        )

        if config.WEBHOOK_URL:
            post_webhook(
                request.app.state.WEBUI_NAME,
                config.WEBHOOK_URL,
                WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
                {
                    "action": "signup",
                    "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
                    "user": user.model_dump_json(exclude_none=True),
                },
            )

        user_permissions = get_permissions(user.id, config.USER_PERMISSIONS)

        if user_count == 0:
            # Disable signup after the first user is created
            config.ENABLE_SIGNUP = False

        return {
            "token": token,
            "token_type": "Bearer",
            "expires_at": expires_at,
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "role": user.role,
            "profile_image_url": user.profile_image_url,
            "permissions": user_permissions,
        }
    except HTTPException:
        # Deliberate HTTP errors keep their status and detail.
        raise
    except Exception as err:
        log.error(f"Signup error: {str(err)}")
        raise HTTPException(500, detail="An internal error occurred during signup.")
async def signout(request: Request, response: Response):
    """Clear the session cookie and end the session.

    If OAuth signup is enabled and an ``oauth_id_token`` cookie is present, the
    provider's OpenID configuration is fetched. When it advertises an
    ``end_session_endpoint``, the client is redirected there with the ID token
    as ``id_token_hint``. Otherwise the client is redirected to
    WEBUI_AUTH_SIGNOUT_REDIRECT_URL when set, or gets ``{"status": True}``.

    Raises:
        HTTPException(<provider status>): the OpenID configuration request
            returned a non-200 status.
        HTTPException(500): any other failure while talking to the provider.
    """
    from urllib.parse import urlencode

    response.delete_cookie("token")

    if ENABLE_OAUTH_SIGNUP.value:
        oauth_id_token = request.cookies.get("oauth_id_token")
        if oauth_id_token:
            try:
                async with ClientSession() as session:
                    async with session.get(OPENID_PROVIDER_URL.value) as resp:
                        if resp.status != 200:
                            raise HTTPException(
                                status_code=resp.status,
                                detail="Failed to fetch OpenID configuration",
                            )
                        openid_data = await resp.json()
                        logout_url = openid_data.get("end_session_endpoint")
                        if logout_url:
                            response.delete_cookie("oauth_id_token")
                            # Encode the hint and respect any query string the
                            # endpoint already carries.
                            separator = "&" if "?" in logout_url else "?"
                            query = urlencode({"id_token_hint": oauth_id_token})
                            return RedirectResponse(
                                headers=response.headers,
                                url=f"{logout_url}{separator}{query}",
                            )
            except HTTPException:
                raise
            except Exception as e:
                log.error(f"OpenID signout error: {str(e)}")
                raise HTTPException(
                    status_code=500,
                    detail="Failed to sign out from the OpenID provider.",
                )

    if WEBUI_AUTH_SIGNOUT_REDIRECT_URL:
        return RedirectResponse(
            headers=response.headers,
            url=WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
        )

    return {"status": True}
############################
# AddUser
############################


async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
    """Create an account on behalf of an admin and return a token for it.

    The token is created without an explicit expiry (no ``expires_delta``).

    Raises:
        HTTPException(400): invalid email format, email already taken, or a
            password longer than 72 UTF-8 bytes.
        HTTPException(500): the account could not be created.
    """
    email = form_data.email.lower()
    if not validate_email_format(email):
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
        )
    if Users.get_user_by_email(email):
        raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)

    # Same bcrypt 72-byte limit that signup enforces (longer input would be
    # silently truncated before hashing).
    if len(form_data.password.encode("utf-8")) > 72:
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.PASSWORD_TOO_LONG,
        )

    try:
        hashed = get_password_hash(form_data.password)
        new_user = Auths.insert_new_auth(
            email,
            hashed,
            form_data.name,
            form_data.profile_image_url,
            form_data.role,
        )
        if not new_user:
            raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)

        token = create_token(data={"id": new_user.id})
        return {
            "token": token,
            "token_type": "Bearer",
            "id": new_user.id,
            "email": new_user.email,
            "name": new_user.name,
            "role": new_user.role,
            "profile_image_url": new_user.profile_image_url,
        }
    except HTTPException:
        # Deliberate HTTP errors keep their status and detail.
        raise
    except Exception as err:
        log.error(f"Add user error: {str(err)}")
        raise HTTPException(
            500, detail="An internal error occurred while adding the user."
        )
############################
# GetAdminDetails
############################


async def get_admin_details(request: Request, user=Depends(get_current_user)):
    """Return the admin contact ``{"name", "email"}`` when SHOW_ADMIN_DETAILS is on.

    If ADMIN_EMAIL is configured, that email is returned. The name comes from
    the matching user, or is None if no user matches. Otherwise the first user
    (Users.get_first_user) supplies both values.

    Raises:
        HTTPException(400): SHOW_ADMIN_DETAILS is disabled (ACTION_PROHIBITED).
    """
    config = request.app.state.config
    if not config.SHOW_ADMIN_DETAILS:
        raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)

    admin_email = config.ADMIN_EMAIL
    admin_name = None

    if admin_email:
        admin = Users.get_user_by_email(admin_email)
        if admin:
            admin_name = admin.name
    else:
        admin = Users.get_first_user()
        if admin:
            admin_email = admin.email
            admin_name = admin.name

    # Log after resolution so the logged name reflects the lookup result.
    log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}")

    return {
        "name": admin_name,
        "email": admin_email,
    }
############################
# Admin Config
############################


async def get_admin_config(request: Request, user=Depends(get_admin_user)):
    """Return the admin-editable general settings, keyed by config name."""
    exposed_keys = (
        "SHOW_ADMIN_DETAILS",
        "WEBUI_URL",
        "ENABLE_SIGNUP",
        "ENABLE_API_KEY",
        "ENABLE_API_KEY_ENDPOINT_RESTRICTIONS",
        "API_KEY_ALLOWED_ENDPOINTS",
        "DEFAULT_USER_ROLE",
        "JWT_EXPIRES_IN",
        "ENABLE_COMMUNITY_SHARING",
        "ENABLE_MESSAGE_RATING",
        "ENABLE_CHANNELS",
        "ENABLE_NOTES",
        "ENABLE_USER_WEBHOOKS",
    )
    app_config = request.app.state.config
    return {key: getattr(app_config, key) for key in exposed_keys}
# Request body for update_admin_config; each field name matches the app config
# key it updates.
class AdminConfig(BaseModel):
    SHOW_ADMIN_DETAILS: bool
    WEBUI_URL: str
    ENABLE_SIGNUP: bool
    ENABLE_API_KEY: bool
    ENABLE_API_KEY_ENDPOINT_RESTRICTIONS: bool
    API_KEY_ALLOWED_ENDPOINTS: str
    # Applied by update_admin_config only if one of "pending", "user", "admin".
    DEFAULT_USER_ROLE: str
    # Duration string (e.g. "-1", "0", "4w"); applied by update_admin_config only
    # if it matches that handler's duration pattern.
    JWT_EXPIRES_IN: str
    ENABLE_COMMUNITY_SHARING: bool
    ENABLE_MESSAGE_RATING: bool
    ENABLE_CHANNELS: bool
    ENABLE_NOTES: bool
    ENABLE_USER_WEBHOOKS: bool
async def update_admin_config(
    request: Request, form_data: AdminConfig, user=Depends(get_admin_user)
):
    """Apply the admin settings in `form_data` and return the resulting values.

    DEFAULT_USER_ROLE is applied only when it is "pending", "user" or "admin".
    JWT_EXPIRES_IN is applied only when the whole string is "-1", "0", or a
    (possibly signed, possibly fractional) number followed by one of the units
    ms, s, m, h, d, w. Invalid values for these two are ignored and the
    previous setting is kept.
    """
    config = request.app.state.config

    config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS
    config.WEBUI_URL = form_data.WEBUI_URL
    config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP
    config.ENABLE_API_KEY = form_data.ENABLE_API_KEY
    config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS = (
        form_data.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS
    )
    config.API_KEY_ALLOWED_ENDPOINTS = form_data.API_KEY_ALLOWED_ENDPOINTS
    config.ENABLE_CHANNELS = form_data.ENABLE_CHANNELS
    config.ENABLE_NOTES = form_data.ENABLE_NOTES

    if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
        config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE

    # fullmatch rather than match with "^...$": "$" also matches just before a
    # trailing newline, which would let values like "1h\n" through.
    duration_pattern = r"(?:-1|0|-?\d+(?:\.\d+)?(?:ms|s|m|h|d|w))"
    if re.fullmatch(duration_pattern, form_data.JWT_EXPIRES_IN):
        config.JWT_EXPIRES_IN = form_data.JWT_EXPIRES_IN

    config.ENABLE_COMMUNITY_SHARING = form_data.ENABLE_COMMUNITY_SHARING
    config.ENABLE_MESSAGE_RATING = form_data.ENABLE_MESSAGE_RATING
    config.ENABLE_USER_WEBHOOKS = form_data.ENABLE_USER_WEBHOOKS

    return {
        "SHOW_ADMIN_DETAILS": config.SHOW_ADMIN_DETAILS,
        "WEBUI_URL": config.WEBUI_URL,
        "ENABLE_SIGNUP": config.ENABLE_SIGNUP,
        "ENABLE_API_KEY": config.ENABLE_API_KEY,
        "ENABLE_API_KEY_ENDPOINT_RESTRICTIONS": config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
        "API_KEY_ALLOWED_ENDPOINTS": config.API_KEY_ALLOWED_ENDPOINTS,
        "DEFAULT_USER_ROLE": config.DEFAULT_USER_ROLE,
        "JWT_EXPIRES_IN": config.JWT_EXPIRES_IN,
        "ENABLE_COMMUNITY_SHARING": config.ENABLE_COMMUNITY_SHARING,
        "ENABLE_MESSAGE_RATING": config.ENABLE_MESSAGE_RATING,
        "ENABLE_CHANNELS": config.ENABLE_CHANNELS,
        "ENABLE_NOTES": config.ENABLE_NOTES,
        "ENABLE_USER_WEBHOOKS": config.ENABLE_USER_WEBHOOKS,
    }
# LDAP server settings as exchanged by get_ldap_server / update_ldap_server;
# each field maps to an LDAP_* key in the app config.
class LdapServerConfig(BaseModel):
    label: str
    host: str
    # Passed to the ldap3 Server as-is; None when not specified.
    port: Optional[int] = None
    # Entry attribute read as the user's email address.
    attribute_for_mail: str = "mail"
    # Entry attribute matched against the login name in the search filter.
    attribute_for_username: str = "uid"
    # DN for the application bind. ldap_auth binds anonymously when this is
    # empty, but update_ldap_server rejects an empty value.
    app_dn: str
    app_dn_password: str
    search_base: str
    # Extra LDAP filter terms appended verbatim inside the "(&...)" filter.
    search_filters: str = ""
    # Passed as `use_ssl` to the ldap3 Server in ldap_auth.
    use_tls: bool = True
    # CA certificate file used to validate the server certificate.
    certificate_path: Optional[str] = None
    # TLS cipher string; ldap_auth falls back to "ALL" when this is empty.
    ciphers: Optional[str] = "ALL"
async def get_ldap_server(request: Request, user=Depends(get_admin_user)):
    """Return the stored LDAP server settings, keyed by LdapServerConfig field name."""
    field_to_config_key = (
        ("label", "LDAP_SERVER_LABEL"),
        ("host", "LDAP_SERVER_HOST"),
        ("port", "LDAP_SERVER_PORT"),
        ("attribute_for_mail", "LDAP_ATTRIBUTE_FOR_MAIL"),
        ("attribute_for_username", "LDAP_ATTRIBUTE_FOR_USERNAME"),
        ("app_dn", "LDAP_APP_DN"),
        ("app_dn_password", "LDAP_APP_PASSWORD"),
        ("search_base", "LDAP_SEARCH_BASE"),
        ("search_filters", "LDAP_SEARCH_FILTERS"),
        ("use_tls", "LDAP_USE_TLS"),
        ("certificate_path", "LDAP_CA_CERT_FILE"),
        ("ciphers", "LDAP_CIPHERS"),
    )
    app_config = request.app.state.config
    return {field: getattr(app_config, key) for field, key in field_to_config_key}
async def update_ldap_server(
    request: Request, form_data: LdapServerConfig, user=Depends(get_admin_user)
):
    """Validate and store LDAP server settings, then return the stored values.

    Raises:
        HTTPException(400): the first required field (in the order listed
            below) that is empty.
    """
    required_fields = [
        "label",
        "host",
        "attribute_for_mail",
        "attribute_for_username",
        "app_dn",
        "app_dn_password",
        "search_base",
    ]
    first_empty = next(
        (name for name in required_fields if not getattr(form_data, name)), None
    )
    if first_empty is not None:
        raise HTTPException(400, detail=f"Required field {first_empty} is empty")

    # Form field -> app config key, in the order they are written and reported.
    field_to_config_key = (
        ("label", "LDAP_SERVER_LABEL"),
        ("host", "LDAP_SERVER_HOST"),
        ("port", "LDAP_SERVER_PORT"),
        ("attribute_for_mail", "LDAP_ATTRIBUTE_FOR_MAIL"),
        ("attribute_for_username", "LDAP_ATTRIBUTE_FOR_USERNAME"),
        ("app_dn", "LDAP_APP_DN"),
        ("app_dn_password", "LDAP_APP_PASSWORD"),
        ("search_base", "LDAP_SEARCH_BASE"),
        ("search_filters", "LDAP_SEARCH_FILTERS"),
        ("use_tls", "LDAP_USE_TLS"),
        ("certificate_path", "LDAP_CA_CERT_FILE"),
        ("ciphers", "LDAP_CIPHERS"),
    )
    app_config = request.app.state.config
    for field, key in field_to_config_key:
        setattr(app_config, key, getattr(form_data, field))

    return {field: getattr(app_config, key) for field, key in field_to_config_key}
async def get_ldap_config(request: Request, user=Depends(get_admin_user)):
    """Report the current ENABLE_LDAP setting."""
    ldap_enabled = request.app.state.config.ENABLE_LDAP
    return {"ENABLE_LDAP": ldap_enabled}
# Request body for update_ldap_config.
class LdapConfigForm(BaseModel):
    # New ENABLE_LDAP value. When omitted it is None, and update_ldap_config
    # stores that None as-is.
    enable_ldap: Optional[bool] = None
async def update_ldap_config(
    request: Request, form_data: LdapConfigForm, user=Depends(get_admin_user)
):
    """Store `form_data.enable_ldap` as ENABLE_LDAP and echo the stored value."""
    app_config = request.app.state.config
    app_config.ENABLE_LDAP = form_data.enable_ldap
    return {"ENABLE_LDAP": app_config.ENABLE_LDAP}
############################
# API Key
############################


# create api key
async def generate_api_key(request: Request, user=Depends(get_current_user)):
    """Create a fresh API key for the current user, store it, and return it.

    Raises:
        HTTPException(403): API keys are disabled (ENABLE_API_KEY is off).
        HTTPException(500): the key could not be saved.
    """
    if not request.app.state.config.ENABLE_API_KEY:
        raise HTTPException(
            status.HTTP_403_FORBIDDEN,
            detail=ERROR_MESSAGES.API_KEY_CREATION_NOT_ALLOWED,
        )

    new_key = create_api_key()
    if not Users.update_user_api_key_by_id(user.id, new_key):
        raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_API_KEY_ERROR)

    return {"api_key": new_key}
# delete api key
async def delete_api_key(user=Depends(get_current_user)):
    """Clear the current user's API key and return the store's result."""
    return Users.update_user_api_key_by_id(user.id, None)
# get api key
async def get_api_key(user=Depends(get_current_user)):
    """Return the current user's stored API key.

    Raises:
        HTTPException(404): the user has no API key.
    """
    stored_key = Users.get_user_api_key_by_id(user.id)
    if not stored_key:
        raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
    return {"api_key": stored_key}