import gradio as gr from dataclasses import dataclass import os from supabase import create_client, Client from supabase.client import ClientOptions from enum import Enum from datasets import get_dataset_infos from transformers import AutoConfig from huggingface_hub import whoami from typing import Optional, List, Tuple, Union """ Still TODO: - validate the user is PRO - check the output dataset token is valid (hardcoded for now as a secret) - validate max model params """ def verify_pro_status(token: Optional[Union[gr.OAuthToken, str]]) -> bool: """Verifies if the user is a Hugging Face PRO user or part of an enterprise org.""" if not token: return False if isinstance(token, gr.OAuthToken): token_str = token.token elif isinstance(token, str): token_str = token else: return False try: user_info = whoami(token=token_str) return ( user_info.get("isPro", False) or any(org.get("isEnterprise", False) for org in user_info.get("orgs", [])) ) except Exception as e: print(f"Could not verify user's PRO/Enterprise status: {e}") return False class GenerationStatus(Enum): PENDING = "PENDING" RUNNING = "RUNNING" COMPLETED = "COMPLETED" FAILED = "FAILED" MAX_SAMPLES = 10000 # max number of samples in the input dataset MAX_TOKENS = 8192 MAX_MODEL_PARAMS = 20_000_000_000 # 20 billion parameters (for now) @dataclass class GenerationRequest: id: str created_at: str status: GenerationStatus input_dataset_name: str input_dataset_config: str input_dataset_split: str output_dataset_name: str prompt_column: str model_name_or_path: str model_revision: str model_token: str | None system_prompt: str | None max_tokens: int temperature: float top_k: int top_p: float input_dataset_token: str | None output_dataset_token: str username: str email: str num_output_examples: int private: bool = False num_retries: int = 0 def validate_request(request: GenerationRequest) -> GenerationRequest: # checks that the request is valid # - input dataset exists and can be accessed with the provided token try: input_dataset_info = get_dataset_infos(request.input_dataset_name, token=request.input_dataset_token)[request.input_dataset_config] except Exception as e: raise Exception(f"Dataset {request.input_dataset_name} does not exist or cannot be accessed with the provided token.") # check that the input dataset split exists if request.input_dataset_split not in input_dataset_info.splits: raise Exception(f"Dataset split {request.input_dataset_split} does not exist in dataset {request.input_dataset_name}. Available splits: {list(input_dataset_info.splits.keys())}") # if num_output_examples is 0, set it to the number of examples in the input dataset split if request.num_output_examples == 0: request.num_output_examples = input_dataset_info.splits[request.input_dataset_split].num_examples else: if request.num_output_examples > input_dataset_info.splits[request.input_dataset_split].num_examples: raise Exception(f"Requested number of output examples {request.num_output_examples} exceeds the number of examples in the input dataset split {input_dataset_info.splits[request.input_dataset_split].num_examples}.") request.input_dataset_split = f"{request.input_dataset_split}[:{request.num_output_examples}]" if request.num_output_examples > MAX_SAMPLES: raise Exception(f"Requested number of output examples {request.num_output_examples} exceeds the max limit of {MAX_SAMPLES}.") # check the prompt column exists in the dataset if request.prompt_column not in input_dataset_info.features: raise Exception(f"Prompt column {request.prompt_column} does not exist in dataset {request.input_dataset_name}. Available columns: {list(input_dataset_info.features.keys())}") # This is currently not supported, the output dataset will be created under the org 'synthetic-data-universe' # check output_dataset name is valid if request.output_dataset_name.count("/") != 1: raise Exception("Output dataset name must be in the format 'dataset_name', e.g., 'my-dataset'. The dataset will be created under the org 'synthetic-data-universe/my-dataset'.") # check the output dataset is valid and accessible with the provided token try: output_dataset_info = get_dataset_infos(request.output_dataset_name, token=request.output_dataset_token) raise Exception(f"Output dataset {request.output_dataset_name} already exists. Please choose a different name.") except Exception as e: pass # dataset does not exist, which is expected # check the models exists try: model_config = AutoConfig.from_pretrained(request.model_name_or_path, revision=request.model_revision, force_download=True, token=False ) except Exception as e: print(e) raise Exception(f"Model {request.model_name_or_path} revision {request.model_revision} does not exist or cannot be accessed. The model may be private or gated, which is not supported at this time.") # check the model max position embeddings is greater than the requested max tokens and less than MAX_TOKENS if model_config.max_position_embeddings < request.max_tokens: raise Exception(f"Model {request.model_name_or_path} max position embeddings {model_config.max_position_embeddings} is less than the requested max tokens {request.max_tokens}.") if request.max_tokens > MAX_TOKENS: raise Exception(f"Requested max tokens {request.max_tokens} exceeds the limit of {MAX_TOKENS}.") # check sampling parameters are valid if request.temperature < 0.0 or request.temperature > 2.0: raise Exception("Temperature must be between 0.0 and 2.0") if request.top_k < 1 or request.top_k > 100: raise Exception("Top K must be between 1 and 100") if request.top_p < 0.0 or request.top_p > 1.0: raise Exception("Top P must be between 0.0 and 1.0") # check valid email address TODO: use py3-validate-email https://stackoverflow.com/questions/8022530/how-to-check-for-valid-email-address if "@" not in request.email or "." not in request.email.split("@")[-1]: raise Exception("Invalid email address") return request def add_request_to_db(request: GenerationRequest): url: str = os.getenv("SUPABASE_URL") key: str = os.getenv("SUPABASE_KEY") try: supabase: Client = create_client( url, key, options=ClientOptions( postgrest_client_timeout=10, storage_client_timeout=10, schema="public", ) ) data = { "status": request.status.value, "input_dataset_name": request.input_dataset_name, "input_dataset_config": request.input_dataset_config, "input_dataset_split": request.input_dataset_split, "output_dataset_name": request.output_dataset_name, "prompt_column": request.prompt_column, "model_name_or_path": request.model_name_or_path, "model_revision": request.model_revision, "model_token": request.model_token, "system_prompt": request.system_prompt, "max_tokens": request.max_tokens, "temperature": request.temperature, "top_k": request.top_k, "top_p": request.top_p, "input_dataset_token": request.input_dataset_token, "output_dataset_token": request.output_dataset_token, "username": request.username, "email": request.email, "num_output_examples": request.num_output_examples, "private": request.private, } supabase.table("gen-requests").insert(data).execute() except Exception as e: raise Exception("Failed to add request to database") def main(): with gr.Blocks(title="Synthetic Data Generation") as demo: gr.HTML("