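"""Gradio app: image to line art with creative filters.

Two pretrained residual-block generators turn an uploaded photo into a line
drawing ('model.pth' for complex lines, 'model2.pth' for simple lines); a
PIL-based filter from the FILTERS palette below is then applied, and the
result is saved to the 'outputs' folder.
"""
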
import numpy as np
import torch
import torch.nn as nn
import gradio as gr
from PIL import Image, ImageFilter, ImageOps, ImageChops
import torchvision.transforms as transforms
import os
import pathlib

# --- ⚙️ Configuration ---
output_dir = "outputs"
os.makedirs(output_dir, exist_ok=True)

# --- 🎨 Filters ---
FILTERS = {
    "Standard": "📄", "Invert": "⚫⚪", "Blur": "🌫️", "Sharpen": "🔪", "Contour": "🗺️",
    "Detail": "🔍", "EdgeEnhance": "📏", "EdgeEnhanceMore": "📐", "Emboss": "🏞️",
    "FindEdges": "🕵️", "Smooth": "🌊", "SmoothMore": "💧", "Solarize": "☀️",
    "Posterize1": "🖼️1", "Posterize2": "🖼️2", "Posterize3": "🖼️3", "Posterize4": "🖼️4",
    "Equalize": "⚖️", "AutoContrast": "🔧", "Thick1": "💪1", "Thick2": "💪2", "Thick3": "💪3",
    "Thin1": "🏃1", "Thin2": "🏃2", "Thin3": "🏃3", "RedOnWhite": "🔴", "OrangeOnWhite": "🟠",
    "YellowOnWhite": "🟡", "GreenOnWhite": "🟢", "BlueOnWhite": "🔵", "PurpleOnWhite": "🟣",
    "PinkOnWhite": "🌸", "CyanOnWhite": "🩵", "MagentaOnWhite": "🟪", "BrownOnWhite": "🤎",
    "GrayOnWhite": "🩶", "WhiteOnBlack": "⚪", "RedOnBlack": "🔴⚫", "OrangeOnBlack": "🟠⚫",
    "YellowOnBlack": "🟡⚫", "GreenOnBlack": "🟢⚫", "BlueOnBlack": "🔵⚫", "PurpleOnBlack": "🟣⚫",
    "PinkOnBlack": "🌸⚫", "CyanOnBlack": "🩵⚫", "MagentaOnBlack": "🟪⚫", "BrownOnBlack": "🤎⚫",
    "GrayOnBlack": "🩶⚫", "Multiply": "✖️", "Screen": "🖥️", "Overlay": "🔲", "Add": "➕",
    "Subtract": "➖", "Difference": "≠", "Darker": "🌑", "Lighter": "🌕", "SoftLight": "💡",
    "HardLight": "🔦", "Binary": "🌓", "Noise": "❄️"
}
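# Each filter renders in the dropdown as "<emoji> <name>" (see filter_choices
# below); predict() splits on the first space to recover the bare filter name.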

# --- 🧠 Neural Network Model ---
norm_layer = nn.InstanceNorm2d
class ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 convs with instance norm and a skip connection."""
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        conv_block = [nn.ReflectionPad2d(1), nn.Conv2d(in_features, in_features, 3), norm_layer(in_features), nn.ReLU(inplace=True),
                      nn.ReflectionPad2d(1), nn.Conv2d(in_features, in_features, 3), norm_layer(in_features)]
        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)

class Generator(nn.Module):
    """ResNet-style encoder/decoder: a 7x7 stem, two stride-2 downsamples,
    a stack of residual blocks, two transposed-conv upsamples, and a 7x7 head."""
    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super(Generator, self).__init__()
        # Stem: reflection-padded 7x7 conv to 64 channels.
        model0 = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, 7), norm_layer(64), nn.ReLU(inplace=True)]
        self.model0 = nn.Sequential(*model0)
        # Downsampling: 64 -> 128 -> 256 channels, halving the resolution twice.
        model1, in_features, out_features = [], 64, 128
        for _ in range(2):
            model1 += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), norm_layer(out_features), nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features * 2
        self.model1 = nn.Sequential(*model1)
        # Bottleneck: residual blocks at 256 channels.
        model2 = [ResidualBlock(in_features) for _ in range(n_residual_blocks)]
        self.model2 = nn.Sequential(*model2)
        # Upsampling: 256 -> 128 -> 64 channels, doubling the resolution twice.
        model3, out_features = [], in_features // 2
        for _ in range(2):
            model3 += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1), norm_layer(out_features), nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features // 2
        self.model3 = nn.Sequential(*model3)
        # Head: 7x7 conv to output_nc channels, optionally squashed to [0, 1].
        model4 = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)]
        if sigmoid:
            model4 += [nn.Sigmoid()]
        self.model4 = nn.Sequential(*model4)

    def forward(self, x, cond=None):  # `cond` is accepted but unused
        return self.model4(self.model3(self.model2(self.model1(self.model0(x)))))
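
# Illustrative shape check (not executed here): the two stride-2 downsamples
# are exactly undone by the two transposed convolutions, so a 256x256 RGB
# input yields a 256x256 single-channel line map:
#   g = Generator(3, 1, 3)
#   g(torch.randn(1, 3, 256, 256)).shape  # -> torch.Size([1, 1, 256, 256])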

# --- 🔧 Model Loading ---
try:
    # weights_only=True restricts torch.load to tensor data, which is safer
    # than unpickling arbitrary objects (requires PyTorch >= 1.13).
    model1 = Generator(3, 1, 3)
    model1.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu'), weights_only=True))
    model1.eval()
    model2 = Generator(3, 1, 3)
    model2.load_state_dict(torch.load('model2.pth', map_location=torch.device('cpu'), weights_only=True))
    model2.eval()
except FileNotFoundError:
    print("⚠️ Warning: Model files 'model.pth' or 'model2.pth' not found. The application will not run correctly.")
    model1, model2 = None, None

# --- ✨ Filter Application Logic ---
def apply_filter(line_img, filter_name, original_img):
    if filter_name == "Standard": return line_img
    line_img_l = line_img.convert('L')
    if filter_name == "Invert": return ImageOps.invert(line_img_l)
    if filter_name == "Blur": return line_img.filter(ImageFilter.GaussianBlur(radius=3))
    if filter_name == "Sharpen": return line_img.filter(ImageFilter.SHARPEN)
    if filter_name == "Contour": return line_img_l.filter(ImageFilter.CONTOUR)
    if filter_name == "Detail": return line_img.filter(ImageFilter.DETAIL)
    if filter_name == "EdgeEnhance": return line_img_l.filter(ImageFilter.EDGE_ENHANCE)
    if filter_name == "EdgeEnhanceMore": return line_img_l.filter(ImageFilter.EDGE_ENHANCE_MORE)
    if filter_name == "Emboss": return line_img_l.filter(ImageFilter.EMBOSS)
    if filter_name == "FindEdges": return line_img_l.filter(ImageFilter.FIND_EDGES)
    if filter_name == "Smooth": return line_img.filter(ImageFilter.SMOOTH)
    if filter_name == "SmoothMore": return line_img.filter(ImageFilter.SMOOTH_MORE)
    if filter_name == "Solarize": return ImageOps.solarize(line_img_l)
    if filter_name.startswith("Posterize"): return ImageOps.posterize(line_img_l, int(filter_name[-1]))
    if filter_name == "Equalize": return ImageOps.equalize(line_img_l)
    if filter_name == "AutoContrast": return ImageOps.autocontrast(line_img_l)
    if filter_name == "Binary": return line_img_l.convert('1')
    if filter_name.startswith("Thick"): return line_img_l.filter(ImageFilter.MinFilter(3 if filter_name[-1]=='1' else (5 if filter_name[-1]=='2' else 7)))
    if filter_name.startswith("Thin"): return line_img_l.filter(ImageFilter.MaxFilter(3 if filter_name[-1]=='1' else (5 if filter_name[-1]=='2' else 7)))
    colors_on_white = {
        "RedOnWhite": "red", "OrangeOnWhite": "orange", "YellowOnWhite": "yellow",
        "GreenOnWhite": "green", "BlueOnWhite": "blue", "PurpleOnWhite": "purple",
        "PinkOnWhite": "pink", "CyanOnWhite": "cyan", "MagentaOnWhite": "magenta",
        "BrownOnWhite": "brown", "GrayOnWhite": "gray"}
    if filter_name in colors_on_white:
        return ImageOps.colorize(line_img_l, black=colors_on_white[filter_name], white="white")
    colors_on_black = {
        "WhiteOnBlack": "white", "RedOnBlack": "red", "OrangeOnBlack": "orange",
        "YellowOnBlack": "yellow", "GreenOnBlack": "green", "BlueOnBlack": "blue",
        "PurpleOnBlack": "purple", "PinkOnBlack": "pink", "CyanOnBlack": "cyan",
        "MagentaOnBlack": "magenta", "BrownOnBlack": "brown", "GrayOnBlack": "gray"}
    if filter_name in colors_on_black:
        return ImageOps.colorize(line_img_l, black=colors_on_black[filter_name], white="black")
    line_img_rgb = line_img.convert('RGB')
    blend_ops = {
        "Multiply": ImageChops.multiply, "Screen": ImageChops.screen,
        "Overlay": ImageChops.overlay, "Add": ImageChops.add,
        "Subtract": ImageChops.subtract, "Difference": ImageChops.difference,
        "Darker": ImageChops.darker, "Lighter": ImageChops.lighter,
        "SoftLight": ImageChops.soft_light, "HardLight": ImageChops.hard_light}
    if filter_name in blend_ops:
        return blend_ops[filter_name](original_img, line_img_rgb)
    if filter_name == "Noise":
        img_array = np.array(line_img_l)
        noise = np.random.randint(-20, 20, img_array.shape, dtype='int16')
        noisy_array = np.clip(img_array.astype('int16') + noise, 0, 255).astype('uint8')
        return Image.fromarray(noisy_array)
    return line_img
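
# Illustrative usage (variable names are placeholders): blend a full-resolution
# line drawing with the matching original photo:
#   blended = apply_filter(line_drawing, "Multiply", original_photo)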

# --- 🖼️ Main Processing Function ---
def predict(input_img_path, line_style, filter_choice):
    if model1 is None or model2 is None:
        raise gr.Error("Models are not loaded. Please check for 'model.pth' and 'model2.pth'.")
    
    filter_name = filter_choice.split(" ", 1)[1]
    original_img = Image.open(input_img_path).convert('RGB')
    
    # Resize the short side to 256 for inference, then map [0, 1] pixel values to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    input_tensor = transform(original_img).unsqueeze(0)

    with torch.no_grad():
        output = model2(input_tensor) if line_style == 'Simple Lines' else model1(input_tensor)
    
    # Convert the single-channel [0, 1] output back to a PIL image, then upscale it to the original resolution.
    line_drawing_low_res = transforms.ToPILImage()(output.squeeze().cpu().clamp(0, 1))
    line_drawing_full_res = line_drawing_low_res.resize(original_img.size, Image.Resampling.BICUBIC)
    
    final_image = apply_filter(line_drawing_full_res, filter_name, original_img)
    
    # --- 💾 Save the output image ---
    base_name = pathlib.Path(input_img_path).stem
    output_filename = f"{base_name}_{filter_name}.png"
    output_filepath = os.path.join(output_dir, output_filename)
    final_image.save(output_filepath)
    
    return final_image
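
# Illustrative direct call, bypassing the UI (the image path is a placeholder):
#   result = predict("01.jpeg", "Simple Lines", "🗺️ Contour")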

# --- 🚀 Gradio UI Setup ---
title = "🖌️ Image to Line Art with Creative Filters"
description = "Upload an image, choose a line style (Complex or Simple), and select a filter from the dropdown to transform your picture. Results are saved in the 'outputs' folder."

filter_choices = [f"{emoji} {name}" for name, emoji in FILTERS.items()]

# --- ✅ Curated Examples ---
examples = []
example_images = [f"{i:02d}.jpeg" for i in range(1, 11)]
# A selection of 6 interesting filters to demonstrate
demo_filters = ["🗺️ Contour", "🔵⚫ BlueOnBlack", "✖️ Multiply", "🏞️ Emboss", "🔪 Sharpen", "❄️ Noise"]

# Create one example for each of the 10 image files, cycling through the demo filters
for i, img_file in enumerate(example_images):
    if os.path.exists(img_file):
        # Use modulo to cycle through the 6 demo filters for the 10 images
        chosen_filter = demo_filters[i % len(demo_filters)]
        examples.append([img_file, 'Simple Lines', chosen_filter])

if not examples:
    print("⚠️ Warning: No example images ('01.jpeg' to '10.jpeg') found. Examples will be empty.")

# Assemble the UI with gr.Interface
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type='filepath', label="Upload Image"),
        gr.Radio(['Complex Lines', 'Simple Lines'], label='Line Style', value='Simple Lines'),
        gr.Dropdown(filter_choices, label="Filter", value=filter_choices[0])
    ],
    outputs=gr.Image(type="pil", label="Filtered Line Art"),
    title=title,
    description=description,
    examples=examples,
    allow_flagging='never'
)

if __name__ == "__main__":
    iface.launch()
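
# Usage note: running this script (e.g. `python app.py`, assuming that file
# name) serves the UI locally, by default at http://127.0.0.1:7860. The model
# weights ('model.pth', 'model2.pth') and the optional example images
# ('01.jpeg' to '10.jpeg') are expected in the current working directory.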