Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -12,7 +12,19 @@ HF_TOKEN = os.getenv("HF_TOKEN")
|
|
| 12 |
if not HF_TOKEN:
|
| 13 |
raise ValueError("HF_TOKEN environment variable is not set")
|
| 14 |
|
| 15 |
-
def query(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
print("Starting query function...")
|
| 17 |
|
| 18 |
if not prompt:
|
|
@@ -235,18 +247,22 @@ def query(prompt, model, custom_lora, is_negative=False, steps=35, cfg_scale=7,
|
|
| 235 |
else:
|
| 236 |
API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
|
| 237 |
|
| 238 |
-
# Prepare payload
|
|
|
|
| 239 |
payload = {
|
| 240 |
"inputs": prompt,
|
| 241 |
-
"is_negative": is_negative,
|
| 242 |
-
"steps": steps,
|
| 243 |
-
"cfg_scale": cfg_scale,
|
| 244 |
-
"seed": seed if seed != -1 else random.randint(1, 1000000000),
|
| 245 |
-
"strength": strength,
|
| 246 |
"parameters": {
|
|
|
|
|
|
|
|
|
|
| 247 |
"width": width,
|
| 248 |
-
"height": height
|
| 249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
}
|
| 251 |
|
| 252 |
# Improved retry logic with exponential backoff
|
|
@@ -256,20 +272,27 @@ def query(prompt, model, custom_lora, is_negative=False, steps=35, cfg_scale=7,
|
|
| 256 |
|
| 257 |
while current_retry < max_retries:
|
| 258 |
try:
|
| 259 |
-
response = requests.post(API_URL, headers=headers, json=payload, timeout=180)
|
| 260 |
-
response.raise_for_status()
|
| 261 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
image = Image.open(io.BytesIO(response.content))
|
|
|
|
| 263 |
print(f'Generation {key} completed successfully')
|
| 264 |
return image
|
| 265 |
|
| 266 |
-
except (requests.exceptions.Timeout,
|
| 267 |
-
requests.exceptions.
|
|
|
|
|
|
|
| 268 |
current_retry += 1
|
| 269 |
if current_retry < max_retries:
|
| 270 |
wait_time = backoff_factor ** current_retry # Exponential backoff
|
| 271 |
print(f"Network error occurred: {str(e)}. Retrying in {wait_time} seconds... (Attempt {current_retry + 1}/{max_retries})")
|
| 272 |
-
time.sleep(wait_time)
|
| 273 |
continue
|
| 274 |
else:
|
| 275 |
# Detailed error message based on exception type
|
|
@@ -513,4 +536,4 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as dalle:
|
|
| 513 |
dalle.load(fn=update_network_status, outputs=network_status)
|
| 514 |
|
| 515 |
if __name__ == "__main__":
|
| 516 |
-
dalle.launch(show_api=False, share=False)
|
|
|
|
| 12 |
if not HF_TOKEN:
|
| 13 |
raise ValueError("HF_TOKEN environment variable is not set")
|
| 14 |
|
| 15 |
+
def query(
|
| 16 |
+
prompt,
|
| 17 |
+
model,
|
| 18 |
+
custom_lora,
|
| 19 |
+
negative_prompt="", # โ ๊ธฐ์กด is_negative=False โ negative_prompt="" ๋ก ๋ณ๊ฒฝ
|
| 20 |
+
steps=35,
|
| 21 |
+
cfg_scale=7,
|
| 22 |
+
sampler="DPM++ 2M Karras",
|
| 23 |
+
seed=-1,
|
| 24 |
+
strength=0.7,
|
| 25 |
+
width=1024,
|
| 26 |
+
height=1024
|
| 27 |
+
):
|
| 28 |
print("Starting query function...")
|
| 29 |
|
| 30 |
if not prompt:
|
|
|
|
| 247 |
else:
|
| 248 |
API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
|
| 249 |
|
| 250 |
+
# Prepare payload in Hugging Face Inference API style
|
| 251 |
+
# (negative_prompt, steps, cfg_scale, seed, strength ๋ฑ์ parameters ์์ ๋ฐฐ์น)
|
| 252 |
payload = {
|
| 253 |
"inputs": prompt,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
"parameters": {
|
| 255 |
+
"negative_prompt": negative_prompt,
|
| 256 |
+
"num_inference_steps": steps,
|
| 257 |
+
"guidance_scale": cfg_scale,
|
| 258 |
"width": width,
|
| 259 |
+
"height": height,
|
| 260 |
+
"strength": strength,
|
| 261 |
+
# seed๋ฅผ ์ง์ํ๋ ๋ชจ๋ธ/์๋ํฌ์ธํธ์ ๋ฐ๋ผ ๋ฌด์๋ ์๋ ์์
|
| 262 |
+
"seed": seed if seed != -1 else random.randint(1, 1000000000),
|
| 263 |
+
},
|
| 264 |
+
# ๋ชจ๋ธ์ด ๋ก๋ฉ ์ค์ผ ๊ฒฝ์ฐ ๊ธฐ๋ค๋ฆฌ๋๋ก ์ค์
|
| 265 |
+
"options": {"wait_for_model": True}
|
| 266 |
}
|
| 267 |
|
| 268 |
# Improved retry logic with exponential backoff
|
|
|
|
| 272 |
|
| 273 |
while current_retry < max_retries:
|
| 274 |
try:
|
| 275 |
+
response = requests.post(API_URL, headers=headers, json=payload, timeout=180)
|
|
|
|
| 276 |
|
| 277 |
+
# ๋๋ฒ๊น
์ฉ ์ ๋ณด ์ถ๋ ฅ
|
| 278 |
+
print("Response Content-Type:", response.headers.get("Content-Type"))
|
| 279 |
+
print("Response Text (snippet):", response.text[:500])
|
| 280 |
+
|
| 281 |
+
response.raise_for_status() # HTTP ์๋ฌ ์ฝ๋ ์ ์์ธ ๋ฐ์
|
| 282 |
image = Image.open(io.BytesIO(response.content))
|
| 283 |
+
|
| 284 |
print(f'Generation {key} completed successfully')
|
| 285 |
return image
|
| 286 |
|
| 287 |
+
except (requests.exceptions.Timeout,
|
| 288 |
+
requests.exceptions.ConnectionError,
|
| 289 |
+
requests.exceptions.HTTPError,
|
| 290 |
+
requests.exceptions.RequestException) as e:
|
| 291 |
current_retry += 1
|
| 292 |
if current_retry < max_retries:
|
| 293 |
wait_time = backoff_factor ** current_retry # Exponential backoff
|
| 294 |
print(f"Network error occurred: {str(e)}. Retrying in {wait_time} seconds... (Attempt {current_retry + 1}/{max_retries})")
|
| 295 |
+
time.sleep(wait_time)
|
| 296 |
continue
|
| 297 |
else:
|
| 298 |
# Detailed error message based on exception type
|
|
|
|
| 536 |
dalle.load(fn=update_network_status, outputs=network_status)
|
| 537 |
|
| 538 |
if __name__ == "__main__":
|
| 539 |
+
dalle.launch(show_api=False, share=False)
|