import gradio as gr
from dataclasses import dataclass
import os
from supabase import create_client, Client
from supabase.client import ClientOptions
from enum import Enum
from datasets import get_dataset_infos
from transformers import AutoConfig, GenerationConfig
from huggingface_hub import whoami
from typing import Optional, Union
from datetime import datetime
"""
Still TODO:
- validate the user is PRO
- check the output dataset token is valid (hardcoded for now as a secret)
- validate max model params
"""
class GenerationStatus(Enum):
PENDING = "PENDING"
RUNNING = "RUNNING"
COMPLETED = "COMPLETED"
FAILED = "FAILED"
MAX_SAMPLES_PRO = 10000 # max number of samples for PRO/Enterprise users
MAX_SAMPLES_FREE = 100 # max number of samples for free users
MAX_TOKENS = 8192
MAX_MODEL_PARAMS = 20_000_000_000 # 20 billion parameters (for now)
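# NOTE: MAX_MODEL_PARAMS is not enforced yet; see the "validate max model params" TODO above.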
# Cache for model generation parameters
MODEL_GEN_PARAMS_CACHE = {}
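# Keyed by model name; populated once at startup by cache_all_model_params()
# and read by update_generation_params() whenever the model dropdown changes.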
@dataclass
class GenerationRequest:
id: str
created_at: str
status: GenerationStatus
input_dataset_name: str
input_dataset_config: str
input_dataset_split: str
output_dataset_name: str
prompt_column: str
model_name_or_path: str
model_revision: str
model_token: str | None
system_prompt: str | None
max_tokens: int
temperature: float
top_k: int
top_p: float
input_dataset_token: str | None
output_dataset_token: str
username: str
email: str
num_output_examples: int
private: bool = False
num_retries: int = 0
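# Most of these fields are written verbatim to the Supabase "gen-requests" table
# by add_request_to_db(); id and created_at are generated by the database.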
SUPPORTED_MODELS = [
"Qwen/Qwen3-4B-Instruct-2507",
"Qwen/Qwen3-30B-A3B-Instruct-2507",
"meta-llama/Llama-3.2-1B-Instruct",
"meta-llama/Llama-3.2-3B-Instruct",
"baidu/ERNIE-4.5-21B-A3B-Thinking",
"LLM360/K2-Think",
"openai/gpt-oss-20b",
]
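# Models offered in the UI dropdown. validate_request() still re-checks that the
# selected model/revision is publicly accessible on the Hub before queuing a job.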
def fetch_model_generation_params(model_name: str) -> dict:
"""Fetch generation parameters and model config from the hub"""
default_params = {
"max_tokens": 1024,
"temperature": 0.7,
"top_k": 50,
"top_p": 0.95,
"max_position_embeddings": 2048,
"recommended_max_tokens": 1024
}
try:
print(f"Attempting to fetch configs for: {model_name}")
# Always try to load the model config first for max_position_embeddings
model_config = None
max_position_embeddings = default_params["max_position_embeddings"]
        # Resolve the token before the inner try so it is always bound when
        # reused for the GenerationConfig lookup below.
        output_dataset_token = os.getenv("OUTPUT_DATASET_TOKEN")
        try:
            model_config = AutoConfig.from_pretrained(model_name, force_download=False, token=output_dataset_token)
max_position_embeddings = getattr(model_config, 'max_position_embeddings', default_params["max_position_embeddings"])
print(f"Loaded AutoConfig for {model_name}, max_position_embeddings: {max_position_embeddings}")
except Exception as e:
print(f"Failed to load AutoConfig for {model_name}: {e}")
# Calculate recommended max tokens (conservative estimate)
# Leave some room for the prompt, so use ~75% of max_position_embeddings
recommended_max_tokens = min(int(max_position_embeddings * 0.75), MAX_TOKENS)
recommended_max_tokens = max(256, recommended_max_tokens) # Ensure minimum
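        # e.g., a model with a 4096-token context window gets
        # min(int(4096 * 0.75), MAX_TOKENS) = 3072 recommended max tokens.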
# Try to load the generation config
gen_config = None
try:
gen_config = GenerationConfig.from_pretrained(model_name, force_download=False, token=output_dataset_token)
print(f"Successfully loaded generation config for {model_name}")
except Exception as e:
print(f"Failed to load GenerationConfig for {model_name}: {e}")
        # Extract parameters from the generation config, guarding against
        # attributes that exist but are set to None (the min/max clamping below
        # would otherwise raise a TypeError); fall back to defaults when absent.
        def _value_or(attr_value, fallback):
            return attr_value if attr_value is not None else fallback
        if gen_config:
            params = {
                "max_tokens": getattr(gen_config, 'max_new_tokens', None) or getattr(gen_config, 'max_length', None) or recommended_max_tokens,
                "temperature": _value_or(getattr(gen_config, 'temperature', None), default_params["temperature"]),
                "top_k": _value_or(getattr(gen_config, 'top_k', None), default_params["top_k"]),
                "top_p": _value_or(getattr(gen_config, 'top_p', None), default_params["top_p"]),
                "max_position_embeddings": max_position_embeddings,
                "recommended_max_tokens": recommended_max_tokens
            }
else:
params = dict(default_params)
params["max_position_embeddings"] = max_position_embeddings
params["recommended_max_tokens"] = recommended_max_tokens
# Ensure parameters are within valid ranges
params["max_tokens"] = max(256, min(params["max_tokens"], MAX_TOKENS, params["recommended_max_tokens"]))
params["temperature"] = max(0.0, min(params["temperature"], 2.0))
params["top_k"] = max(5, min(params["top_k"], 100))
params["top_p"] = max(0.0, min(params["top_p"], 1.0))
print(f"Final params for {model_name}: {params}")
return params
except Exception as e:
print(f"Could not fetch configs for {model_name}: {e}")
return default_params
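# Illustrative call (actual values depend on the configs published on the Hub):
#   fetch_model_generation_params("Qwen/Qwen3-4B-Instruct-2507")
#   -> {"max_tokens": ..., "temperature": ..., "top_k": ..., "top_p": ...,
#       "max_position_embeddings": ..., "recommended_max_tokens": ...}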
def update_generation_params(model_name: str):
"""Update generation parameters based on selected model"""
global MODEL_GEN_PARAMS_CACHE
print(f"Updating generation parameters for model: {model_name}")
print(f"Cache is empty: {len(MODEL_GEN_PARAMS_CACHE) == 0}")
print(f"Current cache keys: {list(MODEL_GEN_PARAMS_CACHE.keys())}")
# If cache is empty, try to populate it now
if len(MODEL_GEN_PARAMS_CACHE) == 0:
print("Cache is empty, attempting to populate now...")
cache_all_model_params()
if model_name in MODEL_GEN_PARAMS_CACHE:
params = MODEL_GEN_PARAMS_CACHE[model_name]
print(f"Found cached params for {model_name}: {params}")
# Set the max_tokens slider maximum to the model's recommended max
max_tokens_limit = min(params.get("recommended_max_tokens", MAX_TOKENS), MAX_TOKENS)
return (
gr.update(value=params["max_tokens"], maximum=max_tokens_limit), # max_tokens with dynamic maximum
gr.update(value=params["temperature"]), # temperature
gr.update(value=params["top_k"]), # top_k
gr.update(value=params["top_p"]) # top_p
)
else:
# Fallback to defaults if model not in cache
print(f"Model {model_name} not found in cache, using defaults")
return (
gr.update(value=1024, maximum=MAX_TOKENS), # max_tokens
gr.update(value=0.7), # temperature
gr.update(value=50), # top_k
gr.update(value=0.95) # top_p
)
def cache_all_model_params():
"""Cache generation parameters for all supported models at startup"""
global MODEL_GEN_PARAMS_CACHE
print(f"Starting to cache parameters for {len(SUPPORTED_MODELS)} models...")
print(f"Supported models: {SUPPORTED_MODELS}")
for model_name in SUPPORTED_MODELS:
try:
print(f"Processing model: {model_name}")
params = fetch_model_generation_params(model_name)
MODEL_GEN_PARAMS_CACHE[model_name] = params
print(f"Successfully cached params for {model_name}: {params}")
except Exception as e:
print(f"Exception while caching params for {model_name}: {e}")
# Use default parameters if caching fails
default_params = {
"max_tokens": 1024,
"temperature": 0.7,
"top_k": 50,
"top_p": 0.95,
"max_position_embeddings": 2048,
"recommended_max_tokens": 1024
}
MODEL_GEN_PARAMS_CACHE[model_name] = default_params
print(f"Using default params for {model_name}: {default_params}")
print(f"Caching complete. Final cache contents:")
for model, params in MODEL_GEN_PARAMS_CACHE.items():
print(f" {model}: {params}")
print(f"Cache size: {len(MODEL_GEN_PARAMS_CACHE)} models")
def verify_pro_status(token: Optional[Union[gr.OAuthToken, str]]) -> bool:
"""Verifies if the user is a Hugging Face PRO user or part of an enterprise org."""
if not token:
return False
if isinstance(token, gr.OAuthToken):
token_str = token.token
elif isinstance(token, str):
token_str = token
else:
return False
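    # whoami() returns the account payload for this token; only the "isPro"
    # flag and the "isEnterprise" flag on each entry in "orgs" are read here.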
try:
user_info = whoami(token=token_str)
return (
user_info.get("isPro", False) or
any(org.get("isEnterprise", False) for org in user_info.get("orgs", []))
)
except Exception as e:
print(f"Could not verify user's PRO/Enterprise status: {e}")
return False
def validate_request(request: GenerationRequest, oauth_token: Optional[Union[gr.OAuthToken, str]] = None) -> GenerationRequest:
# checks that the request is valid
# - input dataset exists and can be accessed with the provided token
try:
input_dataset_info = get_dataset_infos(request.input_dataset_name, token=request.input_dataset_token)[request.input_dataset_config]
    except Exception as e:
        raise Exception(f"Dataset {request.input_dataset_name} does not exist or cannot be accessed with the provided token.") from e
# check that the input dataset split exists
if request.input_dataset_split not in input_dataset_info.splits:
raise Exception(f"Dataset split {request.input_dataset_split} does not exist in dataset {request.input_dataset_name}. Available splits: {list(input_dataset_info.splits.keys())}")
# if num_output_examples is 0, set it to the number of examples in the input dataset split
if request.num_output_examples == 0:
request.num_output_examples = input_dataset_info.splits[request.input_dataset_split].num_examples
else:
if request.num_output_examples > input_dataset_info.splits[request.input_dataset_split].num_examples:
raise Exception(f"Requested number of output examples {request.num_output_examples} exceeds the number of examples in the input dataset split {input_dataset_info.splits[request.input_dataset_split].num_examples}.")
request.input_dataset_split = f"{request.input_dataset_split}[:{request.num_output_examples}]"
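    # "split[:N]" is the datasets library's split-slicing syntax, so downstream
    # loading only reads the first num_output_examples rows.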
# Check user tier and apply appropriate limits
# Anonymous users (oauth_token is None) are treated as free tier
is_pro = verify_pro_status(oauth_token) if oauth_token else False
max_samples = MAX_SAMPLES_PRO if is_pro else MAX_SAMPLES_FREE
if request.num_output_examples > max_samples:
if oauth_token is None:
user_tier = "non-signed-in"
else:
user_tier = "PRO/Enterprise" if is_pro else "free"
raise Exception(f"Requested number of output examples {request.num_output_examples} exceeds the max limit of {max_samples} for {user_tier} users.")
# check the prompt column exists in the dataset
if request.prompt_column not in input_dataset_info.features:
raise Exception(f"Prompt column {request.prompt_column} does not exist in dataset {request.input_dataset_name}. Available columns: {list(input_dataset_info.features.keys())}")
    # Custom output locations are not currently supported; the dataset is always
    # created under the 'synthetic-data-universe' org.
    # check the output dataset name is valid
    if request.output_dataset_name.count("/") != 1:
        raise Exception("Output dataset name must not contain '/'. The 'synthetic-data-universe/' prefix is added automatically, e.g. 'synthetic-data-universe/my-dataset'.")
    # check the output dataset does not already exist on the Hub
    # (the raise must live outside the try block, or the except would swallow it)
    output_dataset_exists = True
    try:
        get_dataset_infos(request.output_dataset_name, token=request.output_dataset_token)
    except Exception:
        output_dataset_exists = False  # dataset does not exist, which is expected
    if output_dataset_exists:
        raise Exception(f"Output dataset {request.output_dataset_name} already exists. Please choose a different name.")
    # check the output dataset name isn't already queued in the database
    existing_request = None
    try:
        url = os.getenv("SUPABASE_URL")
        key = os.getenv("SUPABASE_KEY")
        if url and key:
            supabase = create_client(
                url,
                key,
                options=ClientOptions(
                    postgrest_client_timeout=10,
                    storage_client_timeout=10,
                    schema="public",
                )
            )
            existing_request = supabase.table("gen-requests").select("id").eq("output_dataset_name", request.output_dataset_name).execute()
    except Exception:
        # Ignore database connectivity errors here; add_request_to_db() will fail loudly later if the DB is unreachable
        existing_request = None
    if existing_request is not None and existing_request.data:
        raise Exception(f"Output dataset {request.output_dataset_name} is already being generated or has been requested. Please choose a different name.")
    # check the model exists (token=False forces an anonymous request, so private or gated models are rejected)
    try:
        model_config = AutoConfig.from_pretrained(request.model_name_or_path,
                                                  revision=request.model_revision,
                                                  force_download=True,
                                                  token=False
                                                  )
except Exception as e:
print(e)
raise Exception(f"Model {request.model_name_or_path} revision {request.model_revision} does not exist or cannot be accessed. The model may be private or gated, which is not supported at this time.")
    # check the model's context window covers the requested max tokens
    # (some configs lack max_position_embeddings, so read it defensively)
    max_pos = getattr(model_config, 'max_position_embeddings', None)
    if max_pos is not None and max_pos < request.max_tokens:
        raise Exception(f"Model {request.model_name_or_path} max position embeddings {max_pos} is less than the requested max tokens {request.max_tokens}.")
if request.max_tokens > MAX_TOKENS:
raise Exception(f"Requested max tokens {request.max_tokens} exceeds the limit of {MAX_TOKENS}.")
# check sampling parameters are valid
if request.temperature < 0.0 or request.temperature > 2.0:
raise Exception("Temperature must be between 0.0 and 2.0")
if request.top_k < 1 or request.top_k > 100:
raise Exception("Top K must be between 1 and 100")
if request.top_p < 0.0 or request.top_p > 1.0:
raise Exception("Top P must be between 0.0 and 1.0")
return request
def load_dataset_info(dataset_name, model_name, oauth_token=None):
"""Load dataset information and return choices for dropdowns"""
if not dataset_name.strip():
return (
gr.update(choices=[], value=None), # config
gr.update(choices=[], value=None), # split
gr.update(choices=[], value=None), # prompt_column
gr.update(value="", interactive=True), # output_dataset_name
gr.update(interactive=False), # num_output_samples
"Please enter a dataset name first."
)
try:
# Get dataset info
dataset_infos = get_dataset_infos(dataset_name)
if not dataset_infos:
raise Exception("No configs found for this dataset")
# Get available configs
config_choices = list(dataset_infos.keys())
default_config = config_choices[0] if config_choices else None
# Get splits and features for the default config
if default_config:
config_info = dataset_infos[default_config]
split_choices = list(config_info.splits.keys())
default_split = split_choices[0] if split_choices else None
# Get column choices (features)
column_choices = list(config_info.features.keys())
default_column = None
# Try to find a likely prompt column
for col in column_choices:
if any(keyword in col.lower() for keyword in ['prompt', 'text', 'question', 'input']):
default_column = col
break
if not default_column and column_choices:
default_column = column_choices[0]
# Get sample count for the default split
dataset_sample_count = config_info.splits[default_split].num_examples if default_split else 0
else:
split_choices = []
column_choices = []
default_split = None
default_column = None
dataset_sample_count = 0
# Determine user limits
is_pro = verify_pro_status(oauth_token) if oauth_token else False
user_max_samples = MAX_SAMPLES_PRO if is_pro else MAX_SAMPLES_FREE
# Set slider maximum to the minimum of dataset samples and user limit
slider_max = min(dataset_sample_count, user_max_samples) if dataset_sample_count > 0 else user_max_samples
# Get username from OAuth token
username = "anonymous"
if oauth_token:
try:
if isinstance(oauth_token, gr.OAuthToken):
token_str = oauth_token.token
elif isinstance(oauth_token, str):
token_str = oauth_token
else:
token_str = None
if token_str:
user_info = whoami(token=token_str)
username = user_info.get("name", "anonymous")
except Exception:
username = "anonymous"
# Generate a suggested output dataset name: username-model-dataset
    dataset_base_name = dataset_name.split('/')[-1]
    # Extract the model short name (e.g., "Qwen/Qwen3-4B-Instruct-2507" -> "Qwen3-4B-Instruct-2507")
    model_short_name = model_name.split('/')[-1]
    # Build the output name: username-model-dataset
suggested_output_name = f"{username}-{model_short_name}-{dataset_base_name}"
# Limit to 86 characters
if len(suggested_output_name) > 86:
# Truncate dataset name to fit within limit
available_for_dataset = 86 - len(username) - len(model_short_name) - 2 # -2 for the hyphens
if available_for_dataset > 0:
dataset_base_name = dataset_base_name[:available_for_dataset]
suggested_output_name = f"{username}-{model_short_name}-{dataset_base_name}"
else:
suggested_output_name = f"{username}-{model_short_name}"
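    # e.g., a 5-char username plus the 22-char "Qwen3-4B-Instruct-2507" and two
    # hyphens leaves 86 - 5 - 22 - 2 = 57 characters for the dataset part.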
        status_msg = f"✅ Dataset info loaded successfully! Found {len(config_choices)} config(s), {len(split_choices)} split(s), and {len(column_choices)} column(s)."
if dataset_sample_count > 0:
status_msg += f" Dataset has {dataset_sample_count:,} samples."
if dataset_sample_count > user_max_samples:
user_tier = "PRO/Enterprise" if is_pro else "free tier"
status_msg += f" Limited to {user_max_samples:,} samples for {user_tier} users."
return (
gr.update(choices=config_choices, value=default_config, interactive=True), # config
gr.update(choices=split_choices, value=default_split, interactive=True), # split
gr.update(choices=column_choices, value=default_column, interactive=True), # prompt_column
gr.update(value=suggested_output_name, interactive=True), # output_dataset_name
gr.update(interactive=True, maximum=slider_max, value=0), # num_output_samples
status_msg
)
except Exception as e:
return (
gr.update(choices=[], value=None, interactive=False), # config
gr.update(choices=[], value=None, interactive=False), # split
gr.update(choices=[], value=None, interactive=False), # prompt_column
gr.update(value="", interactive=False), # output_dataset_name
gr.update(interactive=False), # num_output_samples
f"❌ Error loading dataset info: {str(e)}"
)
def add_request_to_db(request: GenerationRequest):
    url = os.getenv("SUPABASE_URL")
    key = os.getenv("SUPABASE_KEY")
    if not url or not key:
        raise Exception("Missing SUPABASE_URL or SUPABASE_KEY environment variables")
try:
supabase: Client = create_client(
url,
key,
options=ClientOptions(
postgrest_client_timeout=10,
storage_client_timeout=10,
schema="public",
)
)
data = {
"status": request.status.value,
"input_dataset_name": request.input_dataset_name,
"input_dataset_config": request.input_dataset_config,
"input_dataset_split": request.input_dataset_split,
"output_dataset_name": request.output_dataset_name,
"prompt_column": request.prompt_column,
"model_name_or_path": request.model_name_or_path,
"model_revision": request.model_revision,
"model_token": request.model_token,
"system_prompt": request.system_prompt,
"max_tokens": request.max_tokens,
"temperature": request.temperature,
"top_k": request.top_k,
"top_p": request.top_p,
"input_dataset_token": request.input_dataset_token,
"output_dataset_token": request.output_dataset_token,
"username": request.username,
"email": request.email,
"num_output_examples": request.num_output_examples,
"private": request.private,
}
supabase.table("gen-requests").insert(data).execute()
except Exception as e:
raise Exception(f"Failed to add request to database: {str(e)}")
def get_generation_stats_safe():
"""Safely fetch generation request statistics with proper error handling"""
try:
url = os.getenv("SUPABASE_URL")
key = os.getenv("SUPABASE_KEY")
if not url or not key:
raise Exception("Missing SUPABASE_URL or SUPABASE_KEY environment variables")
supabase = create_client(
url,
key,
options=ClientOptions(
postgrest_client_timeout=10,
storage_client_timeout=10,
schema="public",
)
)
# Fetch data excluding sensitive token fields
response = supabase.table("gen-requests").select(
"id, created_at, status, input_dataset_name, input_dataset_config, "
"input_dataset_split, output_dataset_name, prompt_column, "
"model_name_or_path, model_revision, max_tokens, temperature, "
"top_k, top_p, username, num_output_examples, private"
).order("created_at", desc=True).limit(50).execute()
return {"status": "success", "data": response.data}
except Exception as e:
return {"status": "error", "message": str(e), "data": []}
def main():
# Cache model generation parameters at startup
print("Caching model generation parameters...")
cache_all_model_params()
print("Model parameter caching complete.")
with gr.Blocks(title="DataForge - Synthetic Data Generation") as demo:
gr.Image("dataforge.png", show_label=False, show_download_button=False, container=False, height=300)
# Store the current oauth token for use in submit_request
current_oauth_token = gr.State(None)
with gr.Row():
gr.Markdown("") # Empty space for alignment
            login_button = gr.LoginButton(value="🔑 Sign in", size="sm")
gr.Markdown("") # Empty space for alignment
        signin_message = gr.Markdown("## 🔑 Sign In Required\n\nPlease sign in with your Hugging Face account to access the synthetic data generation service. Click the **Sign in** button above to continue.", visible=True)
# Main description
gr.Markdown("""
This tool allows you to **generate synthetic data from existing datasets**, for all your **fine-tuning/research/data augmentation** needs!
DataForge is built on top of [DataTrove](https://github.com/huggingface/datatrove), our backend data generation script is open-source and available on [GitHub](https://github.com/huggingface/dataforge). DataForge is **FREE** for HuggingFace PRO users (10,000 samples) β€’ 100 samples for free users.
""")
gr.Markdown("**All generated datasets will be publicly available under the [synthetic-data-universe](https://huggingface.co/synthetic-data-universe) organization.**")
# Usage guide and examples (right below description)
with gr.Row():
with gr.Column(scale=1):
with gr.Accordion("Usage Guide", open=False):
gr.Markdown("""
**Step-by-Step Process:**
                    1. **Choose Model**: Select a supported instruction-tuned model
2. **Load Dataset**: Enter a HF dataset name
3. **Load Info**: Click "Load Dataset Info"
4. **Configure**: Set generation parameters
5. **Submit**: Monitor progress in Statistics tab
**Requirements:**
- Input dataset must be public on HF Hub
- Model must be publicly accessible
- Free users: 100 samples max, PRO: 10K max
- Token limit: 8,192 per sample
""")
with gr.Column(scale=1):
with gr.Accordion("Examples", open=False):
gr.Markdown("""
**Popular Use Cases:**
                    **Conversational**: Multi-turn dialogues
                    - Models: Llama-3.2-3B-Instruct, Qwen3-4B-Instruct-2507
                    - Temperature: 0.7-0.9
                    **Code**: Problem → Solution
                    - Models: Qwen3-30B-A3B-Instruct-2507, gpt-oss-20b
                    - Temperature: 0.1-0.3
**Example datasets to try:**
```
simplescaling/s1K-1.1
HuggingFaceH4/ultrachat_200k
iamtarun/python_code_instructions_18k_alpaca
```
""")
        # Main interface (hidden until the user signs in)
main_interface = gr.Column(visible=False)
with main_interface:
with gr.Tabs():
with gr.TabItem("Generate Data"):
with gr.Row():
with gr.Column():
with gr.Group():
gr.Markdown("## Model information")
with gr.Column():
with gr.Row():
model_name_or_path = gr.Dropdown(
choices=SUPPORTED_MODELS,
label="Select Model",
value="Qwen/Qwen3-4B-Instruct-2507",
info="Choose from popular instruction-tuned models under 40B parameters"
)
# model_token = gr.Textbox(label="Model Token (Optional)", type="password", placeholder="Your HF token with read/write access to the model...")
with gr.Row():
system_prompt = gr.Textbox(label="System Prompt (Optional)", placeholder="Optional system prompt... e.g., You are a helpful assistant.", info="Sets the AI's role/behavior. Leave empty for default model behavior.")
gr.Markdown("### Generation Parameters")
with gr.Row():
with gr.Column():
with gr.Row():
max_tokens = gr.Slider(label="Max Tokens", value=1024, minimum=256, maximum=MAX_TOKENS, step=256, info="Maximum tokens to generate per sample. Higher = longer responses.")
temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=2.0, value=0.7, step=0.1, info="Creativity level: 0.1=focused, 0.7=balanced, 1.0+=creative")
with gr.Row():
top_k = gr.Slider(label="Top K", value=50, minimum=5, maximum=100, step=5, info="Limits word choices to top K options. Lower = more focused.")
top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.95, step=0.05, info="Nucleus sampling: 0.9=focused, 0.95=balanced diversity")
with gr.Column():
with gr.Group():
gr.Markdown("## Dataset information")
# Dynamic user limit info - default to anonymous user
                                user_limit_info = gr.Markdown(value="👤 **Anonymous User**: You can generate up to 100 samples per request. Use the sign-in button above for PRO benefits (10,000 samples).", visible=True)
with gr.Row():
with gr.Column():
input_dataset_name = gr.Textbox(label="Input Dataset Name", placeholder="e.g., simplescaling/s1K-1.1", info="Public HF dataset with prompts to generate from")
                                        load_info_btn = gr.Button("📊 Load Dataset Info", size="sm", variant="secondary")
load_info_status = gr.Markdown("", visible=True)
with gr.Column():
output_dataset_name = gr.Textbox(label="Output Dataset Name", placeholder="e.g., my-generated-dataset, must be unique. Will be created under the org 'synthetic-data-universe'", value=None, interactive=False, info="Click Load Info to populate")
with gr.Row():
with gr.Column():
input_dataset_config = gr.Dropdown(label="Dataset Config", choices=[], value=None, interactive=False, info="Click Load Info to populate")
prompt_column = gr.Dropdown(label="Prompt Column", choices=[], value=None, interactive=False, info="Click Load Info to populate")
with gr.Column():
input_dataset_split = gr.Dropdown(label="Dataset Split", choices=[], value=None, interactive=False, info="Click Load Info to populate")
num_output_samples = gr.Slider(label="Number of samples, leave as '0' for all", value=0, minimum=0, maximum=MAX_SAMPLES_FREE, step=1, interactive=False, info="Click Load Info to populate")
submit_btn = gr.Button("Submit Generation Request", variant="primary")
output_status = gr.Textbox(label="Status", interactive=False)
with gr.TabItem("Statistics Dashboard"):
gr.Markdown("## DataForge Generation Statistics")
gr.Markdown("πŸ“Š View recent synthetic data generation requests and their status.")
with gr.Row():
                        refresh_stats_btn = gr.Button("🔄 Refresh Statistics", size="sm", variant="secondary")
                        clear_stats_btn = gr.Button("🗑️ Clear Display", size="sm")
stats_status = gr.Markdown("Click 'Refresh Statistics' to load recent generation requests.", visible=True)
stats_dataframe = gr.Dataframe(
headers=["ID", "Created", "Status", "Input Dataset", "Output Dataset", "Model", "Samples", "User"],
datatype=["str", "str", "str", "str", "str", "str", "number", "str"],
interactive=False,
wrap=True,
value=[],
label="Recent Generation Requests (Last 50)",
visible=False
)
def load_statistics():
"""Load and format statistics data"""
try:
# Use the new safe database function
result = get_generation_stats_safe()
if result["status"] == "error":
                                return (
                                    f"❌ **Error loading statistics**: {result['message']}",
                                    gr.update(visible=False)
                                )
data = result["data"]
if not data:
                                return (
                                    "📝 **No data found**: The database appears to be empty or the table doesn't exist yet.",
                                    gr.update(visible=False)
                                )
# Format data for display
formatted_data = []
for item in data:
# Format timestamp
created_at = item.get('created_at', 'Unknown')
if created_at and created_at != 'Unknown':
                                    try:
                                        dt = datetime.fromisoformat(created_at.replace('Z', '+00:00'))
                                        created_at = dt.strftime('%Y-%m-%d %H:%M')
                                    except Exception:
                                        pass  # leave the raw timestamp if parsing fails
formatted_data.append([
str(item.get('id', ''))[:8] + "..." if len(str(item.get('id', ''))) > 8 else str(item.get('id', '')),
created_at,
item.get('status', 'Unknown'),
(item.get('input_dataset_name', '')[:30] + "...") if len(item.get('input_dataset_name', '')) > 30 else item.get('input_dataset_name', ''),
(item.get('output_dataset_name', '')[:30] + "...") if len(item.get('output_dataset_name', '')) > 30 else item.get('output_dataset_name', ''),
(item.get('model_name_or_path', '')[:25] + "...") if len(item.get('model_name_or_path', '')) > 25 else item.get('model_name_or_path', ''),
item.get('num_output_examples', 0),
item.get('username', 'Anonymous')
])
                            return (
                                f"✅ **Statistics loaded successfully**: Found {len(formatted_data)} recent requests.",
                                gr.update(value=formatted_data, visible=True)
                            )
                        except Exception as e:
                            return (
                                f"❌ **Unexpected error**: {str(e)}",
                                gr.update(visible=False)
                            )
def clear_statistics():
"""Clear the statistics display"""
                        return (
                            "Click 'Refresh Statistics' to load recent generation requests.",
                            gr.update(value=[], visible=False)
                        )
                    # Connect buttons to functions; each callback returns (status message, dataframe update)
                    refresh_stats_btn.click(
                        load_statistics,
                        outputs=[stats_status, stats_dataframe]
                    )
                    clear_stats_btn.click(
                        clear_statistics,
                        outputs=[stats_status, stats_dataframe]
                    )
def submit_request(input_dataset_name, input_split, input_dataset_config, output_dataset_name, prompt_col, model_name, sys_prompt,
max_tok, temp, top_k_val, top_p_val, num_output_samples, oauth_token=None):
MASTER_ORG = "synthetic-data-universe/"
            model_token = None  # per-request model tokens are not currently supported (field type is str | None)
input_dataset_token = None # This is currently not supported
output_dataset_token = os.getenv("OUTPUT_DATASET_TOKEN")
# Get username from OAuth token
username = "anonymous"
if oauth_token:
try:
if isinstance(oauth_token, gr.OAuthToken):
token_str = oauth_token.token
elif isinstance(oauth_token, str):
token_str = oauth_token
else:
token_str = None
if token_str:
user_info = whoami(token=token_str)
username = user_info.get("name", "unknown")
except Exception:
username = "unknown"
try:
request = GenerationRequest(
id="", # Will be generated when adding to the database
created_at="", # Will be set when adding to the database
status=GenerationStatus.PENDING,
input_dataset_name=input_dataset_name,
input_dataset_split=input_split,
input_dataset_config=input_dataset_config,
output_dataset_name=MASTER_ORG + output_dataset_name,
prompt_column=prompt_col,
model_name_or_path=model_name,
model_revision="main",
model_token=model_token,
system_prompt=sys_prompt if sys_prompt else None,
max_tokens=int(max_tok),
temperature=temp,
top_k=int(top_k_val),
top_p=top_p_val,
input_dataset_token=input_dataset_token if input_dataset_token else None,
output_dataset_token=output_dataset_token,
num_output_examples=num_output_samples, # will be set after validating the input dataset
username=username,
email="n/a",
)
# check the input dataset exists and can be accessed with the provided token
request = validate_request(request, oauth_token)
add_request_to_db(request)
return "Request submitted successfully!"
except Exception as e:
return f"Error: {str(e)}"
# Wire up the Load Dataset Info button
load_info_btn.click(
load_dataset_info,
inputs=[input_dataset_name, model_name_or_path, current_oauth_token],
outputs=[input_dataset_config, input_dataset_split, prompt_column, output_dataset_name, num_output_samples, load_info_status]
)
# Wire up model change to update generation parameters
model_name_or_path.change(
update_generation_params,
inputs=[model_name_or_path],
outputs=[max_tokens, temperature, top_k, top_p]
)
submit_btn.click(
submit_request,
inputs=[input_dataset_name, input_dataset_split, input_dataset_config, output_dataset_name, prompt_column, model_name_or_path,
system_prompt, max_tokens, temperature, top_k, top_p, num_output_samples, current_oauth_token],
outputs=output_status
)
def update_user_limits(oauth_token):
if oauth_token is None:
return "πŸ‘€ **Anonymous User**: You can generate up to 100 samples per request. Use the sign-in button above for PRO benefits (10,000 samples)."
is_pro = verify_pro_status(oauth_token)
if is_pro:
return "✨ **PRO User**: You can generate up to 10,000 samples per request."
else:
return "πŸ‘€ **Free User**: You can generate up to 100 samples per request. [Upgrade to PRO](http://huggingface.co/subscribe/pro?source=synthetic-data-universe) for 10,000 samples."
def control_access(profile: Optional[gr.OAuthProfile] = None, oauth_token: Optional[gr.OAuthToken] = None):
# Require users to be signed in
if oauth_token is None:
# User is not signed in - show sign-in prompt, hide main interface
return (
gr.update(visible=False), # main_interface
gr.update(visible=True), # signin_message
oauth_token, # current_oauth_token
"", # user_limit_info (empty when not signed in)
gr.update(), # num_output_samples (no change)
gr.update(value="πŸ”‘ Sign in") # login_button
)
else:
# User is signed in - show main interface, hide sign-in prompt
limit_msg = update_user_limits(oauth_token)
is_pro = verify_pro_status(oauth_token)
max_samples = MAX_SAMPLES_PRO if is_pro else MAX_SAMPLES_FREE
if is_pro:
button_text = f"✨ Signed in as PRO ({profile.name if profile else 'User'})"
else:
button_text = f"πŸ‘€ Signed in as {profile.name if profile else 'User'}"
return (
gr.update(visible=True), # main_interface
gr.update(visible=False), # signin_message
oauth_token, # current_oauth_token
limit_msg, # user_limit_info
gr.update(maximum=max_samples), # num_output_samples
gr.update(value=button_text) # login_button
)
# Handle login state changes - LoginButton automatically handles auth state changes
# The demo.load will handle both initial load and auth changes
demo.load(control_access, inputs=None, outputs=[main_interface, signin_message, current_oauth_token, user_limit_info, num_output_samples, login_button])
demo.queue(max_size=None, default_concurrency_limit=None).launch(show_error=True)
if __name__ == "__main__":
main()