# Legacy FastAPI implementation, kept for reference only (not executed).
#
# from fastapi import FastAPI, Request, Form, UploadFile, File
# from fastapi.templating import Jinja2Templates
# from fastapi.responses import HTMLResponse, RedirectResponse
# from fastapi.staticfiles import StaticFiles
# from dotenv import load_dotenv
# import os, io
# from PIL import Image
# import markdown
# import google.generativeai as genai
#
# # Load environment variable
# load_dotenv()
# API_KEY = os.getenv("GOOGLE_API_KEY")
# genai.configure(api_key=API_KEY)
#
# app = FastAPI()
# templates = Jinja2Templates(directory="templates")
# app.mount("/static", StaticFiles(directory="static"), name="static")
#
# model = genai.GenerativeModel('gemini-2.0-flash')
#
# # Create a global chat session
# chat = None
# chat_history = []
#
# @app.get("/", response_class=HTMLResponse)
# async def root(request: Request):
#     return templates.TemplateResponse("index.html", {
#         "request": request,
#         "chat_history": chat_history,
#     })
#
# @app.post("/", response_class=HTMLResponse)
# async def handle_input(
#     request: Request,
#     user_input: str = Form(...),
#     image: UploadFile = File(None)
# ):
#     global chat, chat_history
#
#     # Initialize chat session if needed
#     if chat is None:
#         chat = model.start_chat(history=[])
#
#     parts = []
#     if user_input:
#         parts.append(user_input)
#
#     # For display in the UI
#     user_message = user_input
#
#     if image and image.content_type.startswith("image/"):
#         data = await image.read()
#         try:
#             img = Image.open(io.BytesIO(data))
#             parts.append(img)
#             user_message += " [Image uploaded]"  # Indicate image in chat history
#         except Exception as e:
#             chat_history.append({
#                 "role": "model",
#                 "content": markdown.markdown(f"**Error loading image:** {e}")
#             })
#             return RedirectResponse("/", status_code=303)
#
#     # Store user message for display
#     chat_history.append({"role": "user", "content": user_message})
#
#     try:
#         # Send message to Gemini model
#         resp = chat.send_message(parts)
#         # Add model response to history
#         raw = resp.text
#         chat_history.append({"role": "model", "content": raw})
#     except Exception as e:
#         err = f"**Error:** {e}"
#         chat_history.append({
#             "role": "model",
#             "content": markdown.markdown(err)
#         })
#
#     # Post-Redirect-Get
#     return RedirectResponse("/", status_code=303)
#
# # Clear chat history and start fresh
# @app.post("/new")
# async def new_chat():
#     global chat, chat_history
#     chat = None
#     chat_history.clear()
#     return RedirectResponse("/", status_code=303)

import io
import os
from typing import List, TypedDict, Union

import google.generativeai as genai
import streamlit as st
from dotenv import load_dotenv
from langgraph.graph import END, StateGraph
from PIL import Image

# ---------------------------
# Load API Key
# ---------------------------
load_dotenv()
API_KEY = os.getenv("GOOGLE_API_KEY")
genai.configure(api_key=API_KEY)
model = genai.GenerativeModel("gemini-2.0-flash")

# Appended to the displayed user message when an image was attached (same
# marker the FastAPI prototype used). It also keeps image-only turns non-empty.
IMAGE_MARKER = " [Image uploaded]"
# Prefix of model turns that are really local error reports, not model output.
ERROR_PREFIX = "Error: "


# ---------------------------
# State Definition
# ---------------------------
class ChatState(TypedDict):
    user_input: str                    # text typed by the user (may be "")
    image: Union[Image.Image, None]    # decoded upload, or None
    raw_response: str                  # Gemini reply, or "Error: ..." on failure
    final_response: str                # raw_response after boilerplate filtering
    chat_history: List[dict]           # {"role": "user"|"model", "content": str}


# ---------------------------
# LangGraph Nodes
# ---------------------------
def input_node(state: ChatState) -> ChatState:
    """Entry node: pass the incoming state through unchanged."""
    return state


def _to_gemini_history(chat_history: List[dict]) -> List[dict]:
    """Convert stored display history into Gemini ``start_chat`` history.

    ``checking_node`` always appends turns as (user, model) pairs, so the list
    is walked pairwise. A pair is skipped in these cases:

    - the roles are unexpected;
    - either side is empty;
    - the model side is a local error report.

    This keeps user/model turns alternating and never shows the model its own
    error messages as if it had produced them.
    """
    history = []
    for user_msg, model_msg in zip(chat_history[::2], chat_history[1::2]):
        if user_msg.get("role") != "user" or model_msg.get("role") != "model":
            continue
        user_text = user_msg.get("content", "")
        model_text = model_msg.get("content", "")
        if not user_text or not model_text or model_text.startswith(ERROR_PREFIX):
            continue
        history.append({"role": "user", "parts": [user_text]})
        history.append({"role": "model", "parts": [model_text]})
    return history


def processing_node(state: ChatState) -> ChatState:
    """Send the user's text and/or image to Gemini and store the raw reply.

    Earlier turns in ``state["chat_history"]`` are replayed as chat history, so
    follow-up questions keep their context. Any exception raised while talking
    to the API is captured into ``raw_response`` as ``"Error: ..."`` instead of
    propagating.
    """
    parts = []
    if state["user_input"]:
        # Only send a text part when there is text; an empty text part may be
        # rejected by the API (NOTE(review): confirm against current API docs).
        parts.append(state["user_input"])
    if state["image"] is not None:
        parts.append(state["image"])

    if not parts:
        state["raw_response"] = f"{ERROR_PREFIX}nothing to send."
        return state

    try:
        chat = model.start_chat(history=_to_gemini_history(state["chat_history"]))
        resp = chat.send_message(parts)
        state["raw_response"] = resp.text
    except Exception as e:
        state["raw_response"] = f"{ERROR_PREFIX}{e}"
    return state


def checking_node(state: ChatState) -> ChatState:
    """Strip boilerplate lines from the reply and record the turn in history.

    Lines starting with "Sure!" or containing "The image shows" are dropped.
    If filtering would leave nothing, the unfiltered reply is kept so the user
    never sees an empty answer.

    The (user, model) pair is appended to a *new* ``chat_history`` list in the
    returned state. The caller is responsible for persisting it (e.g. into
    ``st.session_state``), which keeps this node free of Streamlit coupling.
    """
    raw = state["raw_response"]

    # Remove unnecessary lines from Gemini responses
    if raw.startswith("Sure!") or "The image shows" in raw:
        filtered = [
            line
            for line in raw.split("\n")
            if not line.startswith("Sure!") and "The image shows" not in line
        ]
        final = "\n".join(filtered).strip()
        state["final_response"] = final or raw
    else:
        state["final_response"] = raw

    user_message = state["user_input"]
    if state["image"] is not None:
        user_message = (user_message + IMAGE_MARKER).strip()

    state["chat_history"] = [
        *state["chat_history"],
        {"role": "user", "content": user_message},
        {"role": "model", "content": state["final_response"]},
    ]
    return state


# ---------------------------
# Build the LangGraph
# ---------------------------
builder = StateGraph(ChatState)
builder.add_node("input", input_node)
builder.add_node("processing", processing_node)
builder.add_node("checking", checking_node)

builder.set_entry_point("input")
builder.add_edge("input", "processing")
builder.add_edge("processing", "checking")
builder.add_edge("checking", END)

graph = builder.compile()

# ---------------------------
# Streamlit UI Setup
# ---------------------------
st.set_page_config(page_title="Math Chatbot", layout="centered")
st.title("Math Chatbot")

if not API_KEY:
    # Fail fast with a clear message instead of an opaque API error later.
    st.error("GOOGLE_API_KEY is not set. Add it to your environment or a .env file.")
    st.stop()

# Initialize session state
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# Display chat history
for msg in st.session_state.chat_history:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])

# ---------------------------
# Sidebar
# ---------------------------
with st.sidebar:
    st.header("Options")
    if st.button("New Chat"):
        st.session_state.chat_history = []
        st.rerun()

# ---------------------------
# Chat Input Form
# ---------------------------
with st.form("chat_form", clear_on_submit=True):
    user_input = st.text_input("Your message:", placeholder="Ask your math problem here")
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    submitted = st.form_submit_button("Send")

if submitted:
    user_input = user_input.strip()
    if not user_input and uploaded_file is None:
        st.warning("Please enter a message or upload an image.")
        st.stop()

    # Load image safely
    image = None
    if uploaded_file is not None:
        try:
            image = Image.open(io.BytesIO(uploaded_file.read()))
            # Image.open is lazy; force decoding so a corrupt file fails here
            # rather than later inside the Gemini call.
            image.load()
        except Exception as e:
            st.error(f"Error loading image: {e}")
            st.stop()

    # Prepare state
    input_state = {
        "user_input": user_input,
        "image": image,
        "raw_response": "",
        "final_response": "",
        "chat_history": st.session_state.chat_history,
    }

    # Run LangGraph
    output = graph.invoke(input_state)

    # Persist the updated history produced by checking_node.
    st.session_state.chat_history = output["chat_history"]

    # Show this turn (the history loop above ran before it existed).
    with st.chat_message("user"):
        st.markdown(output["chat_history"][-2]["content"])
    with st.chat_message("model"):
        st.markdown(output["final_response"])