edbeeching committed
Commit b26773a · 1 Parent(s): a2a9a72

add gen param cache

Files changed (1):
  1. app.py (+176 -36)
app.py CHANGED
@@ -6,7 +6,7 @@ from supabase import create_client, Client
 from supabase.client import ClientOptions
 from enum import Enum
 from datasets import get_dataset_infos
-from transformers import AutoConfig
+from transformers import AutoConfig, GenerationConfig
 from huggingface_hub import whoami
 from typing import Optional, Union
 
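The one-line import change above is the basis of the commit: GenerationConfig.from_pretrained reads a repo's generation_config.json from the Hub (and caches it locally after the first call). A minimal sketch of what it exposes, using a model id from the list added below:

    from transformers import GenerationConfig

    # Downloads and parses generation_config.json for the repo.
    cfg = GenerationConfig.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
    # Sampling settings are plain attributes; fields the repo does not set
    # fall back to library defaults (e.g. max_new_tokens defaults to None).
    print(cfg.temperature, cfg.top_k, cfg.top_p, cfg.max_new_tokens)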
@@ -17,41 +17,6 @@ from typing import Optional, Union
 - validate max model params
 """
 
-SUPPORTED_MODELS = [
-    "Qwen/Qwen3-4B-Instruct-2507",
-    "Qwen/Qwen3-30B-A3B-Instruct-2507",
-    "meta-llama/Llama-3.2-1B-Instruct",
-    "meta-llama/Llama-3.2-3B-Instruct",
-    "baidu/ERNIE-4.5-21B-A3B-Thinking",
-    "LLM360/K2-Think",
-    "openai/gpt-oss-20b",
-]
-
-
-def verify_pro_status(token: Optional[Union[gr.OAuthToken, str]]) -> bool:
-    """Verifies if the user is a Hugging Face PRO user or part of an enterprise org."""
-    if not token:
-        return False
-
-    if isinstance(token, gr.OAuthToken):
-        token_str = token.token
-    elif isinstance(token, str):
-        token_str = token
-    else:
-        return False
-
-    try:
-        user_info = whoami(token=token_str)
-        return (
-            user_info.get("isPro", False) or
-            any(org.get("isEnterprise", False) for org in user_info.get("orgs", []))
-        )
-    except Exception as e:
-        print(f"Could not verify user's PRO/Enterprise status: {e}")
-        return False
-
-
-
 class GenerationStatus(Enum):
     PENDING = "PENDING"
     RUNNING = "RUNNING"
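This hunk is a move rather than a deletion: the same SUPPORTED_MODELS list and verify_pro_status function reappear in the @@ -89,6 +57,166 @@ hunk below, after the new cache helpers. For reference, the PRO check hinges on two fields of the dict returned by huggingface_hub.whoami; a minimal sketch (the token value is a placeholder, not a real credential):

    from huggingface_hub import whoami

    info = whoami(token="hf_xxx")  # placeholder token
    is_pro = info.get("isPro", False)
    is_enterprise = any(org.get("isEnterprise", False) for org in info.get("orgs", []))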
@@ -64,6 +29,9 @@ MAX_SAMPLES_FREE = 100  # max number of samples for free users
 MAX_TOKENS = 8192
 MAX_MODEL_PARAMS = 20_000_000_000  # 20 billion parameters (for now)
 
+# Cache for model generation parameters
+MODEL_GEN_PARAMS_CACHE = {}
+
 @dataclass
 class GenerationRequest:
     id: str
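A module-level dict is sufficient here because the cache is written once at startup and only read afterwards. An alternative sketch using functools.lru_cache (a hypothetical refactor, not part of this commit) would avoid the global and cache lazily per model:

    from functools import lru_cache

    @lru_cache(maxsize=None)
    def get_model_generation_params(model_name: str) -> dict:
        # First call per model fetches from the Hub; repeat calls hit the cache.
        return fetch_model_generation_params(model_name)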
@@ -89,6 +57,166 @@ class GenerationRequest:
     num_output_examples: int
     private: bool = False
     num_retries: int = 0
+
+SUPPORTED_MODELS = [
+    "Qwen/Qwen3-4B-Instruct-2507",
+    "Qwen/Qwen3-30B-A3B-Instruct-2507",
+    "meta-llama/Llama-3.2-1B-Instruct",
+    "meta-llama/Llama-3.2-3B-Instruct",
+    "baidu/ERNIE-4.5-21B-A3B-Thinking",
+    "LLM360/K2-Think",
+    "openai/gpt-oss-20b",
+]
+
+
+def fetch_model_generation_params(model_name: str) -> dict:
+    """Fetch generation parameters from model's generation config on the hub"""
+    default_params = {
+        "max_tokens": 1024,
+        "temperature": 0.7,
+        "top_k": 50,
+        "top_p": 0.95
+    }
+
+    try:
+        print(f"Attempting to fetch generation config for: {model_name}")
+
+        # Try to load the generation config
+        try:
+            gen_config = GenerationConfig.from_pretrained(model_name, force_download=False)
+            print(f"Successfully loaded generation config for {model_name}")
+            print(f"Config attributes: {dir(gen_config)}")
+        except Exception as e:
+            print(f"Failed to load GenerationConfig for {model_name}: {e}")
+            # Try loading from model config instead
+            try:
+                from transformers import AutoConfig
+                model_config = AutoConfig.from_pretrained(model_name, force_download=False)
+                print(f"Loaded AutoConfig for {model_name} instead")
+
+                # Use some reasonable defaults based on model type
+                if "qwen" in model_name.lower():
+                    return {"max_tokens": 2048, "temperature": 0.7, "top_k": 50, "top_p": 0.8}
+                elif "llama" in model_name.lower():
+                    return {"max_tokens": 2048, "temperature": 0.6, "top_k": 40, "top_p": 0.9}
+                elif "ernie" in model_name.lower():
+                    return {"max_tokens": 1024, "temperature": 0.7, "top_k": 50, "top_p": 0.95}
+                else:
+                    return default_params
+
+            except Exception as e2:
+                print(f"Failed to load any config for {model_name}: {e2}")
+                return default_params
+
+        # Extract relevant parameters with fallbacks to defaults
+        params = {
+            "max_tokens": getattr(gen_config, 'max_new_tokens', None) or getattr(gen_config, 'max_length', default_params["max_tokens"]),
+            "temperature": getattr(gen_config, 'temperature', default_params["temperature"]),
+            "top_k": getattr(gen_config, 'top_k', default_params["top_k"]),
+            "top_p": getattr(gen_config, 'top_p', default_params["top_p"])
+        }
+
+        # Ensure parameters are within valid ranges
+        params["max_tokens"] = max(256, min(params["max_tokens"], MAX_TOKENS))
+        params["temperature"] = max(0.0, min(params["temperature"], 2.0))
+        params["top_k"] = max(5, min(params["top_k"], 100))
+        params["top_p"] = max(0.0, min(params["top_p"], 1.0))
+
+        print(f"Final params for {model_name}: {params}")
+        return params
+
+    except Exception as e:
+        print(f"Could not fetch generation config for {model_name}: {e}")
+        return default_params
+
+
+def update_generation_params(model_name: str):
+    """Update generation parameters based on selected model"""
+    global MODEL_GEN_PARAMS_CACHE
+
+    print(f"Updating generation parameters for model: {model_name}")
+    print(f"Cache is empty: {len(MODEL_GEN_PARAMS_CACHE) == 0}")
+    print(f"Current cache keys: {list(MODEL_GEN_PARAMS_CACHE.keys())}")
+
+    # If cache is empty, try to populate it now
+    if len(MODEL_GEN_PARAMS_CACHE) == 0:
+        print("Cache is empty, attempting to populate now...")
+        cache_all_model_params()
+
+    if model_name in MODEL_GEN_PARAMS_CACHE:
+        params = MODEL_GEN_PARAMS_CACHE[model_name]
+        print(f"Found cached params for {model_name}: {params}")
+        return (
+            gr.update(value=params["max_tokens"]),   # max_tokens
+            gr.update(value=params["temperature"]),  # temperature
+            gr.update(value=params["top_k"]),        # top_k
+            gr.update(value=params["top_p"])         # top_p
+        )
+    else:
+        # Fallback to defaults if model not in cache
+        print(f"Model {model_name} not found in cache, using defaults")
+        return (
+            gr.update(value=1024),  # max_tokens
+            gr.update(value=0.7),   # temperature
+            gr.update(value=50),    # top_k
+            gr.update(value=0.95)   # top_p
+        )
+
+
+def cache_all_model_params():
+    """Cache generation parameters for all supported models at startup"""
+    global MODEL_GEN_PARAMS_CACHE
+
+    print(f"Starting to cache parameters for {len(SUPPORTED_MODELS)} models...")
+    print(f"Supported models: {SUPPORTED_MODELS}")
+
+    for model_name in SUPPORTED_MODELS:
+        try:
+            print(f"Processing model: {model_name}")
+            params = fetch_model_generation_params(model_name)
+            MODEL_GEN_PARAMS_CACHE[model_name] = params
+            print(f"Successfully cached params for {model_name}: {params}")
+        except Exception as e:
+            print(f"Exception while caching params for {model_name}: {e}")
+            # Use default parameters if caching fails
+            default_params = {
+                "max_tokens": 1024,
+                "temperature": 0.7,
+                "top_k": 50,
+                "top_p": 0.95
+            }
+            MODEL_GEN_PARAMS_CACHE[model_name] = default_params
+            print(f"Using default params for {model_name}: {default_params}")
+
+    print("Caching complete. Final cache contents:")
+    for model, params in MODEL_GEN_PARAMS_CACHE.items():
+        print(f"  {model}: {params}")
+    print(f"Cache size: {len(MODEL_GEN_PARAMS_CACHE)} models")
+
+
+def verify_pro_status(token: Optional[Union[gr.OAuthToken, str]]) -> bool:
+    """Verifies if the user is a Hugging Face PRO user or part of an enterprise org."""
+    if not token:
+        return False
+
+    if isinstance(token, gr.OAuthToken):
+        token_str = token.token
+    elif isinstance(token, str):
+        token_str = token
+    else:
+        return False
+
+    try:
+        user_info = whoami(token=token_str)
+        return (
+            user_info.get("isPro", False) or
+            any(org.get("isEnterprise", False) for org in user_info.get("orgs", []))
+        )
+    except Exception as e:
+        print(f"Could not verify user's PRO/Enterprise status: {e}")
+        return False
+
+
 
 def validate_request(request: GenerationRequest, oauth_token: Optional[Union[gr.OAuthToken, str]] = None) -> GenerationRequest:
     # checks that the request is valid
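One caveat in fetch_model_generation_params: GenerationConfig fields can be None (for example, max_new_tokens is None unless the repo sets it), in which case min(None, MAX_TOKENS) raises a TypeError that the outer except silently converts into the defaults. A None-safe clamp, sketched with a hypothetical helper:

    def clamp(value, lo, hi, fallback):
        # Clamp a config field to [lo, hi]; use the fallback when it is missing or None.
        if value is None:
            return fallback
        return max(lo, min(value, hi))

    # e.g. params["top_k"] = clamp(getattr(gen_config, "top_k", None), 5, 100, default_params["top_k"])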
@@ -401,6 +529,11 @@ def get_generation_stats_safe():
 
 
 def main():
+    # Cache model generation parameters at startup
+    print("Caching model generation parameters...")
+    cache_all_model_params()
+    print("Model parameter caching complete.")
+
     with gr.Blocks(title="Synthetic Data Generation") as demo:
         gr.HTML("<h3 style='text-align:center'>Generate synthetic data with AI models. Free to use! Sign up for PRO benefits (10k samples vs 100). <a href='http://huggingface.co/subscribe/pro?source=synthetic-data-universe' target='_blank'>Upgrade to PRO</a></h3>", elem_id="sub_title")
 
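Each cache entry costs a network round-trip, so the startup loop in cache_all_model_params scales linearly with SUPPORTED_MODELS. A sketch of a concurrent variant (an alternative, not what the commit does) that fetches the configs on a small thread pool:

    from concurrent.futures import ThreadPoolExecutor

    def cache_all_model_params_concurrent():
        # Fetch every model's generation config in parallel worker threads.
        with ThreadPoolExecutor(max_workers=4) as pool:
            for name, params in zip(SUPPORTED_MODELS, pool.map(fetch_model_generation_params, SUPPORTED_MODELS)):
                MODEL_GEN_PARAMS_CACHE[name] = params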
@@ -659,6 +792,13 @@ def main():
             outputs=[input_dataset_config, input_dataset_split, prompt_column, output_dataset_name, num_output_samples, load_info_status]
         )
 
+        # Wire up model change to update generation parameters
+        model_name_or_path.change(
+            update_generation_params,
+            inputs=[model_name_or_path],
+            outputs=[max_tokens, temperature, top_k, top_p]
+        )
+
         submit_btn.click(
             submit_request,
             inputs=[input_dataset_name, input_dataset_split, input_dataset_config, output_dataset_name, prompt_column, model_name_or_path,
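The return order of update_generation_params must match this outputs list: Gradio unpacks the returned 4-tuple of gr.update(...) positionally onto [max_tokens, temperature, top_k, top_p]. A self-contained sketch of the same wiring pattern (component names here are illustrative, not the app's real layout):

    import gradio as gr

    with gr.Blocks() as demo:
        model = gr.Dropdown(choices=["model-a", "model-b"], label="Model")
        temperature = gr.Slider(0.0, 2.0, value=0.7, label="Temperature")
        top_p = gr.Slider(0.0, 1.0, value=0.95, label="Top-p")

        def on_model_change(name):
            # One update per output component, in the same order as `outputs`.
            return gr.update(value=0.6), gr.update(value=0.9)

        model.change(on_model_change, inputs=[model], outputs=[temperature, top_p])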
 