# Hugging Face Spaces status when this source was captured: Build error
| import streamlit as st | |
| from utils import ( | |
| load_model, | |
| load_finetuned_model, | |
| generate_response, | |
| get_hf_token | |
| ) | |
| import os | |
| import json | |
| from datetime import datetime | |
# Page setup: must be the first Streamlit call in the script.
st.set_page_config(page_title="Gemma Chat", layout="wide")
# -------------------------------
# 💡 Theme Toggle
# -------------------------------
dark_mode = st.sidebar.toggle("🌙 Dark Mode", value=False)
if dark_mode:
    # Streamlit draws its themed background on the .stApp container, which sits
    # on top of <body>; styling body alone leaves the page visually unchanged.
    # NOTE(review): class names are Streamlit internals — re-check on upgrades.
    st.markdown(
        """
        <style>
        body, .stApp { background-color: #1e1e1e; color: #ffffff; }
        .stTextInput, .stTextArea, .stSelectbox, .stSlider { color: #ffffff !important; }
        </style>
        """,
        unsafe_allow_html=True,
    )
st.title("💬 Chat with Gemma Model")
# -------------------------------
# 🔍 Model Source Selection
# -------------------------------
# Choose between a locally fine-tuned checkpoint and a Hugging Face hub model.
model_source = st.sidebar.radio("🔍 Select Model Source", ["Local (.pt)", "Hugging Face"])
# -------------------------------
# 📥 Dynamic Model List
# -------------------------------
# Sets `selected_model` (shown in the UI) and `model_path` (None for hub models).
if model_source == "Local (.pt)":
    model_dir = "models"
    # exist_ok replaces the separate exists() check (and its race window).
    os.makedirs(model_dir, exist_ok=True)
    # os.listdir order is arbitrary; sort so the selectbox order is stable, and
    # skip directories that merely happen to end in ".pt".
    local_models = sorted(
        f
        for f in os.listdir(model_dir)
        if f.endswith(".pt") and os.path.isfile(os.path.join(model_dir, f))
    )
    if not local_models:
        st.warning("⚠️ No fine-tuned models found. Fine-tune a model first.")
        st.stop()  # halts this script run; nothing below executes
    selected_model = st.sidebar.selectbox("🛠️ Select Local Model", local_models)
    model_path = os.path.join(model_dir, selected_model)
else:
    hf_models = [
        "google/gemma-3-1b-it",
        "google/gemma-3-4b-pt",
        "google/gemma-3-4b-it",
        "google/gemma-3-12b-pt",
        "google/gemma-3-12b-it",
        "google/gemma-3-27b-pt",
        "google/gemma-3-27b-it",
    ]
    selected_model = st.sidebar.selectbox("🛠️ Select Hugging Face Model", hf_models)
    model_path = None
# -------------------------------
# 📥 Model Loading
# -------------------------------
# Loads `tokenizer` and `model` for the chosen source; stops the run on failure.
# NOTE(review): Streamlit re-executes this script on every widget interaction.
# Unless load_model / load_finetuned_model cache internally (e.g. with
# st.cache_resource — not visible here), the model is reloaded on each rerun.
hf_token = get_hf_token()
if model_source == "Local (.pt)":
    # NOTE(review): every local .pt is applied on top of google/gemma-3-1b-it;
    # confirm all checkpoints in models/ were fine-tuned from that base.
    tokenizer, model = load_model("google/gemma-3-1b-it", hf_token)  # Base model first
    # Only apply the fine-tuned weights if the base model actually loaded.
    if model:
        model = load_finetuned_model(model, model_path)
    loaded_msg = f"✅ Local fine-tuned model loaded: `{selected_model}`"
    failed_msg = "❌ Failed to load local model."
else:
    tokenizer, model = load_model(selected_model, hf_token)
    loaded_msg = f"✅ Hugging Face model loaded: `{selected_model}`"
    failed_msg = "❌ Failed to load Hugging Face model."

if model:
    st.success(loaded_msg)
else:
    st.error(failed_msg)
    st.stop()  # nothing below can work without a model
# -------------------------------
# ⚙️ Model Configuration Panel
# -------------------------------
st.sidebar.header("⚙️ Model Configuration")
# NOTE(review): stream_response calls generate_response(prompt, model, tokenizer,
# max_length) only, so these three sampling controls currently have no effect.
# Forward them once generate_response's signature is confirmed.
temperature = st.sidebar.slider("🔥 Temperature", 0.1, 1.5, 0.7, 0.1)
top_p = st.sidebar.slider("🎯 Top-p", 0.1, 1.0, 0.9, 0.1)
repetition_penalty = st.sidebar.slider("🔁 Repetition Penalty", 0.5, 2.0, 1.0, 0.1)
# -------------------------------
# 💬 Chat Interface
# -------------------------------
# Session state survives reruns, so the history is created only on first run.
if "conversation" not in st.session_state:
    st.session_state.conversation = []
prompt = st.text_area("💬 Enter your message:", "Hello, how are you?", key="prompt", height=100)
max_length = st.slider("📏 Max Response Length", min_value=50, max_value=1000, value=300, step=50)
# -------------------------------
# 🚀 Response Generation Function
# -------------------------------
def stream_response():
    """Generate a reply to the current prompt and append both turns to the history.

    Despite the name, nothing is streamed: ``generate_response`` is called once
    and its complete return value is recorded.

    Reads the module-level ``prompt``, ``model``, ``tokenizer`` and ``max_length``
    and appends to ``st.session_state.conversation``.

    Returns:
        The generated response, or None when the prompt is blank or generation
        returned a falsy value (a warning/error is shown in those cases).
    """
    if not prompt.strip():
        st.warning("⚠️ Please enter a message first.")
        return None
    # Stamp the user turn before generation so it reflects when it was sent.
    sent_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # NOTE(review): the sidebar temperature / top_p / repetition_penalty values are
    # not forwarded here; check generate_response's signature and pass them through.
    response = generate_response(prompt, model, tokenizer, max_length)
    if not response:
        st.error("❌ Failed to generate response.")
        return None
    replied_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    st.session_state.conversation.append({"sender": "👤 You", "message": prompt, "timestamp": sent_at})
    st.session_state.conversation.append({"sender": "🤖 AI", "message": response, "timestamp": replied_at})
    return response
# -------------------------------
# 🎯 Conversation Controls
# -------------------------------
# Generate / clear / export / import the chat history held in session state.
col1, col2, col3 = st.columns([1, 1, 1])
# NOTE(review): nothing in this call binds CTRL+Enter to the button; in a
# text_area, Ctrl+Enter only applies the edit. Confirm the help text.
if col1.button("🚀 Generate (CTRL+Enter)", help="Use CTRL + Enter to generate"):
    stream_response()
if col2.button("🗑️ Clear Conversation"):
    st.session_state.conversation = []

# Export & Import
if col3.download_button(
    "📥 Export Chat",
    json.dumps(st.session_state.conversation, indent=4),
    "chat_history.json",
    mime="application/json",
):
    st.success("✅ Chat exported successfully!")


def _is_valid_conversation(data):
    """Return True if *data* is a list of dicts with the keys the history view reads."""
    return isinstance(data, list) and all(
        isinstance(entry, dict) and {"sender", "message", "timestamp"} <= entry.keys()
        for entry in data
    )


uploaded_file = st.file_uploader("📤 Import Conversation", type=["json"])
if uploaded_file is None:
    # Uploader cleared: allow the same file to be imported again later.
    st.session_state.imported_file_key = None
else:
    # The uploader keeps returning the file on every rerun. Import each upload
    # only once, otherwise every later click (Generate, Clear, ...) would
    # overwrite the conversation with the file's contents again.
    file_key = (uploaded_file.name, uploaded_file.size)
    if st.session_state.get("imported_file_key") != file_key:
        try:
            data = json.loads(uploaded_file.getvalue())
        except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
            data = None
        if _is_valid_conversation(data):
            st.session_state.conversation = data
            st.session_state.imported_file_key = file_key
            st.success("✅ Conversation imported successfully!")
        else:
            st.error("❌ Invalid conversation file.")
# -------------------------------
# 🛠️ Display Conversation
# -------------------------------
# Render every stored turn: sender and timestamp header, then the message body.
st.subheader("📜 Conversation History")
for msg in st.session_state.conversation:
    with st.container():
        # Two trailing spaces before "\n" force a Markdown hard line break, so the
        # timestamp renders on its own line. .get() tolerates imported entries
        # that lack a key instead of crashing the whole page.
        st.markdown(f"**{msg.get('sender', '')}**  \n🕒 {msg.get('timestamp', '')}")
        st.write(msg.get("message", ""))