# BoundrAI — Gradio chat Space that streams replies from a hosted model and
# logs each finished exchange to a private Hugging Face dataset.
import os
import uuid
from datetime import datetime

import gradio as gr
from datasets import Dataset
from huggingface_hub import InferenceClient
# ---- System Prompt ----
# Read the persona/behaviour instructions shipped alongside the app.
# Explicit encoding so the prompt decodes the same way on every platform.
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    SYSTEM_PROMPT = f.read()

# ---- Constants ----
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"    # hosted chat model id
DATASET_REPO = "frimelle/companion-chat-logs"  # private dataset receiving chat logs
HF_TOKEN = os.environ.get("HF_TOKEN")  # set in Space secrets; None when unset

# Inference API client used for streaming chat completions.
client = InferenceClient(MODEL_NAME)
# ---- Upload to Dataset ----
def upload_chat_to_dataset(user_message, assistant_message, system_prompt):
    """Push a single chat exchange as a one-row dataset to the Hub.

    Parameters
    ----------
    user_message : str
        The user's message for this turn.
    assistant_message : str
        The model's complete (fully streamed) reply.
    system_prompt : str
        The system prompt that was active for this exchange.
    """
    # Function-local import fixes the NameError on `Dataset` (it was never
    # imported) and defers the `datasets` dependency to log time.
    from datasets import Dataset

    row = {
        "timestamp": datetime.now().isoformat(),
        # NOTE(review): a fresh uuid4 per *message* means no two rows ever
        # share a session id — if per-conversation grouping is wanted, the
        # id should be created once per session. Behavior preserved as-is.
        "session_id": str(uuid.uuid4()),
        "user": user_message,
        "assistant": assistant_message,
        "system_prompt": system_prompt,
    }
    # Wrap each scalar in a single-element list: one row per push.
    dataset = Dataset.from_dict({k: [v] for k, v in row.items()})
    dataset.push_to_hub(DATASET_REPO, private=True, token=HF_TOKEN)
# ---- Chat Function ----
def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Stream a model reply for *message*, then log the finished exchange.

    Yields the accumulated partial response after each streamed token so
    Gradio can render the answer incrementally.
    """
    # Rebuild the whole conversation in chat-completion message format.
    conversation = [{"role": "system", "content": system_message}]
    for past_user, past_bot in history:
        if past_user:
            conversation.append({"role": "user", "content": past_user})
        if past_bot:
            conversation.append({"role": "assistant", "content": past_bot})
    conversation.append({"role": "user", "content": message})

    stream = client.chat_completion(
        conversation,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    )

    partial = ""
    for chunk in stream:
        piece = chunk.choices[0].delta.content
        if piece:  # skip empty/None deltas
            partial += piece
            yield partial

    # Log the final full message to the dataset
    upload_chat_to_dataset(message, partial, system_message)
# ---- Gradio UI ----
# Extra controls rendered under the chat box; their values are passed to
# `respond` after (message, history), in this order.
_extra_controls = [
    gr.Textbox(value=SYSTEM_PROMPT, label="System message"),
    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
]

demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=_extra_controls,
    title="BoundrAI",
)

# Launch only when executed as a script (Spaces runs this module directly).
if __name__ == "__main__":
    demo.launch()