# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import warnings
from typing import Optional

import torch
from accelerate.utils import extract_model_from_parallel
from transformers import StoppingCriteria, StoppingCriteriaList
from transformers.utils import is_rich_available

if is_rich_available():
    from rich import print
    from rich.text import Text


class StringStoppingCriteria(StoppingCriteria):
    """Custom `StoppingCriteria` which checks if all generations in the batch are completed."""

    def __init__(self, stop_strings, tokenizer):
        self.stop_strings = stop_strings
        self.tokenizer = tokenizer
        self.first_call = True

    def __call__(self, input_ids, scores, **kwargs):
        """Returns true if all generated sequences contain any of the stop strings."""
        if self.first_call:
            self.generated_tokens = [1 for _ in range(input_ids.shape[0])]
            self.start_length = input_ids.shape[-1] - 1
            self.first_call = False
        decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length :])
        done = []

        for i, decoded_generation in enumerate(decoded_generations):
            sequence_complete = any(stop_string in decoded_generation for stop_string in self.stop_strings)
            done.append(sequence_complete)
            if not sequence_complete:
                self.generated_tokens[i] += 1

        if all(done):
            self.first_call = True

        return all(done)
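
# Usage sketch for `StringStoppingCriteria` (illustrative only; the "gpt2" checkpoint and
# the stop strings below are arbitrary choices). The criteria object is wrapped in a
# `StoppingCriteriaList` and passed to `generate`, exactly as `_generate_batched` does
# further down in this module:
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
#
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     model = AutoModelForCausalLM.from_pretrained("gpt2")
#     inputs = tokenizer("2 + 2 =", return_tensors="pt")
#     criteria = StringStoppingCriteria(["<call>", "<submit>"], tokenizer)
#     model.generate(**inputs, stopping_criteria=StoppingCriteriaList([criteria]), max_new_tokens=32)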


class TextHistory:
    """The TextHistory class keeps track of the history of an interaction between the language model and the environment."""

    def __init__(self, text, tokens, system=True):
        """
        Initialize TextHistory.

        Args:
            text (`str`): The text of the first segment.
            tokens (`torch.LongTensor`): The tokens of the first segment.
            system (`bool`, *optional*): Whether the first segment is a system or user segment.
        """
        self.system_spans = []
        self.text_spans = []
        self.token_spans = []
        self.token_masks = torch.tensor([], dtype=torch.long).to(tokens.device)
        self.text = ""
        self.tokens = torch.tensor([], dtype=torch.long).to(tokens.device)
        self.completed = False
        self.truncated = False
        self.reward = 0.0

        self.prompt_color = "black on grey85"
        self.system_color = "black on cyan3"
        self.model_color = "black on deep_sky_blue1"
        self.reward_color = "black on plum1"

        self.append_segment(text, tokens, system=system)

    def append_segment(self, text, tokens, system=True):
        """
        Append a new segment to the history.

        Args:
            text (`str`): The text of the new segment.
            tokens (`torch.LongTensor`): The tokens of the new segment.
            system (`bool`, *optional*): Whether the new segment is a system or user segment.
        """
        if len(text) == 0 or len(tokens) == 0:
            raise ValueError("Can't append empty text or token list to history.")

        original_text_length = len(self.text)

        self.text += text
        self.text_spans.append((original_text_length, len(self.text)))
        self.system_spans.append(system)

        original_token_length = len(self.tokens)

        self.tokens = torch.cat((self.tokens, tokens))
        if system:
            self.token_masks = torch.cat((self.token_masks, torch.zeros_like(tokens)))
        else:
            self.token_masks = torch.cat((self.token_masks, torch.ones_like(tokens)))
        self.token_spans.append((original_token_length, len(self.tokens)))

    def complete(self, truncated=False):
        """
        Mark the history as completed.
        """
        self.completed = True
        self.truncated = truncated
    @property
    def last_text_segment(self):
        """
        Get the last text segment.
        """
        start, end = self.text_spans[-1]
        return self.text[start:end]
    def split_query_response_tokens(self):
        """
        Split the tokens into query and response tokens.
        """
        split_index = self.token_spans[0][1]
        query = self.tokens[:split_index]
        response = self.tokens[split_index:]
        mask = self.token_masks[split_index:]

        return query, response, mask
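
    # Illustrative example of the split above (token ids are made up): if the history was
    # built from a 3-token prompt segment [101, 102, 103] (system) followed by a 2-token
    # model segment [7, 8] (non-system), then `split_query_response_tokens()` returns
    #     query    = tensor([101, 102, 103])
    #     response = tensor([7, 8])
    #     mask     = tensor([1, 1])
    # i.e. the mask is 1 for model-generated tokens and 0 for system/tool tokens appended
    # in later turns.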
    def show_text(self, show_legend=False):
        """
        Print the text history.
        """
        if not is_rich_available():
            raise ImportError(
                "The `rich` library is required to display text with formatting. "
                "Install it using `pip install rich`."
            )

        text = Text(self.text)
        text.stylize(self.prompt_color, self.text_spans[0][0], self.text_spans[1][0])
        for i, (start, end) in enumerate(self.text_spans[1:]):
            if self.system_spans[i + 1]:
                text.stylize(self.system_color, start, end)
            else:
                text.stylize(self.model_color, start, end)

        text.append(f"\n\nReward: {self.reward}", style=self.reward_color)
        print(text)

        if show_legend:
            self.show_colour_legend()

    def show_tokens(self, tokenizer, show_legend=False):
        """
        Print the history tokens.
        """
        if not is_rich_available():
            raise ImportError(
                "The `rich` library is required to display tokens with formatting. "
                "Install it using `pip install rich`."
            )

        text = Text()
        prompt_end = self.token_spans[0][1]
        for i, (token, mask) in enumerate(zip(self.tokens, self.token_masks)):
            if i < prompt_end:
                text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.prompt_color)
                text.append(" ")
            elif mask == 0:
                text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.system_color)
                text.append(" ")
            else:
                text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.model_color)
                text.append(" ")
        text.append(f"\n\nReward: {self.reward}", style=self.reward_color)
        print(text)

        if show_legend:
            self.show_colour_legend()

    def show_colour_legend(self):
        """
        Print the colour legend.
        """
        if not is_rich_available():
            raise ImportError(
                "The `rich` library is required to display colour legends with formatting. "
                "Install it using `pip install rich`."
            )

        text = Text("\n\n(Colour Legend: ")
        text.append("Prompt", style=self.prompt_color)
        text.append("|")
        text.append("System", style=self.system_color)
        text.append("|")
        text.append("Model", style=self.model_color)
        text.append("|")
        text.append("Reward", style=self.reward_color)
        text.append(")")
        print(text)
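
# Usage sketch for `TextHistory` (illustrative only; the "gpt2" checkpoint and the example
# strings are arbitrary choices):
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     prompt = "Q: What is 13 + 29?\n"
#     history = TextHistory(prompt, tokenizer(prompt, return_tensors="pt").input_ids[0])
#     reply = "<request><Calculator>13 + 29<call>"
#     history.append_segment(reply, tokenizer(reply, return_tensors="pt").input_ids[0], system=False)
#     history.show_text()  # pretty-print the history (requires `rich`)
#     query, response, mask = history.split_query_response_tokens()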


class TextEnvironment:
    """
    The TextEnvironment enables interaction of an LLM with an environment using tools.
    """
    def __init__(
        self,
        model=None,
        tokenizer=None,
        tools=None,
        reward_fn=None,
        prompt=None,
        max_turns=4,
        max_tool_response=100,
        max_length=None,
        generation_kwargs=None,
    ):
        """
        Initialize TextEnvironment.

        Args:
            model (`PreTrainedModelWrapper`): The model to use for generation.
            tokenizer (`transformers.PreTrainedTokenizer`): The tokenizer to use for generation.
            tools (list or dict): A list of tools (or a dict mapping tool names to tools) to use for interaction.
            reward_fn (function): A function that takes a list of strings and returns a list of rewards.
            prompt (str): The base prompt to use for generation. Is prepended to the tasks.
            max_turns (Optional[int]): The maximum number of turns to allow.
            max_tool_response (Optional[int]): The maximum number of characters to allow in a tool response.
            max_length (Optional[int]): The maximum number of tokens to allow in an episode.
            generation_kwargs (Optional[dict]): A dictionary of keyword arguments to pass to the model's generate method.
        """
        warnings.warn(
            "This class is deprecated and will be removed in version 0.21.0. To enable tool use with LLMs, check out smolagents (https://huggingface.co/docs/smolagents/index)",
            DeprecationWarning,
        )
        self.model = model
        self.tokenizer = tokenizer
        self.prompt = prompt
        if isinstance(tools, dict):
            self.tools = tools
        else:
            self.tools = {tool.__class__.__name__: tool for tool in tools}
        self.reward_fn = reward_fn
        self.max_length = max_length
        self.request_token = "<request>"
        self.call_token = "<call>"
        self.response_token = "<response>"
        self.submit_token = "<submit>"
        self.max_turns = max_turns
        self.max_tool_response = max_tool_response

        if generation_kwargs is None:
            self.generation_kwargs = dict()
        else:
            self.generation_kwargs = generation_kwargs

        self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder")
        self.current_device = extract_model_from_parallel(self.model).pretrained_model.device
    def run(self, queries, **rewards_kwargs):
        """
        Run the environment on a list of queries.

        Args:
            queries (list[str]): A list of queries to run the model in the environment on.
        """
        turns = 0

        queries = [self.prompt + task for task in queries]
        queries_tokens = [
            self.tokenizer(query, return_tensors="pt").input_ids[0].to(self.model.pretrained_model.device)
            for query in queries
        ]

        histories = [TextHistory(q, qt, system=True) for q, qt in zip(queries, queries_tokens)]

        while any(not history.completed for history in histories) and turns < self.max_turns:
            histories = self.generate(histories)
            histories = self.tasks_end_check(histories)
            # TODO: make this parallel rather than for-loop
            for i in range(len(histories)):
                histories[i] = self.step(histories[i])
            histories = self.tasks_end_check(histories, model_turn=False)
            turns += 1
        self.compute_reward(histories, **rewards_kwargs)

        # convert a list of (q, r, m) tuples to lists of all qs, rs, and ms respectively
        queries, responses, masks = map(list, zip(*[history.split_query_response_tokens() for history in histories]))

        rewards = [history.reward for history in histories]
        return queries, responses, masks, rewards, histories
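
    # Usage sketch for `TextEnvironment.run` (illustrative only; the checkpoint name, the
    # toy calculator tool, the constant reward function, and the prompt are placeholder
    # choices, not part of this module). The model is expected to be a TRL value-head
    # wrapper, i.e. to expose a `pretrained_model` attribute:
    #
    #     import torch
    #     from transformers import AutoTokenizer
    #     from trl import AutoModelForCausalLMWithValueHead
    #
    #     model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
    #     tokenizer = AutoTokenizer.from_pretrained("gpt2")
    #     tokenizer.pad_token = tokenizer.eos_token
    #
    #     def calculator(query):
    #         return str(eval(query))  # toy tool, do not use eval with untrusted input
    #
    #     env = TextEnvironment(
    #         model,
    #         tokenizer,
    #         {"Calculator": calculator},
    #         lambda texts: [torch.tensor(1.0) for _ in texts],
    #         "Use <request><Calculator>a + b<call> to compute sums, then answer and <submit>.\n",
    #         generation_kwargs={"max_new_tokens": 32, "pad_token_id": tokenizer.eos_token_id},
    #     )
    #     queries, responses, masks, rewards, histories = env.run(["What is 13 + 29?"])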
    def step(self, history):
        """
        Step the environment forward one turn.

        Args:
            history (`TextHistory`): The history to step forward.
        """
        truncated, ended = self.task_end_check(history)
        if ended:
            history.complete(truncated=truncated)
        if history.completed:
            return history

        tool, query = self.parse_tool_call(history.last_text_segment)
        if tool is None or query is None:
            response = f"Unknown tool call: {history.last_text_segment}"
        else:
            if tool not in self.tools:
                response = f"Unknown tool {tool}."
            else:
                try:
                    response = self.tools[tool](query)
                except Exception as error:
                    response = f"Tool error: {str(error)}"

        if len(response) > self.max_tool_response:
            response = response[: (self.max_tool_response - 3)] + "..."

        history.append_segment(
            response + self.response_token,
            self.tokenizer(response + self.response_token, return_tensors="pt")
            .input_ids[0]
            .to(self.model.pretrained_model.device),
            system=True,
        )

        return history
    def parse_tool_call(self, text):
        """
        Parse request string. Expected format: <request><tool_name>query<call>
        """
        result = re.search(f"(?<={self.request_token}).*?(?={self.call_token})", text, re.DOTALL)

        # if we can't find a <request>/<call> span we return none
        if result is None:
            return None, None
        else:
            extracted_text = result.group()

        result = re.search(r"<(.*?)>", extracted_text)

        # if we can't find a tool name we return none
        if result is None:
            return None, None
        else:
            tool = result.group(1)

        # split off the tool name
        query = ">".join(extracted_text.split(">")[1:])

        return tool, query
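
    # Example of the format parsed by `parse_tool_call` above: for the segment
    # "<request><Calculator>13 + 29<call>", the first regex extracts "<Calculator>13 + 29",
    # the second regex extracts the tool name "Calculator", and the remainder after the
    # closing ">" is returned as the query, so the call yields ("Calculator", "13 + 29").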
    def compute_reward(self, histories, **reward_kwargs):
        """
        Compute the reward for a list of histories.
        """
        rewards = self.reward_fn([history.last_text_segment for history in histories], **reward_kwargs)
        for history, reward in zip(histories, rewards):
            history.reward = reward
        return histories

    def generate(self, histories):
        """
        Generate responses for a list of histories.
        """
        active_histories = [i for i, history in enumerate(histories) if not history.completed]
        query_tensors = [histories[i].tokens for i in active_histories]
        response_tensors = self._generate_batched(query_tensors)
        response_texts = self.tokenizer.batch_decode(response_tensors)
        for i, response_text, response_tensor in zip(active_histories, response_texts, response_tensors):
            histories[i].append_segment(response_text, response_tensor, system=False)

        return histories
    def tasks_end_check(self, histories, model_turn=True):
        """
        Check if the current generation sequences have finished.
        """
        for history in histories:
            if not history.completed:
                truncated, ended = self.task_end_check(history, model_turn=model_turn)
                if ended:
                    history.complete(truncated=truncated)
        return histories
    def task_end_check(self, history, model_turn=True):
        """
        Check if the current generation sequence has finished.
        """
        truncated = False
        ended = False
        if history.completed:
            return truncated, ended
        if self.max_length is not None and len(self.tokenizer(history.text).input_ids) > self.max_length:
            truncated = True
            ended = True
        elif self.tokenizer.eos_token in history.text:
            ended = True
        elif model_turn and not (
            (self.request_token in history.last_text_segment and self.call_token in history.last_text_segment)
            or self.submit_token in history.last_text_segment
        ):
            ended = True
        elif self.submit_token in history.last_text_segment:
            ended = True
        return truncated, ended
    def _generate_batched(
        self,
        query_tensors,
        batch_size: int = 16,
        pad_to_multiple_of: Optional[int] = None,
    ):
        """
        Generate responses for a list of query tensors.

        Args:
            query_tensors (list[torch.Tensor]): A list of query tensors to generate responses for.
            batch_size (int): The batch size to use for generation.
            pad_to_multiple_of (int): The padding length to use for generation.
        """
        outputs = []
        padding_side_default = self.tokenizer.padding_side
        if not self.is_encoder_decoder:
            self.tokenizer.padding_side = "left"

        # in case we have fewer examples than bs
        batch_size = min(len(query_tensors), batch_size)

        for i in range(0, len(query_tensors), batch_size):
            # prevent overflow if query tensors are not even multiple of bs
            end_index = min(len(query_tensors), i + batch_size)

            batch = query_tensors[i:end_index]
            batch_mask = [torch.ones_like(element) for element in batch]
            inputs = {"input_ids": batch, "attention_mask": batch_mask}

            padded_inputs = self.tokenizer.pad(
                inputs,
                padding=True,
                max_length=None,
                pad_to_multiple_of=pad_to_multiple_of,
                return_tensors="pt",
            ).to(self.current_device)

            stopping_criteria = StringStoppingCriteria([self.call_token, self.submit_token], self.tokenizer)
            self.generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stopping_criteria])

            generations = extract_model_from_parallel(self.model).generate(**padded_inputs, **self.generation_kwargs)

            for generation, mask, generated_tokens in zip(
                generations, padded_inputs["attention_mask"], stopping_criteria.generated_tokens
            ):
                if not self.is_encoder_decoder:
                    output = generation[(1 - mask).sum() :]  # remove padding
                else:
                    output = generation

                if not self.is_encoder_decoder:
                    output = output[(mask).sum() :]  # remove prompt

                # remove chunk generated after stopping criteria in batch mode
                outputs.append(output[:generated_tokens])
        self.tokenizer.padding_side = padding_side_default
        return outputs