# chatbot-rag / config.py
# (Hugging Face Space file uploaded by sergey21000, commit ac7d372)
import os
import sys
from pathlib import Path
from typing import Any, ClassVar

# NOTE(review): gradio and torch are not referenced in this module —
# presumably imported for side effects or re-export; confirm before removing.
import gradio as gr
import torch
from llama_cpp import Llama
from chromadb import EmbeddingFunction
from dotenv import load_dotenv

# Load variables from a local .env file (e.g. HF_TOKEN) into os.environ.
load_dotenv()
class ModelStorage:
    '''Global model storage.

    Class-level registries shared by the whole app; models are stored in
    dicts keyed by strings (key scheme chosen by the callers).
    '''
    # Loaded llama.cpp LLM instances.
    LLM_MODEL: ClassVar[dict[str, Llama]] = {}
    # Chroma embedding functions.
    EMBED_MODEL: ClassVar[dict[str, EmbeddingFunction]] = {}
class UiBlocksConfig:
'''Gradio settings for gr.Blocks()'''
CSS: str | None = '''
.gradio-container {
width: 70% !important;
margin: 0 auto !important;
}
'''
if hasattr(sys, 'getandroidapilevel') or 'ANDROID_ROOT' in os.environ:
CSS = None
UI_BLOCKS_KWARGS: dict[str, Any] = dict(
theme=None,
css=CSS,
analytics_enabled=False,
)
class InferenceConfig:
    '''Model inference settings.'''

    def __init__(self):
        # Batch settings for the embedding encode() call.
        # NOTE(review): normalize_embeddings=None defers to the library
        # default — confirm whether an explicit True/False is intended.
        self.encode_kwargs: dict[str, Any] = {
            'batch_size': 300,
            'normalize_embeddings': None,
        }
        # Sampling parameters — presumably applied only when do_sample is
        # enabled; confirm in the generation code.
        self.sampling_kwargs: dict[str, Any] = {
            'temperature': 0.2,
            'top_p': 0.95,
            'top_k': 40,
            'repeat_penalty': 1.0,
        }
        self.do_sample: bool = False      # sampling disabled by default
        self.rag_mode: bool = False       # RAG context injection off by default
        self.history_len: int = 0         # chat turns carried into the prompt
        self.show_thinking: bool = False  # hide model "thinking" output by default
class TextLoadConfig:
    '''Settings for loading texts from documents.'''

    def __init__(self):
        # Partitioning/cleaning/chunking options forwarded by the document
        # loader (keys match the unstructured-style partition interface —
        # confirm against the loading code).
        self.partition_kwargs: dict[str, str | int | bool | None] = {
            'chunking_strategy': 'basic',
            'max_characters': 800,
            'new_after_n_chars': 500,
            'overlap': 0,
            'clean': True,
            'bullets': True,
            'extra_whitespace': True,
            'dashes': False,
            'trailing_punctuation': True,
            'lowercase': False,
        }
        # Space-separated list of accepted file extensions.
        self.SUPPORTED_FILE_EXTS: str = '.csv .tsv .docx .md .org .pdf .pptx .xlsx'
        self.subtitle_lang: str = 'ru'                 # default subtitle language
        self.SUBTITLE_LANGS: list[str] = ['ru', 'en']  # selectable subtitle languages
        self.max_lines_text_view: int = 200            # cap for text preview output
class DbConfig:
    '''Vector database parameters (Chroma).'''

    def __init__(self):
        # HNSW index configuration used at collection creation time.
        self.create_collection_kwargs: dict[str, Any] = {
            'configuration': {
                'hnsw': {
                    'space': 'cosine',  # distance metric: l2, ip or cosine (Chroma default: l2)
                    'ef_construction': 200,
                },
            },
        }
        # Retrieval settings for collection queries.
        # NOTE: 'treshold' (sic) is the key spelling used app-wide; do not rename.
        self.query_kwargs: dict[str, Any] = {
            'n_results': 2,
            'max_distance_treshold': 0.5,
        }
class PromptConfig:
    '''Prompts.'''

    def __init__(self):
        self.system_prompt: str | None = None  # no system prompt by default
        self.user_msg_with_context: str = ''   # filled at runtime in RAG mode
        # RAG prompt template (Russian): "Answer the question given the
        # context" with {context} and {user_message} placeholders.
        self.context_template: str = '''Ответь на вопрос при условии контекста.
Контекст:
{context}
Вопрос:
{user_message}
Ответ:'''
class ModelConfig:
    '''Configuration of paths, models and generation parameters.'''

    def __init__(self):
        # Local cache directories, created on first use if missing.
        self.LLM_MODELS_PATH: Path = Path('models')
        self.EMBED_MODELS_PATH: Path = Path('embed_models')
        for directory in (self.LLM_MODELS_PATH, self.EMBED_MODELS_PATH):
            directory.mkdir(exist_ok=True)

        # Default GGUF LLM (Hugging Face Hub repo + file name).
        self.llm_model_repo: str = 'bartowski/google_gemma-3-1b-it-GGUF'
        self.llm_model_file: str = 'google_gemma-3-1b-it-Q8_0.gguf'
        # Default embedding model repo.
        self.embed_model_repo: str = 'Alibaba-NLP/gte-multilingual-base'

        # Kwargs for loading the embedding model (shape suggests a
        # SentenceTransformer constructor — confirm at the call site).
        self.embed_model_kwargs: dict[str, Any] = {
            'device': 'cuda:0',
            'trust_remote_code': True,
            'cache_folder': self.EMBED_MODELS_PATH,
            'token': os.getenv('HF_TOKEN'),  # optional HF auth token from .env
            'model_kwargs': {
                'torch_dtype': 'auto',
            },
        }
        # Kwargs for loading the GGUF LLM via llama.cpp.
        self.llm_model_kwargs: dict[str, Any] = {
            'n_gpu_layers': -1,  # -1 = offload all layers when a GPU is present
            'n_ctx': 4096,
            'verbose': False,
            'local_dir': self.LLM_MODELS_PATH,
        }
class ReposConfig:
    '''Links to repositories with GGUF models and embedding models.'''
    def __init__(self):
        # Hugging Face Hub repos offering GGUF quantizations of the LLMs.
        self.llm_model_repos: list[str] = [
            'bartowski/google_gemma-3-1b-it-GGUF',
            'bartowski/google_gemma-3-4b-it-GGUF',
            'bartowski/Qwen_Qwen3-1.7B-GGUF',
            'bartowski/Qwen_Qwen3-4B-GGUF',
        ]
        # Embedding model repos selectable in the UI (mostly multilingual /
        # Russian-capable models).
        self.embed_model_repos: list[str] = [
            'Alibaba-NLP/gte-multilingual-base',
            'sergeyzh/rubert-tiny-turbo',
            'intfloat/multilingual-e5-large',
            'intfloat/multilingual-e5-base',
            'intfloat/multilingual-e5-small',
            'intfloat/multilingual-e5-large-instruct',
            'sentence-transformers/all-mpnet-base-v2',
            'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
            'ai-forever/ruElectra-medium',
            'ai-forever/sbert_large_nlu_ru',
            'deepvk/USER2-small',
            'BAAI/bge-m3-retromae',
        ]
class Config:
    '''General config.

    Aggregates the sub-configs and flattens them into kwargs dicts; the
    get_* helpers slice those dicts for specific call sites.
    '''

    def __init__(self):
        self.Inference: InferenceConfig = InferenceConfig()
        self.TextLoad: TextLoadConfig = TextLoadConfig()
        self.Prompt: PromptConfig = PromptConfig()
        self.Db: DbConfig = DbConfig()
        self.Model: ModelConfig = ModelConfig()
        self.Repos: ReposConfig = ReposConfig()

        # Everything the generation pipeline needs, flattened into one dict.
        # Key order is preserved exactly as before (** unpacking keeps the
        # sub-dicts' insertion order) in case callers iterate over it.
        self.generation_kwargs: dict[str, Any] = {
            'do_sample': self.Inference.do_sample,
            **self.Inference.sampling_kwargs,  # temperature, top_p, top_k, repeat_penalty
            'history_len': self.Inference.history_len,
            'system_prompt': self.Prompt.system_prompt,
            'context_template': self.Prompt.context_template,
            'show_thinking': self.Inference.show_thinking,
            **self.Db.query_kwargs,  # n_results, max_distance_treshold ('treshold' sic)
            'user_msg_with_context': self.Prompt.user_msg_with_context,
            'rag_mode': self.Inference.rag_mode,
        }
        # Document-loading kwargs: every partition setting plus subtitle language.
        self.load_text_kwargs: dict[str, Any] = {
            **self.TextLoad.partition_kwargs,
            'subtitle_lang': self.TextLoad.subtitle_lang,
        }
        # Model-selection kwargs surfaced to the UI.
        self.load_model_kwargs: dict[str, Any] = {
            'llm_model_repo': self.Model.llm_model_repo,
            'llm_model_file': self.Model.llm_model_file,
            'embed_model_repo': self.Model.embed_model_repo,
            'n_gpu_layers': self.Model.llm_model_kwargs['n_gpu_layers'],
            'n_ctx': self.Model.llm_model_kwargs['n_ctx'],
        }
        self.view_text_kwargs: dict[str, Any] = {
            'max_lines_text_view': self.TextLoad.max_lines_text_view,
        }

    def _subset(self, source: dict[str, Any], keys: tuple[str, ...]) -> dict[str, Any]:
        '''Project a kwargs dict onto the given keys (preserving their order).'''
        return {key: source[key] for key in keys}

    def get_sampling_kwargs(self) -> dict[str, Any]:
        '''Sampling parameters for text generation.'''
        return self._subset(
            self.generation_kwargs,
            ('temperature', 'top_p', 'top_k', 'repeat_penalty'),
        )

    def get_rag_kwargs(self) -> dict[str, Any]:
        '''Retrieval and prompt-injection parameters for RAG mode.'''
        return self._subset(
            self.generation_kwargs,
            ('n_results', 'max_distance_treshold',
             'user_msg_with_context', 'context_template'),
        )

    def get_partition_kwargs(self) -> dict[str, Any]:
        '''All document partitioning/cleaning/chunking parameters.'''
        return self._subset(
            self.load_text_kwargs,
            ('chunking_strategy', 'max_characters', 'new_after_n_chars',
             'overlap', 'clean', 'bullets', 'extra_whitespace', 'dashes',
             'trailing_punctuation', 'lowercase'),
        )

    def get_clean_kwargs(self) -> dict[str, Any]:
        '''Text-cleaning flags only.'''
        return self._subset(
            self.load_text_kwargs,
            ('bullets', 'extra_whitespace', 'dashes',
             'trailing_punctuation', 'lowercase'),
        )

    def get_chunking_kwargs(self) -> dict[str, Any]:
        '''Chunk-size parameters only.'''
        return self._subset(
            self.load_text_kwargs,
            ('max_characters', 'new_after_n_chars', 'overlap'),
        )

    def get_embed_model_kwargs(self) -> dict[str, Any]:
        '''Kwargs for loading the embedding model (shared dict, not a copy).'''
        return self.Model.embed_model_kwargs

    def get_encode_kwargs(self) -> dict[str, Any]:
        '''Kwargs for the embedding encode() call (shared dict, not a copy).'''
        return self.Inference.encode_kwargs

    def get_llm_model_kwargs(self) -> dict[str, Any]:
        '''Kwargs for loading the GGUF LLM (shared dict, not a copy).'''
        return self.Model.llm_model_kwargs

    def get_query_kwargs(self) -> dict[str, Any]:
        '''Vector-search parameters for database queries.'''
        return self._subset(
            self.generation_kwargs,
            ('n_results', 'max_distance_treshold'),
        )