Spaces:
Paused
Paused
| import os, json, time | |
| import litellm | |
| from litellm.utils import ModelResponse | |
| import requests, threading | |
| from typing import Optional, Union, Literal | |
| class BudgetManager: | |
| def __init__( | |
| self, | |
| project_name: str, | |
| client_type: str = "local", | |
| api_base: Optional[str] = None, | |
| ): | |
| self.client_type = client_type | |
| self.project_name = project_name | |
| self.api_base = api_base or "https://api.litellm.ai" | |
| ## load the data or init the initial dictionaries | |
| self.load_data() | |
| def print_verbose(self, print_statement): | |
| try: | |
| if litellm.set_verbose: | |
| import logging | |
| logging.info(print_statement) | |
| except: | |
| pass | |
| def load_data(self): | |
| if self.client_type == "local": | |
| # Check if user dict file exists | |
| if os.path.isfile("user_cost.json"): | |
| # Load the user dict | |
| with open("user_cost.json", "r") as json_file: | |
| self.user_dict = json.load(json_file) | |
| else: | |
| self.print_verbose("User Dictionary not found!") | |
| self.user_dict = {} | |
| self.print_verbose(f"user dict from local: {self.user_dict}") | |
| elif self.client_type == "hosted": | |
| # Load the user_dict from hosted db | |
| url = self.api_base + "/get_budget" | |
| headers = {"Content-Type": "application/json"} | |
| data = {"project_name": self.project_name} | |
| response = requests.post(url, headers=headers, json=data) | |
| response = response.json() | |
| if response["status"] == "error": | |
| self.user_dict = ( | |
| {} | |
| ) # assume this means the user dict hasn't been stored yet | |
| else: | |
| self.user_dict = response["data"] | |
| def create_budget( | |
| self, | |
| total_budget: float, | |
| user: str, | |
| duration: Optional[Literal["daily", "weekly", "monthly", "yearly"]] = None, | |
| created_at: float = time.time(), | |
| ): | |
| self.user_dict[user] = {"total_budget": total_budget} | |
| if duration is None: | |
| return self.user_dict[user] | |
| if duration == "daily": | |
| duration_in_days = 1 | |
| elif duration == "weekly": | |
| duration_in_days = 7 | |
| elif duration == "monthly": | |
| duration_in_days = 28 | |
| elif duration == "yearly": | |
| duration_in_days = 365 | |
| else: | |
| raise ValueError( | |
| """duration needs to be one of ["daily", "weekly", "monthly", "yearly"]""" | |
| ) | |
| self.user_dict[user] = { | |
| "total_budget": total_budget, | |
| "duration": duration_in_days, | |
| "created_at": created_at, | |
| "last_updated_at": created_at, | |
| } | |
| self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution | |
| return self.user_dict[user] | |
| def projected_cost(self, model: str, messages: list, user: str): | |
| text = "".join(message["content"] for message in messages) | |
| prompt_tokens = litellm.token_counter(model=model, text=text) | |
| prompt_cost, _ = litellm.cost_per_token( | |
| model=model, prompt_tokens=prompt_tokens, completion_tokens=0 | |
| ) | |
| current_cost = self.user_dict[user].get("current_cost", 0) | |
| projected_cost = prompt_cost + current_cost | |
| return projected_cost | |
| def get_total_budget(self, user: str): | |
| return self.user_dict[user]["total_budget"] | |
| def update_cost( | |
| self, | |
| user: str, | |
| completion_obj: Optional[ModelResponse] = None, | |
| model: Optional[str] = None, | |
| input_text: Optional[str] = None, | |
| output_text: Optional[str] = None, | |
| ): | |
| if model and input_text and output_text: | |
| prompt_tokens = litellm.token_counter( | |
| model=model, messages=[{"role": "user", "content": input_text}] | |
| ) | |
| completion_tokens = litellm.token_counter( | |
| model=model, messages=[{"role": "user", "content": output_text}] | |
| ) | |
| ( | |
| prompt_tokens_cost_usd_dollar, | |
| completion_tokens_cost_usd_dollar, | |
| ) = litellm.cost_per_token( | |
| model=model, | |
| prompt_tokens=prompt_tokens, | |
| completion_tokens=completion_tokens, | |
| ) | |
| cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar | |
| elif completion_obj: | |
| cost = litellm.completion_cost(completion_response=completion_obj) | |
| model = completion_obj[ | |
| "model" | |
| ] # if this throws an error try, model = completion_obj['model'] | |
| else: | |
| raise ValueError( | |
| "Either a chat completion object or the text response needs to be passed in. Learn more - https://docs.litellm.ai/docs/budget_manager" | |
| ) | |
| self.user_dict[user]["current_cost"] = cost + self.user_dict[user].get( | |
| "current_cost", 0 | |
| ) | |
| if "model_cost" in self.user_dict[user]: | |
| self.user_dict[user]["model_cost"][model] = cost + self.user_dict[user][ | |
| "model_cost" | |
| ].get(model, 0) | |
| else: | |
| self.user_dict[user]["model_cost"] = {model: cost} | |
| self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution | |
| return {"user": self.user_dict[user]} | |
| def get_current_cost(self, user): | |
| return self.user_dict[user].get("current_cost", 0) | |
| def get_model_cost(self, user): | |
| return self.user_dict[user].get("model_cost", 0) | |
| def is_valid_user(self, user: str) -> bool: | |
| return user in self.user_dict | |
| def get_users(self): | |
| return list(self.user_dict.keys()) | |
| def reset_cost(self, user): | |
| self.user_dict[user]["current_cost"] = 0 | |
| self.user_dict[user]["model_cost"] = {} | |
| return {"user": self.user_dict[user]} | |
| def reset_on_duration(self, user: str): | |
| # Get current and creation time | |
| last_updated_at = self.user_dict[user]["last_updated_at"] | |
| current_time = time.time() | |
| # Convert duration from days to seconds | |
| duration_in_seconds = self.user_dict[user]["duration"] * 24 * 60 * 60 | |
| # Check if duration has elapsed | |
| if current_time - last_updated_at >= duration_in_seconds: | |
| # Reset cost if duration has elapsed and update the creation time | |
| self.reset_cost(user) | |
| self.user_dict[user]["last_updated_at"] = current_time | |
| self._save_data_thread() # Save the data | |
| def update_budget_all_users(self): | |
| for user in self.get_users(): | |
| if "duration" in self.user_dict[user]: | |
| self.reset_on_duration(user) | |
| def _save_data_thread(self): | |
| thread = threading.Thread( | |
| target=self.save_data | |
| ) # [Non-Blocking]: saves data without blocking execution | |
| thread.start() | |
| def save_data(self): | |
| if self.client_type == "local": | |
| import json | |
| # save the user dict | |
| with open("user_cost.json", "w") as json_file: | |
| json.dump( | |
| self.user_dict, json_file, indent=4 | |
| ) # Indent for pretty formatting | |
| return {"status": "success"} | |
| elif self.client_type == "hosted": | |
| url = self.api_base + "/set_budget" | |
| headers = {"Content-Type": "application/json"} | |
| data = {"project_name": self.project_name, "user_dict": self.user_dict} | |
| response = requests.post(url, headers=headers, json=data) | |
| response = response.json() | |
| return response | |