import numpy as np
import torch
import torch.nn as nn
import gradio as gr
from PIL import Image, ImageFilter, ImageOps, ImageChops
import torchvision.transforms as transforms
import os
import pathlib

# --- Configuration ---
# Directory where every processed image is written (created at import time).
output_dir = "outputs"
os.makedirs(output_dir, exist_ok=True)

# --- Filters ---
# Filter name -> emoji label; dropdown entries are rendered as "<emoji> <name>".
# NOTE(review): many emoji literals below look mis-encoded (UTF-8 bytes shown
# through a legacy code page). They are kept verbatim because dropdown values
# are compared against exactly these strings; repair them only together with
# every place that hard-codes a dropdown label.
FILTERS = {
    "Standard": "πŸ“„", "Invert": "⚫βšͺ", "Blur": "🌫️", "Sharpen": "πŸ”ͺ",
    "Contour": "πŸ—ΊοΈ", "Detail": "πŸ”", "EdgeEnhance": "πŸ“", "EdgeEnhanceMore": "πŸ“",
    "Emboss": "🏞️", "FindEdges": "πŸ•΅οΈ", "Smooth": "🌊", "SmoothMore": "πŸ’§",
    "Solarize": "β˜€οΈ", "Posterize1": "πŸ–ΌοΈ1", "Posterize2": "πŸ–ΌοΈ2", "Posterize3": "πŸ–ΌοΈ3",
    "Posterize4": "πŸ–ΌοΈ4", "Equalize": "βš–οΈ", "AutoContrast": "πŸ”§", "Thick1": "πŸ’ͺ1",
    "Thick2": "πŸ’ͺ2", "Thick3": "πŸ’ͺ3", "Thin1": "πŸƒ1", "Thin2": "πŸƒ2", "Thin3": "πŸƒ3",
    "RedOnWhite": "πŸ”΄", "OrangeOnWhite": "🟠", "YellowOnWhite": "🟑", "GreenOnWhite": "🟒",
    "BlueOnWhite": "πŸ”΅", "PurpleOnWhite": "🟣", "PinkOnWhite": "🌸", "CyanOnWhite": "🩡",
    "MagentaOnWhite": "πŸŸͺ", "BrownOnWhite": "🀎", "GrayOnWhite": "🩶", "WhiteOnBlack": "βšͺ",
    "RedOnBlack": "πŸ”΄βš«", "OrangeOnBlack": "🟠⚫", "YellowOnBlack": "🟑⚫", "GreenOnBlack": "🟒⚫",
    "BlueOnBlack": "πŸ”΅βš«", "PurpleOnBlack": "🟣⚫", "PinkOnBlack": "🌸⚫", "CyanOnBlack": "🩡⚫",
    "MagentaOnBlack": "πŸŸͺ⚫", "BrownOnBlack": "🀎⚫", "GrayOnBlack": "🩶⚫", "Multiply": "βœ–οΈ",
    "Screen": "πŸ–₯️", "Overlay": "πŸ”²", "Add": "βž•", "Subtract": "βž–", "Difference": "β‰ ",
    "Darker": "πŸŒ‘", "Lighter": "πŸŒ•", "SoftLight": "πŸ’‘", "HardLight": "πŸ”¦", "Binary": "πŸŒ“",
    "Noise": "❄️"
}

# --- Neural Network Model (Unchanged) ---
# Attribute names below (conv_block, model0..model4) form the state_dict keys of
# the saved checkpoints, so they must not be renamed.
norm_layer = nn.InstanceNorm2d


class ResidualBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 convolutions plus an identity skip."""

    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        conv_block = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
        ]
        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)


class Generator(nn.Module):
    """Encoder / residual bottleneck / decoder image-to-image network.

    Args:
        input_nc: number of channels of the input image.
        output_nc: number of channels of the generated image.
        n_residual_blocks: number of ResidualBlocks in the bottleneck.
        sigmoid: when True, a final Sigmoid squashes outputs into [0, 1].
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super(Generator, self).__init__()
        # Stem: 7x7 convolution to 64 feature maps.
        model0 = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            norm_layer(64),
            nn.ReLU(inplace=True),
        ]
        self.model0 = nn.Sequential(*model0)

        # Two stride-2 convolutions: 64 -> 128 -> 256 feature maps.
        model1, in_features, out_features = [], 64, 128
        for _ in range(2):
            model1 += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                norm_layer(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
            out_features = in_features * 2
        self.model1 = nn.Sequential(*model1)

        model2 = [ResidualBlock(in_features) for _ in range(n_residual_blocks)]
        self.model2 = nn.Sequential(*model2)

        # Two stride-2 transposed convolutions: 256 -> 128 -> 64 feature maps.
        model3, out_features = [], in_features // 2
        for _ in range(2):
            model3 += [
                nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                norm_layer(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
            out_features = in_features // 2
        self.model3 = nn.Sequential(*model3)

        model4 = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)]
        if sigmoid:
            model4 += [nn.Sigmoid()]
        self.model4 = nn.Sequential(*model4)

    def forward(self, x, cond=None):
        # `cond` is accepted for interface compatibility but is not used.
        return self.model4(self.model3(self.model2(self.model1(self.model0(x)))))


# --- Model Loading ---
def _load_generator(checkpoint_path):
    """Build a Generator(3, 1, 3), load CPU weights from checkpoint_path, and set eval mode.

    Raises:
        FileNotFoundError: if checkpoint_path does not exist.
    """
    model = Generator(3, 1, 3)
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
    model.eval()
    return model


try:
    model1 = _load_generator('model.pth')   # used for 'Complex Lines'
    model2 = _load_generator('model2.pth')  # used for 'Simple Lines'
except FileNotFoundError:
    print("⚠️ Warning: Model files 'model.pth' or 'model2.pth' not found. The application will not run correctly.")
    model1, model2 = None, None

# --- Filter Application Logic ---
# Lookup tables for apply_filter, built once at import rather than on every call.

# Filters applied to the model output in its own mode.
_DIRECT_FILTERS = {
    "Blur": ImageFilter.GaussianBlur(radius=3),
    "Sharpen": ImageFilter.SHARPEN,
    "Detail": ImageFilter.DETAIL,
    "Smooth": ImageFilter.SMOOTH,
    "SmoothMore": ImageFilter.SMOOTH_MORE,
}
# Convolution filters applied to the grayscale ('L') copy.
_GRAY_FILTERS = {
    "Contour": ImageFilter.CONTOUR,
    "EdgeEnhance": ImageFilter.EDGE_ENHANCE,
    "EdgeEnhanceMore": ImageFilter.EDGE_ENHANCE_MORE,
    "Emboss": ImageFilter.EMBOSS,
    "FindEdges": ImageFilter.FIND_EDGES,
}
# ImageOps point operations applied to the grayscale ('L') copy.
_GRAY_OPS = {
    "Invert": ImageOps.invert,
    "Solarize": ImageOps.solarize,
    "Equalize": ImageOps.equalize,
    "AutoContrast": ImageOps.autocontrast,
}
# Dark pixels take the colour, light pixels become white.
_COLORS_ON_WHITE = {
    "RedOnWhite": "red", "OrangeOnWhite": "orange", "YellowOnWhite": "yellow",
    "GreenOnWhite": "green", "BlueOnWhite": "blue", "PurpleOnWhite": "purple",
    "PinkOnWhite": "pink", "CyanOnWhite": "cyan", "MagentaOnWhite": "magenta",
    "BrownOnWhite": "brown", "GrayOnWhite": "gray",
}
# Dark pixels take the colour, light pixels become black.
_COLORS_ON_BLACK = {
    "WhiteOnBlack": "white", "RedOnBlack": "red", "OrangeOnBlack": "orange",
    "YellowOnBlack": "yellow", "GreenOnBlack": "green", "BlueOnBlack": "blue",
    "PurpleOnBlack": "purple", "PinkOnBlack": "pink", "CyanOnBlack": "cyan",
    "MagentaOnBlack": "magenta", "BrownOnBlack": "brown", "GrayOnBlack": "gray",
}
# Blend modes combining the original photo with the RGB line drawing.
_BLEND_OPS = {
    "Multiply": ImageChops.multiply, "Screen": ImageChops.screen,
    "Overlay": ImageChops.overlay, "Add": ImageChops.add,
    "Subtract": ImageChops.subtract, "Difference": ImageChops.difference,
    "Darker": ImageChops.darker, "Lighter": ImageChops.lighter,
    "SoftLight": ImageChops.soft_light, "HardLight": ImageChops.hard_light,
}
# Thick*/Thin* suffix -> rank-filter kernel size; any other suffix uses 7.
_KERNEL_SIZES = {"1": 3, "2": 5}


def apply_filter(line_img, filter_name, original_img):
    """Apply the named post-processing filter to a generated line drawing.

    Args:
        line_img: PIL image produced by the line-art model.
        filter_name: a FILTERS key, without its emoji prefix.
        original_img: the RGB source photo; only the blend filters use it, and
            ImageChops expects it to match line_img's size.

    Returns:
        A new PIL image. Unknown filter names return line_img unchanged.
    """
    if filter_name == "Standard":
        return line_img
    if filter_name in _DIRECT_FILTERS:
        return line_img.filter(_DIRECT_FILTERS[filter_name])

    line_img_l = line_img.convert('L')
    if filter_name in _GRAY_OPS:
        return _GRAY_OPS[filter_name](line_img_l)
    if filter_name in _GRAY_FILTERS:
        return line_img_l.filter(_GRAY_FILTERS[filter_name])
    if filter_name.startswith("Posterize"):
        # Trailing digit is the number of bits kept per channel.
        return ImageOps.posterize(line_img_l, int(filter_name[-1]))
    if filter_name == "Binary":
        # Pillow's default convert('1') dithers rather than thresholding.
        return line_img_l.convert('1')
    if filter_name.startswith("Thick"):
        # MinFilter spreads dark pixels, so dark lines get wider.
        return line_img_l.filter(ImageFilter.MinFilter(_KERNEL_SIZES.get(filter_name[-1], 7)))
    if filter_name.startswith("Thin"):
        # MaxFilter spreads light pixels, so dark lines get narrower.
        return line_img_l.filter(ImageFilter.MaxFilter(_KERNEL_SIZES.get(filter_name[-1], 7)))
    if filter_name in _COLORS_ON_WHITE:
        return ImageOps.colorize(line_img_l, black=_COLORS_ON_WHITE[filter_name], white="white")
    if filter_name in _COLORS_ON_BLACK:
        return ImageOps.colorize(line_img_l, black=_COLORS_ON_BLACK[filter_name], white="black")
    if filter_name in _BLEND_OPS:
        return _BLEND_OPS[filter_name](original_img, line_img.convert('RGB'))
    if filter_name == "Noise":
        img_array = np.array(line_img_l)
        # randint's upper bound is exclusive: 21 yields symmetric noise in [-20, 20].
        noise = np.random.randint(-20, 21, img_array.shape, dtype='int16')
        noisy_array = np.clip(img_array.astype('int16') + noise, 0, 255).astype('uint8')
        return Image.fromarray(noisy_array)
    return line_img


# --- Main Processing Function ---
# Preprocessing: shorter side resized to 256 px (bicubic), pixels scaled to [-1, 1].
_PREPROCESS = transforms.Compose([
    transforms.Resize(256, transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


def predict(input_img_path, line_style, filter_choice):
    """Turn an uploaded image into filtered line art and save it under output_dir.

    Args:
        input_img_path: filesystem path of the uploaded image (Gradio 'filepath').
        line_style: 'Simple Lines' selects model2; any other value selects model1.
        filter_choice: dropdown label "<emoji> <FilterName>"; a bare filter name
            is also accepted.

    Returns:
        The filtered PIL image. It is also saved as
        ``<output_dir>/<input stem>_<FilterName>.png``.

    Raises:
        gr.Error: if the models are not loaded, or no image or filter was given.
    """
    if model1 is None or model2 is None:
        raise gr.Error("Models are not loaded. Please check for 'model.pth' and 'model2.pth'.")
    if input_img_path is None:
        raise gr.Error("Please upload an image first.")
    if not filter_choice:
        raise gr.Error("Please choose a filter.")

    # Drop the emoji prefix; [-1] also tolerates a plain name with no prefix.
    filter_name = filter_choice.split(" ", 1)[-1]
    with Image.open(input_img_path) as img:
        original_img = img.convert('RGB')

    model = model2 if line_style == 'Simple Lines' else model1
    input_tensor = _PREPROCESS(original_img).unsqueeze(0)
    with torch.no_grad():
        output = model(input_tensor)

    # Model output is clamped to [0, 1], converted to PIL, and upscaled back to the input size.
    line_drawing_low_res = transforms.ToPILImage()(output.squeeze().cpu().clamp(0, 1))
    line_drawing_full_res = line_drawing_low_res.resize(original_img.size, Image.Resampling.BICUBIC)
    final_image = apply_filter(line_drawing_full_res, filter_name, original_img)

    # --- Save the output image ---
    base_name = pathlib.Path(input_img_path).stem
    output_filepath = os.path.join(output_dir, f"{base_name}_{filter_name}.png")
    final_image.save(output_filepath)
    return final_image


# --- Gradio UI Setup ---
title = "πŸ–ŒοΈ Image to Line Art with Creative Filters"
description = "Upload an image, choose a line style (Complex or Simple), and select a filter from the dropdown to transform your picture. Results are saved in the 'outputs' folder."
# Dropdown labels: "<emoji> <FilterName>", in FILTERS insertion order.
filter_choices = [f"{emoji} {name}" for name, emoji in FILTERS.items()]

# --- Curated Examples Section ---
example_images = [f"{i:02d}.jpeg" for i in range(1, 11)]
# Six filters to showcase. Labels are built from FILTERS so they always equal a
# dropdown choice, instead of hand-copied strings that can drift out of sync.
demo_filter_names = ["Contour", "BlueOnBlack", "Multiply", "Emboss", "Sharpen", "Noise"]
demo_filters = [f"{FILTERS[name]} {name}" for name in demo_filter_names]

# One example per image file that exists. The filter cycles by the file's index
# (not by the number of examples kept), so missing files leave gaps in the rotation.
examples = []
for i, img_file in enumerate(example_images):
    if os.path.exists(img_file):
        chosen_filter = demo_filters[i % len(demo_filters)]
        examples.append([img_file, 'Simple Lines', chosen_filter])

if not examples:
    print("⚠️ Warning: No example images ('01.jpeg' to '10.jpeg') found. Examples will be empty.")

# Reverted to the simpler and more stable gr.Interface.
# NOTE(review): newer Gradio releases deprecate `allow_flagging` in favour of
# `flagging_mode`; confirm against the pinned Gradio version.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type='filepath', label="Upload Image"),
        gr.Radio(['Complex Lines', 'Simple Lines'], label='Line Style', value='Simple Lines'),
        gr.Dropdown(filter_choices, label="Filter", value=filter_choices[0]),
    ],
    outputs=gr.Image(type="pil", label="Filtered Line Art"),
    title=title,
    description=description,
    examples=examples,
    allow_flagging='never',
)

if __name__ == "__main__":
    iface.launch()