Shroominic
🧹 black cleanup
e0cbbfe
"""
Module implements an agent that uses OpenAI's APIs function enabled API.
This file is a modified version of the original file
from langchain/agents/openai_functions_agent/base.py.
Credits go to the original authors :)
"""
import json
from dataclasses import dataclass
from json import JSONDecodeError
from typing import Any, List, Optional, Sequence, Tuple, Union
from pydantic import root_validator
from langchain.agents import BaseSingleActionAgent
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chat_models.openai import ChatOpenAI
from langchain.schema import BasePromptTemplate
from langchain.prompts.chat import (
BaseMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
)
from langchain.schema import (
AgentAction,
AgentFinish,
AIMessage,
BaseMessage,
FunctionMessage,
OutputParserException,
HumanMessage,
SystemMessage,
)
from langchain.tools import BaseTool
from langchain.tools.convert_to_openai import format_tool_to_openai_function
@dataclass
class _FunctionsAgentAction(AgentAction):
message_log: List[BaseMessage]
def _convert_agent_action_to_messages(
agent_action: AgentAction, observation: str
) -> List[BaseMessage]:
"""Convert an agent action to a message.
This code is used to reconstruct the original AI message from the agent action.
Args:
agent_action: Agent action to convert.
Returns:
AIMessage that corresponds to the original tool invocation.
"""
if isinstance(agent_action, _FunctionsAgentAction):
return agent_action.message_log + [
_create_function_message(agent_action, observation)
]
else:
return [AIMessage(content=agent_action.log)]
def _create_function_message(
agent_action: AgentAction, observation: str
) -> FunctionMessage:
"""Convert agent action and observation into a function message.
Args:
agent_action: the tool invocation request from the agent
observation: the result of the tool invocation
Returns:
FunctionMessage that corresponds to the original tool invocation
"""
if not isinstance(observation, str):
try:
content = json.dumps(observation, ensure_ascii=False)
except Exception:
content = str(observation)
else:
content = observation
return FunctionMessage(
name=agent_action.tool,
content=content,
)
def _format_intermediate_steps(
intermediate_steps: List[Tuple[AgentAction, str]],
) -> List[BaseMessage]:
"""Format intermediate steps.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
Returns:
list of messages to send to the LLM for the next prediction
"""
messages = []
for intermediate_step in intermediate_steps:
agent_action, observation = intermediate_step
messages.extend(_convert_agent_action_to_messages(agent_action, observation))
return messages
async def _parse_ai_message(
message: BaseMessage, llm: BaseLanguageModel
) -> Union[AgentAction, AgentFinish]:
"""Parse an AI message."""
if not isinstance(message, AIMessage):
raise TypeError(f"Expected an AI message got {type(message)}")
function_call = message.additional_kwargs.get("function_call", {})
if function_call:
function_call = message.additional_kwargs["function_call"]
function_name = function_call["name"]
try:
_tool_input = json.loads(function_call["arguments"])
except JSONDecodeError:
if function_name == "python":
code = function_call["arguments"]
_tool_input = {
"code": code,
}
else:
raise OutputParserException(
f"Could not parse tool input: {function_call} because "
f"the `arguments` is not valid JSON."
)
# HACK HACK HACK:
# The code that encodes tool input into Open AI uses a special variable
# name called `__arg1` to handle old style tools that do not expose a
# schema and expect a single string argument as an input.
# We unpack the argument here if it exists.
# Open AI does not support passing in a JSON array as an argument.
if "__arg1" in _tool_input:
tool_input = _tool_input["__arg1"]
else:
tool_input = _tool_input
content_msg = "responded: {content}\n" if message.content else "\n"
return _FunctionsAgentAction(
tool=function_name,
tool_input=tool_input,
log=f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n",
message_log=[message],
)
return AgentFinish(return_values={"output": message.content}, log=message.content)
class OpenAIFunctionsAgent(BaseSingleActionAgent):
"""An Agent driven by OpenAIs function powered API.
Args:
llm: This should be an instance of ChatOpenAI, specifically a model
that supports using `functions`.
tools: The tools this agent has access to.
prompt: The prompt for this agent, should support agent_scratchpad as one
of the variables. For an easy way to construct this prompt, use
`OpenAIFunctionsAgent.create_prompt(...)`
"""
llm: BaseLanguageModel
tools: Sequence[BaseTool]
prompt: BasePromptTemplate
def get_allowed_tools(self) -> List[str]:
"""Get allowed tools."""
return list([t.name for t in self.tools])
@root_validator
def validate_llm(cls, values: dict) -> dict:
if not isinstance(values["llm"], ChatOpenAI):
raise ValueError("Only supported with ChatOpenAI models.")
return values
@root_validator
def validate_prompt(cls, values: dict) -> dict:
prompt: BasePromptTemplate = values["prompt"]
if "agent_scratchpad" not in prompt.input_variables:
raise ValueError(
"`agent_scratchpad` should be one of the variables in the prompt, "
f"got {prompt.input_variables}"
)
return values
@property
def input_keys(self) -> List[str]:
"""Get input keys. Input refers to user input here."""
return ["input"]
@property
def functions(self) -> List[dict]:
return [dict(format_tool_to_openai_function(t)) for t in self.tools]
def plan(self):
raise NotImplementedError
async def aplan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
predicted_message = await self.llm.apredict_messages(
messages, functions=self.functions, callbacks=callbacks
)
agent_decision = await _parse_ai_message(predicted_message, self.llm)
return agent_decision
@classmethod
def create_prompt(
cls,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
) -> BasePromptTemplate:
"""Create prompt for this agent.
Args:
system_message: Message to use as the system message that will be the
first in the prompt.
extra_prompt_messages: Prompt messages that will be placed between the
system message and the new human input.
Returns:
A prompt template to pass into this agent.
"""
_prompts = extra_prompt_messages or []
messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
if system_message:
messages = [system_message]
else:
messages = []
messages.extend(
[
*_prompts,
HumanMessagePromptTemplate.from_template("{input}"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
)
return ChatPromptTemplate(messages=messages) # type: ignore
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
**kwargs: Any,
) -> BaseSingleActionAgent:
"""Construct an agent from an LLM and tools."""
if not isinstance(llm, ChatOpenAI):
raise ValueError("Only supported with ChatOpenAI models.")
prompt = cls.create_prompt(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
)
return cls(
llm=llm,
prompt=prompt,
tools=tools,
callback_manager=callback_manager, # type: ignore
**kwargs,
)