vikramvasudevan committed · verified
Commit a88708b · 1 Parent(s): f5fe593

Upload folder using huggingface_hub

Files changed (2):
  1. embeddings.py (+11 −6)
  2. graph_helper.py (+17 −4)
embeddings.py CHANGED
```diff
@@ -51,14 +51,19 @@ def _get_openai_embedding(texts: list[str]) -> list[list[float]]:
 
     return final_embeddings
 
+embedding_cache = {}
+
 def get_embedding(texts: list[str], backend: Literal["hf","openai"] = "hf") -> list[list[float]]:
-    """
-    Get embeddings for a list of texts.
-    backend = "openai" or "hf"
-    """
+    key = (backend, tuple(texts))  # tuple is hashable
+    if key in embedding_cache:
+        return embedding_cache[key]
+
     if backend == "hf":
-        return _get_hf_embedding(texts)
-    return _get_openai_embedding(texts)
+        embedding_cache[key] = _get_hf_embedding(texts)
+    else:
+        embedding_cache[key] = _get_openai_embedding(texts)
+
+    return embedding_cache[key]
 
 # -------------------------------
 # Example
```
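For context: the new module-level dict memoizes whole batches. Since lists are unhashable, the texts are frozen into a tuple so `(backend, tuple(texts))` can key the cache, and a repeat call with an identical list skips the backend entirely. The dict is unbounded, though; a bounded variant is a small refactor with `functools.lru_cache`. A minimal sketch of that alternative (not the committed code), assuming the module's existing `_get_hf_embedding` and `_get_openai_embedding` helpers are in scope:

```python
from functools import lru_cache
from typing import Literal


@lru_cache(maxsize=1024)  # evicts least-recently-used batches instead of growing without bound
def _cached_embedding(
    backend: str, texts: tuple[str, ...]
) -> tuple[tuple[float, ...], ...]:
    # lru_cache requires hashable arguments, hence the tuple of texts;
    # nested tuples on the way out keep the cached vectors immutable.
    if backend == "hf":
        vectors = _get_hf_embedding(list(texts))  # assumed existing helper
    else:
        vectors = _get_openai_embedding(list(texts))  # assumed existing helper
    return tuple(tuple(v) for v in vectors)


def get_embedding(
    texts: list[str], backend: Literal["hf", "openai"] = "hf"
) -> list[list[float]]:
    # Same public signature as the committed function; convert back to lists.
    return [list(v) for v in _cached_embedding(backend, tuple(texts))]
```

As a side benefit, `lru_cache` provides `cache_info()` and `cache_clear()` for free, which helps if embeddings must be recomputed after a model change.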
graph_helper.py CHANGED
```diff
@@ -1,5 +1,6 @@
 import json
 from typing import Annotated, TypedDict
+from httpx import Timeout
 from langgraph.graph import StateGraph, START, END
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph.message import add_messages
@@ -45,8 +46,10 @@ def branching_condition(state: ChatState) -> str:
     else:
         return check_debug_condition(state)
 
+
 from typing import List, Dict
 
+
 def truncate_messages_for_token_limit(messages, max_tokens=50000):
     """
     messages: list of dicts or LangChain messages
@@ -65,7 +68,10 @@ def truncate_messages_for_token_limit(messages, max_tokens=50000):
         for call in tool_calls:
             # include corresponding tool messages
             for m in messages:
-                if getattr(m, "additional_kwargs", {}).get("tool_call_id") == call["id"]:
+                if (
+                    getattr(m, "additional_kwargs", {}).get("tool_call_id")
+                    == call["id"]
+                ):
                     group.append(m)
 
         group_tokens = sum(len(getattr(m, "content", "")) // 4 for m in group)
@@ -78,6 +84,7 @@ def truncate_messages_for_token_limit(messages, max_tokens=50000):
 
     return result
 
+
 def generate_graph() -> CompiledStateGraph:
     memory = MemorySaver()
     tools = [
@@ -91,12 +98,18 @@ def generate_graph() -> CompiledStateGraph:
         tool_search_db_by_metadata,
         tool_search_db_for_literal,
     ]
-    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.2).bind_tools(tools)
-    llm_without_tools = ChatOpenAI(model="gpt-4o-mini", temperature=0.1)
+    llm = ChatOpenAI(
+        model="gpt-4o-mini", temperature=0.2, max_retries=0, timeout=Timeout(60.0)
+    ).bind_tools(tools)
+    llm_without_tools = ChatOpenAI(
+        model="gpt-4o-mini", temperature=0.1, max_retries=0, timeout=Timeout(60.0)
+    )
 
     def chatNode(state: ChatState) -> ChatState:
         # logger.info("messages before LLM: %s", str(state["messages"]))
-        state["messages"] = truncate_messages_for_token_limit(messages=state["messages"])
+        state["messages"] = truncate_messages_for_token_limit(
+            messages=state["messages"]
+        )
         response = llm.invoke(state["messages"])
         # return {"messages": [response]}
         return {"messages": state["messages"] + [response]}
```