import gradio as gr
from dataclasses import dataclass
import os
from supabase import create_client, Client
from supabase.client import ClientOptions
from enum import Enum
from datasets import get_dataset_infos
from transformers import AutoConfig, GenerationConfig
from huggingface_hub import whoami
from typing import Optional, Union

"""
Still TODO:
- validate the user is PRO
- check the output dataset token is valid (hardcoded for now as a secret)
- validate max model params
"""
class GenerationStatus(Enum):
    PENDING = "PENDING"
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"

MAX_SAMPLES_PRO = 10000  # max number of samples for PRO/Enterprise users
MAX_SAMPLES_FREE = 100  # max number of samples for free users
MAX_TOKENS = 8192
MAX_MODEL_PARAMS = 20_000_000_000  # 20 billion parameters (for now)

# Cache for model generation parameters
MODEL_GEN_PARAMS_CACHE = {}
@dataclass
class GenerationRequest:
    id: str
    created_at: str
    status: GenerationStatus
    input_dataset_name: str
    input_dataset_config: str
    input_dataset_split: str
    output_dataset_name: str
    prompt_column: str
    model_name_or_path: str
    model_revision: str
    model_token: str | None
    system_prompt: str | None
    max_tokens: int
    temperature: float
    top_k: int
    top_p: float
    input_dataset_token: str | None
    output_dataset_token: str
    username: str
    email: str
    num_output_examples: int
    private: bool = False
    num_retries: int = 0
SUPPORTED_MODELS = [
    "Qwen/Qwen3-4B-Instruct-2507",
    "Qwen/Qwen3-30B-A3B-Instruct-2507",
    "meta-llama/Llama-3.2-1B-Instruct",
    "meta-llama/Llama-3.2-3B-Instruct",
    "baidu/ERNIE-4.5-21B-A3B-Thinking",
    "LLM360/K2-Think",
    "openai/gpt-oss-20b",
]
def fetch_model_generation_params(model_name: str) -> dict:
    """Fetch generation parameters and model config from the Hub."""
    default_params = {
        "max_tokens": 1024,
        "temperature": 0.7,
        "top_k": 50,
        "top_p": 0.95,
        "max_position_embeddings": 2048,
        "recommended_max_tokens": 1024
    }
    try:
        print(f"Attempting to fetch configs for: {model_name}")
        # Always try to load the model config first for max_position_embeddings
        model_config = None
        max_position_embeddings = default_params["max_position_embeddings"]
        output_dataset_token = os.getenv("OUTPUT_DATASET_TOKEN")
        try:
            model_config = AutoConfig.from_pretrained(model_name, force_download=False, token=output_dataset_token)
            max_position_embeddings = getattr(model_config, 'max_position_embeddings', default_params["max_position_embeddings"])
            print(f"Loaded AutoConfig for {model_name}, max_position_embeddings: {max_position_embeddings}")
        except Exception as e:
            print(f"Failed to load AutoConfig for {model_name}: {e}")

        # Calculate recommended max tokens (conservative estimate).
        # Leave some room for the prompt, so use ~75% of max_position_embeddings.
        recommended_max_tokens = min(int(max_position_embeddings * 0.75), MAX_TOKENS)
        recommended_max_tokens = max(256, recommended_max_tokens)  # Ensure minimum

        # Try to load the generation config
        gen_config = None
        try:
            gen_config = GenerationConfig.from_pretrained(model_name, force_download=False, token=output_dataset_token)
            print(f"Successfully loaded generation config for {model_name}")
        except Exception as e:
            print(f"Failed to load GenerationConfig for {model_name}: {e}")

        # Extract parameters from generation config or use model-specific defaults
        if gen_config:
            params = {
                "max_tokens": getattr(gen_config, 'max_new_tokens', None) or getattr(gen_config, 'max_length', recommended_max_tokens),
                "temperature": getattr(gen_config, 'temperature', default_params["temperature"]),
                "top_k": getattr(gen_config, 'top_k', default_params["top_k"]),
                "top_p": getattr(gen_config, 'top_p', default_params["top_p"]),
                "max_position_embeddings": max_position_embeddings,
                "recommended_max_tokens": recommended_max_tokens
            }
        else:
            params = dict(default_params)
            params["max_position_embeddings"] = max_position_embeddings
            params["recommended_max_tokens"] = recommended_max_tokens

        # Ensure parameters are within valid ranges
        params["max_tokens"] = max(256, min(params["max_tokens"], MAX_TOKENS, params["recommended_max_tokens"]))
        params["temperature"] = max(0.0, min(params["temperature"], 2.0))
        params["top_k"] = max(5, min(params["top_k"], 100))
        params["top_p"] = max(0.0, min(params["top_p"], 1.0))

        print(f"Final params for {model_name}: {params}")
        return params
    except Exception as e:
        print(f"Could not fetch configs for {model_name}: {e}")
        return default_params
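# Illustrative usage (an assumption-level sketch, not called by the app): fetching the
# parameters for one of the supported models returns a dict shaped like `default_params`
# above, with values taken from the model's generation_config.json when available.
def _example_fetch_params() -> None:
    """Print the fetched/derived sampling parameters for a sample model (needs Hub access)."""
    example = fetch_model_generation_params("Qwen/Qwen3-4B-Instruct-2507")
    print(example["max_tokens"], example["temperature"], example["recommended_max_tokens"])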
def update_generation_params(model_name: str):
    """Update generation parameters based on selected model"""
    global MODEL_GEN_PARAMS_CACHE
    print(f"Updating generation parameters for model: {model_name}")
    print(f"Cache is empty: {len(MODEL_GEN_PARAMS_CACHE) == 0}")
    print(f"Current cache keys: {list(MODEL_GEN_PARAMS_CACHE.keys())}")

    # If cache is empty, try to populate it now
    if len(MODEL_GEN_PARAMS_CACHE) == 0:
        print("Cache is empty, attempting to populate now...")
        cache_all_model_params()

    if model_name in MODEL_GEN_PARAMS_CACHE:
        params = MODEL_GEN_PARAMS_CACHE[model_name]
        print(f"Found cached params for {model_name}: {params}")
        # Set the max_tokens slider maximum to the model's recommended max
        max_tokens_limit = min(params.get("recommended_max_tokens", MAX_TOKENS), MAX_TOKENS)
        return (
            gr.update(value=params["max_tokens"], maximum=max_tokens_limit),  # max_tokens with dynamic maximum
            gr.update(value=params["temperature"]),  # temperature
            gr.update(value=params["top_k"]),  # top_k
            gr.update(value=params["top_p"])  # top_p
        )
    else:
        # Fallback to defaults if model not in cache
        print(f"Model {model_name} not found in cache, using defaults")
        return (
            gr.update(value=1024, maximum=MAX_TOKENS),  # max_tokens
            gr.update(value=0.7),  # temperature
            gr.update(value=50),  # top_k
            gr.update(value=0.95)  # top_p
        )
def cache_all_model_params():
    """Cache generation parameters for all supported models at startup"""
    global MODEL_GEN_PARAMS_CACHE
    print(f"Starting to cache parameters for {len(SUPPORTED_MODELS)} models...")
    print(f"Supported models: {SUPPORTED_MODELS}")
    for model_name in SUPPORTED_MODELS:
        try:
            print(f"Processing model: {model_name}")
            params = fetch_model_generation_params(model_name)
            MODEL_GEN_PARAMS_CACHE[model_name] = params
            print(f"Successfully cached params for {model_name}: {params}")
        except Exception as e:
            print(f"Exception while caching params for {model_name}: {e}")
            # Use default parameters if caching fails
            default_params = {
                "max_tokens": 1024,
                "temperature": 0.7,
                "top_k": 50,
                "top_p": 0.95,
                "max_position_embeddings": 2048,
                "recommended_max_tokens": 1024
            }
            MODEL_GEN_PARAMS_CACHE[model_name] = default_params
            print(f"Using default params for {model_name}: {default_params}")
    print("Caching complete. Final cache contents:")
    for model, params in MODEL_GEN_PARAMS_CACHE.items():
        print(f"  {model}: {params}")
    print(f"Cache size: {len(MODEL_GEN_PARAMS_CACHE)} models")
def verify_pro_status(token: Optional[Union[gr.OAuthToken, str]]) -> bool:
    """Verifies if the user is a Hugging Face PRO user or part of an enterprise org."""
    if not token:
        return False
    if isinstance(token, gr.OAuthToken):
        token_str = token.token
    elif isinstance(token, str):
        token_str = token
    else:
        return False
    try:
        user_info = whoami(token=token_str)
        return (
            user_info.get("isPro", False) or
            any(org.get("isEnterprise", False) for org in user_info.get("orgs", []))
        )
    except Exception as e:
        print(f"Could not verify user's PRO/Enterprise status: {e}")
        return False
def validate_request(request: GenerationRequest, oauth_token: Optional[Union[gr.OAuthToken, str]] = None) -> GenerationRequest:
    # checks that the request is valid
    # - input dataset exists and can be accessed with the provided token
    try:
        input_dataset_info = get_dataset_infos(request.input_dataset_name, token=request.input_dataset_token)[request.input_dataset_config]
    except Exception as e:
        raise Exception(f"Dataset {request.input_dataset_name} does not exist or cannot be accessed with the provided token.")

    # check that the input dataset split exists
    if request.input_dataset_split not in input_dataset_info.splits:
        raise Exception(f"Dataset split {request.input_dataset_split} does not exist in dataset {request.input_dataset_name}. Available splits: {list(input_dataset_info.splits.keys())}")

    # if num_output_examples is 0, set it to the number of examples in the input dataset split
    if request.num_output_examples == 0:
        request.num_output_examples = input_dataset_info.splits[request.input_dataset_split].num_examples
    else:
        if request.num_output_examples > input_dataset_info.splits[request.input_dataset_split].num_examples:
            raise Exception(f"Requested number of output examples {request.num_output_examples} exceeds the number of examples in the input dataset split {input_dataset_info.splits[request.input_dataset_split].num_examples}.")
        request.input_dataset_split = f"{request.input_dataset_split}[:{request.num_output_examples}]"

    # Check user tier and apply appropriate limits.
    # Anonymous users (oauth_token is None) are treated as free tier.
    is_pro = verify_pro_status(oauth_token) if oauth_token else False
    max_samples = MAX_SAMPLES_PRO if is_pro else MAX_SAMPLES_FREE
    if request.num_output_examples > max_samples:
        if oauth_token is None:
            user_tier = "non-signed-in"
        else:
            user_tier = "PRO/Enterprise" if is_pro else "free"
        raise Exception(f"Requested number of output examples {request.num_output_examples} exceeds the max limit of {max_samples} for {user_tier} users.")

    # check the prompt column exists in the dataset
    if request.prompt_column not in input_dataset_info.features:
        raise Exception(f"Prompt column {request.prompt_column} does not exist in dataset {request.input_dataset_name}. Available columns: {list(input_dataset_info.features.keys())}")
    # This is currently not supported; the output dataset will be created under the org 'synthetic-data-universe'.
    # check the output_dataset name is valid
    if request.output_dataset_name.count("/") != 1:
        raise Exception("The output dataset name is populated automatically. The dataset will be created under the 'synthetic-data-universe' org, e.g. 'synthetic-data-universe/my-dataset'.")

    # check the output dataset does not already exist on the Hub
    output_dataset_exists = True
    try:
        get_dataset_infos(request.output_dataset_name, token=request.output_dataset_token)
    except Exception:
        output_dataset_exists = False  # dataset does not exist, which is expected
    if output_dataset_exists:
        raise Exception(f"Output dataset {request.output_dataset_name} already exists. Please choose a different name.")
    # check the output dataset name doesn't already exist in the database
    try:
        url = os.getenv("SUPABASE_URL")
        key = os.getenv("SUPABASE_KEY")
        if url and key:
            supabase = create_client(
                url,
                key,
                options=ClientOptions(
                    postgrest_client_timeout=10,
                    storage_client_timeout=10,
                    schema="public",
                )
            )
            existing_request = supabase.table("gen-requests").select("id").eq("output_dataset_name", request.output_dataset_name).execute()
            if existing_request.data:
                raise Exception(f"Output dataset {request.output_dataset_name} is already being generated or has been requested. Please choose a different name.")
    except Exception as e:
        # If it's our custom exception about the dataset already existing, re-raise it
        if "already being generated" in str(e):
            raise e
        # Otherwise, ignore database connection errors and continue
        pass
    # check the model exists
    try:
        model_config = AutoConfig.from_pretrained(
            request.model_name_or_path,
            revision=request.model_revision,
            force_download=True,
            token=False
        )
    except Exception as e:
        print(e)
        raise Exception(f"Model {request.model_name_or_path} revision {request.model_revision} does not exist or cannot be accessed. The model may be private or gated, which is not supported at this time.")

    # check the model max position embeddings is greater than the requested max tokens and less than MAX_TOKENS
    if model_config.max_position_embeddings < request.max_tokens:
        raise Exception(f"Model {request.model_name_or_path} max position embeddings {model_config.max_position_embeddings} is less than the requested max tokens {request.max_tokens}.")
    if request.max_tokens > MAX_TOKENS:
        raise Exception(f"Requested max tokens {request.max_tokens} exceeds the limit of {MAX_TOKENS}.")

    # check sampling parameters are valid
    if request.temperature < 0.0 or request.temperature > 2.0:
        raise Exception("Temperature must be between 0.0 and 2.0")
    if request.top_k < 1 or request.top_k > 100:
        raise Exception("Top K must be between 1 and 100")
    if request.top_p < 0.0 or request.top_p > 1.0:
        raise Exception("Top P must be between 0.0 and 1.0")

    return request
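# Illustrative sketch only (not called by the app): how a request object is assembled and
# passed through validate_request. The dataset name, prompt column, and sample count here
# are assumptions for demonstration, and the call reaches the Hub and Supabase, so treat
# this as documentation of the expected shape rather than a test.
def _example_validate_request() -> None:
    example = GenerationRequest(
        id="",  # filled in by the database
        created_at="",  # filled in by the database
        status=GenerationStatus.PENDING,
        input_dataset_name="simplescaling/s1K-1.1",
        input_dataset_config="default",
        input_dataset_split="train",
        output_dataset_name="synthetic-data-universe/example-output",
        prompt_column="question",
        model_name_or_path="Qwen/Qwen3-4B-Instruct-2507",
        model_revision="main",
        model_token=None,
        system_prompt=None,
        max_tokens=1024,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        input_dataset_token=None,
        output_dataset_token=os.getenv("OUTPUT_DATASET_TOKEN", ""),
        username="example-user",
        email="n/a",
        num_output_examples=10,
    )
    validate_request(example, oauth_token=None)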
def load_dataset_info(dataset_name, model_name, oauth_token=None):
    """Load dataset information and return choices for dropdowns"""
    if not dataset_name.strip():
        return (
            gr.update(choices=[], value=None),  # config
            gr.update(choices=[], value=None),  # split
            gr.update(choices=[], value=None),  # prompt_column
            gr.update(value="", interactive=True),  # output_dataset_name
            gr.update(interactive=False),  # num_output_samples
            "Please enter a dataset name first."
        )
    try:
        # Get dataset info
        dataset_infos = get_dataset_infos(dataset_name)
        if not dataset_infos:
            raise Exception("No configs found for this dataset")

        # Get available configs
        config_choices = list(dataset_infos.keys())
        default_config = config_choices[0] if config_choices else None

        # Get splits and features for the default config
        if default_config:
            config_info = dataset_infos[default_config]
            split_choices = list(config_info.splits.keys())
            default_split = split_choices[0] if split_choices else None

            # Get column choices (features)
            column_choices = list(config_info.features.keys())
            default_column = None
            # Try to find a likely prompt column
            for col in column_choices:
                if any(keyword in col.lower() for keyword in ['prompt', 'text', 'question', 'input']):
                    default_column = col
                    break
            if not default_column and column_choices:
                default_column = column_choices[0]

            # Get sample count for the default split
            dataset_sample_count = config_info.splits[default_split].num_examples if default_split else 0
        else:
            split_choices = []
            column_choices = []
            default_split = None
            default_column = None
            dataset_sample_count = 0

        # Determine user limits
        is_pro = verify_pro_status(oauth_token) if oauth_token else False
        user_max_samples = MAX_SAMPLES_PRO if is_pro else MAX_SAMPLES_FREE

        # Set slider maximum to the minimum of dataset samples and user limit
        slider_max = min(dataset_sample_count, user_max_samples) if dataset_sample_count > 0 else user_max_samples

        # Get username from OAuth token
        username = "anonymous"
        if oauth_token:
            try:
                if isinstance(oauth_token, gr.OAuthToken):
                    token_str = oauth_token.token
                elif isinstance(oauth_token, str):
                    token_str = oauth_token
                else:
                    token_str = None
                if token_str:
                    user_info = whoami(token=token_str)
                    username = user_info.get("name", "anonymous")
            except Exception:
                username = "anonymous"

        # Generate a suggested output dataset name: username-model-dataset
        dataset_base_name = dataset_name.split('/')[-1] if '/' in dataset_name else dataset_name
        # Extract the model short name (e.g., "Qwen/Qwen3-4B-Instruct-2507" -> "Qwen3-4B-Instruct-2507")
        model_short_name = model_name.split('/')[-1]
        # Build the output name: username-model-dataset
        suggested_output_name = f"{username}-{model_short_name}-{dataset_base_name}"
        # Limit to 86 characters
        if len(suggested_output_name) > 86:
            # Truncate the dataset name to fit within the limit
            available_for_dataset = 86 - len(username) - len(model_short_name) - 2  # -2 for the hyphens
            if available_for_dataset > 0:
                dataset_base_name = dataset_base_name[:available_for_dataset]
                suggested_output_name = f"{username}-{model_short_name}-{dataset_base_name}"
            else:
                suggested_output_name = f"{username}-{model_short_name}"

        status_msg = f"✅ Dataset info loaded successfully! Found {len(config_choices)} config(s), {len(split_choices)} split(s), and {len(column_choices)} column(s)."
        if dataset_sample_count > 0:
            status_msg += f" Dataset has {dataset_sample_count:,} samples."
            if dataset_sample_count > user_max_samples:
                user_tier = "PRO/Enterprise" if is_pro else "free tier"
                status_msg += f" Limited to {user_max_samples:,} samples for {user_tier} users."

        return (
            gr.update(choices=config_choices, value=default_config, interactive=True),  # config
            gr.update(choices=split_choices, value=default_split, interactive=True),  # split
            gr.update(choices=column_choices, value=default_column, interactive=True),  # prompt_column
            gr.update(value=suggested_output_name, interactive=True),  # output_dataset_name
            gr.update(interactive=True, maximum=slider_max, value=0),  # num_output_samples
            status_msg
        )
    except Exception as e:
        return (
            gr.update(choices=[], value=None, interactive=False),  # config
            gr.update(choices=[], value=None, interactive=False),  # split
            gr.update(choices=[], value=None, interactive=False),  # prompt_column
            gr.update(value="", interactive=False),  # output_dataset_name
            gr.update(interactive=False),  # num_output_samples
            f"❌ Error loading dataset info: {str(e)}"
        )
def add_request_to_db(request: GenerationRequest):
    url: str = os.getenv("SUPABASE_URL")
    key: str = os.getenv("SUPABASE_KEY")
    try:
        supabase: Client = create_client(
            url,
            key,
            options=ClientOptions(
                postgrest_client_timeout=10,
                storage_client_timeout=10,
                schema="public",
            )
        )
        data = {
            "status": request.status.value,
            "input_dataset_name": request.input_dataset_name,
            "input_dataset_config": request.input_dataset_config,
            "input_dataset_split": request.input_dataset_split,
            "output_dataset_name": request.output_dataset_name,
            "prompt_column": request.prompt_column,
            "model_name_or_path": request.model_name_or_path,
            "model_revision": request.model_revision,
            "model_token": request.model_token,
            "system_prompt": request.system_prompt,
            "max_tokens": request.max_tokens,
            "temperature": request.temperature,
            "top_k": request.top_k,
            "top_p": request.top_p,
            "input_dataset_token": request.input_dataset_token,
            "output_dataset_token": request.output_dataset_token,
            "username": request.username,
            "email": request.email,
            "num_output_examples": request.num_output_examples,
            "private": request.private,
        }
        supabase.table("gen-requests").insert(data).execute()
    except Exception as e:
        raise Exception(f"Failed to add request to database: {str(e)}")
def get_generation_stats_safe():
    """Safely fetch generation request statistics with proper error handling"""
    try:
        url = os.getenv("SUPABASE_URL")
        key = os.getenv("SUPABASE_KEY")
        if not url or not key:
            raise Exception("Missing SUPABASE_URL or SUPABASE_KEY environment variables")
        supabase = create_client(
            url,
            key,
            options=ClientOptions(
                postgrest_client_timeout=10,
                storage_client_timeout=10,
                schema="public",
            )
        )
        # Fetch data excluding sensitive token fields
        response = supabase.table("gen-requests").select(
            "id, created_at, status, input_dataset_name, input_dataset_config, "
            "input_dataset_split, output_dataset_name, prompt_column, "
            "model_name_or_path, model_revision, max_tokens, temperature, "
            "top_k, top_p, username, num_output_examples, private"
        ).order("created_at", desc=True).limit(50).execute()
        return {"status": "success", "data": response.data}
    except Exception as e:
        return {"status": "error", "message": str(e), "data": []}
def main():
    # Cache model generation parameters at startup
    print("Caching model generation parameters...")
    cache_all_model_params()
    print("Model parameter caching complete.")

    with gr.Blocks(title="DataForge - Synthetic Data Generation") as demo:
        gr.Image("dataforge.png", show_label=False, show_download_button=False, container=False, height=300)

        # Store the current oauth token for use in submit_request
        current_oauth_token = gr.State(None)
        with gr.Row():
            gr.Markdown("")  # Empty space for alignment
            login_button = gr.LoginButton(value="🔒 Sign in", size="sm")
            gr.Markdown("")  # Empty space for alignment

        signin_message = gr.Markdown("## 🔒 Sign In Required\n\nPlease sign in with your Hugging Face account to access the synthetic data generation service. Click the **Sign in** button above to continue.", visible=True)

        # Main description
        gr.Markdown("""
This tool allows you to **generate synthetic data from existing datasets**, for all your **fine-tuning/research/data augmentation** needs!

DataForge is built on top of [DataTrove](https://github.com/huggingface/datatrove); our backend data generation script is open-source and available on [GitHub](https://github.com/huggingface/dataforge). DataForge is **FREE** for HuggingFace PRO users (10,000 samples) • 100 samples for free users.
""")
        gr.Markdown("**All generated datasets will be publicly available under the [synthetic-data-universe](https://huggingface.co/synthetic-data-universe) organization.**")
        # Usage guide and examples (right below description)
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Accordion("Usage Guide", open=False):
                    gr.Markdown("""
**Step-by-Step Process:**
1. **Choose Model**: Select one of the supported instruction-tuned models
2. **Load Dataset**: Enter an HF dataset name
3. **Load Info**: Click "Load Dataset Info"
4. **Configure**: Set generation parameters
5. **Submit**: Monitor progress in the Statistics tab

**Requirements:**
- Input dataset must be public on the HF Hub
- Model must be publicly accessible
- Free users: 100 samples max, PRO: 10K max
- Token limit: 8,192 per sample
""")
            with gr.Column(scale=1):
                with gr.Accordion("Examples", open=False):
                    gr.Markdown("""
**Popular Use Cases:**

**Conversational**: Multi-turn dialogues
- Models: Llama-3.2-3B, Mistral-7B
- Temperature: 0.7-0.9

**Code**: Problem → Solution
- Models: Qwen2.5-Coder, DeepSeek-Coder
- Temperature: 0.1-0.3

**Example datasets to try:**
```
simplescaling/s1K-1.1
HuggingFaceH4/ultrachat_200k
iamtarun/python_code_instructions_18k_alpaca
```
""")
        # Sign in button
        main_interface = gr.Column(visible=False)
        with main_interface:
            with gr.Tabs():
                with gr.TabItem("Generate Data"):
                    with gr.Row():
                        with gr.Column():
                            with gr.Group():
                                gr.Markdown("## Model information")
                                with gr.Column():
                                    with gr.Row():
                                        model_name_or_path = gr.Dropdown(
                                            choices=SUPPORTED_MODELS,
                                            label="Select Model",
                                            value="Qwen/Qwen3-4B-Instruct-2507",
                                            info="Choose from popular instruction-tuned models under 40B parameters"
                                        )
                                    # model_token = gr.Textbox(label="Model Token (Optional)", type="password", placeholder="Your HF token with read/write access to the model...")
                                    with gr.Row():
                                        system_prompt = gr.Textbox(label="System Prompt (Optional)", placeholder="Optional system prompt... e.g., You are a helpful assistant.", info="Sets the AI's role/behavior. Leave empty for default model behavior.")
                            gr.Markdown("### Generation Parameters")
                            with gr.Row():
                                with gr.Column():
                                    with gr.Row():
                                        max_tokens = gr.Slider(label="Max Tokens", value=1024, minimum=256, maximum=MAX_TOKENS, step=256, info="Maximum tokens to generate per sample. Higher = longer responses.")
                                        temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=2.0, value=0.7, step=0.1, info="Creativity level: 0.1=focused, 0.7=balanced, 1.0+=creative")
                                    with gr.Row():
                                        top_k = gr.Slider(label="Top K", value=50, minimum=5, maximum=100, step=5, info="Limits word choices to top K options. Lower = more focused.")
                                        top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.95, step=0.05, info="Nucleus sampling: 0.9=focused, 0.95=balanced diversity")
                        with gr.Column():
                            with gr.Group():
                                gr.Markdown("## Dataset information")
                                # Dynamic user limit info - default to anonymous user
                                user_limit_info = gr.Markdown(value="👤 **Anonymous User**: You can generate up to 100 samples per request. Use the sign-in button above for PRO benefits (10,000 samples).", visible=True)
                                with gr.Row():
                                    with gr.Column():
                                        input_dataset_name = gr.Textbox(label="Input Dataset Name", placeholder="e.g., simplescaling/s1K-1.1", info="Public HF dataset with prompts to generate from")
                                        load_info_btn = gr.Button("🔍 Load Dataset Info", size="sm", variant="secondary")
                                        load_info_status = gr.Markdown("", visible=True)
                                    with gr.Column():
                                        output_dataset_name = gr.Textbox(label="Output Dataset Name", placeholder="e.g., my-generated-dataset, must be unique. Will be created under the org 'synthetic-data-universe'", value=None, interactive=False, info="Click Load Info to populate")
                                with gr.Row():
                                    with gr.Column():
                                        input_dataset_config = gr.Dropdown(label="Dataset Config", choices=[], value=None, interactive=False, info="Click Load Info to populate")
                                        prompt_column = gr.Dropdown(label="Prompt Column", choices=[], value=None, interactive=False, info="Click Load Info to populate")
                                    with gr.Column():
                                        input_dataset_split = gr.Dropdown(label="Dataset Split", choices=[], value=None, interactive=False, info="Click Load Info to populate")
                                        num_output_samples = gr.Slider(label="Number of samples, leave as '0' for all", value=0, minimum=0, maximum=MAX_SAMPLES_FREE, step=1, interactive=False, info="Click Load Info to populate")
                    submit_btn = gr.Button("Submit Generation Request", variant="primary")
                    output_status = gr.Textbox(label="Status", interactive=False)
                with gr.TabItem("Statistics Dashboard"):
                    gr.Markdown("## DataForge Generation Statistics")
                    gr.Markdown("📊 View recent synthetic data generation requests and their status.")
                    with gr.Row():
                        refresh_stats_btn = gr.Button("🔄 Refresh Statistics", size="sm", variant="secondary")
                        clear_stats_btn = gr.Button("🗑️ Clear Display", size="sm")
                    stats_status = gr.Markdown("Click 'Refresh Statistics' to load recent generation requests.", visible=True)
                    stats_dataframe = gr.Dataframe(
                        headers=["ID", "Created", "Status", "Input Dataset", "Output Dataset", "Model", "Samples", "User"],
                        datatype=["str", "str", "str", "str", "str", "str", "number", "str"],
                        interactive=False,
                        wrap=True,
                        value=[],
                        label="Recent Generation Requests (Last 50)",
                        visible=False
                    )
                    def load_statistics():
                        """Load and format statistics data"""
                        try:
                            # Use the safe database helper
                            result = get_generation_stats_safe()
                            if result["status"] == "error":
                                return (
                                    f"❌ **Error loading statistics**: {result['message']}",
                                    gr.update(visible=False)
                                )
                            data = result["data"]
                            if not data:
                                return (
                                    "**No data found**: The database appears to be empty or the table doesn't exist yet.",
                                    gr.update(visible=False)
                                )
                            # Format data for display
                            formatted_data = []
                            for item in data:
                                # Format timestamp
                                created_at = item.get('created_at', 'Unknown')
                                if created_at and created_at != 'Unknown':
                                    try:
                                        from datetime import datetime
                                        dt = datetime.fromisoformat(created_at.replace('Z', '+00:00'))
                                        created_at = dt.strftime('%Y-%m-%d %H:%M')
                                    except Exception:
                                        pass
                                formatted_data.append([
                                    str(item.get('id', ''))[:8] + "..." if len(str(item.get('id', ''))) > 8 else str(item.get('id', '')),
                                    created_at,
                                    item.get('status', 'Unknown'),
                                    (item.get('input_dataset_name', '')[:30] + "...") if len(item.get('input_dataset_name', '')) > 30 else item.get('input_dataset_name', ''),
                                    (item.get('output_dataset_name', '')[:30] + "...") if len(item.get('output_dataset_name', '')) > 30 else item.get('output_dataset_name', ''),
                                    (item.get('model_name_or_path', '')[:25] + "...") if len(item.get('model_name_or_path', '')) > 25 else item.get('model_name_or_path', ''),
                                    item.get('num_output_examples', 0),
                                    item.get('username', 'Anonymous')
                                ])
                            return (
                                f"✅ **Statistics loaded successfully**: Found {len(formatted_data)} recent requests.",
                                gr.update(value=formatted_data, visible=True)
                            )
                        except Exception as e:
                            return (
                                f"❌ **Unexpected error**: {str(e)}",
                                gr.update(visible=False)
                            )

                    def clear_statistics():
                        """Clear the statistics display"""
                        return (
                            "Click 'Refresh Statistics' to load recent generation requests.",
                            gr.update(value=[], visible=False)
                        )

                    # Connect buttons to functions (each handler returns the status text and the dataframe update)
                    refresh_stats_btn.click(
                        load_statistics,
                        outputs=[stats_status, stats_dataframe]
                    )
                    clear_stats_btn.click(
                        clear_statistics,
                        outputs=[stats_status, stats_dataframe]
                    )
        def submit_request(input_dataset_name, input_split, input_dataset_config, output_dataset_name, prompt_col, model_name, sys_prompt,
                           max_tok, temp, top_k_val, top_p_val, num_output_samples, oauth_token=None):
            MASTER_ORG = "synthetic-data-universe/"
            model_token = False  # This is currently not supported
            input_dataset_token = None  # This is currently not supported
            output_dataset_token = os.getenv("OUTPUT_DATASET_TOKEN")

            # Get username from OAuth token
            username = "anonymous"
            if oauth_token:
                try:
                    if isinstance(oauth_token, gr.OAuthToken):
                        token_str = oauth_token.token
                    elif isinstance(oauth_token, str):
                        token_str = oauth_token
                    else:
                        token_str = None
                    if token_str:
                        user_info = whoami(token=token_str)
                        username = user_info.get("name", "unknown")
                except Exception:
                    username = "unknown"

            try:
                request = GenerationRequest(
                    id="",  # Will be generated when adding to the database
                    created_at="",  # Will be set when adding to the database
                    status=GenerationStatus.PENDING,
                    input_dataset_name=input_dataset_name,
                    input_dataset_split=input_split,
                    input_dataset_config=input_dataset_config,
                    output_dataset_name=MASTER_ORG + output_dataset_name,
                    prompt_column=prompt_col,
                    model_name_or_path=model_name,
                    model_revision="main",
                    model_token=model_token,
                    system_prompt=sys_prompt if sys_prompt else None,
                    max_tokens=int(max_tok),
                    temperature=temp,
                    top_k=int(top_k_val),
                    top_p=top_p_val,
                    input_dataset_token=input_dataset_token if input_dataset_token else None,
                    output_dataset_token=output_dataset_token,
                    num_output_examples=int(num_output_samples),  # will be adjusted after validating the input dataset
                    username=username,
                    email="n/a",
                )
                # check the input dataset exists and can be accessed with the provided token
                request = validate_request(request, oauth_token)
                add_request_to_db(request)
                return "Request submitted successfully!"
            except Exception as e:
                return f"Error: {str(e)}"
        # Wire up the Load Dataset Info button
        load_info_btn.click(
            load_dataset_info,
            inputs=[input_dataset_name, model_name_or_path, current_oauth_token],
            outputs=[input_dataset_config, input_dataset_split, prompt_column, output_dataset_name, num_output_samples, load_info_status]
        )

        # Wire up model change to update generation parameters
        model_name_or_path.change(
            update_generation_params,
            inputs=[model_name_or_path],
            outputs=[max_tokens, temperature, top_k, top_p]
        )

        submit_btn.click(
            submit_request,
            inputs=[input_dataset_name, input_dataset_split, input_dataset_config, output_dataset_name, prompt_column, model_name_or_path,
                    system_prompt, max_tokens, temperature, top_k, top_p, num_output_samples, current_oauth_token],
            outputs=output_status
        )
        def update_user_limits(oauth_token):
            if oauth_token is None:
                return "👤 **Anonymous User**: You can generate up to 100 samples per request. Use the sign-in button above for PRO benefits (10,000 samples)."
            is_pro = verify_pro_status(oauth_token)
            if is_pro:
                return "✨ **PRO User**: You can generate up to 10,000 samples per request."
            else:
                return "👤 **Free User**: You can generate up to 100 samples per request. [Upgrade to PRO](http://huggingface.co/subscribe/pro?source=synthetic-data-universe) for 10,000 samples."
        def control_access(profile: Optional[gr.OAuthProfile] = None, oauth_token: Optional[gr.OAuthToken] = None):
            # Require users to be signed in
            if oauth_token is None:
                # User is not signed in - show sign-in prompt, hide main interface
                return (
                    gr.update(visible=False),  # main_interface
                    gr.update(visible=True),  # signin_message
                    oauth_token,  # current_oauth_token
                    "",  # user_limit_info (empty when not signed in)
                    gr.update(),  # num_output_samples (no change)
                    gr.update(value="🔒 Sign in")  # login_button
                )
            else:
                # User is signed in - show main interface, hide sign-in prompt
                limit_msg = update_user_limits(oauth_token)
                is_pro = verify_pro_status(oauth_token)
                max_samples = MAX_SAMPLES_PRO if is_pro else MAX_SAMPLES_FREE
                if is_pro:
                    button_text = f"✨ Signed in as PRO ({profile.name if profile else 'User'})"
                else:
                    button_text = f"👤 Signed in as {profile.name if profile else 'User'}"
                return (
                    gr.update(visible=True),  # main_interface
                    gr.update(visible=False),  # signin_message
                    oauth_token,  # current_oauth_token
                    limit_msg,  # user_limit_info
                    gr.update(maximum=max_samples),  # num_output_samples
                    gr.update(value=button_text)  # login_button
                )

        # Handle login state changes - LoginButton automatically handles auth state changes.
        # The demo.load will handle both the initial load and auth changes.
        demo.load(control_access, inputs=None, outputs=[main_interface, signin_message, current_oauth_token, user_limit_info, num_output_samples, login_button])

    demo.queue(max_size=None, default_concurrency_limit=None).launch(show_error=True)


if __name__ == "__main__":
    main()