File size: 10,556 Bytes
d45e6fb 3503a08 d45e6fb 3503a08 fdce011 3503a08 fdce011 3503a08 d45e6fb fdce011 d45e6fb fdce011 d45e6fb fdce011 d45e6fb fdce011 d45e6fb fdce011 d45e6fb fdce011 d45e6fb fdce011 d45e6fb fdce011 d45e6fb fdce011 d45e6fb fdce011 d45e6fb fdce011 d45e6fb fdce011 3503a08 fdce011 3503a08 fdce011 3503a08 fdce011 3503a08 fdce011 3503a08 fdce011 3503a08 fdce011 3503a08 fdce011 3503a08 fdce011 3503a08 fdce011 3503a08 fdce011 3503a08 fdce011 3503a08 3f97e41 3503a08 3f97e41 fdce011 3503a08 fdce011 d45e6fb 3503a08 d45e6fb 3503a08 d45e6fb fdce011 3503a08 fdce011 3503a08 fdce011 3f97e41 fdce011 3f97e41 d45e6fb fdce011 3503a08 3f97e41 d45e6fb 3503a08 3f97e41 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
import numpy as np
import torch
import torch.nn as nn
import gradio as gr
from PIL import Image, ImageFilter, ImageOps, ImageChops
import torchvision.transforms as transforms
import os
import pathlib
# --- ⚙️ Configuration ---
# Directory that predict() writes every rendered image into; created at import
# time so saving never fails on a fresh checkout.
output_dir = "outputs"
pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
# --- 🎨 Filters ---
# Filter name -> emoji badge. The names are what apply_filter dispatches on, and
# insertion order is the order in which the UI dropdown lists the choices.
FILTERS = dict(
    # Pass-through, inversion and PIL convolution filters
    Standard="📄",
    Invert="⚫⚪",
    Blur="🌫️",
    Sharpen="🔪",
    Contour="🗺️",
    Detail="🔍",
    EdgeEnhance="📏",
    EdgeEnhanceMore="📐",
    Emboss="🏞️",
    FindEdges="🕵️",
    Smooth="🌊",
    SmoothMore="💧",
    # Tonal adjustments
    Solarize="☀️",
    Posterize1="🖼️1",
    Posterize2="🖼️2",
    Posterize3="🖼️3",
    Posterize4="🖼️4",
    Equalize="⚖️",
    AutoContrast="🔧",
    # Stroke width (suffix selects the strength)
    Thick1="💪1",
    Thick2="💪2",
    Thick3="💪3",
    Thin1="🏃1",
    Thin2="🏃2",
    Thin3="🏃3",
    # <Colour>OnWhite variants
    RedOnWhite="🔴",
    OrangeOnWhite="🟠",
    YellowOnWhite="🟡",
    GreenOnWhite="🟢",
    BlueOnWhite="🔵",
    PurpleOnWhite="🟣",
    PinkOnWhite="🌸",
    CyanOnWhite="🩵",
    MagentaOnWhite="🟪",
    BrownOnWhite="🤎",
    GrayOnWhite="🩶",
    # <Colour>OnBlack variants
    WhiteOnBlack="⚪",
    RedOnBlack="🔴⚫",
    OrangeOnBlack="🟠⚫",
    YellowOnBlack="🟡⚫",
    GreenOnBlack="🟢⚫",
    BlueOnBlack="🔵⚫",
    PurpleOnBlack="🟣⚫",
    PinkOnBlack="🌸⚫",
    CyanOnBlack="🩵⚫",
    MagentaOnBlack="🟪⚫",
    BrownOnBlack="🤎⚫",
    GrayOnBlack="🩶⚫",
    # Blends of the line art with the original photo
    Multiply="✖️",
    Screen="🖥️",
    Overlay="🔲",
    Add="➕",
    Subtract="➖",
    Difference="≠",
    Darker="🌑",
    Lighter="🌕",
    SoftLight="💡",
    HardLight="🔦",
    # Miscellaneous
    Binary="🌓",
    Noise="❄️",
)
# --- 🧠 Neural Network Model ---
# Normalisation layer class used by every block of the generator.
norm_layer = nn.InstanceNorm2d


class ResidualBlock(nn.Module):
    """Residual unit: two reflection-padded 3x3 convolutions plus a skip path.

    The channel count is unchanged (``in_features`` in and out), so ``forward``
    can return ``x + conv_block(x)``.
    """

    def __init__(self, in_features):
        super().__init__()

        def pad_conv_norm():
            # A 1-pixel reflection pad keeps the 3x3 convolution size-preserving.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, 3),
                norm_layer(in_features),
            ]

        # Layer positions (convs at indices 1 and 5) define the state_dict keys,
        # so this ordering must stay fixed for saved checkpoints to load.
        self.conv_block = nn.Sequential(
            *pad_conv_norm(),
            nn.ReLU(inplace=True),
            *pad_conv_norm(),
        )

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual
class Generator(nn.Module):
    """Encoder / residual-bottleneck / decoder image-to-image generator.

    Layout:
      * model0: reflection pad 3 + 7x7 conv from ``input_nc`` to 64 channels.
      * model1: two stride-2 3x3 convs, 64 -> 128 -> 256 channels.
      * model2: ``n_residual_blocks`` ResidualBlocks at 256 channels.
      * model3: two stride-2 transposed convs, 256 -> 128 -> 64 channels.
      * model4: reflection pad 3 + 7x7 conv from 64 to ``output_nc`` channels,
        optionally followed by a Sigmoid.

    The two downsampling steps each map a side of length L to ceil(L/2), and
    the two upsampling steps each double it. The output therefore has the
    input's height and width when both are divisible by 4.

    Args:
        input_nc: Number of channels of the input tensor.
        output_nc: Number of channels produced.
        n_residual_blocks: Number of ResidualBlocks in the bottleneck.
        sigmoid: If True, append ``nn.Sigmoid()`` so outputs are squashed to [0, 1].

    NOTE: the attribute names ``model0``..``model4`` and each layer's position
    inside its ``nn.Sequential`` determine the ``state_dict`` keys. Changing
    either breaks ``load_state_dict`` on the existing ``.pth`` checkpoints.
    """
    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super(Generator, self).__init__()
        # Stem: reflection pad 3 + 7x7 conv keeps the spatial size.
        model0 = [ nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, 7), norm_layer(64), nn.ReLU(inplace=True) ]
        self.model0 = nn.Sequential(*model0)
        # Downsampling: each iteration halves H/W (stride 2) and doubles the channel count.
        model1, in_features, out_features = [], 64, 128
        for _ in range(2):
            model1 += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), norm_layer(out_features), nn.ReLU(inplace=True) ]
            in_features = out_features; out_features = in_features*2
        self.model1 = nn.Sequential(*model1)
        # Bottleneck: in_features is 256 after the loop above.
        model2 = [ResidualBlock(in_features) for _ in range(n_residual_blocks)]
        self.model2 = nn.Sequential(*model2)
        # Upsampling: each transposed conv doubles H/W and halves channels (256 -> 128 -> 64).
        model3, out_features = [], in_features//2
        for _ in range(2):
            model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1), norm_layer(out_features), nn.ReLU(inplace=True) ]
            in_features = out_features; out_features = in_features//2
        self.model3 = nn.Sequential(*model3)
        # Head: the hard-coded 64 matches the channel count left by model3.
        model4 = [ nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)]
        if sigmoid: model4 += [nn.Sigmoid()]
        self.model4 = nn.Sequential(*model4)
    # ``cond`` is accepted for call compatibility but is not used.
    def forward(self, x, cond=None): return self.model4(self.model3(self.model2(self.model1(self.model0(x)))))
# --- 🔧 Model Loading ---
def _load_generator(weights_path):
    """Build ``Generator(3, 1, 3)``, load its weights from *weights_path* onto the CPU, and return it in eval mode.

    Raises:
        FileNotFoundError: If *weights_path* does not exist.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    net = Generator(3, 1, 3)
    state = torch.load(weights_path, map_location=torch.device('cpu'))
    net.load_state_dict(state)
    net.eval()
    return net


try:
    model1 = _load_generator('model.pth')   # used by predict() for 'Complex Lines'
    model2 = _load_generator('model2.pth')  # used by predict() for 'Simple Lines'
except FileNotFoundError:
    print("⚠️ Warning: Model files 'model.pth' or 'model2.pth' not found. The application will not run correctly.")
    # Either checkpoint missing disables both models; predict() checks for this.
    model1, model2 = None, None
# --- ✨ Filter Application Logic ---
# "<Colour>OnWhite": ImageOps.colorize maps black pixels to this colour and
# white pixels to white.
_COLORS_ON_WHITE = {
    "RedOnWhite": "red", "OrangeOnWhite": "orange", "YellowOnWhite": "yellow",
    "GreenOnWhite": "green", "BlueOnWhite": "blue", "PurpleOnWhite": "purple",
    "PinkOnWhite": "pink", "CyanOnWhite": "cyan", "MagentaOnWhite": "magenta",
    "BrownOnWhite": "brown", "GrayOnWhite": "gray",
}
# "<Colour>OnBlack": black pixels become this colour, white pixels become black.
_COLORS_ON_BLACK = {
    "WhiteOnBlack": "white", "RedOnBlack": "red", "OrangeOnBlack": "orange",
    "YellowOnBlack": "yellow", "GreenOnBlack": "green", "BlueOnBlack": "blue",
    "PurpleOnBlack": "purple", "PinkOnBlack": "pink", "CyanOnBlack": "cyan",
    "MagentaOnBlack": "magenta", "BrownOnBlack": "brown", "GrayOnBlack": "gray",
}
# Rank-filter kernel size keyed by the last character of a Thick*/Thin* name;
# any other suffix falls back to 7.
_STROKE_KERNEL_SIZES = {"1": 3, "2": 5, "3": 7}


def add_uniform_noise(pixels, amplitude=20):
    """Add uniform integer noise in ``[-amplitude, amplitude]`` to a pixel array.

    Args:
        pixels: ``uint8`` numpy array of any shape.
        amplitude: Largest absolute noise value added to a pixel.

    Returns:
        A new ``uint8`` array of the same shape, clipped to ``0..255``.
    """
    # randint's upper bound is exclusive; the +1 keeps the noise symmetric
    # around zero (the previous (-20, 20) call produced a slight darkening bias).
    noise = np.random.randint(-amplitude, amplitude + 1, pixels.shape, dtype=np.int32)
    return np.clip(pixels.astype(np.int32) + noise, 0, 255).astype(np.uint8)


def apply_filter(line_img, filter_name, original_img):
    """Apply one of the ``FILTERS`` effects to a generated line drawing.

    Args:
        line_img: Line-art PIL image produced by the generator.
        filter_name: Filter key, e.g. ``"Contour"`` or ``"BlueOnBlack"``.
        original_img: Source photo. Only the blend filters (Multiply, Screen,
            Overlay, ...) read it. It is converted to RGB, and the line art is
            resized to its size when the two differ.

    Returns:
        The filtered PIL image. ``"Standard"`` and unrecognised names return
        ``line_img`` itself.
    """
    if filter_name == "Standard":
        return line_img
    gray = line_img.convert('L')

    # PIL convolution filters applied in the line art's own mode.
    native_filters = {
        "Blur": ImageFilter.GaussianBlur(radius=3),
        "Sharpen": ImageFilter.SHARPEN,
        "Detail": ImageFilter.DETAIL,
        "Smooth": ImageFilter.SMOOTH,
        "SmoothMore": ImageFilter.SMOOTH_MORE,
    }
    if filter_name in native_filters:
        return line_img.filter(native_filters[filter_name])

    # PIL convolution filters applied to the grayscale version.
    gray_filters = {
        "Contour": ImageFilter.CONTOUR,
        "EdgeEnhance": ImageFilter.EDGE_ENHANCE,
        "EdgeEnhanceMore": ImageFilter.EDGE_ENHANCE_MORE,
        "Emboss": ImageFilter.EMBOSS,
        "FindEdges": ImageFilter.FIND_EDGES,
    }
    if filter_name in gray_filters:
        return gray.filter(gray_filters[filter_name])

    # Single-argument ImageOps tone operations on the grayscale version.
    gray_ops = {
        "Invert": ImageOps.invert,
        "Solarize": ImageOps.solarize,
        "Equalize": ImageOps.equalize,
        "AutoContrast": ImageOps.autocontrast,
    }
    if filter_name in gray_ops:
        return gray_ops[filter_name](gray)

    if filter_name.startswith("Posterize"):
        # Trailing digit is the number of bits kept per channel.
        return ImageOps.posterize(gray, int(filter_name[-1]))
    if filter_name == "Binary":
        return gray.convert('1')
    if filter_name.startswith(("Thick", "Thin")):
        size = _STROKE_KERNEL_SIZES.get(filter_name[-1], 7)
        # Min filter spreads dark pixels (thicker dark strokes); max filter
        # spreads light pixels (thinner dark strokes).
        rank_filter = ImageFilter.MinFilter if filter_name.startswith("Thick") else ImageFilter.MaxFilter
        return gray.filter(rank_filter(size))

    if filter_name in _COLORS_ON_WHITE:
        return ImageOps.colorize(gray, black=_COLORS_ON_WHITE[filter_name], white="white")
    if filter_name in _COLORS_ON_BLACK:
        return ImageOps.colorize(gray, black=_COLORS_ON_BLACK[filter_name], white="black")

    blend_ops = {
        "Multiply": ImageChops.multiply, "Screen": ImageChops.screen,
        "Overlay": ImageChops.overlay, "Add": ImageChops.add,
        "Subtract": ImageChops.subtract, "Difference": ImageChops.difference,
        "Darker": ImageChops.darker, "Lighter": ImageChops.lighter,
        "SoftLight": ImageChops.soft_light, "HardLight": ImageChops.hard_light,
    }
    if filter_name in blend_ops:
        # ImageChops requires both images to share mode and size.
        base = original_img.convert('RGB')
        lines_rgb = line_img.convert('RGB')
        if lines_rgb.size != base.size:
            lines_rgb = lines_rgb.resize(base.size, Image.Resampling.BICUBIC)
        return blend_ops[filter_name](base, lines_rgb)

    if filter_name == "Noise":
        return Image.fromarray(add_uniform_noise(np.array(gray)))
    return line_img
# --- 🖼️ Main Processing Function ---
# Preprocessing is identical for every request, so it is built once.
_PREPROCESS = transforms.Compose([
    transforms.Resize(256, transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


def predict(input_img_path, line_style, filter_choice):
    """Turn an uploaded image into filtered line art and save it to ``output_dir``.

    Args:
        input_img_path: Path of the uploaded image (Gradio ``type='filepath'``).
        line_style: ``'Simple Lines'`` selects ``model2``; any other value
            selects ``model1``.
        filter_choice: Dropdown label ``"<emoji> <FilterName>"``. A bare
            ``"<FilterName>"`` is also accepted.

    Returns:
        The filtered PIL image. It is also saved as
        ``<output_dir>/<input stem>_<FilterName>.png``.

    Raises:
        gr.Error: If the models failed to load, no image was provided, or the
            filter name is not a key of ``FILTERS``.
    """
    if model1 is None or model2 is None:
        raise gr.Error("Models are not loaded. Please check for 'model.pth' and 'model2.pth'.")
    if not input_img_path:
        raise gr.Error("Please upload an image first.")

    # Filter names never contain spaces, so the last space-separated token is the
    # name whatever the emoji prefix looks like.
    filter_name = (filter_choice or "").rsplit(" ", 1)[-1]
    if filter_name not in FILTERS:
        # Also prevents arbitrary API input from reaching the output filename.
        raise gr.Error(f"Unknown filter: {filter_choice!r}")

    # Context manager releases the file handle Image.open keeps open.
    with Image.open(input_img_path) as img:
        original_img = img.convert('RGB')

    input_tensor = _PREPROCESS(original_img).unsqueeze(0)
    model = model2 if line_style == 'Simple Lines' else model1
    with torch.no_grad():
        output = model(input_tensor)

    # Clamp to [0, 1] before converting the single-channel output to a PIL image.
    line_drawing_low_res = transforms.ToPILImage()(output.squeeze().cpu().clamp(0, 1))
    # The network runs on a 256px-resized copy; scale back to the upload's size.
    line_drawing_full_res = line_drawing_low_res.resize(original_img.size, Image.Resampling.BICUBIC)
    final_image = apply_filter(line_drawing_full_res, filter_name, original_img)

    # --- 💾 Save the output image ---
    base_name = pathlib.Path(input_img_path).stem
    output_filepath = os.path.join(output_dir, f"{base_name}_{filter_name}.png")
    final_image.save(output_filepath)
    return final_image
# --- 🚀 Gradio UI Setup ---
title = "🖌️ Image to Line Art with Creative Filters"
description = "Upload an image, choose a line style (Complex or Simple), and select a filter from the dropdown to transform your picture. Results are saved in the 'outputs' folder."
# Dropdown labels are "<emoji> <name>", in FILTERS order.
filter_choices = [f"{emoji} {name}" for name, emoji in FILTERS.items()]

# --- ✅ Curated Examples Section ---
# Filters showcased by the examples. Labels are built from FILTERS in the same
# format as filter_choices, so they always match a dropdown entry; a name that
# disappears from FILTERS fails loudly with a KeyError here.
_DEMO_FILTER_NAMES = ["Contour", "BlueOnBlack", "Multiply", "Emboss", "Sharpen", "Noise"]
demo_filters = [f"{FILTERS[name]} {name}" for name in _DEMO_FILTER_NAMES]

# One example per existing file '01.jpeg'..'10.jpeg'. The filter cycles with the
# file's position in this list, so a missing file does not shift the others.
example_images = [f"{i:02d}.jpeg" for i in range(1, 11)]
examples = [
    [img_file, 'Simple Lines', demo_filters[i % len(demo_filters)]]
    for i, img_file in enumerate(example_images)
    if os.path.exists(img_file)
]
if not examples:
    print("⚠️ Warning: No example images ('01.jpeg' to '10.jpeg') found. Examples will be empty.")
# A single gr.Interface wires predict() to its three inputs and one output.
_input_components = [
    gr.Image(type='filepath', label="Upload Image"),
    gr.Radio(['Complex Lines', 'Simple Lines'], label='Line Style', value='Simple Lines'),
    gr.Dropdown(filter_choices, label="Filter", value=filter_choices[0]),
]
_output_component = gr.Image(type="pil", label="Filtered Line Art")

iface = gr.Interface(
    fn=predict,
    inputs=_input_components,
    outputs=_output_component,
    title=title,
    description=description,
    examples=examples,
    allow_flagging='never',
)
# Launch the Gradio app only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()