from httpx import Timeout
from langchain_openai import ChatOpenAI

from modules.nodes.state import ChatState
from tools import (
    tool_format_scripture_answer,
    tool_get_standardized_prabandham_names,
    tool_search_db,
    tool_search_web,
    tool_push,
    tool_get_standardized_azhwar_names,
    tool_get_standardized_divya_desam_names,
)

tools = [
    tool_search_db,
    tool_get_standardized_azhwar_names,
    tool_get_standardized_prabandham_names,
    tool_get_standardized_divya_desam_names,
    ## disabled: these tools add too much latency
    # tool_format_scripture_answer,
    # tool_search_web,
    # tool_push,
]

llm = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0.2,
    max_retries=0,
    timeout=Timeout(60.0),
).bind_tools(tools)


def _truncate_messages_for_token_limit(messages, max_tokens=50000):
    """
    Truncate messages to stay under the token limit while preserving
    assistant/tool-call integrity: an assistant message that issues tool
    calls must never be kept without its matching tool responses, or the
    OpenAI API will reject the request.

    Keeps the oldest messages that fit within max_tokens; later messages
    are dropped once the budget is exhausted. Token counts are estimated
    with a cheap ~4-characters-per-token heuristic.
    """
    # NOTE: truncation is currently disabled -- the full history is passed
    # through unchanged. Delete this early return to re-enable the
    # budgeting logic below.
    return messages

    total_tokens = 0
    result = []
    seen = set()  # identities of messages already emitted, so tool responses are not duplicated

    # iterate oldest → newest to preserve conversation order
    for msg in messages:
        if id(msg) in seen:
            continue

        # gather this message's tool responses (if any) into one atomic group
        tool_calls = getattr(msg, "additional_kwargs", {}).get("tool_calls", [])
        group = [msg]
        for call in tool_calls:
            for m in messages:
                # LangChain's ToolMessage carries tool_call_id as an attribute;
                # fall back to additional_kwargs for raw OpenAI-style payloads.
                m_call_id = getattr(m, "tool_call_id", None) or getattr(
                    m, "additional_kwargs", {}
                ).get("tool_call_id")
                if m_call_id == call["id"]:
                    group.append(m)

        group_tokens = sum(len(str(getattr(m, "content", ""))) // 4 for m in group)

        # if this whole group would exceed the limit, stop: taking part of a
        # group would orphan tool calls from their responses
        if total_tokens + group_tokens > max_tokens:
            break

        total_tokens += group_tokens
        seen.update(id(m) for m in group)
        result.extend(group)

    return result


def chatNode(state: ChatState) -> ChatState:
    """LangGraph node: trim the history to the token budget, invoke the LLM, append its reply."""
    # logger.info("messages before LLM: %s", str(state["messages"]))
    state["messages"] = _truncate_messages_for_token_limit(messages=state["messages"])
    response = llm.invoke(state["messages"])
    state["messages"] = state["messages"] + [response]
    return state
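
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module's public surface):
# how chatNode is typically wired into a LangGraph StateGraph. The node
# names ("chat", "tools"), the prebuilt ToolNode/tools_condition routing,
# and the assumption that ChatState's "messages" field uses the
# add_messages reducer are all illustrative guesses -- the real graph for
# this repo may be assembled elsewhere under different names.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from langchain_core.messages import AIMessage, HumanMessage
    from langgraph.graph import StateGraph, START
    from langgraph.prebuilt import ToolNode, tools_condition

    # Sanity check: while the early return above is in place, the truncation
    # helper passes the history through unchanged.
    history = [HumanMessage(content="hello"), AIMessage(content="hi there")]
    assert _truncate_messages_for_token_limit(history) == history

    builder = StateGraph(ChatState)
    builder.add_node("chat", chatNode)
    builder.add_node("tools", ToolNode(tools))  # executes whichever tool the LLM calls
    builder.add_edge(START, "chat")
    # Route to "tools" when the last AI message contains tool calls, else END.
    builder.add_conditional_edges("chat", tools_condition)
    builder.add_edge("tools", "chat")  # feed tool results back to the LLM
    graph = builder.compile()

    out = graph.invoke({"messages": [HumanMessage(content="Who composed the Thiruppavai?")]})
    print(out["messages"][-1].content)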