from typing import Any, Dict, List, Optional, Union

from typing_extensions import Self

from camel.agents import ChatAgent
from camel.data_collector.base import BaseDataCollector
from camel.messages import AlpacaItem, BaseMessage
from camel.schemas import OpenAISchemaConverter
|
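# Default prompt for `llm_convert`: it instructs the schema converter to
# flatten a recorded conversation into a single Alpaca-style
# instruction/input/output record, following the embedded example.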
DEFAULT_CONVERTER_PROMPTS = """
Extract key entities and attributes from the conversations
and convert them into a structured JSON format.
For example:
Instruction: You are a helpful assistant.
User: When is the release date of the video game Portal?
Assistant: The release date of the video game Portal is October 9.
Your output should be:
{
    "instruction": "You are a helpful assistant. When is the release date of the video game Portal?",
    "input": "",
    "output": "The release date of the video game Portal is October 9."
}
"""
|
|
class AlpacaDataCollector(BaseDataCollector):
    r"""Data collector that converts a recorded agent conversation into the
    Alpaca instruction/input/output format."""

    def __init__(self) -> None:
        super().__init__()
        # Captured from the first injected agent in `record` and reused when
        # converting the collected history.
        self.system_message: Optional[BaseMessage] = None
        self.agent_name: Optional[str] = None
|
    def record(
        self,
        agent: Union[List[ChatAgent], ChatAgent],
    ) -> Self:
        r"""Inject an agent into the data collector.

        Args:
            agent (Union[List[ChatAgent], ChatAgent]):
                The agent to inject.

        Returns:
            Self: The data collector itself, allowing calls to be chained.
        """
        if not self.agent_name:
            # Only the first injected agent provides the role name and system
            # message that `convert` and `llm_convert` rely on.
            _agent = agent if isinstance(agent, ChatAgent) else agent[0]
            self.agent_name = _agent.role_name
            self.system_message = _agent._system_message
        super().record(agent)
        return self
|
    def convert(self) -> Dict[str, Any]:
        r"""Convert the collected data into an Alpaca-format dictionary.

        Returns:
            Dict[str, Any]: A dictionary with ``instruction``, ``input`` and
                ``output`` keys.

        Raises:
            ValueError: If no agent was injected or the collected history is
                not a single message pair (optionally preceded by a system
                message).
        """
        if self.agent_name is None:
            raise ValueError("No agent injected")

        history = self.get_agent_history(self.agent_name)
        if not history:
            raise ValueError("No data collected.")

        # Accept either [system, user, assistant] or [user, assistant];
        # anything else cannot be mapped to a single Alpaca record.
        if len(history) == 3 and history[0].role == "system":
            history = history[1:]
        elif len(history) != 2:
            raise ValueError(
                f"AlpacaDataCollector only supports one message pair, but "
                f"got {len(history)}"
            )

        input_message, output_message = history
        # Alpaca has no separate system field, so the system prompt (if any)
        # is prepended to the user message to form the instruction.
        instruction = (
            self.system_message.content if self.system_message else ""
        ) + str(input_message.message)

        data = {
            "instruction": instruction,
            "input": "",
            "output": output_message.message,
        }
        self.data.append(data)
        return data
|
    def llm_convert(
        self,
        converter: Optional[OpenAISchemaConverter] = None,
        prompt: Optional[str] = None,
    ) -> Dict[str, str]:
        r"""Convert collected data using an LLM schema converter.

        Args:
            converter (Optional[OpenAISchemaConverter], optional):
                The converter to use. (default: :obj:`OpenAISchemaConverter`)
            prompt (Optional[str], optional): Prompt to guide the conversion.
                (default: :obj:`DEFAULT_CONVERTER_PROMPTS`)

        Returns:
            Dict[str, str]: The converted data.

        Raises:
            ValueError: If no agent is injected or data cannot be collected.
        """
        prompt = prompt or DEFAULT_CONVERTER_PROMPTS
        converter = converter or OpenAISchemaConverter()

        # Render the collected history as a plain-text transcript.
        system = self.system_message.content if self.system_message else ""
        context = [f"Instruction: {system}\n"]
        for message in self.get_agent_history(str(self.agent_name)):
            if message.role == "user":
                context.append(f"User: {message.message}\n")
            else:
                context.append(f"{message.name}: {message.message}\n")

        # Let the schema converter coerce the transcript into an `AlpacaItem`.
        return converter.convert(
            "\n".join(context), AlpacaItem, prompt=prompt
        ).model_dump()
|
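# Minimal usage sketch (illustrative only, kept as a comment so nothing runs on
# import). It assumes a model-backed `ChatAgent` and that `start()` and
# `ChatAgent.step()` are available as in other camel data-collector examples;
# those names are assumptions, not defined in this module.
#
#     agent = ChatAgent("You are a helpful assistant.")
#     collector = AlpacaDataCollector().record(agent).start()
#     agent.step("When is the release date of the video game Portal?")
#     alpaca_dict = collector.convert()    # deterministic conversion
#     llm_dict = collector.llm_convert()   # schema-converter based conversion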