# NOTE(review): the three lines below were Hugging Face Spaces page residue
# ("Spaces:" / "Running" / "Running") accidentally captured with the source;
# kept only as this comment so the module parses as valid Python.
# Standard library
import os
import sys
from pathlib import Path
from typing import Any, ClassVar

# Third-party: UI, inference, vector store, env loading
import gradio as gr
import torch
from llama_cpp import Llama
from chromadb import EmbeddingFunction
from dotenv import load_dotenv

# Read .env into the process environment (e.g. HF_TOKEN used below).
load_dotenv()
class ModelStorage:
    '''Process-wide registry of loaded models, shared by the whole app.

    Both registries start empty and are presumably filled lazily by the
    model-loading code elsewhere in the project.
    '''
    # Loaded llama.cpp LLM instances, keyed by string identifier.
    LLM_MODEL: ClassVar[dict[str, Llama]] = dict()
    # Loaded Chroma embedding functions, keyed by string identifier.
    EMBED_MODEL: ClassVar[dict[str, EmbeddingFunction]] = dict()
| class UiBlocksConfig: | |
| '''Gradio settings for gr.Blocks()''' | |
| CSS: str | None = ''' | |
| .gradio-container { | |
| width: 70% !important; | |
| margin: 0 auto !important; | |
| } | |
| ''' | |
| if hasattr(sys, 'getandroidapilevel') or 'ANDROID_ROOT' in os.environ: | |
| CSS = None | |
| UI_BLOCKS_KWARGS: dict[str, Any] = dict( | |
| theme=None, | |
| css=CSS, | |
| analytics_enabled=False, | |
| ) | |
class InferenceConfig:
    '''Model inference settings'''

    def __init__(self):
        # Arguments for the embedding model's encode() call.
        self.encode_kwargs: dict[str, Any] = {
            'batch_size': 300,
            'normalize_embeddings': None,
        }
        # Token-sampling parameters for generation.
        self.sampling_kwargs: dict[str, Any] = {
            'temperature': 0.2,
            'top_p': 0.95,
            'top_k': 40,
            'repeat_penalty': 1.0,
        }
        self.do_sample: bool = False       # sampling disabled by default
        self.rag_mode: bool = False        # retrieval-augmented mode off
        self.history_len: int = 0          # presumably chat turns kept as context
        self.show_thinking: bool = False   # hide the model's "thinking" output
class TextLoadConfig:
    '''Settings for loading texts from documents'''

    def __init__(self):
        # Partitioning/chunking/cleaning options — presumably consumed by
        # the `unstructured` partition pipeline; confirm against the loader.
        self.partition_kwargs: dict[str, str | int | bool | None] = {
            'chunking_strategy': 'basic',
            'max_characters': 800,
            'new_after_n_chars': 500,
            'overlap': 0,
            'clean': True,
            'bullets': True,
            'extra_whitespace': True,
            'dashes': False,
            'trailing_punctuation': True,
            'lowercase': False,
        }
        # Space-separated whitelist of ingestible file extensions.
        self.SUPPORTED_FILE_EXTS: str = '.csv .tsv .docx .md .org .pdf .pptx .xlsx'
        self.subtitle_lang: str = 'ru'                 # currently selected subtitle language
        self.SUBTITLE_LANGS: list[str] = ['ru', 'en']  # selectable subtitle languages
        self.max_lines_text_view: int = 200            # cap on lines shown in the text preview
class DbConfig:
    '''Vector database parameters (Chroma)'''

    def __init__(self):
        # HNSW index settings for collection creation.
        # space: one of 'l2', 'ip', 'cosine' (Chroma default is 'l2').
        hnsw_settings = {
            'space': 'cosine',
            'ef_construction': 200,
        }
        self.create_collection_kwargs: dict[str, Any] = {
            'configuration': {'hnsw': hnsw_settings},
        }
        # Query-time retrieval options.
        # NOTE: 'max_distance_treshold' spelling (sic) is used as a key
        # throughout the codebase — renaming it would break lookups.
        self.query_kwargs: dict[str, Any] = {
            'n_results': 2,
            'max_distance_treshold': 0.5,
        }
class PromptConfig:
    '''Prompts'''
    def __init__(self):
        # Optional system prompt; None presumably means "use the model's
        # default" — confirm against the generation code.
        self.system_prompt: str | None = None
        # Holds the user message already wrapped with retrieved context;
        # starts empty and is presumably filled at runtime in RAG mode.
        self.user_msg_with_context: str = ''
        # RAG prompt template (Russian). English gloss: "Answer the question
        # given the context. Context: {context} Question: {user_message}
        # Answer:". The {context}/{user_message} placeholders are presumably
        # filled via str.format at generation time — verify against caller.
        self.context_template: str = '''Ответь на вопрос при условии контекста.
Контекст:
{context}
Вопрос:
{user_message}
Ответ:'''
class ModelConfig:
    '''Configuration of paths, models and generation parameters.

    Instantiating this class creates the local model cache directories
    as a side effect.
    '''

    def __init__(self):
        # Local cache directories for GGUF LLMs and embedding models.
        self.LLM_MODELS_PATH: Path = Path('models')
        self.EMBED_MODELS_PATH: Path = Path('embed_models')
        # parents=True added for robustness should the paths ever gain
        # intermediate components; exist_ok keeps re-runs idempotent.
        self.LLM_MODELS_PATH.mkdir(parents=True, exist_ok=True)
        self.EMBED_MODELS_PATH.mkdir(parents=True, exist_ok=True)

        # Default model selection.
        self.llm_model_repo: str = 'bartowski/google_gemma-3-1b-it-GGUF'
        self.llm_model_file: str = 'google_gemma-3-1b-it-Q8_0.gguf'
        self.embed_model_repo: str = 'Alibaba-NLP/gte-multilingual-base'

        # FIX: was hard-coded to 'cuda:0', which made embedding-model
        # loading fail on CPU-only machines; fall back to CPU when CUDA
        # is unavailable.
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

        # Kwargs for the embedding model loader.
        self.embed_model_kwargs: dict[str, Any] = dict(
            device=device,
            trust_remote_code=True,
            cache_folder=self.EMBED_MODELS_PATH,
            token=os.getenv('HF_TOKEN'),  # needed for private/gated HF repos
            model_kwargs=dict(
                torch_dtype='auto',
            ),
        )
        # Kwargs for llama_cpp.Llama(); n_gpu_layers=-1 offloads all layers
        # to GPU (a no-op on CPU-only llama.cpp builds).
        self.llm_model_kwargs: dict[str, Any] = dict(
            n_gpu_layers=-1,
            n_ctx=4096,
            verbose=False,
            local_dir=self.LLM_MODELS_PATH,
        )
class ReposConfig:
    '''Links to Hugging Face repositories with GGUF models and embedding models'''
    def __init__(self):
        # GGUF-quantized chat models selectable in the UI.
        self.llm_model_repos: list[str] = [
            'bartowski/google_gemma-3-1b-it-GGUF',
            'bartowski/google_gemma-3-4b-it-GGUF',
            'bartowski/Qwen_Qwen3-1.7B-GGUF',
            'bartowski/Qwen_Qwen3-4B-GGUF',
        ]
        # Embedding models selectable for the vector store (mix of
        # multilingual and Russian-specific models).
        self.embed_model_repos: list[str] = [
            'Alibaba-NLP/gte-multilingual-base',
            'sergeyzh/rubert-tiny-turbo',
            'intfloat/multilingual-e5-large',
            'intfloat/multilingual-e5-base',
            'intfloat/multilingual-e5-small',
            'intfloat/multilingual-e5-large-instruct',
            'sentence-transformers/all-mpnet-base-v2',
            'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
            'ai-forever/ruElectra-medium',
            'ai-forever/sbert_large_nlu_ru',
            'deepvk/USER2-small',
            'BAAI/bge-m3-retromae',
        ]
class Config:
    '''General config.

    Aggregates all sub-configs and flattens selected values into kwargs
    dictionaries consumed by the UI/inference layer.

    NOTE: the flattened dicts (generation_kwargs, load_text_kwargs, ...)
    snapshot the sub-config values at construction time; mutating a
    sub-config afterwards does NOT propagate into them automatically.

    NOTE: 'max_distance_treshold' spelling (sic) is a key used throughout
    the codebase; renaming it here would break callers.
    '''

    def __init__(self):
        # Sub-configurations.
        self.Inference: InferenceConfig = InferenceConfig()
        self.TextLoad: TextLoadConfig = TextLoadConfig()
        self.Prompt: PromptConfig = PromptConfig()
        self.Db: DbConfig = DbConfig()
        self.Model: ModelConfig = ModelConfig()
        self.Repos: ReposConfig = ReposConfig()

        # Flattened generation parameters (sampling + RAG + prompt).
        self.generation_kwargs: dict[str, Any] = dict(
            do_sample=self.Inference.do_sample,
            temperature=self.Inference.sampling_kwargs['temperature'],
            top_p=self.Inference.sampling_kwargs['top_p'],
            top_k=self.Inference.sampling_kwargs['top_k'],
            repeat_penalty=self.Inference.sampling_kwargs['repeat_penalty'],
            history_len=self.Inference.history_len,
            system_prompt=self.Prompt.system_prompt,
            context_template=self.Prompt.context_template,
            show_thinking=self.Inference.show_thinking,
            n_results=self.Db.query_kwargs['n_results'],
            max_distance_treshold=self.Db.query_kwargs['max_distance_treshold'],
            user_msg_with_context=self.Prompt.user_msg_with_context,
            rag_mode=self.Inference.rag_mode,
        )
        # Flattened document-loading parameters.
        self.load_text_kwargs: dict[str, Any] = dict(
            chunking_strategy=self.TextLoad.partition_kwargs['chunking_strategy'],
            max_characters=self.TextLoad.partition_kwargs['max_characters'],
            new_after_n_chars=self.TextLoad.partition_kwargs['new_after_n_chars'],
            overlap=self.TextLoad.partition_kwargs['overlap'],
            clean=self.TextLoad.partition_kwargs['clean'],
            bullets=self.TextLoad.partition_kwargs['bullets'],
            extra_whitespace=self.TextLoad.partition_kwargs['extra_whitespace'],
            dashes=self.TextLoad.partition_kwargs['dashes'],
            trailing_punctuation=self.TextLoad.partition_kwargs['trailing_punctuation'],
            lowercase=self.TextLoad.partition_kwargs['lowercase'],
            subtitle_lang=self.TextLoad.subtitle_lang,
        )
        # Flattened model-loading parameters.
        self.load_model_kwargs: dict[str, Any] = dict(
            llm_model_repo=self.Model.llm_model_repo,
            llm_model_file=self.Model.llm_model_file,
            embed_model_repo=self.Model.embed_model_repo,
            n_gpu_layers=self.Model.llm_model_kwargs['n_gpu_layers'],
            n_ctx=self.Model.llm_model_kwargs['n_ctx'],
        )
        # Text-preview parameters.
        self.view_text_kwargs: dict[str, Any] = dict(
            max_lines_text_view=self.TextLoad.max_lines_text_view,
        )

    def get_sampling_kwargs(self) -> dict[str, Any]:
        '''Return a fresh dict of the current token-sampling parameters.'''
        return dict(
            temperature=self.generation_kwargs['temperature'],
            top_p=self.generation_kwargs['top_p'],
            top_k=self.generation_kwargs['top_k'],
            repeat_penalty=self.generation_kwargs['repeat_penalty'],
        )

    def get_rag_kwargs(self) -> dict[str, Any]:
        '''Return a fresh dict of the current RAG parameters.'''
        return dict(
            n_results=self.generation_kwargs['n_results'],
            max_distance_treshold=self.generation_kwargs['max_distance_treshold'],
            user_msg_with_context=self.generation_kwargs['user_msg_with_context'],
            context_template=self.generation_kwargs['context_template'],
        )

    def get_partition_kwargs(self) -> dict[str, Any]:
        '''Return a fresh dict of all document-partitioning parameters.'''
        return dict(
            chunking_strategy=self.load_text_kwargs['chunking_strategy'],
            max_characters=self.load_text_kwargs['max_characters'],
            new_after_n_chars=self.load_text_kwargs['new_after_n_chars'],
            overlap=self.load_text_kwargs['overlap'],
            clean=self.load_text_kwargs['clean'],
            bullets=self.load_text_kwargs['bullets'],
            extra_whitespace=self.load_text_kwargs['extra_whitespace'],
            dashes=self.load_text_kwargs['dashes'],
            trailing_punctuation=self.load_text_kwargs['trailing_punctuation'],
            lowercase=self.load_text_kwargs['lowercase'],
        )

    def get_clean_kwargs(self) -> dict[str, Any]:
        '''Return a fresh dict of the text-cleaning subset of parameters.'''
        return dict(
            bullets=self.load_text_kwargs['bullets'],
            extra_whitespace=self.load_text_kwargs['extra_whitespace'],
            dashes=self.load_text_kwargs['dashes'],
            trailing_punctuation=self.load_text_kwargs['trailing_punctuation'],
            lowercase=self.load_text_kwargs['lowercase'],
        )

    # FIX: return annotation added for consistency with the sibling getters.
    def get_chunking_kwargs(self) -> dict[str, Any]:
        '''Return a fresh dict of the chunk-sizing subset of parameters.'''
        return dict(
            max_characters=self.load_text_kwargs['max_characters'],
            new_after_n_chars=self.load_text_kwargs['new_after_n_chars'],
            overlap=self.load_text_kwargs['overlap'],
        )

    def get_embed_model_kwargs(self) -> dict[str, Any]:
        '''Return the embedding-model kwargs.

        NOTE: unlike the other getters, this returns the live dict (shared
        reference), not a copy — callers that mutate it affect the config.
        '''
        return self.Model.embed_model_kwargs

    def get_encode_kwargs(self) -> dict[str, Any]:
        '''Return the encode() kwargs (live dict, shared reference).'''
        return self.Inference.encode_kwargs

    def get_llm_model_kwargs(self) -> dict[str, Any]:
        '''Return the llama_cpp.Llama() kwargs (live dict, shared reference).'''
        return self.Model.llm_model_kwargs

    def get_query_kwargs(self) -> dict[str, Any]:
        '''Return a fresh dict of the vector-store query parameters.'''
        return dict(
            n_results=self.generation_kwargs['n_results'],
            max_distance_treshold=self.generation_kwargs['max_distance_treshold'],
        )