Spaces:
Paused
Paused
| from fastapi import APIRouter, Depends, HTTPException, Response, status, Request | |
| from fastapi.responses import JSONResponse, RedirectResponse | |
| from pydantic import BaseModel | |
| from typing import Optional | |
| import logging | |
| import re | |
| from open_webui.utils.chat import generate_chat_completion | |
| from open_webui.utils.task import ( | |
| title_generation_template, | |
| query_generation_template, | |
| image_prompt_generation_template, | |
| autocomplete_generation_template, | |
| tags_generation_template, | |
| emoji_generation_template, | |
| moa_response_generation_template, | |
| ) | |
| from open_webui.utils.auth import get_admin_user, get_verified_user | |
| from open_webui.constants import TASKS | |
| from open_webui.routers.pipelines import process_pipeline_inlet_filter | |
| from open_webui.utils.filter import ( | |
| get_sorted_filter_ids, | |
| process_filter_functions, | |
| ) | |
| from open_webui.utils.task import get_task_model_id | |
| from open_webui.config import ( | |
| DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE, | |
| DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE, | |
| DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, | |
| DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE, | |
| DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, | |
| DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE, | |
| DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE, | |
| ) | |
| from open_webui.env import SRC_LOG_LEVELS | |
| log = logging.getLogger(__name__) | |
| log.setLevel(SRC_LOG_LEVELS["MODELS"]) | |
| router = APIRouter() | |
| ################################## | |
| # | |
| # Task Endpoints | |
| # | |
| ################################## | |
async def get_task_config(request: Request, user=Depends(get_verified_user)):
    """Return the current task-related configuration visible to any verified user.

    Values are read straight off ``request.app.state.config``; key order matches
    what the admin UI expects.
    """
    config = request.app.state.config
    # Keys exposed to the client, in the order the response should carry them.
    exposed_keys = (
        "TASK_MODEL",
        "TASK_MODEL_EXTERNAL",
        "TITLE_GENERATION_PROMPT_TEMPLATE",
        "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_AUTOCOMPLETE_GENERATION",
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH",
        "TAGS_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_TAGS_GENERATION",
        "ENABLE_TITLE_GENERATION",
        "ENABLE_SEARCH_QUERY_GENERATION",
        "ENABLE_RETRIEVAL_QUERY_GENERATION",
        "QUERY_GENERATION_PROMPT_TEMPLATE",
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
    )
    return {key: getattr(config, key) for key in exposed_keys}
class TaskConfigForm(BaseModel):
    """Payload accepted by the admin endpoint that updates task settings.

    Field names intentionally mirror the attribute names on
    ``request.app.state.config`` so values can be copied over one-to-one.
    """

    # Task model identifiers; presumably local vs. external connections —
    # TODO(review) confirm against get_task_model_id's handling.
    TASK_MODEL: Optional[str]
    TASK_MODEL_EXTERNAL: Optional[str]
    ENABLE_TITLE_GENERATION: bool
    # Empty template strings mean "fall back to the built-in default template".
    TITLE_GENERATION_PROMPT_TEMPLATE: str
    IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str
    ENABLE_AUTOCOMPLETE_GENERATION: bool
    # Maximum accepted autocomplete input length; values <= 0 skip the check.
    AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
    TAGS_GENERATION_PROMPT_TEMPLATE: str
    ENABLE_TAGS_GENERATION: bool
    ENABLE_SEARCH_QUERY_GENERATION: bool
    ENABLE_RETRIEVAL_QUERY_GENERATION: bool
    QUERY_GENERATION_PROMPT_TEMPLATE: str
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
async def update_task_config(
    request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user)
):
    """Persist admin-submitted task settings onto the app config.

    Returns the now-current values so the client can refresh its view.
    Admin-only (``get_admin_user`` dependency).
    """
    config = request.app.state.config
    # Form field names mirror the attribute names on the config object, so the
    # copy is a straight name-for-name assignment. Order matches the response
    # shape the original handler produced.
    field_names = (
        "TASK_MODEL",
        "TASK_MODEL_EXTERNAL",
        "ENABLE_TITLE_GENERATION",
        "TITLE_GENERATION_PROMPT_TEMPLATE",
        "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_AUTOCOMPLETE_GENERATION",
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH",
        "TAGS_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_TAGS_GENERATION",
        "ENABLE_SEARCH_QUERY_GENERATION",
        "ENABLE_RETRIEVAL_QUERY_GENERATION",
        "QUERY_GENERATION_PROMPT_TEMPLATE",
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
    )
    for name in field_names:
        setattr(config, name, getattr(form_data, name))
    return {name: getattr(config, name) for name in field_names}
async def generate_title(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Generate a chat title for the given messages using the task model.

    Returns a 200 with a "disabled" detail when title generation is off;
    raises 404 for an unknown model; returns 400 on completion failure.
    """
    if not request.app.state.config.ENABLE_TITLE_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={"detail": "Title generation is disabled"},
        )

    # A "direct" request carries its own model; otherwise use the app registry.
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {request.state.model["id"]: request.state.model}
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Prefer an admin-configured task model over the chat's own model, if set.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )
    log.debug(
        f"generating chat title using model {task_model_id} for user {user.email} "
    )

    # An empty template string means: use the built-in default.
    template = request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
    if template == "":
        template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE

    content = title_generation_template(
        template,
        form_data["messages"],
        {
            "name": user.name,
            "location": user.info.get("location") if user.info else None,
        },
    )

    # Ollama uses "max_tokens"; other backends take "max_completion_tokens".
    if models[task_model_id].get("owned_by") == "ollama":
        token_limit = {"max_tokens": 1000}
    else:
        token_limit = {"max_completion_tokens": 1000}

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        **token_limit,
        "metadata": {
            **getattr(request.state, "metadata", {}),
            "task": str(TASKS.TITLE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Run pipeline inlet filters; failures propagate to the caller unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
async def generate_chat_tags(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Generate tags for a chat's messages using the task model.

    Returns a 200 with a "disabled" detail when tag generation is off;
    raises 404 for an unknown model; returns 500 on completion failure.
    """
    if not request.app.state.config.ENABLE_TAGS_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={"detail": "Tags generation is disabled"},
        )

    # A "direct" request carries its own model; otherwise use the app registry.
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {request.state.model["id"]: request.state.model}
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Prefer an admin-configured task model over the chat's own model, if set.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )
    log.debug(
        f"generating chat tags using model {task_model_id} for user {user.email} "
    )

    # An empty template string means: use the built-in default.
    template = request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE
    if template == "":
        template = DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE

    content = tags_generation_template(
        template, form_data["messages"], {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **getattr(request.state, "metadata", {}),
            "task": str(TASKS.TAGS_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Run pipeline inlet filters; failures propagate to the caller unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )
async def generate_image_prompt(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Generate an image-generation prompt from the chat messages.

    Raises 404 for an unknown model; returns 400 on completion failure.
    """
    # A "direct" request carries its own model; otherwise use the app registry.
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {request.state.model["id"]: request.state.model}
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Prefer an admin-configured task model over the chat's own model, if set.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )
    log.debug(
        f"generating image prompt using model {task_model_id} for user {user.email} "
    )

    # An empty template string means: use the built-in default.
    template = request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
    if template == "":
        template = DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE

    content = image_prompt_generation_template(
        template,
        form_data["messages"],
        user={
            "name": user.name,
        },
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **getattr(request.state, "metadata", {}),
            "task": str(TASKS.IMAGE_PROMPT_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Run pipeline inlet filters; failures propagate to the caller unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
async def generate_queries(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Generate search/retrieval queries from the chat history.

    ``form_data["type"]`` selects the mode ("web_search" or "retrieval"),
    each of which can be disabled independently via config.

    Raises:
        HTTPException: 400 when the requested mode is disabled,
            404 when the requested model is unknown.
    """
    # Renamed from `type` to avoid shadowing the builtin.
    query_type = form_data.get("type")
    if query_type == "web_search":
        if not request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Search query generation is disabled",
            )
    elif query_type == "retrieval":
        if not request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Query generation is disabled",
            )

    # A "direct" request carries its own model; otherwise use the app registry.
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Prefer an admin-configured task model over the chat's own model, if set.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )
    log.debug(
        f"generating {query_type} queries using model {task_model_id} for user {user.email}"
    )

    # A blank/whitespace-only template means: use the built-in default.
    if (request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != "":
        template = request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE

    content = query_generation_template(
        template, form_data["messages"], {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.QUERY_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Run pipeline inlet filters; failures propagate to the caller unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        # Log the full traceback server-side and return a generic message so
        # internal details are not leaked to the client — consistent with
        # generate_title / generate_chat_tags (previously this returned str(e)).
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
async def generate_autocompletion(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Generate an autocompletion suggestion for a partially typed prompt.

    Raises:
        HTTPException: 400 when the feature is disabled, the prompt is
            missing, or the prompt exceeds the configured length limit;
            404 when the requested model is unknown.
    """
    if not request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Autocompletion generation is disabled",
        )

    # Renamed from `type` to avoid shadowing the builtin.
    completion_type = form_data.get("type")
    prompt = form_data.get("prompt")
    messages = form_data.get("messages")

    # A missing prompt previously crashed with a TypeError (HTTP 500) on
    # len(None) whenever a length limit was configured; reject it cleanly.
    if prompt is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Prompt is required",
        )

    # A limit of <= 0 disables the length check entirely.
    max_length = request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
    if max_length > 0 and len(prompt) > max_length:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Input prompt exceeds maximum length of {max_length}",
        )

    # A "direct" request carries its own model; otherwise use the app registry.
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Prefer an admin-configured task model over the chat's own model, if set.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )
    log.debug(
        f"generating autocompletion using model {task_model_id} for user {user.email}"
    )

    # A blank/whitespace-only template means: use the built-in default.
    if (request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "":
        template = request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE

    content = autocomplete_generation_template(
        template, prompt, messages, completion_type, {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.AUTOCOMPLETE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Run pipeline inlet filters; failures propagate to the caller unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )
async def generate_emoji(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Generate a single emoji reaction for the given prompt.

    Always uses the built-in emoji template and caps output at 4 tokens.

    Raises:
        HTTPException: 404 when the requested model is unknown.
    """
    # A "direct" request carries its own model; otherwise use the app registry.
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Prefer an admin-configured task model over the chat's own model, if set.
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )
    log.debug(f"generating emoji using model {task_model_id} for user {user.email} ")

    template = DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE

    content = emoji_generation_template(
        template,
        form_data["prompt"],
        {
            "name": user.name,
            "location": user.info.get("location") if user.info else None,
        },
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        # Ollama uses "max_tokens"; other backends take "max_completion_tokens".
        **(
            {"max_tokens": 4}
            if models[task_model_id].get("owned_by") == "ollama"
            else {
                "max_completion_tokens": 4,
            }
        ),
        # NOTE(review): "chat_id" sits at the payload top level here, unlike the
        # sibling endpoints which nest it under "metadata" — confirm downstream
        # expectations before relocating it.
        "chat_id": form_data.get("chat_id", None),
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.EMOJI_GENERATION),
            "task_body": form_data,
        },
    }

    # Run pipeline inlet filters; failures propagate to the caller unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        # Log the full traceback server-side and return a generic message so
        # internal details are not leaked to the client — consistent with
        # generate_title / generate_chat_tags (previously this returned str(e)).
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
async def generate_moa_response(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Generate a mixture-of-agents (MoA) synthesis from multiple responses.

    Unlike the other task endpoints this uses the requested chat model
    directly (no task-model substitution) and honors ``stream``.

    Raises:
        HTTPException: 404 when the requested model is unknown.
    """
    # A "direct" request carries its own model; otherwise use the app registry.
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE

    content = moa_response_generation_template(
        template,
        form_data["prompt"],
        form_data["responses"],
    )

    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": form_data.get("stream", False),
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "chat_id": form_data.get("chat_id", None),
            "task": str(TASKS.MOA_RESPONSE_GENERATION),
            "task_body": form_data,
        },
    }

    # Run pipeline inlet filters; failures propagate to the caller unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        # Log the full traceback server-side and return a generic message so
        # internal details are not leaked to the client — consistent with
        # generate_title / generate_chat_tags (previously this returned str(e)).
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )