edbeeching committed
Commit 61d5b4a · 1 Parent(s): b26773a

add model gen params

Files changed (1)
  1. app.py +58 -37
app.py CHANGED
@@ -70,54 +70,69 @@ SUPPORTED_MODELS = [
 
 
 def fetch_model_generation_params(model_name: str) -> dict:
-    """Fetch generation parameters from model's generation config on the hub"""
+    """Fetch generation parameters and model config from the hub"""
     default_params = {
         "max_tokens": 1024,
         "temperature": 0.7,
         "top_k": 50,
-        "top_p": 0.95
+        "top_p": 0.95,
+        "max_position_embeddings": 2048,
+        "recommended_max_tokens": 1024
     }
 
     try:
-        print(f"Attempting to fetch generation config for: {model_name}")
+        print(f"Attempting to fetch configs for: {model_name}")
+
+        # Always try to load the model config first for max_position_embeddings
+        model_config = None
+        max_position_embeddings = default_params["max_position_embeddings"]
+
+        try:
+            output_dataset_token = os.getenv("OUTPUT_DATASET_TOKEN")
+            model_config = AutoConfig.from_pretrained(model_name, force_download=False, token=output_dataset_token)
+            max_position_embeddings = getattr(model_config, 'max_position_embeddings', default_params["max_position_embeddings"])
+            print(f"Loaded AutoConfig for {model_name}, max_position_embeddings: {max_position_embeddings}")
+        except Exception as e:
+            print(f"Failed to load AutoConfig for {model_name}: {e}")
+
+        # Calculate recommended max tokens (conservative estimate)
+        # Leave some room for the prompt, so use ~75% of max_position_embeddings
+        recommended_max_tokens = min(int(max_position_embeddings * 0.75), MAX_TOKENS)
+        recommended_max_tokens = max(256, recommended_max_tokens)  # Ensure minimum
 
         # Try to load the generation config
+        gen_config = None
         try:
-            gen_config = GenerationConfig.from_pretrained(model_name, force_download=False)
+            gen_config = GenerationConfig.from_pretrained(model_name, force_download=False, token=output_dataset_token)
             print(f"Successfully loaded generation config for {model_name}")
-            print(f"Config attributes: {dir(gen_config)}")
         except Exception as e:
             print(f"Failed to load GenerationConfig for {model_name}: {e}")
-            # Try loading from model config instead
-            try:
-                from transformers import AutoConfig
-                model_config = AutoConfig.from_pretrained(model_name, force_download=False)
-                print(f"Loaded AutoConfig for {model_name} instead")
-
-                # Use some reasonable defaults based on model type
-                if "qwen" in model_name.lower():
-                    return {"max_tokens": 2048, "temperature": 0.7, "top_k": 50, "top_p": 0.8}
-                elif "llama" in model_name.lower():
-                    return {"max_tokens": 2048, "temperature": 0.6, "top_k": 40, "top_p": 0.9}
-                elif "ernie" in model_name.lower():
-                    return {"max_tokens": 1024, "temperature": 0.7, "top_k": 50, "top_p": 0.95}
-                else:
-                    return default_params
-
-            except Exception as e2:
-                print(f"Failed to load any config for {model_name}: {e2}")
-                return default_params
-
-        # Extract relevant parameters with fallbacks to defaults
-        params = {
-            "max_tokens": getattr(gen_config, 'max_new_tokens', None) or getattr(gen_config, 'max_length', default_params["max_tokens"]),
-            "temperature": getattr(gen_config, 'temperature', default_params["temperature"]),
-            "top_k": getattr(gen_config, 'top_k', default_params["top_k"]),
-            "top_p": getattr(gen_config, 'top_p', default_params["top_p"])
-        }
+
+        # Extract parameters from generation config or use model-specific defaults
+        if gen_config:
+            params = {
+                "max_tokens": getattr(gen_config, 'max_new_tokens', None) or getattr(gen_config, 'max_length', recommended_max_tokens),
+                "temperature": getattr(gen_config, 'temperature', default_params["temperature"]),
+                "top_k": getattr(gen_config, 'top_k', default_params["top_k"]),
+                "top_p": getattr(gen_config, 'top_p', default_params["top_p"]),
+                "max_position_embeddings": max_position_embeddings,
+                "recommended_max_tokens": recommended_max_tokens
+            }
+        else:
+            # Use model-specific defaults based on model name
+            if "qwen" in model_name.lower():
+                params = {"max_tokens": recommended_max_tokens, "temperature": 0.7, "top_k": 50, "top_p": 0.8, "max_position_embeddings": max_position_embeddings, "recommended_max_tokens": recommended_max_tokens}
+            elif "llama" in model_name.lower():
+                params = {"max_tokens": recommended_max_tokens, "temperature": 0.6, "top_k": 40, "top_p": 0.9, "max_position_embeddings": max_position_embeddings, "recommended_max_tokens": recommended_max_tokens}
+            elif "ernie" in model_name.lower():
+                params = {"max_tokens": min(recommended_max_tokens, 1024), "temperature": 0.7, "top_k": 50, "top_p": 0.95, "max_position_embeddings": max_position_embeddings, "recommended_max_tokens": recommended_max_tokens}
+            else:
+                params = dict(default_params)
+                params["max_position_embeddings"] = max_position_embeddings
+                params["recommended_max_tokens"] = recommended_max_tokens
 
         # Ensure parameters are within valid ranges
-        params["max_tokens"] = max(256, min(params["max_tokens"], MAX_TOKENS))
+        params["max_tokens"] = max(256, min(params["max_tokens"], MAX_TOKENS, params["recommended_max_tokens"]))
         params["temperature"] = max(0.0, min(params["temperature"], 2.0))
         params["top_k"] = max(5, min(params["top_k"], 100))
         params["top_p"] = max(0.0, min(params["top_p"], 1.0))
@@ -126,7 +141,7 @@ def fetch_model_generation_params(model_name: str) -> dict:
         return params
 
     except Exception as e:
-        print(f"Could not fetch generation config for {model_name}: {e}")
+        print(f"Could not fetch configs for {model_name}: {e}")
        return default_params
 
 
@@ -146,8 +161,12 @@ def update_generation_params(model_name: str):
     if model_name in MODEL_GEN_PARAMS_CACHE:
         params = MODEL_GEN_PARAMS_CACHE[model_name]
         print(f"Found cached params for {model_name}: {params}")
+
+        # Set the max_tokens slider maximum to the model's recommended max
+        max_tokens_limit = min(params.get("recommended_max_tokens", MAX_TOKENS), MAX_TOKENS)
+
         return (
-            gr.update(value=params["max_tokens"]),  # max_tokens
+            gr.update(value=params["max_tokens"], maximum=max_tokens_limit),  # max_tokens with dynamic maximum
             gr.update(value=params["temperature"]),  # temperature
             gr.update(value=params["top_k"]),  # top_k
             gr.update(value=params["top_p"])  # top_p
@@ -156,7 +175,7 @@ def update_generation_params(model_name: str):
     # Fallback to defaults if model not in cache
     print(f"Model {model_name} not found in cache, using defaults")
     return (
-        gr.update(value=1024),  # max_tokens
+        gr.update(value=1024, maximum=MAX_TOKENS),  # max_tokens
         gr.update(value=0.7),  # temperature
         gr.update(value=50),  # top_k
         gr.update(value=0.95)  # top_p
@@ -183,7 +202,9 @@ def cache_all_model_params():
             "max_tokens": 1024,
             "temperature": 0.7,
             "top_k": 50,
-            "top_p": 0.95
+            "top_p": 0.95,
+            "max_position_embeddings": 2048,
+            "recommended_max_tokens": 1024
         }
         MODEL_GEN_PARAMS_CACHE[model_name] = default_params
         print(f"Using default params for {model_name}: {default_params}")