Upload 4 files
Browse files- constants.py +7 -7
- requirements-complete.txt +19 -0
- server.py +158 -36
constants.py
CHANGED
|
@@ -1,18 +1,18 @@
|
|
| 1 |
# Constants
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
|
|
|
| 6 |
# Also try: 'Salesforce/blip-image-captioning-base'
|
| 7 |
DEFAULT_CAPTIONING_MODEL = "Salesforce/blip-image-captioning-large"
|
| 8 |
-
|
| 9 |
-
DEFAULT_SD_MODEL = "sinkinai/MeinaHentai-v3-baked-vae"
|
| 10 |
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
|
| 11 |
DEFAULT_REMOTE_SD_HOST = "127.0.0.1"
|
| 12 |
DEFAULT_REMOTE_SD_PORT = 7860
|
| 13 |
DEFAULT_CHROMA_PORT = 8000
|
| 14 |
SILERO_SAMPLES_PATH = "tts_samples"
|
| 15 |
-
SILERO_SAMPLE_TEXT = "
|
| 16 |
# ALL_MODULES = ['caption', 'summarize', 'classify', 'keywords', 'prompt', 'sd']
|
| 17 |
DEFAULT_SUMMARIZE_PARAMS = {
|
| 18 |
"temperature": 1.0,
|
|
|
|
| 1 |
# Constants
|
| 2 |
+
DEFAULT_CUDA_DEVICE = "cuda:0"
|
| 3 |
+
# Also try: 'Qiliang/bart-large-cnn-samsum-ElectrifAi_v10'
|
| 4 |
+
DEFAULT_SUMMARIZATION_MODEL = "Qiliang/bart-large-cnn-samsum-ChatGPT_v3"
|
| 5 |
+
# Also try: 'joeddav/distilbert-base-uncased-go-emotions-student'
|
| 6 |
+
DEFAULT_CLASSIFICATION_MODEL = "nateraw/bert-base-uncased-emotion"
|
| 7 |
# Also try: 'Salesforce/blip-image-captioning-base'
|
| 8 |
DEFAULT_CAPTIONING_MODEL = "Salesforce/blip-image-captioning-large"
|
| 9 |
+
DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
|
|
|
|
| 10 |
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
|
| 11 |
DEFAULT_REMOTE_SD_HOST = "127.0.0.1"
|
| 12 |
DEFAULT_REMOTE_SD_PORT = 7860
|
| 13 |
DEFAULT_CHROMA_PORT = 8000
|
| 14 |
SILERO_SAMPLES_PATH = "tts_samples"
|
| 15 |
+
SILERO_SAMPLE_TEXT = "The quick brown fox jumps over the lazy dog"
|
| 16 |
# ALL_MODULES = ['caption', 'summarize', 'classify', 'keywords', 'prompt', 'sd']
|
| 17 |
DEFAULT_SUMMARIZE_PARAMS = {
|
| 18 |
"temperature": 1.0,
|
requirements-complete.txt
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
flask
|
| 2 |
+
flask-cloudflared
|
| 3 |
+
flask-cors
|
| 4 |
+
flask-compress
|
| 5 |
+
markdown
|
| 6 |
+
Pillow
|
| 7 |
+
colorama
|
| 8 |
+
webuiapi
|
| 9 |
+
--extra-index-url https://download.pytorch.org/whl/cu117
|
| 10 |
+
torch==2.0.0+cu117
|
| 11 |
+
torchvision==0.15.1
|
| 12 |
+
torchaudio==2.0.1+cu117
|
| 13 |
+
accelerate
|
| 14 |
+
transformers==4.28.1
|
| 15 |
+
diffusers==0.16.1
|
| 16 |
+
silero-api-server
|
| 17 |
+
chromadb
|
| 18 |
+
sentence_transformers
|
| 19 |
+
edge-tts
|
server.py
CHANGED
|
@@ -21,6 +21,7 @@ import torch
|
|
| 21 |
import time
|
| 22 |
import os
|
| 23 |
import gc
|
|
|
|
| 24 |
import secrets
|
| 25 |
from PIL import Image
|
| 26 |
import base64
|
|
@@ -33,6 +34,9 @@ from colorama import Fore, Style, init as colorama_init
|
|
| 33 |
|
| 34 |
colorama_init()
|
| 35 |
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
class SplitArgs(argparse.Action):
|
| 38 |
def __call__(self, parser, namespace, values, option_string=None):
|
|
@@ -40,6 +44,16 @@ class SplitArgs(argparse.Action):
|
|
| 40 |
namespace, self.dest, values.replace('"', "").replace("'", "").split(",")
|
| 41 |
)
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
# Script arguments
|
| 45 |
parser = argparse.ArgumentParser(
|
|
@@ -56,6 +70,8 @@ parser.add_argument(
|
|
| 56 |
)
|
| 57 |
parser.add_argument("--cpu", action="store_true", help="Run the models on the CPU")
|
| 58 |
parser.add_argument("--cuda", action="store_false", dest="cpu", help="Run the models on the GPU")
|
|
|
|
|
|
|
| 59 |
parser.set_defaults(cpu=True)
|
| 60 |
parser.add_argument("--summarization-model", help="Load a custom summarization model")
|
| 61 |
parser.add_argument(
|
|
@@ -66,11 +82,10 @@ parser.add_argument("--embedding-model", help="Load a custom text embedding mode
|
|
| 66 |
parser.add_argument("--chroma-host", help="Host IP for a remote ChromaDB instance")
|
| 67 |
parser.add_argument("--chroma-port", help="HTTP port for a remote ChromaDB instance (defaults to 8000)")
|
| 68 |
parser.add_argument("--chroma-folder", help="Path for chromadb persistence folder", default='.chroma_db')
|
| 69 |
-
parser.add_argument('--chroma-persist', help="
|
| 70 |
parser.add_argument(
|
| 71 |
"--secure", action="store_true", help="Enforces the use of an API key"
|
| 72 |
)
|
| 73 |
-
|
| 74 |
sd_group = parser.add_mutually_exclusive_group()
|
| 75 |
|
| 76 |
local_sd = sd_group.add_argument_group("sd-local")
|
|
@@ -105,8 +120,8 @@ parser.add_argument(
|
|
| 105 |
|
| 106 |
args = parser.parse_args()
|
| 107 |
|
| 108 |
-
port =
|
| 109 |
-
host = "0.0.0.0"
|
| 110 |
summarization_model = (
|
| 111 |
args.summarization_model
|
| 112 |
if args.summarization_model
|
|
@@ -142,12 +157,16 @@ if len(modules) == 0:
|
|
| 142 |
print(f"Example: --enable-modules=caption,summarize{Style.RESET_ALL}")
|
| 143 |
|
| 144 |
# Models init
|
| 145 |
-
|
|
|
|
| 146 |
device = torch.device(device_string)
|
| 147 |
-
torch_dtype = torch.float32 if device_string
|
| 148 |
|
| 149 |
if not torch.cuda.is_available() and not args.cpu:
|
| 150 |
-
print(f"{Fore.YELLOW}{Style.BRIGHT}torch-cuda is not supported on this device.
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
print(f"{Fore.GREEN}{Style.BRIGHT}Using torch device: {device_string}{Style.RESET_ALL}")
|
| 153 |
|
|
@@ -184,12 +203,10 @@ if "sd" in modules and not sd_use_remote:
|
|
| 184 |
from diffusers import StableDiffusionPipeline
|
| 185 |
from diffusers import EulerAncestralDiscreteScheduler
|
| 186 |
|
| 187 |
-
print("Initializing Stable Diffusion pipeline")
|
| 188 |
-
sd_device_string = (
|
| 189 |
-
"cuda" if torch.cuda.is_available() and not args.sd_cpu else "cpu"
|
| 190 |
-
)
|
| 191 |
sd_device = torch.device(sd_device_string)
|
| 192 |
-
sd_torch_dtype = torch.float32 if sd_device_string
|
| 193 |
sd_pipe = StableDiffusionPipeline.from_pretrained(
|
| 194 |
sd_model, custom_pipeline="lpw_stable_diffusion", torch_dtype=sd_torch_dtype
|
| 195 |
).to(sd_device)
|
|
@@ -252,26 +269,19 @@ if "chromadb" in modules:
|
|
| 252 |
posthog.capture = lambda *args, **kwargs: None
|
| 253 |
if args.chroma_host is None:
|
| 254 |
if args.chroma_persist:
|
| 255 |
-
chromadb_client = chromadb.
|
| 256 |
print(f"ChromaDB is running in-memory with persistence. Persistence is stored in {args.chroma_folder}. Can be cleared by deleting the folder or purging db.")
|
| 257 |
else:
|
| 258 |
-
chromadb_client = chromadb.
|
| 259 |
print(f"ChromaDB is running in-memory without persistence.")
|
| 260 |
else:
|
| 261 |
chroma_port=(
|
| 262 |
args.chroma_port if args.chroma_port else DEFAULT_CHROMA_PORT
|
| 263 |
)
|
| 264 |
-
chromadb_client = chromadb.
|
| 265 |
-
Settings(
|
| 266 |
-
anonymized_telemetry=False,
|
| 267 |
-
chroma_api_impl="rest",
|
| 268 |
-
chroma_server_host=args.chroma_host,
|
| 269 |
-
chroma_server_http_port=chroma_port
|
| 270 |
-
)
|
| 271 |
-
)
|
| 272 |
print(f"ChromaDB is remotely configured at {args.chroma_host}:{chroma_port}")
|
| 273 |
|
| 274 |
-
chromadb_embedder = SentenceTransformer(embedding_model)
|
| 275 |
chromadb_embed_fn = lambda *args, **kwargs: chromadb_embedder.encode(*args, **kwargs).tolist()
|
| 276 |
|
| 277 |
# Check if the db is connected and running, otherwise tell the user
|
|
@@ -405,10 +415,24 @@ def image_to_base64(image: Image, quality: int = 75) -> str:
|
|
| 405 |
image.save(buffer, format="JPEG", quality=quality)
|
| 406 |
img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
| 407 |
return img_str
|
| 408 |
-
|
| 409 |
-
ignore_auth = []
|
| 410 |
|
| 411 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
|
| 413 |
def is_authorize_ignored(request):
|
| 414 |
view_func = app.view_functions.get(request.endpoint)
|
|
@@ -418,6 +442,7 @@ def is_authorize_ignored(request):
|
|
| 418 |
return True
|
| 419 |
return False
|
| 420 |
|
|
|
|
| 421 |
@app.before_request
|
| 422 |
def before_request():
|
| 423 |
# Request time measuring
|
|
@@ -426,14 +451,14 @@ def before_request():
|
|
| 426 |
# Checks if an API key is present and valid, otherwise return unauthorized
|
| 427 |
# The options check is required so CORS doesn't get angry
|
| 428 |
try:
|
| 429 |
-
if request.method != 'OPTIONS' and is_authorize_ignored(request) == False and getattr(request.authorization, 'token', '') != api_key:
|
| 430 |
print(f"WARNING: Unauthorized API key access from {request.remote_addr}")
|
| 431 |
response = jsonify({ 'error': '401: Invalid API key' })
|
| 432 |
response.status_code = 401
|
| 433 |
-
return
|
| 434 |
except Exception as e:
|
| 435 |
print(f"API key check error: {e}")
|
| 436 |
-
return "
|
| 437 |
|
| 438 |
|
| 439 |
@app.after_request
|
|
@@ -645,7 +670,7 @@ def tts_speakers():
|
|
| 645 |
]
|
| 646 |
return jsonify(voices)
|
| 647 |
|
| 648 |
-
|
| 649 |
@app.route("/api/tts/generate", methods=["POST"])
|
| 650 |
@require_module("silero-tts")
|
| 651 |
def tts_generate():
|
|
@@ -657,8 +682,15 @@ def tts_generate():
|
|
| 657 |
# Remove asterisks
|
| 658 |
voice["text"] = voice["text"].replace("*", "")
|
| 659 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 660 |
audio = tts_service.generate(voice["speaker"], voice["text"])
|
| 661 |
-
|
|
|
|
|
|
|
|
|
|
| 662 |
except Exception as e:
|
| 663 |
print(e)
|
| 664 |
abort(500, voice["speaker"])
|
|
@@ -743,8 +775,6 @@ def chromadb_purge():
|
|
| 743 |
|
| 744 |
count = collection.count()
|
| 745 |
collection.delete()
|
| 746 |
-
#Write deletion to persistent folder
|
| 747 |
-
chromadb_client.persist()
|
| 748 |
print("ChromaDB embeddings deleted", count)
|
| 749 |
return 'Ok', 200
|
| 750 |
|
|
@@ -768,6 +798,11 @@ def chromadb_query():
|
|
| 768 |
name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
|
| 769 |
)
|
| 770 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 771 |
n_results = min(collection.count(), n_results)
|
| 772 |
query_result = collection.query(
|
| 773 |
query_texts=[data["query"]],
|
|
@@ -793,6 +828,69 @@ def chromadb_query():
|
|
| 793 |
|
| 794 |
return jsonify(messages)
|
| 795 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 796 |
|
| 797 |
@app.route("/api/chromadb/export", methods=["POST"])
|
| 798 |
@require_module("chromadb")
|
|
@@ -802,9 +900,14 @@ def chromadb_export():
|
|
| 802 |
abort(400, '"chat_id" is required')
|
| 803 |
|
| 804 |
chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
|
| 805 |
-
|
| 806 |
-
|
| 807 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 808 |
collection_content = collection.get()
|
| 809 |
documents = collection_content.get('documents', [])
|
| 810 |
ids = collection_content.get('ids', [])
|
|
@@ -847,8 +950,27 @@ def chromadb_import():
|
|
| 847 |
|
| 848 |
|
| 849 |
collection.upsert(documents=documents, metadatas=metadatas, ids=ids)
|
|
|
|
| 850 |
|
| 851 |
return jsonify({"count": len(ids)})
|
| 852 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 853 |
ignore_auth.append(tts_play_sample)
|
| 854 |
app.run(host=host, port=port)
|
|
|
|
| 21 |
import time
|
| 22 |
import os
|
| 23 |
import gc
|
| 24 |
+
import sys
|
| 25 |
import secrets
|
| 26 |
from PIL import Image
|
| 27 |
import base64
|
|
|
|
| 34 |
|
| 35 |
colorama_init()
|
| 36 |
|
| 37 |
+
if sys.hexversion < 0x030b0000:
|
| 38 |
+
print(f"{Fore.BLUE}{Style.BRIGHT}Python 3.11 or newer is recommended to run this program.{Style.RESET_ALL}")
|
| 39 |
+
time.sleep(2)
|
| 40 |
|
| 41 |
class SplitArgs(argparse.Action):
|
| 42 |
def __call__(self, parser, namespace, values, option_string=None):
|
|
|
|
| 44 |
namespace, self.dest, values.replace('"', "").replace("'", "").split(",")
|
| 45 |
)
|
| 46 |
|
| 47 |
+
#Setting Root Folders for Silero Generations so it is compatible with STSL, should not effect regular runs. - Rolyat
|
| 48 |
+
parent_dir = os.path.dirname(os.path.abspath(__file__))
|
| 49 |
+
SILERO_SAMPLES_PATH = os.path.join(parent_dir, "tts_samples")
|
| 50 |
+
SILERO_SAMPLE_TEXT = os.path.join(parent_dir)
|
| 51 |
+
|
| 52 |
+
# Create directories if they don't exist
|
| 53 |
+
if not os.path.exists(SILERO_SAMPLES_PATH):
|
| 54 |
+
os.makedirs(SILERO_SAMPLES_PATH)
|
| 55 |
+
if not os.path.exists(SILERO_SAMPLE_TEXT):
|
| 56 |
+
os.makedirs(SILERO_SAMPLE_TEXT)
|
| 57 |
|
| 58 |
# Script arguments
|
| 59 |
parser = argparse.ArgumentParser(
|
|
|
|
| 70 |
)
|
| 71 |
parser.add_argument("--cpu", action="store_true", help="Run the models on the CPU")
|
| 72 |
parser.add_argument("--cuda", action="store_false", dest="cpu", help="Run the models on the GPU")
|
| 73 |
+
parser.add_argument("--cuda-device", help="Specify the CUDA device to use")
|
| 74 |
+
parser.add_argument("--mps", "--apple", "--m1", "--m2", action="store_false", dest="cpu", help="Run the models on Apple Silicon")
|
| 75 |
parser.set_defaults(cpu=True)
|
| 76 |
parser.add_argument("--summarization-model", help="Load a custom summarization model")
|
| 77 |
parser.add_argument(
|
|
|
|
| 82 |
parser.add_argument("--chroma-host", help="Host IP for a remote ChromaDB instance")
|
| 83 |
parser.add_argument("--chroma-port", help="HTTP port for a remote ChromaDB instance (defaults to 8000)")
|
| 84 |
parser.add_argument("--chroma-folder", help="Path for chromadb persistence folder", default='.chroma_db')
|
| 85 |
+
parser.add_argument('--chroma-persist', help="ChromaDB persistence", default=True, action=argparse.BooleanOptionalAction)
|
| 86 |
parser.add_argument(
|
| 87 |
"--secure", action="store_true", help="Enforces the use of an API key"
|
| 88 |
)
|
|
|
|
| 89 |
sd_group = parser.add_mutually_exclusive_group()
|
| 90 |
|
| 91 |
local_sd = sd_group.add_argument_group("sd-local")
|
|
|
|
| 120 |
|
| 121 |
args = parser.parse_args()
|
| 122 |
|
| 123 |
+
port = args.port if args.port else 5100
|
| 124 |
+
host = "0.0.0.0" if args.listen else "localhost"
|
| 125 |
summarization_model = (
|
| 126 |
args.summarization_model
|
| 127 |
if args.summarization_model
|
|
|
|
| 157 |
print(f"Example: --enable-modules=caption,summarize{Style.RESET_ALL}")
|
| 158 |
|
| 159 |
# Models init
|
| 160 |
+
cuda_device = DEFAULT_CUDA_DEVICE if not args.cuda_device else args.cuda_device
|
| 161 |
+
device_string = cuda_device if torch.cuda.is_available() and not args.cpu else 'mps' if torch.backends.mps.is_available() and not args.cpu else 'cpu'
|
| 162 |
device = torch.device(device_string)
|
| 163 |
+
torch_dtype = torch.float32 if device_string != cuda_device else torch.float16
|
| 164 |
|
| 165 |
if not torch.cuda.is_available() and not args.cpu:
|
| 166 |
+
print(f"{Fore.YELLOW}{Style.BRIGHT}torch-cuda is not supported on this device.{Style.RESET_ALL}")
|
| 167 |
+
if not torch.backends.mps.is_available() and not args.cpu:
|
| 168 |
+
print(f"{Fore.YELLOW}{Style.BRIGHT}torch-mps is not supported on this device.{Style.RESET_ALL}")
|
| 169 |
+
|
| 170 |
|
| 171 |
print(f"{Fore.GREEN}{Style.BRIGHT}Using torch device: {device_string}{Style.RESET_ALL}")
|
| 172 |
|
|
|
|
| 203 |
from diffusers import StableDiffusionPipeline
|
| 204 |
from diffusers import EulerAncestralDiscreteScheduler
|
| 205 |
|
| 206 |
+
print("Initializing Stable Diffusion pipeline...")
|
| 207 |
+
sd_device_string = cuda_device if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
|
|
|
|
|
|
|
| 208 |
sd_device = torch.device(sd_device_string)
|
| 209 |
+
sd_torch_dtype = torch.float32 if sd_device_string != cuda_device else torch.float16
|
| 210 |
sd_pipe = StableDiffusionPipeline.from_pretrained(
|
| 211 |
sd_model, custom_pipeline="lpw_stable_diffusion", torch_dtype=sd_torch_dtype
|
| 212 |
).to(sd_device)
|
|
|
|
| 269 |
posthog.capture = lambda *args, **kwargs: None
|
| 270 |
if args.chroma_host is None:
|
| 271 |
if args.chroma_persist:
|
| 272 |
+
chromadb_client = chromadb.PersistentClient(path=args.chroma_folder, settings=Settings(anonymized_telemetry=False))
|
| 273 |
print(f"ChromaDB is running in-memory with persistence. Persistence is stored in {args.chroma_folder}. Can be cleared by deleting the folder or purging db.")
|
| 274 |
else:
|
| 275 |
+
chromadb_client = chromadb.EphemeralClient(Settings(anonymized_telemetry=False))
|
| 276 |
print(f"ChromaDB is running in-memory without persistence.")
|
| 277 |
else:
|
| 278 |
chroma_port=(
|
| 279 |
args.chroma_port if args.chroma_port else DEFAULT_CHROMA_PORT
|
| 280 |
)
|
| 281 |
+
chromadb_client = chromadb.HttpClient(host=args.chroma_host, port=chroma_port, settings=Settings(anonymized_telemetry=False))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
print(f"ChromaDB is remotely configured at {args.chroma_host}:{chroma_port}")
|
| 283 |
|
| 284 |
+
chromadb_embedder = SentenceTransformer(embedding_model, device=device_string)
|
| 285 |
chromadb_embed_fn = lambda *args, **kwargs: chromadb_embedder.encode(*args, **kwargs).tolist()
|
| 286 |
|
| 287 |
# Check if the db is connected and running, otherwise tell the user
|
|
|
|
| 415 |
image.save(buffer, format="JPEG", quality=quality)
|
| 416 |
img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
| 417 |
return img_str
|
|
|
|
|
|
|
| 418 |
|
| 419 |
+
ignore_auth = []
|
| 420 |
+
# Reads an API key from an already existing file. If that file doesn't exist, create it.
|
| 421 |
+
if args.secure:
|
| 422 |
+
try:
|
| 423 |
+
with open("api_key.txt", "r") as txt:
|
| 424 |
+
api_key = txt.read().replace('\n', '')
|
| 425 |
+
except:
|
| 426 |
+
api_key = secrets.token_hex(5)
|
| 427 |
+
with open("api_key.txt", "w") as txt:
|
| 428 |
+
txt.write(api_key)
|
| 429 |
+
|
| 430 |
+
print(f"Your API key is {api_key}")
|
| 431 |
+
elif args.share and args.secure != True:
|
| 432 |
+
print("WARNING: This instance is publicly exposed without an API key! It is highly recommended to restart with the \"--secure\" argument!")
|
| 433 |
+
else:
|
| 434 |
+
print("No API key given because you are running locally.")
|
| 435 |
+
|
| 436 |
|
| 437 |
def is_authorize_ignored(request):
|
| 438 |
view_func = app.view_functions.get(request.endpoint)
|
|
|
|
| 442 |
return True
|
| 443 |
return False
|
| 444 |
|
| 445 |
+
|
| 446 |
@app.before_request
|
| 447 |
def before_request():
|
| 448 |
# Request time measuring
|
|
|
|
| 451 |
# Checks if an API key is present and valid, otherwise return unauthorized
|
| 452 |
# The options check is required so CORS doesn't get angry
|
| 453 |
try:
|
| 454 |
+
if request.method != 'OPTIONS' and args.secure and is_authorize_ignored(request) == False and getattr(request.authorization, 'token', '') != api_key:
|
| 455 |
print(f"WARNING: Unauthorized API key access from {request.remote_addr}")
|
| 456 |
response = jsonify({ 'error': '401: Invalid API key' })
|
| 457 |
response.status_code = 401
|
| 458 |
+
return response
|
| 459 |
except Exception as e:
|
| 460 |
print(f"API key check error: {e}")
|
| 461 |
+
return "401 Unauthorized\n{}\n\n".format(e), 401
|
| 462 |
|
| 463 |
|
| 464 |
@app.after_request
|
|
|
|
| 670 |
]
|
| 671 |
return jsonify(voices)
|
| 672 |
|
| 673 |
+
# Added fix for Silero not working as new files were unable to be created if one already existed. - Rolyat 7/7/23
|
| 674 |
@app.route("/api/tts/generate", methods=["POST"])
|
| 675 |
@require_module("silero-tts")
|
| 676 |
def tts_generate():
|
|
|
|
| 682 |
# Remove asterisks
|
| 683 |
voice["text"] = voice["text"].replace("*", "")
|
| 684 |
try:
|
| 685 |
+
# Remove the destination file if it already exists
|
| 686 |
+
if os.path.exists('test.wav'):
|
| 687 |
+
os.remove('test.wav')
|
| 688 |
+
|
| 689 |
audio = tts_service.generate(voice["speaker"], voice["text"])
|
| 690 |
+
audio_file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.path.basename(audio))
|
| 691 |
+
|
| 692 |
+
os.rename(audio, audio_file_path)
|
| 693 |
+
return send_file(audio_file_path, mimetype="audio/x-wav")
|
| 694 |
except Exception as e:
|
| 695 |
print(e)
|
| 696 |
abort(500, voice["speaker"])
|
|
|
|
| 775 |
|
| 776 |
count = collection.count()
|
| 777 |
collection.delete()
|
|
|
|
|
|
|
| 778 |
print("ChromaDB embeddings deleted", count)
|
| 779 |
return 'Ok', 200
|
| 780 |
|
|
|
|
| 798 |
name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
|
| 799 |
)
|
| 800 |
|
| 801 |
+
if collection.count() == 0:
|
| 802 |
+
print(f"Queried empty/missing collection for {repr(data['chat_id'])}.")
|
| 803 |
+
return jsonify([])
|
| 804 |
+
|
| 805 |
+
|
| 806 |
n_results = min(collection.count(), n_results)
|
| 807 |
query_result = collection.query(
|
| 808 |
query_texts=[data["query"]],
|
|
|
|
| 828 |
|
| 829 |
return jsonify(messages)
|
| 830 |
|
| 831 |
+
@app.route("/api/chromadb/multiquery", methods=["POST"])
|
| 832 |
+
@require_module("chromadb")
|
| 833 |
+
def chromadb_multiquery():
|
| 834 |
+
data = request.get_json()
|
| 835 |
+
if "chat_list" not in data or not isinstance(data["chat_list"], list):
|
| 836 |
+
abort(400, '"chat_list" is required and should be a list')
|
| 837 |
+
if "query" not in data or not isinstance(data["query"], str):
|
| 838 |
+
abort(400, '"query" is required')
|
| 839 |
+
|
| 840 |
+
if "n_results" not in data or not isinstance(data["n_results"], int):
|
| 841 |
+
n_results = 1
|
| 842 |
+
else:
|
| 843 |
+
n_results = data["n_results"]
|
| 844 |
+
|
| 845 |
+
messages = []
|
| 846 |
+
|
| 847 |
+
for chat_id in data["chat_list"]:
|
| 848 |
+
if not isinstance(chat_id, str):
|
| 849 |
+
continue
|
| 850 |
+
|
| 851 |
+
try:
|
| 852 |
+
chat_id_md5 = hashlib.md5(chat_id.encode()).hexdigest()
|
| 853 |
+
collection = chromadb_client.get_collection(
|
| 854 |
+
name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
|
| 855 |
+
)
|
| 856 |
+
|
| 857 |
+
# Skip this chat if the collection is empty
|
| 858 |
+
if collection.count() == 0:
|
| 859 |
+
continue
|
| 860 |
+
|
| 861 |
+
n_results_per_chat = min(collection.count(), n_results)
|
| 862 |
+
query_result = collection.query(
|
| 863 |
+
query_texts=[data["query"]],
|
| 864 |
+
n_results=n_results_per_chat,
|
| 865 |
+
)
|
| 866 |
+
documents = query_result["documents"][0]
|
| 867 |
+
ids = query_result["ids"][0]
|
| 868 |
+
metadatas = query_result["metadatas"][0]
|
| 869 |
+
distances = query_result["distances"][0]
|
| 870 |
+
|
| 871 |
+
chat_messages = [
|
| 872 |
+
{
|
| 873 |
+
"id": ids[i],
|
| 874 |
+
"date": metadatas[i]["date"],
|
| 875 |
+
"role": metadatas[i]["role"],
|
| 876 |
+
"meta": metadatas[i]["meta"],
|
| 877 |
+
"content": documents[i],
|
| 878 |
+
"distance": distances[i],
|
| 879 |
+
}
|
| 880 |
+
for i in range(len(ids))
|
| 881 |
+
]
|
| 882 |
+
|
| 883 |
+
messages.extend(chat_messages)
|
| 884 |
+
except Exception as e:
|
| 885 |
+
print(e)
|
| 886 |
+
|
| 887 |
+
#remove duplicate msgs, filter down to the right number
|
| 888 |
+
seen = set()
|
| 889 |
+
messages = [d for d in messages if not (d['content'] in seen or seen.add(d['content']))]
|
| 890 |
+
messages = sorted(messages, key=lambda x: x['distance'])[0:n_results]
|
| 891 |
+
|
| 892 |
+
return jsonify(messages)
|
| 893 |
+
|
| 894 |
|
| 895 |
@app.route("/api/chromadb/export", methods=["POST"])
|
| 896 |
@require_module("chromadb")
|
|
|
|
| 900 |
abort(400, '"chat_id" is required')
|
| 901 |
|
| 902 |
chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
|
| 903 |
+
try:
|
| 904 |
+
collection = chromadb_client.get_collection(
|
| 905 |
+
name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
|
| 906 |
+
)
|
| 907 |
+
except Exception as e:
|
| 908 |
+
print(e)
|
| 909 |
+
abort(400, "Chat collection not found in chromadb")
|
| 910 |
+
|
| 911 |
collection_content = collection.get()
|
| 912 |
documents = collection_content.get('documents', [])
|
| 913 |
ids = collection_content.get('ids', [])
|
|
|
|
| 950 |
|
| 951 |
|
| 952 |
collection.upsert(documents=documents, metadatas=metadatas, ids=ids)
|
| 953 |
+
print(f"Imported {len(ids)} (total {collection.count()}) content entries into {repr(data['chat_id'])}")
|
| 954 |
|
| 955 |
return jsonify({"count": len(ids)})
|
| 956 |
|
| 957 |
+
|
| 958 |
+
if args.share:
|
| 959 |
+
from flask_cloudflared import _run_cloudflared
|
| 960 |
+
import inspect
|
| 961 |
+
|
| 962 |
+
sig = inspect.signature(_run_cloudflared)
|
| 963 |
+
sum = sum(
|
| 964 |
+
1
|
| 965 |
+
for param in sig.parameters.values()
|
| 966 |
+
if param.kind == param.POSITIONAL_OR_KEYWORD
|
| 967 |
+
)
|
| 968 |
+
if sum > 1:
|
| 969 |
+
metrics_port = randint(8100, 9000)
|
| 970 |
+
cloudflare = _run_cloudflared(port, metrics_port)
|
| 971 |
+
else:
|
| 972 |
+
cloudflare = _run_cloudflared(port)
|
| 973 |
+
print("Running on", cloudflare)
|
| 974 |
+
|
| 975 |
ignore_auth.append(tts_play_sample)
|
| 976 |
app.run(host=host, port=port)
|