ImageEditPro / app.py
selfit-camera's picture
init
fce53a0
raw
history blame
23.6 kB
import gradio as gr
import threading
import os
import shutil
import tempfile
import time
from util import process_image_edit, process_local_image_edit
from nfsw import NSFWDetector
# Configuration parameters
NSFW_TIME_WINDOW = 5 # Length of one rate-limit window, in minutes
NSFW_LIMIT = 10 # Maximum NSFW detections allowed per IP within one window (NOTE(review): the earlier comment said 6; the value in effect is 10)
IP_Dict = {} # Number of requests seen per client IP (incremented by the edit handlers)
NSFW_Dict = {} # Total number of NSFW detections recorded for each IP
NSFW_Time_Dict = {} # NSFW detection count per IP per time window; key format: "ip_timestamp"
def get_current_time_window(window_minutes=None, now=None):
    """
    Return the hour-aligned rate-limit window that contains ``now``.

    Windows start at minute 0 of every local-time hour and repeat every
    ``window_minutes`` minutes. When the window length does not divide 60,
    the last window of the hour is cut short at the hour boundary, because
    a fresh window always starts at minute 0 of the next hour.

    Args:
        window_minutes (int, optional): Window length in whole minutes.
            Defaults to the module-level ``NSFW_TIME_WINDOW``.
        now (float, optional): Epoch timestamp to locate. Defaults to
            ``time.time()``.

    Returns:
        tuple: (window start timestamp, window end timestamp), both in
        epoch seconds.

    Raises:
        ValueError: If ``window_minutes`` is not positive.
    """
    if window_minutes is None:
        window_minutes = NSFW_TIME_WINDOW
    if window_minutes <= 0:
        raise ValueError("window_minutes must be a positive number of minutes")
    if now is None:
        now = time.time()

    local_now = time.localtime(now)

    def _timestamp_at_minute(minute):
        # Same local date and hour as `now`, at the given minute, second 0.
        return time.mktime(time.struct_time((
            local_now.tm_year, local_now.tm_mon, local_now.tm_mday,
            local_now.tm_hour, minute, 0,
            local_now.tm_wday, local_now.tm_yday, local_now.tm_isdst,
        )))

    # Round the current minute down to the start of its window.
    window_start_minute = (local_now.tm_min // window_minutes) * window_minutes
    window_start_time = _timestamp_at_minute(window_start_minute)

    # A window never extends past the hour, since the next hour restarts
    # the window sequence at minute 0.
    next_hour_time = _timestamp_at_minute(0) + 3600
    window_end_time = min(window_start_time + (window_minutes * 60), next_hour_time)
    return window_start_time, window_end_time
def check_nsfw_rate_limit(client_ip):
    """
    Check whether an IP has reached the NSFW detection limit for the
    current time window.

    As a side effect this makes sure a counter exists for the current
    window and removes counters that belong to windows that have already
    ended.

    Args:
        client_ip (str): Client IP address (used as part of the counter key).

    Returns:
        tuple: (True if the limit is reached, seconds left until the
        current window ends). When the limit is not reached the second
        element is 0.
    """
    current_time = time.time()
    window_start_time, window_end_time = get_current_time_window()
    current_window_key = f"{client_ip}_{int(window_start_time)}"

    # Make sure the current window has a counter.
    NSFW_Time_Dict.setdefault(current_window_key, 0)

    # Drop counters of expired windows for every IP, so entries for clients
    # that never come back do not accumulate. Iterate over a snapshot of
    # the keys: the handlers are registered with a concurrency limit above
    # one, and another request inserting a key during iteration of the live
    # dict would raise RuntimeError.
    for key in list(NSFW_Time_Dict):
        if key == current_window_key:
            continue
        try:
            # The timestamp is the part after the LAST underscore; the IP
            # part comes from a request header and may itself contain "_".
            window_time = int(key.rsplit("_", 1)[1])
        except (IndexError, ValueError):
            continue  # not a "ip_timestamp" key; leave it alone
        if window_time < window_start_time:
            # pop() tolerates the key having been removed by another request.
            NSFW_Time_Dict.pop(key, None)

    # Has the current window reached the limit?
    if NSFW_Time_Dict.get(current_window_key, 0) >= NSFW_LIMIT:
        # Time left until the next window starts.
        wait_time = window_end_time - current_time
        return True, max(0, wait_time)
    return False, 0
def record_nsfw_detection(client_ip):
    """
    Count one NSFW detection for the given IP.

    Increments the IP's counter for the current time window in
    ``NSFW_Time_Dict`` and its overall counter in ``NSFW_Dict``.

    Args:
        client_ip (str): Client IP address.
    """
    window_start, _window_end = get_current_time_window()
    window_key = f"{client_ip}_{int(window_start)}"

    # Per-window counter (starts from 0 the first time the key is seen).
    NSFW_Time_Dict[window_key] = NSFW_Time_Dict.get(window_key, 0) + 1

    # Overall per-IP counter, kept for compatibility with existing logic.
    NSFW_Dict[client_ip] = NSFW_Dict.get(client_ip, 0) + 1
def get_current_window_info(client_ip):
    """
    Summarize the current rate-limit window for an IP (used for debug logging).

    Args:
        client_ip (str): Client IP address.

    Returns:
        dict: ``window_start`` / ``window_end`` as local ``HH:MM:SS`` strings,
        ``current_count`` (0 if nothing was recorded), ``limit`` and the
        ``window_key`` used in ``NSFW_Time_Dict``.
    """
    start_ts, end_ts = get_current_time_window()
    key = f"{client_ip}_{int(start_ts)}"

    def _as_clock(timestamp):
        # Local wall-clock representation of an epoch timestamp.
        return time.strftime("%H:%M:%S", time.localtime(timestamp))

    info = {
        "window_start": _as_clock(start_ts),
        "window_end": _as_clock(end_ts),
        "current_count": NSFW_Time_Dict.get(key, 0),
        "limit": NSFW_LIMIT,
        "window_key": key
    }
    return info
# Initialize the NSFW detector at import time (per the original note, the model is fetched from Hugging Face)
try:
    nsfw_detector = NSFWDetector() # Original note: automatically downloads falconsai_yolov9_nsfw_model_quantized.pt from Hugging Face
    print("✅ NSFW检测器初始化成功")
except Exception as e:
    # Keep the app usable without the detector: the edit handlers skip the
    # NSFW check when nsfw_detector is None.
    print(f"❌ NSFW检测器初始化失败: {e}")
    nsfw_detector = None
def edit_image_interface(input_image, prompt, request: gr.Request, progress=gr.Progress()):
    """
    Handle a global (whole-image) editing request from the UI.

    The request is validated, the image is screened by the NSFW detector
    (when one is available), and the edit is delegated to
    ``process_image_edit``.

    Args:
        input_image: Image to edit; passed to the NSFW detector's
            ``predict_pil_label_only`` and to ``process_image_edit``.
        prompt (str): Editing instruction; must be longer than 3 characters
            once surrounding whitespace is stripped.
        request (gr.Request): Request object supplied by Gradio, used to
            identify the client IP.
        progress: Gradio progress tracker used for status updates.

    Returns:
        tuple: (result URL, or None on failure/rejection; status message).
    """
    try:
        # Identify the client
        client_ip = request.client.host
        x_forwarded_for = dict(request.headers).get('x-forwarded-for')
        if x_forwarded_for:
            # X-Forwarded-For may carry a comma-separated chain
            # ("client, proxy1, proxy2"). Key on the first entry only;
            # using the whole header would give the same client a different
            # counter whenever the proxy chain differs.
            forwarded_ip = x_forwarded_for.split(",")[0].strip()
            if forwarded_ip:
                client_ip = forwarded_ip
        IP_Dict[client_ip] = IP_Dict.get(client_ip, 0) + 1

        if input_image is None:
            return None, "Please upload an image first"
        if not prompt or prompt.strip() == "":
            return None, "Please enter editing prompt"
        # The prompt must be longer than 3 characters
        if len(prompt.strip()) <= 3:
            return None, "❌ Editing prompt must be more than 3 characters"
    except Exception as e:
        print(f"⚠️ Request preprocessing error: {e}")
        return None, "❌ Request processing error"

    # Screen the image for NSFW content
    if nsfw_detector is not None:
        try:
            # Run detection on the image object directly (no temp file needed)
            nsfw_result = nsfw_detector.predict_pil_label_only(input_image)
            if nsfw_result.lower() == "nsfw":
                print(f"🔍 NSFW检测结果: ❌❌❌ {nsfw_result} - IP: {client_ip}")
                # Apply the per-IP NSFW rate limit
                is_rate_limited, wait_time = check_nsfw_rate_limit(client_ip)
                if is_rate_limited:
                    # Over the limit: tell the user how long to wait and stop here
                    wait_minutes = int(wait_time / 60) + 1 # whole minutes, rounded up
                    window_info = get_current_window_info(client_ip)
                    print(f"⚠️ NSFW频率限制 - IP: {client_ip}")
                    print(f" 时间窗口: {window_info['window_start']} - {window_info['window_end']}")
                    print(f" 当前计数: {window_info['current_count']}/{NSFW_LIMIT}, 需要等待 {wait_minutes} 分钟")
                    return None, f"❌ Please wait {wait_minutes} minutes before generating again"
                # Under the limit: count this detection but still process the request
                record_nsfw_detection(client_ip)
            else:
                print(f"🔍 NSFW检测结果: ✅✅✅ {nsfw_result} - IP: {client_ip}")
        except Exception as e:
            # A detector failure must not block the request
            print(f"⚠️ NSFW检测失败: {e}")

    def progress_callback(message):
        # Forward status text to the UI; a failed update must never abort the edit.
        try:
            if progress is not None:
                progress(0.5, desc=message)
        except Exception as e:
            print(f"⚠️ Progress update failed: {e}")

    try:
        # Log the accepted request
        print(f"✅ Processing started - IP: {client_ip}, count: {IP_Dict[client_ip]}, prompt: {prompt.strip()}", flush=True)
        # Call image editing processing function
        result_url, message = process_image_edit(input_image, prompt.strip(), progress_callback)
        if result_url:
            print(f"✅ Processing completed successfully - IP: {client_ip}, result_url: {result_url}", flush=True)
            try:
                if progress is not None:
                    progress(1.0, desc="Processing completed")
            except Exception as e:
                print(f"⚠️ Final progress update failed: {e}")
            return result_url, f"✅ {message}"
        print(f"❌ Processing failed - IP: {client_ip}, error: {message}", flush=True)
        # f-string so that a non-string message is shown as-is and does not
        # raise into the generic exception path below.
        return None, f"❌ {message}"
    except Exception as e:
        print(f"❌ Processing exception - IP: {client_ip}, error: {str(e)}")
        return None, f"❌ Error occurred during processing: {str(e)}"
# Complex state-management functions are no longer needed; they have been simplified into inline functions
def local_edit_interface(image_dict, prompt, request: gr.Request, progress=gr.Progress()):
    """
    Handle a local (masked-area) editing request from the UI.

    The request is validated, the background image is screened by the NSFW
    detector (when one is available), and the edit is delegated to
    ``process_local_image_edit``.

    Args:
        image_dict (dict): Image-editor value; must contain ``"background"``
            (the base image) and a non-empty ``"layers"`` entry (the drawn
            areas).
        prompt (str): Editing instruction; must be longer than 3 characters
            once surrounding whitespace is stripped.
        request (gr.Request): Request object supplied by Gradio, used to
            identify the client IP.
        progress: Gradio progress tracker used for status updates.

    Returns:
        tuple: (result URL, or None on failure/rejection; status message).
    """
    try:
        # Identify the client
        client_ip = request.client.host
        x_forwarded_for = dict(request.headers).get('x-forwarded-for')
        if x_forwarded_for:
            # X-Forwarded-For may carry a comma-separated chain
            # ("client, proxy1, proxy2"). Key on the first entry only;
            # using the whole header would give the same client a different
            # counter whenever the proxy chain differs.
            forwarded_ip = x_forwarded_for.split(",")[0].strip()
            if forwarded_ip:
                client_ip = forwarded_ip
        IP_Dict[client_ip] = IP_Dict.get(client_ip, 0) + 1

        if image_dict is None:
            return None, "Please upload an image and draw the area to edit"
        # Check if background and layers exist
        if "background" not in image_dict or "layers" not in image_dict:
            return None, "Please draw the area to edit on the image"
        base_image = image_dict["background"]
        layers = image_dict["layers"]
        if not layers:
            return None, "Please draw the area to edit on the image"
        if not prompt or prompt.strip() == "":
            return None, "Please enter editing prompt"
        # Check prompt length
        if len(prompt.strip()) <= 3:
            return None, "❌ Editing prompt must be more than 3 characters"
    except Exception as e:
        print(f"⚠️ Local edit request preprocessing error: {e}")
        return None, "❌ Request processing error"

    # Screen the base image for NSFW content
    if nsfw_detector is not None and base_image is not None:
        try:
            nsfw_result = nsfw_detector.predict_pil_label_only(base_image)
            if nsfw_result.lower() == "nsfw":
                print(f"🔍 NSFW检测结果: ❌❌❌ {nsfw_result} - IP: {client_ip}")
                # Apply the per-IP NSFW rate limit
                is_rate_limited, wait_time = check_nsfw_rate_limit(client_ip)
                if is_rate_limited:
                    wait_minutes = int(wait_time / 60) + 1 # whole minutes, rounded up
                    window_info = get_current_window_info(client_ip)
                    print(f"⚠️ NSFW频率限制 - IP: {client_ip}")
                    print(f" 时间窗口: {window_info['window_start']} - {window_info['window_end']}")
                    print(f" 当前计数: {window_info['current_count']}/{NSFW_LIMIT}, 需要等待 {wait_minutes} 分钟")
                    return None, f"❌ Please wait {wait_minutes} minutes before generating again"
                # Under the limit: count this detection but still process the request
                record_nsfw_detection(client_ip)
            else:
                print(f"🔍 NSFW检测结果: ✅✅✅ {nsfw_result} - IP: {client_ip}")
        except Exception as e:
            # A detector failure must not block the request
            print(f"⚠️ NSFW检测失败: {e}")

    def progress_callback(message):
        # Forward status text to the UI; a failed update must never abort the edit.
        try:
            if progress is not None:
                progress(0.5, desc=message)
        except Exception as e:
            print(f"⚠️ Local edit progress update failed: {e}")

    try:
        print(f"✅ Local editing started - IP: {client_ip}, count: {IP_Dict[client_ip]}, prompt: {prompt.strip()}", flush=True)
        # Run the local (masked) image edit
        result_url, message = process_local_image_edit(base_image, layers, prompt.strip(), progress_callback)
        if result_url:
            print(f"✅ Local editing completed successfully - IP: {client_ip}, result_url: {result_url}", flush=True)
            try:
                if progress is not None:
                    progress(1.0, desc="Processing completed")
            except Exception as e:
                print(f"⚠️ Local edit final progress update failed: {e}")
            return result_url, f"✅ {message}"
        print(f"❌ Local editing processing failed - IP: {client_ip}, error: {message}", flush=True)
        # f-string so that a non-string message is shown as-is and does not
        # raise into the generic exception path below.
        return None, f"❌ {message}"
    except Exception as e:
        print(f"❌ Local editing exception - IP: {client_ip}, error: {str(e)}")
        return None, f"❌ Error occurred during processing: {str(e)}"
# Create Gradio interface
def create_app():
    """
    Build the Gradio Blocks application.

    The UI has two tabs: "Global Editing" (whole image, wired to
    ``edit_image_interface``) and "Local Editing" (masked area, wired to
    ``local_edit_interface``). Each tab offers example-prompt buttons and a
    "Use as Input" button that feeds the result back into the input.

    Returns:
        The configured ``gr.Blocks`` instance (queue and launch are left to
        the caller).
    """
    # Static page assets, kept out of the layout code for readability.
    custom_css = """
.main-container {
max-width: 1200px;
margin: 0 auto;
}
.upload-area {
border: 2px dashed #ccc;
border-radius: 10px;
padding: 20px;
text-align: center;
}
.result-area {
margin-top: 20px;
padding: 20px;
border-radius: 10px;
background-color: #f8f9fa;
}
.use-as-input-btn {
margin-top: 10px;
width: 100%;
}
"""
    # Script injected into <head>; intended to improve behaviour under load.
    head_html = """
<script>
// 减少客户端状态更新频率,避免过度的 SSE 连接
if (window.gradio) {
window.gradio.update_frequency = 2000; // 2秒更新一次
}
</script>
"""

    global_examples = [
        "Change the character's background to a sunny seaside with blue waves",
        "Change the character's background to New York at night with neon lights",
        "Change the character's background to a fairytale castle with bright colors",
        "Change background to forest",
        "Change background to snow mountain"
    ]
    local_examples = [
        "Change selected area hair to golden",
        "Add pattern designs to selected clothing",
        "Change selected area to different material",
        "Add decorations to selected object",
        "Change selected area color and style"
    ]

    def _add_example_buttons(examples, target_box):
        # One small button per example; clicking it copies the example text
        # into target_box. The default argument pins each example's text.
        for example_text in examples:
            example_button = gr.Button(example_text, size="sm")
            example_button.click(
                lambda p=example_text: p,
                outputs=target_box
            )

    def simple_use_as_input(output_img):
        # Global tab "Use as Input": hand the result image straight back.
        return None if output_img is None else output_img

    def simple_local_use_as_input(output_img):
        # Local tab "Use as Input": wrap the result in the ImageEditor
        # value format with no drawn layers.
        if output_img is None:
            return None
        return {
            "background": output_img,
            "layers": [],
            "composite": output_img
        }

    with gr.Blocks(
        title="AI Image Editor",
        theme=gr.themes.Soft(),
        css=custom_css,
        head=head_html
    ) as app:
        # State components were mostly removed to keep state handling simple.
        gr.Markdown(
            """
# 🎨 AI Image Editor
""",
            elem_classes=["main-container"]
        )

        with gr.Tabs():
            # Global editing tab
            with gr.Tab("🌍 Global Editing"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### 📸 Upload Image")
                        input_image = gr.Image(
                            label="Select image to edit",
                            type="pil",
                            height=512,
                            elem_classes=["upload-area"]
                        )
                        gr.Markdown("### ✍️ Editing Instructions")
                        prompt_input = gr.Textbox(
                            label="Enter editing prompt",
                            placeholder="For example: change background to beach, add rainbow, remove background, etc...",
                            lines=3,
                            max_lines=5
                        )
                        edit_button = gr.Button(
                            "🚀 Start Editing",
                            variant="primary",
                            size="lg"
                        )

                    with gr.Column(scale=1):
                        gr.Markdown("### 🎯 Editing Result")
                        output_image = gr.Image(
                            label="Edited image",
                            height=320,
                            elem_classes=["result-area"]
                        )
                        # "Use as Input" button
                        use_as_input_btn = gr.Button(
                            "🔄 Use as Input",
                            variant="secondary",
                            size="sm",
                            elem_classes=["use-as-input-btn"]
                        )
                        status_output = gr.Textbox(
                            label="Processing status",
                            lines=2,
                            max_lines=3,
                            interactive=False
                        )

                # Example area
                gr.Markdown("### 💡 Prompt Examples")
                with gr.Row():
                    _add_example_buttons(global_examples, prompt_input)

                # Wire the edit button (no extra state management)
                edit_button.click(
                    fn=edit_image_interface,
                    inputs=[input_image, prompt_input],
                    outputs=[output_image, status_output],
                    show_progress=True,
                    concurrency_limit=10, # cap on concurrent runs of this event
                    api_name="global_edit"
                )

                # "Use as Input": copy the result image into the input
                use_as_input_btn.click(
                    fn=simple_use_as_input,
                    inputs=[output_image],
                    outputs=[input_image]
                )

            # Local editing tab
            with gr.Tab("🖌️ Local Editing"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### 📸 Upload Image and Draw Edit Area")
                        local_input_image = gr.ImageEditor(
                            label="Upload image and draw mask",
                            type="pil",
                            height=512,
                            brush=gr.Brush(colors=["#ff0000"], default_size=60),
                            elem_classes=["upload-area"]
                        )
                        gr.Markdown("### ✍️ Editing Instructions")
                        local_prompt_input = gr.Textbox(
                            label="Enter local editing prompt",
                            placeholder="For example: change selected area hair to golden, add patterns to selected object, change selected area color, etc...",
                            lines=3,
                            max_lines=5
                        )
                        local_edit_button = gr.Button(
                            "🎯 Start Local Editing",
                            variant="primary",
                            size="lg"
                        )

                    with gr.Column(scale=1):
                        gr.Markdown("### 🎯 Editing Result")
                        local_output_image = gr.Image(
                            label="Local edited image",
                            height=320,
                            elem_classes=["result-area"]
                        )
                        # "Use as Input" button
                        local_use_as_input_btn = gr.Button(
                            "🔄 Use as Input",
                            variant="secondary",
                            size="sm",
                            elem_classes=["use-as-input-btn"]
                        )
                        local_status_output = gr.Textbox(
                            label="Processing status",
                            lines=2,
                            max_lines=3,
                            interactive=False
                        )

                # Local editing examples
                gr.Markdown("### 💡 Local Editing Prompt Examples")
                with gr.Row():
                    _add_example_buttons(local_examples, local_prompt_input)

                # Wire the local edit button (no extra state management)
                local_edit_button.click(
                    fn=local_edit_interface,
                    inputs=[local_input_image, local_prompt_input],
                    outputs=[local_output_image, local_status_output],
                    show_progress=True,
                    concurrency_limit=8, # local editing is heavier, so allow fewer concurrent runs
                    api_name="local_edit"
                )

                # "Use as Input" for the local tab
                local_use_as_input_btn.click(
                    fn=simple_local_use_as_input,
                    inputs=[local_output_image],
                    outputs=[local_input_image]
                )

    return app
if __name__ == "__main__":
    app = create_app()

    # Queue settings chosen to cope with many concurrent users and to avoid
    # SSE connection problems.
    queue_settings = {
        "default_concurrency_limit": 20, # default per-event concurrency limit
        "max_size": 50, # maximum queue length
        "api_open": False, # API access closed to reduce resource usage
    }
    app.queue(**queue_settings)

    launch_settings = {
        "server_name": "0.0.0.0",
        "show_error": True, # show detailed error messages
        "quiet": False, # keep log output
        "max_threads": 40, # larger thread pool
        "height": 800,
        "favicon_path": None, # no custom favicon, fewer resources to load
    }
    app.launch(**launch_settings)