Spaces: Running on CPU Upgrade
import math
import os
import shutil
import tempfile
import threading
import time

import gradio as gr

from util import process_image_edit, process_local_image_edit
from nfsw import NSFWDetector
# Configuration and in-memory state for the per-IP NSFW rate limit.
# All state below lives in this process's memory only: it is lost on restart
# and not shared between processes.
# NOTE(review): these dicts are mutated by request handlers that may run
# concurrently (see the concurrency_limit settings in create_app) without a
# lock, so counts can be slightly off under contention.

# Length of one rate-limit window, in minutes (5 minutes).
NSFW_TIME_WINDOW = 5
# Maximum NSFW-flagged requests allowed per IP per window: 10.
# Once 10 have been counted in a window, further NSFW-flagged requests from
# that IP are rejected until the window ends.
NSFW_LIMIT = 10
# Total number of edit requests seen per client IP (never reset in the visible
# code; grows with the number of distinct clients).
IP_Dict = {}
# Cumulative count of NSFW-flagged requests per IP that were allowed through
# (never reset in the visible code).
NSFW_Dict = {}
# Per-window NSFW counters, keyed "<ip>_<window start as int Unix timestamp>".
NSFW_Time_Dict = {}
def get_current_time_window():
    """Return the rate-limit window that contains the current moment.

    Windows are NSFW_TIME_WINDOW minutes long and aligned to the local wall
    clock within each hour. The start is the current local minute rounded down
    to a multiple of NSFW_TIME_WINDOW, with seconds set to zero (for a 5-minute
    window: hh:00, hh:05, hh:10, ...).

    NOTE(review): if NSFW_TIME_WINDOW does not divide 60, the last window of an
    hour extends past the next hour's first window start.

    Returns:
        tuple: (window_start, window_end) as Unix timestamps in seconds;
        window_end is exactly NSFW_TIME_WINDOW * 60 seconds after window_start.
    """
    local_now = time.localtime(time.time())
    # Floor the minute onto the window grid (same as (m // W) * W).
    start_minute = local_now.tm_min - local_now.tm_min % NSFW_TIME_WINDOW
    window_start = time.mktime((
        local_now.tm_year, local_now.tm_mon, local_now.tm_mday,
        local_now.tm_hour, start_minute, 0,
        local_now.tm_wday, local_now.tm_yday, local_now.tm_isdst,
    ))
    window_end = window_start + NSFW_TIME_WINDOW * 60
    return window_start, window_end
def check_nsfw_rate_limit(client_ip):
    """Check whether ``client_ip`` has used up its NSFW quota for the current window.

    The window comes from ``get_current_time_window()``. A client is limited
    once its counter for the current window has reached ``NSFW_LIMIT``. The
    counter itself is incremented by ``record_nsfw_detection``.

    Side effects on ``NSFW_Time_Dict``:
    * ensures a (zero) counter exists for this client's current window;
    * deletes the counters of every client whose window has already ended,
      so the dict does not grow without bound.

    Args:
        client_ip (str): Client identifier used as the key prefix (the socket
            address or the raw X-Forwarded-For value).

    Returns:
        tuple: ``(is_limited, wait_seconds)``. When limited, ``wait_seconds``
        is the time left until the current window ends (never negative).
        Otherwise it is 0.
    """
    current_time = time.time()
    window_start_time, window_end_time = get_current_time_window()
    current_window_key = f"{client_ip}_{int(window_start_time)}"
    NSFW_Time_Dict.setdefault(current_window_key, 0)

    # Drop counters from windows that have already ended.
    # * Iterate over a snapshot: handlers may run concurrently, and another
    #   thread inserting into the dict while it is being iterated raises
    #   "RuntimeError: dictionary changed size during iteration".
    # * Split on the LAST "_" only: the client identifier may itself contain
    #   underscores (X-Forwarded-For is client-controlled). Splitting on the
    #   first "_" would raise ValueError for such identifiers. The callers treat
    #   that as a detector failure and skip the rate limit entirely.
    for key in list(NSFW_Time_Dict):
        window_time = int(key.rsplit("_", 1)[1])
        if window_time < window_start_time:
            NSFW_Time_Dict.pop(key, None)

    # .get(): a concurrent request may already have pruned this key.
    if NSFW_Time_Dict.get(current_window_key, 0) >= NSFW_LIMIT:
        wait_time = window_end_time - current_time
        return True, max(0, wait_time)
    return False, 0
def record_nsfw_detection(client_ip):
    """Count one NSFW-flagged request for ``client_ip``.

    Two counters are bumped:
    * the counter for the client's current time window in ``NSFW_Time_Dict``
      (the one ``check_nsfw_rate_limit`` compares against ``NSFW_LIMIT``);
    * the client's cumulative total in ``NSFW_Dict``.

    Args:
        client_ip (str): Client identifier (socket address or X-Forwarded-For value).
    """
    start, _end = get_current_time_window()
    window_key = f"{client_ip}_{int(start)}"
    NSFW_Time_Dict[window_key] = NSFW_Time_Dict.get(window_key, 0) + 1
    # Cumulative per-IP total, kept alongside the per-window counter.
    NSFW_Dict[client_ip] = NSFW_Dict.get(client_ip, 0) + 1
def get_current_window_info(client_ip):
    """Describe ``client_ip``'s NSFW counter for the current window (for logging).

    Read-only: missing counters are reported as 0, not created.

    Args:
        client_ip (str): Client identifier (socket address or X-Forwarded-For value).

    Returns:
        dict: Contains these keys:
        * "window_start" / "window_end": local times formatted "%H:%M:%S";
        * "current_count": the counter value;
        * "limit": NSFW_LIMIT;
        * "window_key": the NSFW_Time_Dict key that was looked up.
    """
    start, end = get_current_time_window()
    key = f"{client_ip}_{int(start)}"

    def _clock(timestamp):
        # Render a Unix timestamp as local wall-clock time.
        return time.strftime("%H:%M:%S", time.localtime(timestamp))

    return {
        "window_start": _clock(start),
        "window_end": _clock(end),
        "current_count": NSFW_Time_Dict.get(key, 0),
        "limit": NSFW_LIMIT,
        "window_key": key,
    }
def _init_nsfw_detector():
    """Create the module-wide NSFW detector, or return None if that fails.

    Per the original note, ``NSFWDetector()`` downloads its model
    (falconsai_yolov9_nsfw_model_quantized.pt) from Hugging Face.
    Failure is tolerated: the edit handlers skip NSFW screening when the
    detector is None.
    """
    try:
        detector = NSFWDetector()
        print("✅ NSFW检测器初始化成功")
        return detector
    except Exception as err:
        print(f"❌ NSFW检测器初始化失败: {err}")
        return None


# Initialise the NSFW detector once at import time (model fetched from Hugging Face).
nsfw_detector = _init_nsfw_detector()
def edit_image_interface(input_image, prompt, request: gr.Request, progress=gr.Progress()):
    """Handle a request from the "Global Editing" tab.

    Steps:
    1. Identify the client.
    2. Validate the inputs.
    3. Screen the image with the NSFW detector, applying the per-IP NSFW
       rate limit.
    4. Run ``process_image_edit``.

    Args:
        input_image: The uploaded image (the input component uses
            ``type="pil"``), or None when nothing was uploaded.
        prompt (str): The editing instruction.
        request (gr.Request): Injected by Gradio; used to derive the client IP.
        progress (gr.Progress): Injected by Gradio; receives progress updates.

    Returns:
        tuple: ``(result_url, status_message)``. On success, ``result_url`` is
        the first value returned by ``process_image_edit``. It is None on a
        validation failure, rate limit, or processing error.
    """
    try:
        # Identify the client: X-Forwarded-For takes precedence over the socket peer.
        # NOTE(review): X-Forwarded-For is client-controlled and may be a
        # comma-separated proxy chain. It is used verbatim as the key for
        # IP_Dict and the NSFW rate limit, so it can be spoofed.
        client_ip = request.client.host
        x_forwarded_for = dict(request.headers).get('x-forwarded-for')
        if x_forwarded_for:
            client_ip = x_forwarded_for
        IP_Dict[client_ip] = IP_Dict.get(client_ip, 0) + 1

        if input_image is None:
            return None, "Please upload an image first"
        if not prompt or prompt.strip() == "":
            return None, "Please enter editing prompt"
        clean_prompt = prompt.strip()
        # The stripped prompt must be longer than 3 characters.
        if len(clean_prompt) <= 3:
            return None, "❌ Editing prompt must be more than 3 characters"
    except Exception as e:
        print(f"⚠️ Request preprocessing error: {e}")
        return None, "❌ Request processing error"

    # NSFW screening runs only when the detector loaded. NSFW images are still
    # processed, but each client may submit at most NSFW_LIMIT of them per time
    # window. Beyond that, the request is rejected.
    # input_image is guaranteed non-None here (checked above).
    if nsfw_detector is not None:
        try:
            # Detect on the PIL Image object directly to avoid file-path issues.
            nsfw_result = nsfw_detector.predict_pil_label_only(input_image)
            if nsfw_result.lower() == "nsfw":
                print(f"🔍 NSFW检测结果: ❌❌❌ {nsfw_result} - IP: {client_ip}")
                is_rate_limited, wait_time = check_nsfw_rate_limit(client_ip)
                if is_rate_limited:
                    # Round the remaining wait UP to whole minutes (at least 1).
                    # int(x / 60) + 1 would over-report exact multiples by a minute.
                    wait_minutes = max(1, math.ceil(wait_time / 60))
                    window_info = get_current_window_info(client_ip)
                    print(f"⚠️ NSFW频率限制 - IP: {client_ip}")
                    print(f" 时间窗口: {window_info['window_start']} - {window_info['window_end']}")
                    print(f" 当前计数: {window_info['current_count']}/{NSFW_LIMIT}, 需要等待 {wait_minutes} 分钟")
                    return None, f"❌ Please wait {wait_minutes} minutes before generating again"
                # Under the limit: count this NSFW request and let it proceed.
                record_nsfw_detection(client_ip)
            else:
                print(f"🔍 NSFW检测结果: ✅✅✅ {nsfw_result} - IP: {client_ip}")
        except Exception as e:
            # Deliberate best effort: a detector error must not block editing.
            print(f"⚠️ NSFW检测失败: {e}")

    def progress_callback(message):
        # Forward backend status messages to the Gradio progress bar. A failed
        # progress update is logged and must never abort the edit itself.
        try:
            if progress is not None:
                progress(0.5, desc=message)
        except Exception as e:
            print(f"⚠️ Progress update failed: {e}")

    try:
        print(f"✅ Processing started - IP: {client_ip}, count: {IP_Dict[client_ip]}, prompt: {clean_prompt}", flush=True)
        result_url, message = process_image_edit(input_image, clean_prompt, progress_callback)
        if not result_url:
            print(f"❌ Processing failed - IP: {client_ip}, error: {message}", flush=True)
            return None, "❌ " + message
        print(f"✅ Processing completed successfully - IP: {client_ip}, result_url: {result_url}", flush=True)
        try:
            if progress is not None:
                progress(1.0, desc="Processing completed")
        except Exception as e:
            print(f"⚠️ Final progress update failed: {e}")
        return result_url, "✅ " + message
    except Exception as e:
        print(f"❌ Processing exception - IP: {client_ip}, error: {str(e)}")
        return None, f"❌ Error occurred during processing: {str(e)}"
# The former state-management helper functions are no longer needed; that
# logic was simplified and inlined into the handlers.
def local_edit_interface(image_dict, prompt, request: gr.Request, progress=gr.Progress()):
    """Handle a request from the "Local Editing" tab.

    Steps:
    1. Identify the client.
    2. Validate the editor value and prompt.
    3. Screen the background image with the NSFW detector, applying the
       per-IP NSFW rate limit.
    4. Run ``process_local_image_edit`` on the background plus the drawn layers.

    Args:
        image_dict: Value of the ``gr.ImageEditor(type="pil")`` input. It is
            expected to be a dict with "background" (the uploaded image) and
            "layers" (the strokes the user drew as the edit mask).
        prompt (str): The local editing instruction.
        request (gr.Request): Injected by Gradio; used to derive the client IP.
        progress (gr.Progress): Injected by Gradio; receives progress updates.

    Returns:
        tuple: ``(result_url, status_message)``. On success, ``result_url`` is
        the first value returned by ``process_local_image_edit``. It is None on
        a validation failure, rate limit, or processing error.
    """
    try:
        # Identify the client: X-Forwarded-For takes precedence over the socket peer.
        # NOTE(review): the header is client-controlled and used verbatim as the
        # key for IP_Dict and the NSFW rate limit.
        client_ip = request.client.host
        x_forwarded_for = dict(request.headers).get('x-forwarded-for')
        if x_forwarded_for:
            client_ip = x_forwarded_for
        IP_Dict[client_ip] = IP_Dict.get(client_ip, 0) + 1

        if image_dict is None:
            return None, "Please upload an image and draw the area to edit"
        # Check that background and layers exist.
        if "background" not in image_dict or "layers" not in image_dict:
            return None, "Please draw the area to edit on the image"
        base_image = image_dict["background"]
        layers = image_dict["layers"]
        # Strokes drawn without an uploaded background would otherwise skip NSFW
        # screening and be sent to the backend with no base image.
        if base_image is None:
            return None, "Please upload an image and draw the area to edit"
        if not layers:
            return None, "Please draw the area to edit on the image"
        if not prompt or prompt.strip() == "":
            return None, "Please enter editing prompt"
        clean_prompt = prompt.strip()
        # The stripped prompt must be longer than 3 characters.
        if len(clean_prompt) <= 3:
            return None, "❌ Editing prompt must be more than 3 characters"
    except Exception as e:
        print(f"⚠️ Local edit request preprocessing error: {e}")
        return None, "❌ Request processing error"

    # NSFW screening of the background image (only when the detector loaded).
    # NSFW images are still processed, up to NSFW_LIMIT per client per window.
    if nsfw_detector is not None:
        try:
            nsfw_result = nsfw_detector.predict_pil_label_only(base_image)
            if nsfw_result.lower() == "nsfw":
                print(f"🔍 NSFW检测结果: ❌❌❌ {nsfw_result} - IP: {client_ip}")
                is_rate_limited, wait_time = check_nsfw_rate_limit(client_ip)
                if is_rate_limited:
                    # Round the remaining wait UP to whole minutes (at least 1).
                    wait_minutes = max(1, math.ceil(wait_time / 60))
                    window_info = get_current_window_info(client_ip)
                    print(f"⚠️ NSFW频率限制 - IP: {client_ip}")
                    print(f" 时间窗口: {window_info['window_start']} - {window_info['window_end']}")
                    print(f" 当前计数: {window_info['current_count']}/{NSFW_LIMIT}, 需要等待 {wait_minutes} 分钟")
                    return None, f"❌ Please wait {wait_minutes} minutes before generating again"
                # Under the limit: count this NSFW request and let it proceed.
                record_nsfw_detection(client_ip)
            else:
                print(f"🔍 NSFW检测结果: ✅✅✅ {nsfw_result} - IP: {client_ip}")
        except Exception as e:
            # Deliberate best effort: a detector error must not block editing.
            print(f"⚠️ NSFW检测失败: {e}")

    def progress_callback(message):
        # Forward backend status messages to the Gradio progress bar. A failed
        # progress update is logged and must never abort the edit itself.
        try:
            if progress is not None:
                progress(0.5, desc=message)
        except Exception as e:
            print(f"⚠️ Local edit progress update failed: {e}")

    try:
        print(f"✅ Local editing started - IP: {client_ip}, count: {IP_Dict[client_ip]}, prompt: {clean_prompt}", flush=True)
        # Run the local (masked) image edit.
        result_url, message = process_local_image_edit(base_image, layers, clean_prompt, progress_callback)
        if not result_url:
            print(f"❌ Local editing processing failed - IP: {client_ip}, error: {message}", flush=True)
            return None, "❌ " + message
        print(f"✅ Local editing completed successfully - IP: {client_ip}, result_url: {result_url}", flush=True)
        try:
            if progress is not None:
                progress(1.0, desc="Processing completed")
        except Exception as e:
            print(f"⚠️ Local edit final progress update failed: {e}")
        return result_url, "✅ " + message
    except Exception as e:
        print(f"❌ Local editing exception - IP: {client_ip}, error: {str(e)}")
        return None, f"❌ Error occurred during processing: {str(e)}"
# Create Gradio interface
def create_app():
    """Build the Gradio Blocks UI for the image editor.

    The UI has two tabs:
      * "Global Editing": a ``gr.Image`` input plus a prompt, handled by
        ``edit_image_interface`` (API name "global_edit").
      * "Local Editing": a ``gr.ImageEditor`` (red brush for marking the area)
        plus a prompt, handled by ``local_edit_interface`` (API name "local_edit").

    Each tab also has clickable example prompts and a "Use as Input" button
    that copies the result back into that tab's input component.

    Components render in the order they are created inside the ``with``
    contexts, so statement order here defines the page layout.

    Returns:
        gr.Blocks: The application (not yet queued or launched).
    """
    with gr.Blocks(
        title="AI Image Editor",
        theme=gr.themes.Soft(),
        css="""
.main-container {
max-width: 1200px;
margin: 0 auto;
}
.upload-area {
border: 2px dashed #ccc;
border-radius: 10px;
padding: 20px;
text-align: center;
}
.result-area {
margin-top: 20px;
padding: 20px;
border-radius: 10px;
background-color: #f8f9fa;
}
.use-as-input-btn {
margin-top: 10px;
width: 100%;
}
""",
        # Configuration intended to improve behaviour under concurrency: the
        # script tries to lower the client's status-update frequency to 2 s.
        # NOTE(review): `window.gradio.update_frequency` does not appear to be a
        # documented Gradio client setting; this script may be a no-op. Confirm.
        head="""
<script>
// 减少客户端状态更新频率,避免过度的 SSE 连接
if (window.gradio) {
window.gradio.update_frequency = 2000; // 2秒更新一次
}
</script>
"""
    ) as app:
        # gr.State components were reduced to the necessary minimum; most were
        # removed to cut down state-management complexity.
        gr.Markdown(
            """
# 🎨 AI Image Editor
""",
            elem_classes=["main-container"]
        )
        with gr.Tabs():
            # Global editing tab
            with gr.Tab("🌍 Global Editing"):
                with gr.Row():
                    # Left column: input image, prompt and submit button.
                    with gr.Column(scale=1):
                        gr.Markdown("### 📸 Upload Image")
                        input_image = gr.Image(
                            label="Select image to edit",
                            type="pil",
                            height=512,
                            elem_classes=["upload-area"]
                        )
                        gr.Markdown("### ✍️ Editing Instructions")
                        prompt_input = gr.Textbox(
                            label="Enter editing prompt",
                            placeholder="For example: change background to beach, add rainbow, remove background, etc...",
                            lines=3,
                            max_lines=5
                        )
                        edit_button = gr.Button(
                            "🚀 Start Editing",
                            variant="primary",
                            size="lg"
                        )
                    # Right column: result image, "Use as Input" and status text.
                    with gr.Column(scale=1):
                        gr.Markdown("### 🎯 Editing Result")
                        output_image = gr.Image(
                            label="Edited image",
                            height=320,
                            elem_classes=["result-area"]
                        )
                        # "Use as Input" button: copies the result into the input.
                        use_as_input_btn = gr.Button(
                            "🔄 Use as Input",
                            variant="secondary",
                            size="sm",
                            elem_classes=["use-as-input-btn"]
                        )
                        status_output = gr.Textbox(
                            label="Processing status",
                            lines=2,
                            max_lines=3,
                            interactive=False
                        )
                # Example area
                gr.Markdown("### 💡 Prompt Examples")
                with gr.Row():
                    example_prompts = [
                        "Change the character's background to a sunny seaside with blue waves",
                        "Change the character's background to New York at night with neon lights",
                        "Change the character's background to a fairytale castle with bright colors",
                        "Change background to forest",
                        "Change background to snow mountain"
                    ]
                    for prompt in example_prompts:
                        # Each button fills the prompt box with its own text.
                        # `p=prompt` binds the current value now; a plain closure
                        # would make every button emit the last prompt.
                        gr.Button(
                            prompt,
                            size="sm"
                        ).click(
                            lambda p=prompt: p,
                            outputs=prompt_input
                        )
                # Bind the edit button's click event (simplified, no gr.State involved).
                edit_button.click(
                    fn=edit_image_interface,
                    inputs=[input_image, prompt_input],
                    outputs=[output_image, status_output],
                    show_progress=True,
                    # Concurrency settings
                    concurrency_limit=10,  # at most 10 concurrent runs of this event
                    api_name="global_edit"
                )
                # Simplified "Use as Input": copy the result image straight into the input.
                def simple_use_as_input(output_img):
                    if output_img is not None:
                        return output_img
                    return None
                use_as_input_btn.click(
                    fn=simple_use_as_input,
                    inputs=[output_image],
                    outputs=[input_image]
                )
            # Local editing tab
            with gr.Tab("🖌️ Local Editing"):
                with gr.Row():
                    # Left column: image editor (draw the area), prompt, submit button.
                    with gr.Column(scale=1):
                        gr.Markdown("### 📸 Upload Image and Draw Edit Area")
                        local_input_image = gr.ImageEditor(
                            label="Upload image and draw mask",
                            type="pil",
                            height=512,
                            brush=gr.Brush(colors=["#ff0000"], default_size=60),
                            elem_classes=["upload-area"]
                        )
                        gr.Markdown("### ✍️ Editing Instructions")
                        local_prompt_input = gr.Textbox(
                            label="Enter local editing prompt",
                            placeholder="For example: change selected area hair to golden, add patterns to selected object, change selected area color, etc...",
                            lines=3,
                            max_lines=5
                        )
                        local_edit_button = gr.Button(
                            "🎯 Start Local Editing",
                            variant="primary",
                            size="lg"
                        )
                    # Right column: result image, "Use as Input" and status text.
                    with gr.Column(scale=1):
                        gr.Markdown("### 🎯 Editing Result")
                        local_output_image = gr.Image(
                            label="Local edited image",
                            height=320,
                            elem_classes=["result-area"]
                        )
                        # "Use as Input" button: copies the result into the editor.
                        local_use_as_input_btn = gr.Button(
                            "🔄 Use as Input",
                            variant="secondary",
                            size="sm",
                            elem_classes=["use-as-input-btn"]
                        )
                        local_status_output = gr.Textbox(
                            label="Processing status",
                            lines=2,
                            max_lines=3,
                            interactive=False
                        )
                # Local editing examples
                gr.Markdown("### 💡 Local Editing Prompt Examples")
                with gr.Row():
                    local_example_prompts = [
                        "Change selected area hair to golden",
                        "Add pattern designs to selected clothing",
                        "Change selected area to different material",
                        "Add decorations to selected object",
                        "Change selected area color and style"
                    ]
                    for prompt in local_example_prompts:
                        # `p=prompt` binds each button to its own prompt text.
                        gr.Button(
                            prompt,
                            size="sm"
                        ).click(
                            lambda p=prompt: p,
                            outputs=local_prompt_input
                        )
                # Bind the local-edit button's click event (simplified, no gr.State involved).
                local_edit_button.click(
                    fn=local_edit_interface,
                    inputs=[local_input_image, local_prompt_input],
                    outputs=[local_output_image, local_status_output],
                    show_progress=True,
                    # Concurrency settings
                    concurrency_limit=8,  # local editing is heavier, so allow fewer concurrent runs
                    api_name="local_edit"
                )
                # Simplified local "Use as Input" button.
                def simple_local_use_as_input(output_img):
                    if output_img is not None:
                        # Build a minimal ImageEditor value: the result becomes the
                        # background (and composite) with no drawn layers.
                        editor_data = {
                            "background": output_img,
                            "layers": [],
                            "composite": output_img
                        }
                        return editor_data
                    return None
                local_use_as_input_btn.click(
                    fn=simple_local_use_as_input,
                    inputs=[local_output_image],
                    outputs=[local_input_image]
                )
    return app
if __name__ == "__main__":
    app = create_app()

    # Request-queue settings, tuned for many concurrent users and to avoid
    # SSE connection problems.
    queue_options = {
        "default_concurrency_limit": 20,  # for events without their own concurrency_limit
        "max_size": 50,  # upper bound on queued requests
        "api_open": False,  # keep the REST routes that bypass the queue closed
    }
    app.queue(**queue_options)

    launch_options = {
        "server_name": "0.0.0.0",  # listen on all interfaces
        "show_error": True,  # surface detailed errors in the UI
        "quiet": False,  # keep startup logging
        "max_threads": 40,  # max threads Gradio may use in parallel
        "height": 800,  # iframe height when displayed inline
        "favicon_path": None,  # NOTE(review): None is Gradio's default, so this has no effect
    }
    app.launch(**launch_options)