Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from dataclasses import dataclass | |
| import os | |
| from supabase import create_client, Client | |
| from supabase.client import ClientOptions | |
| from enum import Enum | |
| from datasets import get_dataset_infos | |
| from transformers import AutoConfig | |
| """ | |
| Still TODO: | |
| - validate the user is PRO | |
| - check the output dataset token is valid (hardcoded for now as a secret) | |
| - validate max model params | |
| """ | |
class GenerationStatus(Enum):
    """Lifecycle states of a generation request.

    Each member's value is the upper-case string of its own name; that
    string is what gets written to the database ``status`` column.
    """

    # Declaration order is preserved because it defines iteration order.
    PENDING = "PENDING"
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
# Service-wide limits applied to generation requests.
MAX_SAMPLES = 10_000  # max number of samples in the input dataset
MAX_TOKENS = 8 * 1024  # 8192; upper bound for the requested max tokens
MAX_MODEL_PARAMS = 20 * 10**9  # 20 billion parameters (for now)
@dataclass
class GenerationRequest:
    """One synthetic-data generation job as submitted through the UI.

    The class must be a dataclass: the UI builds it with keyword arguments,
    which requires the generated ``__init__``. Fields without a default are
    required at construction time.
    """

    # Passed as "" by the UI; its comment there says the database assigns them.
    id: str
    created_at: str
    status: GenerationStatus  # persisted as ``status.value``

    # Input dataset: repo name, config and split that hold the prompts.
    input_dataset_name: str
    input_dataset_config: str
    input_dataset_split: str
    # Destination repo; the UI prefixes it with the "synthetic-data-universe/" org.
    output_dataset_name: str
    prompt_column: str  # column of the input dataset that contains the prompts

    # Model used for generation.
    model_name_or_path: str
    model_revision: str
    model_token: str | None  # the UI currently always passes None
    system_prompt: str | None

    # Sampling parameters (ranges are enforced by validate_request).
    max_tokens: int
    temperature: float
    top_k: int
    top_p: float

    input_dataset_token: str | None  # the UI currently always passes None
    output_dataset_token: str  # the UI reads it from the OUTPUT_DATASET_TOKEN env var
    username: str
    email: str

    # Assigned by validate_request from the size of the input split, so the
    # caller does not know it at construction time; hence the default.
    num_output_examples: int = 0
    private: bool = False
    num_retries: int = 0
def validate_request(request: GenerationRequest):
    """Validate a generation request, raising on the first problem found.

    Checks, in order: the input dataset is reachable, the config, split and
    prompt column exist, the split is within ``MAX_SAMPLES``, the model config
    can be loaded, ``max_tokens`` fits both the model and ``MAX_TOKENS``, the
    sampling parameters are in range, and the email looks plausible.

    Side effect: ``request.num_output_examples`` is set to the number of
    examples in the selected split as soon as that number is known.

    Args:
        request: The request to validate; mutated as described above.

    Raises:
        Exception: With a user-facing message describing the first failed check.
    """
    # - input dataset exists and can be accessed with the provided token
    try:
        dataset_infos = get_dataset_infos(request.input_dataset_name, token=request.input_dataset_token)
    except Exception as e:
        raise Exception(f"Dataset {request.input_dataset_name} does not exist or cannot be accessed with the provided token.") from e

    # A missing config is a different problem from an inaccessible dataset,
    # so report it separately and list what is available.
    if request.input_dataset_config not in dataset_infos:
        raise Exception(f"Dataset config {request.input_dataset_config} does not exist in dataset {request.input_dataset_name}. Available configs: {list(dataset_infos.keys())}")
    input_dataset_info = dataset_infos[request.input_dataset_config]

    # Split/feature metadata may be absent; treat that as "nothing available".
    splits = input_dataset_info.splits or {}
    features = input_dataset_info.features or {}

    # check that the input dataset split exists
    if request.input_dataset_split not in splits:
        raise Exception(f"Dataset split {request.input_dataset_split} does not exist in dataset {request.input_dataset_name}. Available splits: {list(splits.keys())}")

    # Record the split size unconditionally so it is available on the success
    # path too (it is later written to the database), then enforce the limit.
    num_examples = splits[request.input_dataset_split].num_examples
    request.num_output_examples = num_examples
    if num_examples > MAX_SAMPLES:
        raise Exception(f"Dataset split {request.input_dataset_split} in dataset {request.input_dataset_name} exceeds max sample limit of {MAX_SAMPLES}.")

    # check the prompt column exists in the dataset
    if request.prompt_column not in features:
        raise Exception(f"Prompt column {request.prompt_column} does not exist in dataset {request.input_dataset_name}. Available columns: {list(features.keys())}")

    # check the model exists
    try:
        model_config = AutoConfig.from_pretrained(request.model_name_or_path, revision=request.model_revision, token=request.model_token)
    except Exception as e:
        print(e)
        raise Exception(f"Model {request.model_name_or_path} revision {request.model_revision} does not exist or cannot be accessed with the provided token.") from e

    # check the model max position embeddings is greater than the requested max tokens and less than MAX_TOKENS
    max_position_embeddings = getattr(model_config, "max_position_embeddings", None)
    if max_position_embeddings is None:
        # Not every config exposes this attribute; fail with a readable
        # message instead of a raw AttributeError.
        raise Exception(f"Model {request.model_name_or_path} does not report max position embeddings, so the requested max tokens cannot be validated.")
    if max_position_embeddings < request.max_tokens:
        raise Exception(f"Model {request.model_name_or_path} max position embeddings {max_position_embeddings} is less than the requested max tokens {request.max_tokens}.")
    if request.max_tokens > MAX_TOKENS:
        raise Exception(f"Requested max tokens {request.max_tokens} exceeds the limit of {MAX_TOKENS}.")

    # check sampling parameters are valid
    if request.temperature < 0.0 or request.temperature > 2.0:
        raise Exception("Temperature must be between 0.0 and 2.0")
    if request.top_k < 1 or request.top_k > 100:
        raise Exception("Top K must be between 1 and 100")
    if request.top_p < 0.0 or request.top_p > 1.0:
        raise Exception("Top P must be between 0.0 and 1.0")

    # check valid email address TODO: use py3-validate-email https://stackoverflow.com/questions/8022530/how-to-check-for-valid-email-address
    if "@" not in request.email or "." not in request.email.split("@")[-1]:
        raise Exception("Invalid email address")
def add_request_to_db(request: GenerationRequest):
    """Insert a generation request into the Supabase ``gen-requests`` table.

    Connection settings come from the ``SUPABASE_URL`` and ``SUPABASE_KEY``
    environment variables. ``id``, ``created_at`` and ``num_retries`` are not
    part of the inserted row.

    Args:
        request: The request to persist.

    Raises:
        Exception: "Failed to add request to database" on any failure; the
            underlying error is printed and attached as ``__cause__``.
    """
    url: str | None = os.getenv("SUPABASE_URL")
    key: str | None = os.getenv("SUPABASE_KEY")
    try:
        # Fail with an explicit cause when the deployment is misconfigured
        # instead of passing None into the client.
        if not url or not key:
            raise RuntimeError("SUPABASE_URL and SUPABASE_KEY must be set")

        supabase: Client = create_client(
            url,
            key,
            options=ClientOptions(
                postgrest_client_timeout=10,
                storage_client_timeout=10,
                schema="public",
            ),
        )
        data = {
            "status": request.status.value,
            "input_dataset_name": request.input_dataset_name,
            "input_dataset_config": request.input_dataset_config,
            "input_dataset_split": request.input_dataset_split,
            "output_dataset_name": request.output_dataset_name,
            "prompt_column": request.prompt_column,
            "model_name_or_path": request.model_name_or_path,
            "model_revision": request.model_revision,
            "model_token": request.model_token,
            "system_prompt": request.system_prompt,
            "max_tokens": request.max_tokens,
            "temperature": request.temperature,
            "top_k": request.top_k,
            "top_p": request.top_p,
            "input_dataset_token": request.input_dataset_token,
            "output_dataset_token": request.output_dataset_token,
            "username": request.username,
            "email": request.email,
            "num_output_examples": request.num_output_examples,
            "private": request.private,
        }
        supabase.table("gen-requests").insert(data).execute()
    except Exception as e:
        # The caller only shows the generic message to the user, so print the
        # real error for the logs and keep it as the exception cause.
        print(e)
        raise Exception("Failed to add request to database") from e
def create_gradio_interface():
    """Build the Gradio UI for submitting synthetic data generation requests.

    The page consists of an introduction, dataset / model / generation-parameter
    / user-information sections, a submit button and a read-only status box.
    Clicking the button runs ``submit_request``, which validates the request,
    stores it in the database and reports the outcome in the status box.

    Returns:
        The assembled ``gr.Blocks`` interface (not yet launched).
    """
    with gr.Blocks(title="Synthetic Data Generation") as interface:
        with gr.Group():
            with gr.Row():
                gr.Markdown("# Synthetic Data Generation Request")
            with gr.Row():
                gr.Markdown("""
                    Welcome to the Synthetic Data Generation service! This tool allows you to generate synthetic data using large language models. Generation is FREE for Hugging Face PRO users and uses idle GPUs on the HF science cluster.\n
                """)

        with gr.Group():
            with gr.Row():
                gr.Markdown("""
                    **How it works:**
                    1. Provide an input dataset with prompts
                    2. Select a public language model for generation
                    3. Configure generation parameters
                    4. Submit your request and receive generated data
                """)
                gr.Markdown("""
                    **Requirements:**
                    - Input dataset must be publicly accessible
                    - Output dataset repository must exist and you must have write access
                    - Model must be accessible (public or with valid token)
                    - Maximum 10,000 samples per dataset
                    - Maximum of 8192 generation tokens
                """)

        with gr.Group():
            gr.Markdown("## Dataset information")
            with gr.Column():
                with gr.Row():
                    input_dataset_name = gr.Textbox(label="Input Dataset Name", placeholder="e.g., simplescaling/s1K-1.1")
                    input_dataset_split = gr.Textbox(label="Input Dataset Split", value="train", placeholder="e.g., train, test, validation")
                    input_dataset_config = gr.Textbox(label="Input Dataset Config", value="default", placeholder="e.g., default, custom")
                    prompt_column = gr.Textbox(label="Prompt Column", placeholder="e.g., text, prompt, question")
            with gr.Column():
                output_dataset_name = gr.Textbox(label="Output Dataset Name", placeholder="e.g., my-generated-dataset, must be unique. Will be created under the org 'synthetic-data-universe'")

        with gr.Group():
            gr.Markdown("## Model information")
            with gr.Column():
                with gr.Row():
                    model_name_or_path = gr.Textbox(label="Model Name or Path", placeholder="e.g., Qwen/Qwen3-4B-Instruct-2507")
                    model_revision = gr.Textbox(label="Model Revision", value="main", placeholder="e.g., main, v1.0")
                    # model_token = gr.Textbox(label="Model Token (Optional)", type="password", placeholder="Your HF token with read/write access to the model...")

        with gr.Group():
            gr.Markdown("## Generation Parameters")
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        max_tokens = gr.Slider(label="Max Tokens", value=512, minimum=256, maximum=MAX_TOKENS, step=256)
                        temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=2.0, value=0.7, step=0.1)
                    with gr.Row():
                        top_k = gr.Slider(label="Top K", value=50, minimum=5, maximum=100, step=5)
                        top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.95, step=0.05)
            with gr.Row():
                system_prompt = gr.Textbox(label="System Prompt (Optional)", lines=3, placeholder="Optional system prompt... e.g., You are a helpful assistant.")

        with gr.Group():
            gr.Markdown("## User Information, for notification when your job is completed")
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        email = gr.Textbox(label="Email", placeholder="your.email@example.com")
                    # with gr.Row():
                    #     input_dataset_token = gr.Textbox(label="Input dataset token", type="password", placeholder="Your HF token with read access to the input dataset, leave blank if public dataset")
                    #     output_dataset_token = gr.Textbox(label="Output dataset token", type="password", placeholder="Your HF token with write access to the output dataset")

        submit_btn = gr.Button("Submit Generation Request", variant="primary")
        output_status = gr.Textbox(label="Status", interactive=False)

        def submit_request(input_dataset_name, input_split, input_dataset_config, output_dataset_name, prompt_col, model_name, model_rev, sys_prompt,
                           max_tok, temp, top_k_val, top_p_val, email_addr):
            """Build, validate and store a request; return the status text to display.

            Never raises: every failure is turned into an "Error: ..." string.
            """
            MASTER_ORG = "synthetic-data-universe/"
            model_token = None  # This is currently not supported
            input_dataset_token = None  # This is currently not supported
            output_dataset_token = os.getenv("OUTPUT_DATASET_TOKEN")
            try:
                # Without a name the output repo id would be just the org prefix.
                if not output_dataset_name or not output_dataset_name.strip():
                    raise Exception("Output dataset name is required")

                request = GenerationRequest(
                    id="",  # Will be generated when adding to the database
                    created_at="",  # Will be set when adding to the database
                    status=GenerationStatus.PENDING,
                    input_dataset_name=input_dataset_name,
                    input_dataset_split=input_split,
                    input_dataset_config=input_dataset_config,
                    output_dataset_name=MASTER_ORG + output_dataset_name,
                    prompt_column=prompt_col,
                    model_name_or_path=model_name,
                    model_revision=model_rev,
                    model_token=model_token if model_token else None,
                    system_prompt=sys_prompt if sys_prompt else None,
                    max_tokens=int(max_tok),
                    temperature=temp,
                    top_k=int(top_k_val),
                    top_p=top_p_val,
                    input_dataset_token=input_dataset_token if input_dataset_token else None,
                    output_dataset_token=output_dataset_token,
                    username="user",
                    email=email_addr,
                    # Required field; placeholder until validation derives the
                    # real value from the input split.
                    num_output_examples=0,
                )
                # check the input dataset exists and can be accessed with the provided token
                validate_request(request)
                add_request_to_db(request)
                return "Request submitted successfully!"
            except Exception as e:
                return f"Error: {str(e)}"

        # The order of `inputs` must match the positional parameters of submit_request.
        submit_btn.click(
            submit_request,
            inputs=[input_dataset_name, input_dataset_split, input_dataset_config, output_dataset_name, prompt_column, model_name_or_path,
                    model_revision, system_prompt, max_tokens, temperature, top_k, top_p,
                    email],
            outputs=output_status,
        )

    return interface
# Script entry point: build the interface and start the Gradio server with
# default launch settings. Nothing runs when the module is merely imported.
if __name__ == "__main__":
    app = create_gradio_interface()
    app.launch()