# ImageEditPro / util.py
# selfit-camera - commit f7d53f5 ("init")
import os
import sys
import cv2
import json
import random
import time
import datetime
import requests
import func_timeout
import numpy as np
import gradio as gr
import boto3
import tempfile
import io
import uuid
from botocore.client import Config
from PIL import Image
# TOKEN = os.environ['TOKEN']
# APIKEY = os.environ['APIKEY']
# UKAPIURL = os.environ['UKAPIURL']
# OneKey bundles every service credential into one "#"-separated secret.
# Field order (by index): TOKEN, APIKEY, UKAPIURL, LLMKEY, R2_ACCESS_KEY,
# R2_SECRET_KEY, R2_ENDPOINT.
_ONEKEY_FIELD_COUNT = 7
OneKey = os.environ['OneKey'].strip()
OneKey = OneKey.split("#")
if len(OneKey) < _ONEKEY_FIELD_COUNT:
    # Fail fast with a readable message instead of an opaque IndexError.
    # Only the field count is reported so the secret never reaches the logs.
    raise RuntimeError(
        f"OneKey must contain {_ONEKEY_FIELD_COUNT} '#'-separated fields "
        f"(TOKEN#APIKEY#UKAPIURL#LLMKEY#R2_ACCESS_KEY#R2_SECRET_KEY#R2_ENDPOINT), "
        f"got {len(OneKey)}"
    )
TOKEN = OneKey[0]
APIKEY = OneKey[1]
UKAPIURL = OneKey[2]
LLMKEY = OneKey[3]
R2_ACCESS_KEY = OneKey[4]
R2_SECRET_KEY = OneKey[5]
R2_ENDPOINT = OneKey[6]
# tmpFolder is no longer needed since we upload directly from memory
# tmpFolder = "tmp"
# os.makedirs(tmpFolder, exist_ok=True)
# Legacy function - no longer used since we upload directly from memory
# def upload_user_img(clientIp, timeId, img):
# fileName = clientIp.replace(".", "")+str(timeId)+".jpg"
# local_path = os.path.join(tmpFolder, fileName)
# img = cv2.imread(img)
# cv2.imwrite(os.path.join(tmpFolder, fileName), img)
#
# json_data = {
# "token": TOKEN,
# "input1": fileName,
# "input2": "",
# "protocol": "",
# "cloud": "ali"
# }
#
# session = requests.session()
# ret = requests.post(
# f"{UKAPIURL}/upload",
# headers={'Content-Type': 'application/json'},
# json=json_data
# )
#
# res = ""
# if ret.status_code==200:
# if 'upload1' in ret.json():
# upload_url = ret.json()['upload1']
# headers = {'Content-Type': 'image/jpeg'}
# response = session.put(upload_url, data=open(local_path, 'rb').read(), headers=headers)
# # print(response.status_code)
# if response.status_code == 200:
# res = upload_url
# if os.path.exists(local_path):
# os.remove(local_path)
# return res
class R2Api:
    """Upload helper for the "omni-creator" R2 bucket.

    Objects are written through presigned PUT URLs produced by a boto3 S3
    client pointed at ``R2_ENDPOINT``. The returned public URL is
    ``self.domain`` followed by the object key.
    """

    def __init__(self, session=None):
        """Create the boto3 client and the HTTP session used for uploads.

        Args:
            session (requests.Session, optional): Session used for the PUT
                requests. A new ``requests.Session`` is created when omitted.
        """
        super().__init__()
        self.R2_BUCKET = "omni-creator"
        self.domain = "https://www.omnicreator.net/"
        self.R2_ACCESS_KEY = R2_ACCESS_KEY
        self.R2_SECRET_KEY = R2_SECRET_KEY
        self.R2_ENDPOINT = R2_ENDPOINT
        self.client = boto3.client(
            "s3",
            endpoint_url=self.R2_ENDPOINT,
            aws_access_key_id=self.R2_ACCESS_KEY,
            aws_secret_access_key=self.R2_SECRET_KEY,
            config=Config(signature_version="s3v4")
        )
        self.session = requests.Session() if session is None else session

    def upload_from_memory(self, image_data, filename, content_type='image/jpeg',
                           max_retries=3, retry_delay=1):
        """Upload in-memory bytes to R2 and return the public URL.

        The object key is ``ImageEdit/Uploads/<today's date>/<filename>``.

        Args:
            image_data (bytes): Payload to upload.
            filename (str): Object name inside today's upload folder.
            content_type (str): MIME type. It is included in the presigned
                URL parameters and sent as the ``Content-Type`` header.
            max_retries (int): Total number of PUT attempts.
            retry_delay (float): Seconds to wait between failed attempts.

        Returns:
            str: Public URL of the uploaded object.

        Raises:
            Exception: If no attempt returned HTTP 200.
        """
        t1 = time.time()
        headers = {"Content-Type": content_type}
        cloud_path = f"ImageEdit/Uploads/{str(datetime.date.today())}/{filename}"
        url = self.client.generate_presigned_url(
            "put_object",
            Params={"Bucket": self.R2_BUCKET, "Key": cloud_path, "ContentType": content_type},
            ExpiresIn=604800  # 7 days
        )
        last_error = None
        for attempt in range(1, max_retries + 1):
            try:
                response = self.session.put(url, data=image_data, headers=headers, timeout=15)
            except requests.exceptions.RequestException as e:
                # RequestException already covers Timeout.
                last_error = e
                print(f"⚠️ Upload retry {attempt}/{max_retries} failed: {e}")
            else:
                if response.status_code == 200:
                    print("upload_from_memory time is ====>", time.time() - t1)
                    return f"{self.domain}{cloud_path}"
                # A non-200 response is a failed upload too: retry, and raise
                # below if no attempt succeeds, instead of returning a URL that
                # points at nothing.
                last_error = f"HTTP {response.status_code}"
                print(f"⚠️ Upload failed with status code: {response.status_code}")
            if attempt < max_retries:
                time.sleep(retry_delay)  # back off before the next attempt
        raise Exception(f'Failed to upload file to R2 after {max_retries} retries! Last error: {last_error}')
def upload_user_img_r2(clientIp, timeId, pil_image):
    """Encode ``pil_image`` as JPEG in memory and upload it to R2.

    Nothing is written to local disk.

    Args:
        clientIp (str): Client IP address. It is not used to build the
            filename and is kept for call compatibility.
        timeId (int): Timestamp appended to the generated filename.
        pil_image (PIL.Image.Image): Image to upload. Any mode other than
            RGB is converted to RGB before JPEG encoding.

    Returns:
        str: URL returned by ``R2Api.upload_from_memory``.
    """
    # uuid4 keeps names unique even when concurrent requests share a timestamp.
    file_name = f"user_img_{uuid.uuid4()}_{timeId}.jpg"
    rgb_image = pil_image if pil_image.mode == 'RGB' else pil_image.convert('RGB')
    jpeg_buffer = io.BytesIO()
    rgb_image.save(jpeg_buffer, format='JPEG', quality=95)
    return R2Api().upload_from_memory(jpeg_buffer.getvalue(), file_name, 'image/jpeg')
def create_mask_from_layers(base_image, layers):
    """Build a black-and-white mask from ImageEditor layer data.

    A pixel becomes white (255) if any layer has drawn on it; every other
    pixel stays black (0). What counts as drawn depends on the layer:

    * RGBA layers: alpha > 0.
    * Other 3-D layers (e.g. RGB): any channel > 0.
    * 2-D (grayscale) layers: value > 0.

    Layers with any other rank, and ``None`` entries, are skipped.

    Args:
        base_image (PIL.Image.Image): Original image. Only its ``size`` is used.
        layers (list): ImageEditor layers; each must be convertible with
            ``np.array`` (e.g. PIL images or NumPy arrays).

    Returns:
        PIL.Image.Image: Mode ``'L'`` mask with the same size as ``base_image``.
    """
    width, height = base_image.size
    # Accumulate in NumPy (rows x cols) and convert to PIL once at the end,
    # rather than round-tripping through PIL for every layer.
    combined = np.zeros((height, width), dtype=np.uint8)
    if not layers:
        return Image.fromarray(combined)

    for layer in layers:
        if layer is None:
            continue
        layer_array = np.array(layer)
        if layer_array.ndim == 3:
            if layer_array.shape[2] == 4:
                # RGBA: every non-transparent pixel counts as drawn.
                drawn = layer_array[:, :, 3] > 0
            else:
                # RGB-like: every non-black pixel counts as drawn.
                drawn = np.any(layer_array > 0, axis=2)
        elif layer_array.ndim == 2:
            drawn = layer_array > 0
        else:
            continue

        layer_mask = np.where(drawn, 255, 0).astype(np.uint8)
        if layer_mask.shape != (height, width):
            # NEAREST keeps the mask strictly binary. LANCZOS would add grey
            # edge values and ringing, breaking the black/white contract.
            resized = Image.fromarray(layer_mask).resize(base_image.size, Image.NEAREST)
            layer_mask = np.array(resized)
        # Union of all layers: keep the per-pixel maximum.
        np.maximum(combined, layer_mask, out=combined)

    # A 2-D uint8 array maps to mode 'L' automatically.
    return Image.fromarray(combined)
def upload_mask_image_r2(client_ip, time_id, mask_image):
    """Encode a mask as PNG in memory and upload it to R2.

    Upload is best-effort: any failure is printed and reported as ``None``
    instead of raising.

    Args:
        client_ip (str): Client IP. It is not used to build the filename.
        time_id (int): Timestamp appended to the generated filename.
        mask_image (PIL.Image.Image): Mask to upload.

    Returns:
        str | None: Uploaded URL, or ``None`` if encoding or upload failed.
    """
    # uuid4 keeps names unique even when concurrent requests share a timestamp.
    file_name = f"mask_img_{uuid.uuid4()}_{time_id}.png"
    try:
        png_buffer = io.BytesIO()
        mask_image.save(png_buffer, format='PNG')
        return R2Api().upload_from_memory(png_buffer.getvalue(), file_name, 'image/png')
    except Exception as exc:
        print(f"Failed to upload mask image: {exc}")
        return None
def submit_image_edit_task(user_image_url, prompt, task_type="80", mask_image_url="", reference_image_url=""):
    """Submit an image-editing task to the ``public_image_edit_v2`` endpoint.

    At most three attempts are made. HTTP 502/503/504 responses, timeouts
    and connection errors are retried after a 2 second pause. Any other
    failure is returned immediately.

    Args:
        user_image_url (str): URL of the uploaded source image.
        prompt (str): Editing instructions.
        task_type (str): Task type sent to the API (``"80"`` by default;
            the local-edit flow passes ``"81"``).
        mask_image_url (str): URL of the mask image, or ``""``.
        reference_image_url (str): Optional reference image URL. It is only
            sent when non-empty.

    Returns:
        tuple: ``(task_id, None)`` on success, ``(None, error_message)`` otherwise.
    """
    max_retries = 3
    retryable_statuses = (502, 503, 504)
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {APIKEY}'
    }
    payload = {
        "user_image": user_image_url,
        "user_mask": mask_image_url,
        "type": task_type,
        "text": prompt,
        "user_uuid": APIKEY,
        "priority": 0,
        # NOTE(review): hard-coded shared secret in source; consider moving it
        # into the OneKey configuration.
        "secret_key": "219ngu"
    }
    if reference_image_url:
        payload["user_image2"] = reference_image_url

    for attempt in range(1, max_retries + 1):
        try:
            response = requests.post(
                f'{UKAPIURL}/public_image_edit_v2',
                headers=headers,
                json=payload,
                timeout=30
            )
            if response.status_code == 200:
                result = response.json()
                if result.get('code') != 0:
                    return None, f"API Error: {result.get('message', 'Unknown error')}"
                return result['data']['task_id'], None
            if response.status_code not in retryable_statuses:
                return None, f"HTTP Error: {response.status_code}"
            if attempt == max_retries:
                return None, f"HTTP Error after {max_retries} retries: {response.status_code}"
            print(f"⚠️ Server error {response.status_code}, retrying {attempt}/{max_retries}")
            time.sleep(2)
        except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e:
            if attempt == max_retries:
                return None, f"Network error after {max_retries} retries: {str(e)}"
            print(f"⚠️ Network error, retrying {attempt}/{max_retries}: {e}")
            time.sleep(2)
        except Exception as e:
            return None, f"Request Exception: {str(e)}"
    return None, f"Failed after {max_retries} retries"
def check_task_status(task_id):
    """Query the status of an image-editing task (``status_image_edit_v2``).

    At most two attempts are made. HTTP 502/503/504 responses, timeouts and
    connection errors are retried after a 1 second pause. Any other failure
    is returned immediately.

    Args:
        task_id: Task identifier returned by ``submit_image_edit_task``.

    Returns:
        tuple: ``(status, image_url, task_data)`` on success. ``task_data`` is
        the response's ``data`` dict; the polling loops read ``uuid`` and
        ``queue_info`` from it. On failure, returns
        ``('error', None, error_message)``.
    """
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {APIKEY}'
    }
    data = {
        "task_id": task_id
    }
    retry_count = 0
    max_retries = 2  # status polls are frequent, so retry less than on submit
    while retry_count < max_retries:
        try:
            response = requests.post(
                f'{UKAPIURL}/status_image_edit_v2',
                headers=headers,
                json=data,
                timeout=15  # shorter than the submit timeout to keep polls quick
            )
            if response.status_code == 200:
                result = response.json()
                if result.get('code') == 0:
                    task_data = result['data']
                    return task_data['status'], task_data.get('image_url'), task_data
                return 'error', None, result.get('message', 'Unknown error')
            if response.status_code in [502, 503, 504]:  # transient server errors: retry
                retry_count += 1
                if retry_count < max_retries:
                    print(f"⚠️ Status check server error {response.status_code}, retrying {retry_count}/{max_retries}")
                    time.sleep(1)  # short pause between status retries
                    continue
                return 'error', None, f"HTTP Error after {max_retries} retries: {response.status_code}"
            return 'error', None, f"HTTP Error: {response.status_code}"
        except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e:
            retry_count += 1
            if retry_count < max_retries:
                print(f"⚠️ Status check network error, retrying {retry_count}/{max_retries}: {e}")
                time.sleep(1)
                continue
            return 'error', None, f"Network error after {max_retries} retries: {str(e)}"
        except Exception as e:
            return 'error', None, f"Request Exception: {str(e)}"
    return 'error', None, f"Failed after {max_retries} retries"
def process_image_edit(img_input, prompt, reference_image=None, progress_callback=None,
                       max_attempts=60, poll_interval=1):
    """Run the complete global image-editing flow: upload, submit, poll.

    Args:
        img_input: Source image, either a PIL Image (any object with
            ``save``) or a file path.
        prompt (str): Editing instructions.
        reference_image: Optional reference image (PIL Image or file path).
            It is uploaded and passed to the task as ``reference_image_url``.
        progress_callback (callable, optional): Called with human-readable
            status strings.
        max_attempts (int): Maximum number of status polls before giving up.
        poll_interval (float): Seconds to sleep after each non-terminal poll.
            With the defaults this is roughly one minute of polling, plus
            request time.

    Returns:
        tuple: ``(uploaded_url, output_url, message, task_uuid)``. The URLs
        are ``None`` on failure. ``task_uuid`` comes from the last status
        response that carried one.
    """
    try:
        client_ip = "127.0.0.1"  # fixed placeholder; no real client address here
        time_id = int(time.time())

        # Accept either a PIL Image (anything with .save) or a file path.
        if hasattr(img_input, 'save'):
            pil_image = img_input
            print("💾 Using PIL Image directly from memory")
        else:
            pil_image = Image.open(img_input)
            print(f"📁 Loaded image from file: {img_input}")

        if progress_callback:
            progress_callback("uploading image...")
        uploaded_url = upload_user_img_r2(client_ip, time_id, pil_image)
        if not uploaded_url:
            return None, None, "image upload failed", None
        # Keep only the bare object URL (strip any query string).
        if "?" in uploaded_url:
            uploaded_url = uploaded_url.split("?")[0]

        if progress_callback:
            progress_callback("submitting edit task...")

        reference_url = ""
        if reference_image is not None:
            try:
                if progress_callback:
                    progress_callback("uploading reference image...")
                if hasattr(reference_image, 'save'):
                    reference_pil = reference_image
                else:
                    reference_pil = Image.open(reference_image)
                reference_url = upload_user_img_r2(client_ip, time_id, reference_pil)
                if not reference_url:
                    return None, None, "reference image upload failed", None
                if "?" in reference_url:
                    reference_url = reference_url.split("?")[0]
            except Exception as e:
                return None, None, f"reference image processing failed: {str(e)}", None

        task_id, error = submit_image_edit_task(uploaded_url, prompt, reference_image_url=reference_url)
        if error:
            return None, None, error, None
        if progress_callback:
            progress_callback(f"task submitted, ID: {task_id}, processing...")

        # Poll until a terminal status. Only non-terminal polls sleep, so the
        # total wait is about max_attempts * poll_interval seconds plus request time.
        task_uuid = None
        for _ in range(max_attempts):
            status, output_url, task_data = check_task_status(task_id)
            if task_data and isinstance(task_data, dict):
                task_uuid = task_data.get('uuid', None)

            if status == 'completed':
                if output_url:
                    return uploaded_url, output_url, "image edit completed", task_uuid
                return None, None, "Task completed but no result image returned", task_uuid
            if status in ('error', 'failed'):
                return None, None, f"task processing failed: {task_data}", task_uuid

            if progress_callback:
                if status not in ('queued', 'processing', 'running', 'created', 'working'):
                    progress_callback(f"unknown status: {status}")
                elif not (task_data and isinstance(task_data, dict)):
                    progress_callback(f"task processing... (status: {status})")
                else:
                    # Richer progress message using the queue info, when present.
                    queue_info = task_data.get('queue_info', {})
                    if queue_info and status in ('queued', 'created'):
                        # `or 0` guards against an explicit null from the API.
                        tasks_ahead = queue_info.get('tasks_ahead') or 0
                        if tasks_ahead > 0:
                            progress_callback(f"⏳ Queue: {tasks_ahead} tasks ahead | Low priority | Visit website for instant processing → https://omnicreator.net/#generator")
                        else:
                            progress_callback("🚀 Processing your image editing request...")
                    elif status == 'processing':
                        progress_callback("🎨 AI is processing... Please wait")
                    elif status in ('running', 'working'):
                        progress_callback("⚡ Generating... Almost done")
                    else:
                        progress_callback(f"📋 Task status: {status}")
            time.sleep(poll_interval)

        return None, None, "task processing timeout", task_uuid
    except Exception as e:
        return None, None, f"error occurred during processing: {str(e)}", None
def process_local_image_edit(base_image, layers, prompt, reference_image=None, progress_callback=None, use_example_mask=None,
                             max_attempts=60, poll_interval=1):
    """Run the complete local (masked) image-editing flow.

    The flow builds a mask, uploads the original image, the mask and an
    optional reference image, submits a ``task_type="81"`` (local edit)
    task, and polls it until it finishes.

    Args:
        base_image (PIL.Image.Image): Original image.
        layers (list): ImageEditor layer data the mask is built from.
        prompt (str): Editing instructions.
        reference_image: Optional reference image (PIL Image or file path).
        progress_callback (callable, optional): Called with human-readable
            status strings.
        use_example_mask (str, optional): Path of a mask file to use instead
            of ``layers`` (the shortcut for the bundled example).
        max_attempts (int): Maximum number of status polls before giving up.
        poll_interval (float): Seconds to sleep after each non-terminal poll.
            With the defaults this is roughly one minute of polling, plus
            request time.

    Returns:
        tuple: ``(uploaded_url, output_url, message, task_uuid)``. The URLs
        are ``None`` on failure.
    """
    try:
        client_ip = "127.0.0.1"  # fixed placeholder; no real client address here
        time_id = int(time.time())

        if progress_callback:
            progress_callback("creating mask image...")

        if use_example_mask:
            # Example shortcut: load a mask file instead of reading the layers.
            # `Image` and `os` are the module-level imports. Re-importing them
            # inside this function would make them local for the whole body, so
            # `Image.open(reference_image)` below would raise UnboundLocalError
            # whenever this branch is skipped.
            try:
                if base_image is None:
                    return None, None, "Base image is None, cannot process example mask", None
                if not os.path.exists(use_example_mask):
                    return None, None, f"Example mask file not found: {use_example_mask}", None
                mask_image = Image.open(use_example_mask)
                # Match the base image size and force single-channel grayscale.
                if hasattr(base_image, 'size') and mask_image.size != base_image.size:
                    mask_image = mask_image.resize(base_image.size)
                if mask_image.mode != 'L':
                    mask_image = mask_image.convert('L')
                print(f"🎭 Using example mask from: {use_example_mask}, size: {mask_image.size}")
            except Exception as e:
                import traceback
                traceback.print_exc()
                return None, None, f"Failed to load example mask: {str(e)}", None
        else:
            mask_image = create_mask_from_layers(base_image, layers)

        # Reject an empty mask (nothing drawn, or an all-black example mask).
        mask_array = np.array(mask_image)
        if np.max(mask_array) == 0:
            return None, None, "please draw mask", None
        mask_pixels = np.sum(mask_array > 0)
        if use_example_mask:
            print(f"🎭 Example mask loaded successfully, mask pixels: {mask_pixels}")
        else:
            print(f"📝 创建mask图片成功,绘制区域像素数: {mask_pixels}")

        if progress_callback:
            progress_callback("uploading original image...")
        # Upload the original image straight from memory.
        uploaded_url = upload_user_img_r2(client_ip, time_id, base_image)
        if not uploaded_url:
            return None, None, "original image upload failed", None
        # Keep only the bare object URL (strip any query string).
        if "?" in uploaded_url:
            uploaded_url = uploaded_url.split("?")[0]

        if progress_callback:
            progress_callback("uploading mask image...")
        # Upload the mask straight from memory.
        mask_url = upload_mask_image_r2(client_ip, time_id, mask_image)
        if not mask_url:
            return None, None, "mask image upload failed", None
        if "?" in mask_url:
            mask_url = mask_url.split("?")[0]

        reference_url = ""
        if reference_image is not None:
            try:
                if progress_callback:
                    progress_callback("uploading reference image...")
                if hasattr(reference_image, 'save'):
                    reference_pil = reference_image
                else:
                    reference_pil = Image.open(reference_image)
                reference_url = upload_user_img_r2(client_ip, time_id, reference_pil)
                if not reference_url:
                    return None, None, "reference image upload failed", None
                if "?" in reference_url:
                    reference_url = reference_url.split("?")[0]
            except Exception as e:
                return None, None, f"reference image processing failed: {str(e)}", None

        print(f"📤 图片上传成功:")
        print(f" 原始图片: {uploaded_url}")
        print(f" Mask图片: {mask_url}")
        if reference_url:
            print(f" 参考图片: {reference_url}")

        if progress_callback:
            progress_callback("submitting local edit task...")
        # task_type "81" is the local (masked) image-edit task.
        task_id, error = submit_image_edit_task(
            uploaded_url,
            prompt,
            task_type="81",
            mask_image_url=mask_url,
            reference_image_url=reference_url
        )
        if error:
            return None, None, error, None
        if progress_callback:
            progress_callback(f"task submitted, ID: {task_id}, processing...")
        print(f"🚀 局部编辑任务已提交,任务ID: {task_id}")

        # Poll until a terminal status. Only non-terminal polls sleep, so the
        # total wait is about max_attempts * poll_interval seconds plus request time.
        task_uuid = None
        for _ in range(max_attempts):
            status, output_url, task_data = check_task_status(task_id)
            if task_data and isinstance(task_data, dict):
                task_uuid = task_data.get('uuid', None)

            if status == 'completed':
                if output_url:
                    print(f"✅ 局部编辑任务完成,结果: {output_url}")
                    return uploaded_url, output_url, "local image edit completed", task_uuid
                return None, None, "task completed but no result image returned", task_uuid
            if status in ('error', 'failed'):
                return None, None, f"task processing failed: {task_data}", task_uuid

            if progress_callback:
                if status not in ('queued', 'processing', 'running', 'created', 'working'):
                    progress_callback(f"unknown status: {status}")
                elif not (task_data and isinstance(task_data, dict)):
                    progress_callback(f"processing... (status: {status})")
                else:
                    # Richer progress message using the queue info, when present.
                    queue_info = task_data.get('queue_info', {})
                    if queue_info and status in ('queued', 'created'):
                        # `or 0` guards against an explicit null from the API.
                        tasks_ahead = queue_info.get('tasks_ahead') or 0
                        if tasks_ahead > 0:
                            progress_callback(f"⏳ Queue: {tasks_ahead} tasks ahead | Low priority | Visit website for instant processing → https://omnicreator.net/#generator")
                        else:
                            progress_callback("🚀 Processing your local editing request...")
                    elif status == 'processing':
                        progress_callback("🎨 AI is processing... Please wait")
                    elif status in ('running', 'working'):
                        progress_callback("⚡ Generating... Almost done")
                    else:
                        progress_callback(f"📋 Task status: {status}")
            time.sleep(poll_interval)

        return None, None, "task processing timeout", task_uuid
    except Exception as e:
        print(f"❌ 局部编辑处理异常: {str(e)}")
        return None, None, f"error occurred during processing: {str(e)}", None
def download_and_check_result_nsfw(image_url, nsfw_detector=None):
    """Download a result image and run NSFW detection on it.

    Args:
        image_url (str): URL of the result image.
        nsfw_detector: Object with a ``predict_pil_label_only(image)`` method
            returning a label string. When ``None``, detection is skipped.

    Returns:
        tuple: ``(is_nsfw, error_message)``. Detection is best-effort: any
        download or detection failure gives ``is_nsfw=False`` plus an error
        message instead of raising.
    """
    if nsfw_detector is not None:
        try:
            resp = requests.get(image_url, timeout=30)
            if resp.status_code != 200:
                return False, f"Failed to download result image: HTTP {resp.status_code}"
            # Decode the downloaded bytes in memory and classify them.
            label = nsfw_detector.predict_pil_label_only(Image.open(io.BytesIO(resp.content)))
            flagged = label.lower() == "nsfw"
            marker = '❌❌❌ ' if flagged else '✅✅✅ '
            print(f"🔍 结果图片NSFW检测: {marker + label}")
            return flagged, None
        except Exception as err:
            print(f"⚠️ 结果图片NSFW检测失败: {err}")
            return False, f"Failed to check result image: {str(err)}"
    return False, None
if __name__ == "__main__":
    # No command-line entry point: running this file directly does nothing
    # beyond the module-level setup.
    ...