Spaces:
Running
Running
| import os | |
| import sys | |
| import cv2 | |
| import json | |
| import random | |
| import time | |
| import datetime | |
| import requests | |
| import func_timeout | |
| import numpy as np | |
| import gradio as gr | |
| import boto3 | |
| import tempfile | |
| import io | |
| import uuid | |
| from botocore.client import Config | |
| from PIL import Image | |
| # TOKEN = os.environ['TOKEN'] | |
| # APIKEY = os.environ['APIKEY'] | |
| # UKAPIURL = os.environ['UKAPIURL'] | |
# All credentials arrive in a single "OneKey" environment variable holding
# '#'-separated values in this fixed order:
#   TOKEN#APIKEY#UKAPIURL#LLMKEY#R2_ACCESS_KEY#R2_SECRET_KEY#R2_ENDPOINT
# Extra trailing fields are ignored.
OneKey = os.environ['OneKey'].strip()
OneKey = OneKey.split("#")
_ONEKEY_FIELDS = ("TOKEN", "APIKEY", "UKAPIURL", "LLMKEY", "R2_ACCESS_KEY", "R2_SECRET_KEY", "R2_ENDPOINT")
if len(OneKey) < len(_ONEKEY_FIELDS):
    # Fail with an explicit message instead of an opaque IndexError; report only
    # the count, never the (secret) values themselves.
    raise RuntimeError(
        f"OneKey environment variable must contain {len(_ONEKEY_FIELDS)} '#'-separated values "
        f"({'#'.join(_ONEKEY_FIELDS)}), got {len(OneKey)}"
    )
TOKEN, APIKEY, UKAPIURL, LLMKEY, R2_ACCESS_KEY, R2_SECRET_KEY, R2_ENDPOINT = OneKey[:len(_ONEKEY_FIELDS)]
| # tmpFolder is no longer needed since we upload directly from memory | |
| # tmpFolder = "tmp" | |
| # os.makedirs(tmpFolder, exist_ok=True) | |
| # Legacy function - no longer used since we upload directly from memory | |
| # def upload_user_img(clientIp, timeId, img): | |
| # fileName = clientIp.replace(".", "")+str(timeId)+".jpg" | |
| # local_path = os.path.join(tmpFolder, fileName) | |
| # img = cv2.imread(img) | |
| # cv2.imwrite(os.path.join(tmpFolder, fileName), img) | |
| # | |
| # json_data = { | |
| # "token": TOKEN, | |
| # "input1": fileName, | |
| # "input2": "", | |
| # "protocol": "", | |
| # "cloud": "ali" | |
| # } | |
| # | |
| # session = requests.session() | |
| # ret = requests.post( | |
| # f"{UKAPIURL}/upload", | |
| # headers={'Content-Type': 'application/json'}, | |
| # json=json_data | |
| # ) | |
| # | |
| # res = "" | |
| # if ret.status_code==200: | |
| # if 'upload1' in ret.json(): | |
| # upload_url = ret.json()['upload1'] | |
| # headers = {'Content-Type': 'image/jpeg'} | |
| # response = session.put(upload_url, data=open(local_path, 'rb').read(), headers=headers) | |
| # # print(response.status_code) | |
| # if response.status_code == 200: | |
| # res = upload_url | |
| # if os.path.exists(local_path): | |
| # os.remove(local_path) | |
| # return res | |
class R2Api:
    """
    Minimal uploader for the Cloudflare R2 (S3-compatible) bucket.

    Objects are uploaded with an HTTP PUT to a presigned URL generated by a boto3
    S3 client and are then addressed through ``self.domain``.
    """

    def __init__(self, session=None):
        """
        Args:
            session (requests.Session | None): HTTP session used for the PUT
                requests; a new ``requests.Session`` is created when omitted.
        """
        super().__init__()
        self.R2_BUCKET = "omni-creator"
        self.domain = "https://www.omnicreator.net/"
        self.R2_ACCESS_KEY = R2_ACCESS_KEY
        self.R2_SECRET_KEY = R2_SECRET_KEY
        self.R2_ENDPOINT = R2_ENDPOINT
        self.client = boto3.client(
            "s3",
            endpoint_url=self.R2_ENDPOINT,
            aws_access_key_id=self.R2_ACCESS_KEY,
            aws_secret_access_key=self.R2_SECRET_KEY,
            config=Config(signature_version="s3v4"),
        )
        self.session = requests.Session() if session is None else session

    def upload_from_memory(self, image_data, filename, content_type='image/jpeg', max_retries=3, retry_delay=1):
        """
        Upload bytes directly from memory to R2.

        Args:
            image_data (bytes): File content to upload.
            filename (str): Object name; stored under ``ImageEdit/Uploads/<today>/``.
            content_type (str): MIME type sent with the upload.
            max_retries (int): Total number of PUT attempts (at least 1).
            retry_delay (float): Seconds to wait between failed attempts.

        Returns:
            str: Public URL of the uploaded file (``self.domain`` + object key).

        Raises:
            Exception: if no attempt succeeded with HTTP 200.
        """
        t1 = time.time()
        headers = {"Content-Type": content_type}
        cloud_path = f"ImageEdit/Uploads/{str(datetime.date.today())}/{filename}"
        # ContentType is signed into the URL, so the PUT must send the same header.
        # The URL stays valid for 7 days (604800 s).
        url = self.client.generate_presigned_url(
            "put_object",
            Params={"Bucket": self.R2_BUCKET, "Key": cloud_path, "ContentType": content_type},
            ExpiresIn=604800,
        )
        max_retries = max(1, int(max_retries))
        last_error = None
        for attempt in range(1, max_retries + 1):
            try:
                response = self.session.put(url, data=image_data, headers=headers, timeout=15)
            except requests.exceptions.RequestException as e:  # includes Timeout
                print(f"⚠️ Upload retry {attempt}/{max_retries} failed: {e}")
                last_error = str(e)
            else:
                if response.status_code == 200:
                    print("upload_from_memory time is ====>", time.time() - t1)
                    return f"{self.domain}{cloud_path}"
                print(f"⚠️ Upload failed with status code: {response.status_code}")
                last_error = f"HTTP status {response.status_code}"
            if attempt < max_retries:
                time.sleep(retry_delay)  # back off before the next attempt
        # Never report success for an object that was not actually stored.
        raise Exception(f'Failed to upload file to R2 after {max_retries} retries! Last error: {last_error}')
def upload_user_img_r2(clientIp, timeId, pil_image):
    """
    Encode a PIL image as JPEG in memory and upload it to R2.

    Args:
        clientIp (str): Client IP address (kept for interface compatibility; unused).
        timeId (int): Timestamp embedded in the object name.
        pil_image (PIL.Image): Image to upload. Any non-RGB mode is converted to
            RGB first so it can be written as JPEG.

    Returns:
        str: URL of the uploaded object.
    """
    # A random UUID in the name keeps concurrent uploads from colliding.
    object_name = f"user_img_{str(uuid.uuid4())}_{timeId}.jpg"

    rgb_image = pil_image if pil_image.mode == 'RGB' else pil_image.convert('RGB')
    with io.BytesIO() as buffer:
        rgb_image.save(buffer, format='JPEG', quality=95)
        jpeg_bytes = buffer.getvalue()

    return R2Api().upload_from_memory(jpeg_bytes, object_name, 'image/jpeg')
def _layer_to_binary_mask(layer_array):
    """Return a uint8 {0, 255} mask for one layer array, or None if its rank is unsupported."""
    if layer_array.ndim == 3:
        if layer_array.shape[2] == 4:
            # RGBA: any non-transparent pixel counts as drawn
            drawn = layer_array[:, :, 3] > 0
        else:
            # RGB (or other channel count): any non-black pixel counts as drawn
            drawn = np.any(layer_array > 0, axis=2)
    elif layer_array.ndim == 2:
        # Grayscale / palette-index layer: any non-zero value counts as drawn
        drawn = layer_array > 0
    else:
        return None
    return np.where(drawn, 255, 0).astype(np.uint8)


def create_mask_from_layers(base_image, layers):
    """
    Build a black-and-white mask from ImageEditor layers.

    A pixel is white (255) when it is drawn in any layer:
    RGBA layers use alpha > 0, other 3-D layers use "any channel > 0", and 2-D
    layers use "value > 0". ``None`` layers and layers of any other rank are
    skipped. Layers whose size differs from ``base_image`` are resized with
    nearest-neighbour sampling, so the result stays strictly 0/255.

    Args:
        base_image (PIL.Image): Original image; only its ``size`` is used.
        layers (list): ImageEditor layer data (PIL images or array-likes).

    Returns:
        PIL.Image: Mode 'L' mask of ``base_image.size`` (0 = untouched, 255 = drawn).
    """
    from PIL import Image

    width, height = base_image.size
    if not layers:
        return Image.new('L', (width, height), 0)  # all black: nothing drawn

    combined = np.zeros((height, width), dtype=np.uint8)
    for layer in layers:
        if layer is None:
            continue
        layer_mask = _layer_to_binary_mask(np.array(layer))
        if layer_mask is None:
            continue
        if layer_mask.shape != combined.shape:
            # NEAREST keeps the mask binary; smoothing filters would add grey/ringing values.
            layer_mask = np.array(Image.fromarray(layer_mask).resize((width, height), Image.NEAREST))
        # Union of all layers: a pixel drawn in any layer stays white.
        np.maximum(combined, layer_mask, out=combined)
    return Image.fromarray(combined)  # 2-D uint8 array -> mode 'L'
def upload_mask_image_r2(client_ip, time_id, mask_image):
    """
    Encode a mask image as PNG in memory and upload it to R2.

    Best effort: any failure (encoding or upload) is printed and reported as
    ``None`` instead of raising.

    Args:
        client_ip (str): Client IP (kept for interface compatibility; unused).
        time_id (int): Timestamp embedded in the object name.
        mask_image (PIL.Image): Mask image to upload.

    Returns:
        str | None: URL of the uploaded object, or None on failure.
    """
    # A random UUID in the name keeps concurrent uploads from colliding.
    object_name = f"mask_img_{str(uuid.uuid4())}_{time_id}.png"
    try:
        with io.BytesIO() as buffer:
            mask_image.save(buffer, format='PNG')
            png_bytes = buffer.getvalue()
        return R2Api().upload_from_memory(png_bytes, object_name, 'image/png')
    except Exception as e:
        print(f"Failed to upload mask image: {e}")
        return None
def submit_image_edit_task(user_image_url, prompt, task_type="80", mask_image_url="", reference_image_url=""):
    """
    Submit an image-editing task to the v2 API.

    HTTP 502/503/504 responses and timeouts / connection errors are retried
    (3 attempts in total, 2 s apart). Any other failure is returned immediately.

    Args:
        user_image_url (str): URL of the image to edit.
        prompt (str): Editing instruction.
        task_type (str): API task type ("80" by default; callers use "81" for masked edits).
        mask_image_url (str): Optional mask URL.
        reference_image_url (str): Optional reference image URL, sent as ``user_image2``.

    Returns:
        tuple: ``(task_id, None)`` on success, ``(None, error_message)`` otherwise.
    """
    max_retries = 3
    request_headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {APIKEY}'
    }
    payload = {
        "user_image": user_image_url,
        "user_mask": mask_image_url,
        "type": task_type,
        "text": prompt,
        "user_uuid": APIKEY,
        "priority": 0,
        # NOTE(review): hard-coded shared secret; consider moving it into the OneKey env var.
        "secret_key": "219ngu"
    }
    if reference_image_url:
        payload["user_image2"] = reference_image_url

    for attempt in range(1, max_retries + 1):
        try:
            response = requests.post(
                f'{UKAPIURL}/public_image_edit_v2',
                headers=request_headers,
                json=payload,
                timeout=30
            )
            status_code = response.status_code
            if status_code == 200:
                body = response.json()
                if body.get('code') != 0:
                    return None, f"API Error: {body.get('message', 'Unknown error')}"
                return body['data']['task_id'], None
            if status_code not in (502, 503, 504):
                return None, f"HTTP Error: {status_code}"
            # Transient gateway/server error: retry unless this was the last attempt.
            if attempt >= max_retries:
                return None, f"HTTP Error after {max_retries} retries: {status_code}"
            print(f"⚠️ Server error {status_code}, retrying {attempt}/{max_retries}")
            time.sleep(2)
        except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as err:
            if attempt >= max_retries:
                return None, f"Network error after {max_retries} retries: {str(err)}"
            print(f"⚠️ Network error, retrying {attempt}/{max_retries}: {err}")
            time.sleep(2)
        except Exception as err:
            return None, f"Request Exception: {str(err)}"
    return None, f"Failed after {max_retries} retries"
def check_task_status(task_id):
    """
    Query the status of an image-editing task via the v2 API.

    HTTP 502/503/504 responses and timeouts / connection errors are retried
    (2 attempts in total, 1 s apart). Any other failure is returned immediately.

    Args:
        task_id: Task identifier returned by ``submit_image_edit_task``.

    Returns:
        tuple: ``(status, image_url, task_data)``. On success ``task_data`` is the
        API's ``data`` dict (which may carry ``uuid`` and ``queue_info``). On failure
        ``status`` is ``'error'``, ``image_url`` is None and the third element is an
        error-message string.
    """
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {APIKEY}'
    }
    data = {
        "task_id": task_id
    }
    max_retries = 2  # status checks are polled repeatedly, so retry fewer times than submission
    for attempt in range(1, max_retries + 1):
        try:
            response = requests.post(
                f'{UKAPIURL}/status_image_edit_v2',
                headers=headers,
                json=data,
                timeout=15  # shorter timeout than task submission
            )
            if response.status_code == 200:
                result = response.json()
                if result.get('code') != 0:
                    return 'error', None, result.get('message', 'Unknown error')
                task_data = result['data']
                # Any 'queue_info' is left inside task_data for callers to turn into progress messages.
                return task_data['status'], task_data.get('image_url'), task_data
            if response.status_code not in (502, 503, 504):
                return 'error', None, f"HTTP Error: {response.status_code}"
            # Transient gateway/server error: retry unless this was the last attempt.
            if attempt >= max_retries:
                return 'error', None, f"HTTP Error after {max_retries} retries: {response.status_code}"
            print(f"⚠️ Status check server error {response.status_code}, retrying {attempt}/{max_retries}")
            time.sleep(1)
        except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e:
            if attempt >= max_retries:
                return 'error', None, f"Network error after {max_retries} retries: {str(e)}"
            print(f"⚠️ Status check network error, retrying {attempt}/{max_retries}: {e}")
            time.sleep(1)
        except Exception as e:
            return 'error', None, f"Request Exception: {str(e)}"
    return 'error', None, f"Failed after {max_retries} retries"
def process_image_edit(img_input, prompt, reference_image=None, progress_callback=None):
    """
    Run the complete (whole-image) editing flow.

    Uploads the input image and the optional reference image to R2 and submits
    an edit task with the default task type. It then polls the task status once
    per second until the task completes, fails, or 60 polls have been made
    (about one minute plus request latency).

    Args:
        img_input: File path (str) or PIL Image object to edit.
        prompt (str): Editing instructions.
        reference_image: Optional reference image (PIL Image or file path).
        progress_callback: Optional callable that receives progress strings.

    Returns:
        tuple: ``(uploaded_url, output_url, message, task_uuid)``; both URLs are
        None on failure.
    """
    try:
        # Generate client IP and timestamp
        client_ip = "127.0.0.1"  # Default IP
        time_id = int(time.time())

        # Accept either a PIL Image (anything with .save) or a file path
        if hasattr(img_input, 'save'):
            pil_image = img_input
            print(f"💾 Using PIL Image directly from memory")
        else:
            pil_image = Image.open(img_input)
            print(f"📁 Loaded image from file: {img_input}")

        if progress_callback:
            progress_callback("uploading image...")
        uploaded_url = upload_user_img_r2(client_ip, time_id, pil_image)
        if not uploaded_url:
            return None, None, "image upload failed", None
        # Drop any query string so only the plain object URL is passed on
        if "?" in uploaded_url:
            uploaded_url = uploaded_url.split("?")[0]

        if progress_callback:
            progress_callback("submitting edit task...")

        reference_url = ""
        if reference_image is not None:
            try:
                if progress_callback:
                    progress_callback("uploading reference image...")
                if hasattr(reference_image, 'save'):
                    reference_pil = reference_image
                else:
                    reference_pil = Image.open(reference_image)
                reference_url = upload_user_img_r2(client_ip, time_id, reference_pil)
                if not reference_url:
                    return None, None, "reference image upload failed", None
                if "?" in reference_url:
                    reference_url = reference_url.split("?")[0]
            except Exception as e:
                return None, None, f"reference image processing failed: {str(e)}", None

        task_id, error = submit_image_edit_task(uploaded_url, prompt, reference_image_url=reference_url)
        if error:
            return None, None, error, None
        if progress_callback:
            progress_callback(f"task submitted, ID: {task_id}, processing...")

        max_attempts = 60  # one poll per second -> roughly one minute plus request latency
        task_uuid = None
        for _ in range(max_attempts):
            status, output_url, task_data = check_task_status(task_id)
            if task_data and isinstance(task_data, dict):
                task_uuid = task_data.get('uuid', None)

            if status == 'completed':
                if output_url:
                    return uploaded_url, output_url, "image edit completed", task_uuid
                return None, None, "Task completed but no result image returned", task_uuid
            if status in ('error', 'failed'):
                return None, None, f"task processing failed: {task_data}", task_uuid

            if progress_callback:
                if status not in ('queued', 'processing', 'running', 'created', 'working'):
                    progress_callback(f"unknown status: {status}")
                elif not (task_data and isinstance(task_data, dict)):
                    progress_callback(f"task processing... (status: {status})")
                else:
                    queue_info = task_data.get('queue_info', {})
                    if queue_info and status in ('queued', 'created'):
                        tasks_ahead = queue_info.get('tasks_ahead') or 0  # tolerate an explicit null
                        if tasks_ahead > 0:
                            progress_callback(f"⏳ Queue: {tasks_ahead} tasks ahead | Low priority | Visit website for instant processing → https://omnicreator.net/#generator")
                        else:
                            progress_callback(f"🚀 Processing your image editing request...")
                    elif status == 'processing':
                        progress_callback(f"🎨 AI is processing... Please wait")
                    elif status in ('running', 'working'):
                        progress_callback(f"⚡ Generating... Almost done")
                    else:
                        progress_callback(f"📋 Task status: {status}")
            # Sleep after every non-terminal poll (with or without a callback) so the
            # status endpoint is not hammered and the poll budget spans real time.
            time.sleep(1)

        return None, None, "task processing timeout", task_uuid
    except Exception as e:
        return None, None, f"error occurred during processing: {str(e)}", None
def process_local_image_edit(base_image, layers, prompt, reference_image=None, progress_callback=None, use_example_mask=None):
    """
    Run the complete local (masked) image-editing flow.

    Builds a mask, either from the ImageEditor layers or from the file at
    ``use_example_mask``. Uploads the original image, the mask and an optional
    reference image to R2, then submits a task_type "81" edit task. The task
    status is polled once per second until it completes, fails, or 60 polls have
    been made.

    Args:
        base_image (PIL.Image): Original image.
        layers (list): ImageEditor layer data (ignored when ``use_example_mask`` is set).
        prompt (str): Editing instructions.
        reference_image: Optional reference image (PIL Image or file path).
        progress_callback: Optional callable that receives progress strings.
        use_example_mask (str | None): Path to a pre-made mask file used for the
            example case instead of the drawn layers.

    Returns:
        tuple: ``(uploaded_url, output_url, message, task_uuid)``; both URLs are
        None on failure.
    """
    # NOTE: do not import PIL.Image or os inside this function. A function-level
    # import makes the name local to the *whole* function, so Image.open() on the
    # reference-image path would raise UnboundLocalError whenever the example-mask
    # branch was not taken. The module-level imports are used instead.
    try:
        # Generate client IP and timestamp
        client_ip = "127.0.0.1"  # Default IP
        time_id = int(time.time())

        if progress_callback:
            progress_callback("creating mask image...")

        if use_example_mask:
            # Example-case shortcut: load a pre-made mask file instead of the drawn layers
            try:
                if base_image is None:
                    return None, None, "Base image is None, cannot process example mask", None
                if not os.path.exists(use_example_mask):
                    return None, None, f"Example mask file not found: {use_example_mask}", None
                mask_image = Image.open(use_example_mask)
                # Make the mask match the base image size
                if hasattr(base_image, 'size') and mask_image.size != base_image.size:
                    mask_image = mask_image.resize(base_image.size)
                # Ensure the mask is single-channel grayscale ('L')
                if mask_image.mode != 'L':
                    mask_image = mask_image.convert('L')
                print(f"🎭 Using example mask from: {use_example_mask}, size: {mask_image.size}")
            except Exception as e:
                import traceback
                traceback.print_exc()
                return None, None, f"Failed to load example mask: {str(e)}", None
        else:
            # Normal case: build the mask from the drawn layers
            mask_image = create_mask_from_layers(base_image, layers)

        # Reject an empty (all-black) mask
        mask_array = np.array(mask_image)
        if np.max(mask_array) == 0:
            return None, None, "please draw mask", None
        if use_example_mask:
            print(f"🎭 Example mask loaded successfully, mask pixels: {np.sum(mask_array > 0)}")
        else:
            print(f"📝 创建mask图片成功,绘制区域像素数: {np.sum(mask_array > 0)}")

        if progress_callback:
            progress_callback("uploading original image...")
        # Upload the original image directly from memory
        uploaded_url = upload_user_img_r2(client_ip, time_id, base_image)
        if not uploaded_url:
            return None, None, "original image upload failed", None
        # Drop any query string so only the plain object URL is passed on
        if "?" in uploaded_url:
            uploaded_url = uploaded_url.split("?")[0]

        if progress_callback:
            progress_callback("uploading mask image...")
        # Upload the mask directly from memory
        mask_url = upload_mask_image_r2(client_ip, time_id, mask_image)
        if not mask_url:
            return None, None, "mask image upload failed", None
        if "?" in mask_url:
            mask_url = mask_url.split("?")[0]

        reference_url = ""
        if reference_image is not None:
            try:
                if progress_callback:
                    progress_callback("uploading reference image...")
                if hasattr(reference_image, 'save'):
                    reference_pil = reference_image
                else:
                    reference_pil = Image.open(reference_image)
                reference_url = upload_user_img_r2(client_ip, time_id, reference_pil)
                if not reference_url:
                    return None, None, "reference image upload failed", None
                if "?" in reference_url:
                    reference_url = reference_url.split("?")[0]
            except Exception as e:
                return None, None, f"reference image processing failed: {str(e)}", None

        print(f"📤 图片上传成功:")
        print(f" 原始图片: {uploaded_url}")
        print(f" Mask图片: {mask_url}")
        if reference_url:
            print(f" 参考图片: {reference_url}")

        if progress_callback:
            progress_callback("submitting local edit task...")
        # Submit the local (masked) edit task (task_type=81)
        task_id, error = submit_image_edit_task(
            uploaded_url,
            prompt,
            task_type="81",
            mask_image_url=mask_url,
            reference_image_url=reference_url
        )
        if error:
            return None, None, error, None
        if progress_callback:
            progress_callback(f"task submitted, ID: {task_id}, processing...")
        print(f"🚀 局部编辑任务已提交,任务ID: {task_id}")

        max_attempts = 60  # one poll per second -> roughly one minute plus request latency
        task_uuid = None
        for _ in range(max_attempts):
            status, output_url, task_data = check_task_status(task_id)
            if task_data and isinstance(task_data, dict):
                task_uuid = task_data.get('uuid', None)

            if status == 'completed':
                if output_url:
                    print(f"✅ 局部编辑任务完成,结果: {output_url}")
                    return uploaded_url, output_url, "local image edit completed", task_uuid
                return None, None, "task completed but no result image returned", task_uuid
            if status in ('error', 'failed'):
                return None, None, f"task processing failed: {task_data}", task_uuid

            if progress_callback:
                if status not in ('queued', 'processing', 'running', 'created', 'working'):
                    progress_callback(f"unknown status: {status}")
                elif not (task_data and isinstance(task_data, dict)):
                    progress_callback(f"processing... (status: {status})")
                else:
                    queue_info = task_data.get('queue_info', {})
                    if queue_info and status in ('queued', 'created'):
                        tasks_ahead = queue_info.get('tasks_ahead') or 0  # tolerate an explicit null
                        if tasks_ahead > 0:
                            progress_callback(f"⏳ Queue: {tasks_ahead} tasks ahead | Low priority | Visit website for instant processing → https://omnicreator.net/#generator")
                        else:
                            progress_callback(f"🚀 Processing your local editing request...")
                    elif status == 'processing':
                        progress_callback(f"🎨 AI is processing... Please wait")
                    elif status in ('running', 'working'):
                        progress_callback(f"⚡ Generating... Almost done")
                    else:
                        progress_callback(f"📋 Task status: {status}")
            # Sleep after every non-terminal poll (with or without a callback) so the
            # status endpoint is not hammered and the poll budget spans real time.
            time.sleep(1)

        return None, None, "task processing timeout", task_uuid
    except Exception as e:
        print(f"❌ 局部编辑处理异常: {str(e)}")
        return None, None, f"error occurred during processing: {str(e)}", None
def download_and_check_result_nsfw(image_url, nsfw_detector=None):
    """
    Download a result image and run NSFW detection on it.

    Without a detector the check is skipped and the image is treated as safe.
    Download and detection failures are reported, not raised.

    Args:
        image_url (str): URL of the result image.
        nsfw_detector: Detector object exposing ``predict_pil_label_only(image) -> str``.

    Returns:
        tuple: ``(is_nsfw, error_message)``. ``error_message`` is None on success;
        ``is_nsfw`` is False whenever the check could not be completed.
    """
    if nsfw_detector is None:
        return False, None
    try:
        response = requests.get(image_url, timeout=30)
        if response.status_code != 200:
            return False, f"Failed to download result image: HTTP {response.status_code}"
        result_image = Image.open(io.BytesIO(response.content))
        label = nsfw_detector.predict_pil_label_only(result_image)
        is_nsfw = label.lower() == "nsfw"
        verdict = ('❌❌❌ ' if is_nsfw else '✅✅✅ ') + label
        print(f"🔍 结果图片NSFW检测: {verdict}")
        return is_nsfw, None
    except Exception as e:
        print(f"⚠️ 结果图片NSFW检测失败: {e}")
        return False, f"Failed to check result image: {str(e)}"
# Running this file directly does nothing; it only defines helper functions.
if __name__ == "__main__":
    ...