Spaces:
Runtime error
Runtime error
File size: 9,989 Bytes
a14ae24 e0cbbfe a14ae24 e0cbbfe a14ae24 e0cbbfe a14ae24 e0cbbfe a14ae24 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 |
"""
This module implements an agent driven by OpenAI's function-calling chat API.
This file is a modified version of the original file
from langchain/agents/openai_functions_agent/base.py.
Credits go to the original authors :)
"""
import json
from dataclasses import dataclass
from json import JSONDecodeError
from typing import Any, List, Optional, Sequence, Tuple, Union
from pydantic import root_validator
from langchain.agents import BaseSingleActionAgent
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chat_models.openai import ChatOpenAI
from langchain.schema import BasePromptTemplate
from langchain.prompts.chat import (
BaseMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
)
from langchain.schema import (
AgentAction,
AgentFinish,
AIMessage,
BaseMessage,
FunctionMessage,
OutputParserException,
HumanMessage,
SystemMessage,
)
from langchain.tools import BaseTool
from langchain.tools.convert_to_openai import format_tool_to_openai_function
@dataclass
class _FunctionsAgentAction(AgentAction):
    """An ``AgentAction`` that also remembers the messages which produced it."""

    # Messages stored with the action; `_parse_ai_message` places the AI message
    # that requested the function call here so it can be replayed later.
    message_log: List[BaseMessage]
def _convert_agent_action_to_messages(
    agent_action: AgentAction, observation: str
) -> List[BaseMessage]:
    """Rebuild the chat messages that correspond to one executed agent step.

    Args:
        agent_action: Agent action to convert.
        observation: Result of running the tool for this action.

    Returns:
        For a ``_FunctionsAgentAction``: its stored ``message_log`` followed by
        a function message built from the observation. For any other action:
        a single ``AIMessage`` holding the action's ``log`` (the observation is
        not included in that case).
    """
    # Plain actions carry no message history; only their log text is replayed.
    if not isinstance(agent_action, _FunctionsAgentAction):
        return [AIMessage(content=agent_action.log)]

    history = agent_action.message_log
    function_reply = _create_function_message(agent_action, observation)
    return history + [function_reply]
def _create_function_message(
    agent_action: AgentAction, observation: str
) -> FunctionMessage:
    """Wrap a tool's observation in a ``FunctionMessage`` named after the tool.

    Args:
        agent_action: the tool invocation request from the agent
        observation: the result of the tool invocation

    Returns:
        FunctionMessage whose ``name`` is ``agent_action.tool`` and whose
        ``content`` is the observation rendered as text.
    """
    if isinstance(observation, str):
        # Strings are passed through untouched.
        payload = observation
    else:
        # Non-string observations are serialised as JSON when possible;
        # anything json cannot handle falls back to its str() form.
        try:
            payload = json.dumps(observation, ensure_ascii=False)
        except Exception:
            payload = str(observation)

    return FunctionMessage(name=agent_action.tool, content=payload)
def _format_intermediate_steps(
    intermediate_steps: List[Tuple[AgentAction, str]],
) -> List[BaseMessage]:
    """Flatten the agent's past steps into one message list.

    Args:
        intermediate_steps: Steps the LLM has taken to date, along with observations

    Returns:
        The messages for every step, in step order, ready to be sent to the
        LLM for the next prediction.
    """
    return [
        step_message
        for action, result in intermediate_steps
        for step_message in _convert_agent_action_to_messages(action, result)
    ]
async def _parse_ai_message(
    message: BaseMessage, llm: BaseLanguageModel
) -> Union[AgentAction, AgentFinish]:
    """Parse an AI message into the agent's next step.

    Args:
        message: The model reply; must be an ``AIMessage``.
        llm: Not used by this function; kept so existing callers keep working.

    Returns:
        A ``_FunctionsAgentAction`` when the message carries a ``function_call``
        in its ``additional_kwargs``; otherwise an ``AgentFinish`` whose output
        and log are the message content.

    Raises:
        TypeError: If ``message`` is not an ``AIMessage``.
        OutputParserException: If the function-call ``arguments`` are not valid
            JSON and the function is not named ``python``.
    """
    if not isinstance(message, AIMessage):
        raise TypeError(f"Expected an AI message got {type(message)}")

    function_call = message.additional_kwargs.get("function_call", {})
    if not function_call:
        # No function requested: the model's text is the final answer.
        return AgentFinish(
            return_values={"output": message.content}, log=message.content
        )

    function_name = function_call["name"]
    try:
        _tool_input = json.loads(function_call["arguments"])
    except JSONDecodeError:
        if function_name == "python":
            # For the `python` function the raw (non-JSON) arguments are
            # treated as the code to run.
            _tool_input = {"code": function_call["arguments"]}
        else:
            raise OutputParserException(
                f"Could not parse tool input: {function_call} because "
                f"the `arguments` is not valid JSON."
            )

    # HACK HACK HACK:
    # The code that encodes tool input into Open AI uses a special variable
    # name called `__arg1` to handle old style tools that do not expose a
    # schema and expect a single string argument as an input.
    # We unpack the argument here if it exists.
    # Open AI does not support passing in a JSON array as an argument.
    # The isinstance guard keeps non-dict JSON values (string, number, list)
    # from raising or being mis-indexed by the membership test below.
    if isinstance(_tool_input, dict) and "__arg1" in _tool_input:
        tool_input = _tool_input["__arg1"]
    else:
        tool_input = _tool_input

    # Must be an f-string: otherwise the log shows the literal "{content}"
    # placeholder instead of what the model actually said.
    content_msg = f"responded: {message.content}\n" if message.content else "\n"
    return _FunctionsAgentAction(
        tool=function_name,
        tool_input=tool_input,
        log=f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n",
        message_log=[message],
    )
class OpenAIFunctionsAgent(BaseSingleActionAgent):
    """An Agent driven by OpenAIs function powered API.

    This variant only plans asynchronously: ``aplan`` is implemented, while
    ``plan`` always raises ``NotImplementedError``.

    Args:
        llm: This should be an instance of ChatOpenAI, specifically a model
            that supports using `functions`.
        tools: The tools this agent has access to.
        prompt: The prompt for this agent, should support agent_scratchpad as one
            of the variables. For an easy way to construct this prompt, use
            `OpenAIFunctionsAgent.create_prompt(...)`
    """

    llm: BaseLanguageModel
    tools: Sequence[BaseTool]
    prompt: BasePromptTemplate

    def get_allowed_tools(self) -> List[str]:
        """Get allowed tools."""
        return [t.name for t in self.tools]

    @root_validator
    def validate_llm(cls, values: dict) -> dict:
        """Reject any ``llm`` that is not a ``ChatOpenAI`` instance."""
        if not isinstance(values["llm"], ChatOpenAI):
            raise ValueError("Only supported with ChatOpenAI models.")
        return values

    @root_validator
    def validate_prompt(cls, values: dict) -> dict:
        """Require ``agent_scratchpad`` among the prompt's input variables."""
        prompt: BasePromptTemplate = values["prompt"]
        if "agent_scratchpad" not in prompt.input_variables:
            raise ValueError(
                "`agent_scratchpad` should be one of the variables in the prompt, "
                f"got {prompt.input_variables}"
            )
        return values

    @property
    def input_keys(self) -> List[str]:
        """Get input keys. Input refers to user input here."""
        return ["input"]

    @property
    def functions(self) -> List[dict]:
        """OpenAI function descriptions for every tool, as plain dicts."""
        return [dict(format_tool_to_openai_function(t)) for t in self.tools]

    def plan(
        self,
        intermediate_steps: Optional[List[Tuple[AgentAction, str]]] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Synchronous planning is not supported; use ``aplan`` instead.

        The parameters mirror ``aplan`` so that a synchronous caller passing
        steps and callbacks gets the intended ``NotImplementedError`` rather
        than a ``TypeError`` about unexpected arguments. All of them have
        defaults, so a bare ``plan()`` call still behaves as before.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "OpenAIFunctionsAgent only supports asynchronous planning; "
            "use `aplan` instead."
        )

    async def aplan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decided what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations
            callbacks: Callbacks forwarded to the LLM call.
            **kwargs: User inputs. Every prompt input variable other than
                ``agent_scratchpad`` must be present, otherwise a ``KeyError``
                is raised.

        Returns:
            Action specifying what tool to use.
        """
        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
        # Only forward the inputs the prompt declares; the scratchpad is
        # supplied from the intermediate steps, not from the caller.
        selected_inputs = {
            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
        }
        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
        prompt = self.prompt.format_prompt(**full_inputs)
        messages = prompt.to_messages()
        predicted_message = await self.llm.apredict_messages(
            messages, functions=self.functions, callbacks=callbacks
        )
        agent_decision = await _parse_ai_message(predicted_message, self.llm)
        return agent_decision

    @classmethod
    def create_prompt(
        cls,
        system_message: Optional[SystemMessage] = SystemMessage(
            content="You are a helpful AI assistant."
        ),
        extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
    ) -> BasePromptTemplate:
        """Create prompt for this agent.

        Args:
            system_message: Message to use as the system message that will be the
                first in the prompt. A falsy value (e.g. ``None``) omits it.
            extra_prompt_messages: Prompt messages that will be placed between the
                system message and the new human input.

        Returns:
            A prompt template to pass into this agent.
        """
        _prompts = extra_prompt_messages or []
        messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
        if system_message:
            messages = [system_message]
        else:
            messages = []

        messages.extend(
            [
                *_prompts,
                HumanMessagePromptTemplate.from_template("{input}"),
                MessagesPlaceholder(variable_name="agent_scratchpad"),
            ]
        )
        return ChatPromptTemplate(messages=messages)  # type: ignore

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
        system_message: Optional[SystemMessage] = SystemMessage(
            content="You are a helpful AI assistant."
        ),
        **kwargs: Any,
    ) -> BaseSingleActionAgent:
        """Construct an agent from an LLM and tools.

        Raises:
            ValueError: If ``llm`` is not a ``ChatOpenAI`` instance.
        """
        if not isinstance(llm, ChatOpenAI):
            raise ValueError("Only supported with ChatOpenAI models.")
        prompt = cls.create_prompt(
            extra_prompt_messages=extra_prompt_messages,
            system_message=system_message,
        )
        return cls(
            llm=llm,
            prompt=prompt,
            tools=tools,
            callback_manager=callback_manager,  # type: ignore
            **kwargs,
        )
|