Spaces:
Runtime error
Runtime error
| import os | |
| import cv2 | |
| import torch | |
| from flask import Flask, render_template, request, send_file | |
| from basicsr.archs.srvgg_arch import SRVGGNetCompact | |
| from gfpgan.utils import GFPGANer | |
| from realesrgan.utils import RealESRGANer | |
# Restrict OpenMP to one thread (workaround for a threading issue).
# NOTE(review): this is assigned after torch/cv2 are imported; confirm the
# OpenMP runtime still honours it at that point.
os.environ["OMP_NUM_THREADS"] = "1"

# Directory that holds both uploaded images and enhanced results.
os.makedirs("output", exist_ok=True)

app = Flask(__name__)

# Background upsampler: Real-ESRGAN weights loaded into a compact SRVGG net.
model_path = 'realesr-general-x4v3.pth'
model = SRVGGNetCompact(
    num_in_ch=3,
    num_out_ch=3,
    num_feat=64,
    num_conv=32,
    upscale=4,
    act_type='prelu',
)
# Half precision is requested only when a CUDA device is available.
half = torch.cuda.is_available()
upsampler = RealESRGANer(
    scale=4,
    model_path=model_path,
    model=model,
    tile=0,
    tile_pad=10,
    pre_pad=0,
    half=half,
)
# GFPGAN inference function
def enhance_face(img_path, version="v1.4", scale=2):
    """Restore the faces in an image with GFPGAN and write the result to disk.

    Args:
        img_path: Path of the image to load with OpenCV.
        version: Model selector. ``'RestoreFormer'`` loads
            ``RestoreFormer.pth`` with the RestoreFormer architecture; any
            other value loads ``f"{version}.pth"`` with the ``clean``
            architecture. Must be a bare name without path separators.
        scale: Overall rescale factor. The enhancer is configured with
            ``upscale=2``; when ``scale`` differs from 2 the enhanced image
            is resized by ``scale / 2``.

    Returns:
        The path of the saved image: ``output/out.png`` when the input has
        four channels, ``output/out.jpg`` otherwise.

    Raises:
        ValueError: If ``scale`` is not a positive finite number, ``version``
            is empty or contains a path separator, or the image cannot be
            decoded.
        OSError: If the result cannot be written.
    """
    # Validate cheap arguments first so bad requests fail before any model
    # weights are loaded. The chained comparison also rejects NaN.
    if not 0 < scale < float("inf"):
        raise ValueError(f"scale must be a positive finite number, got {scale!r}")
    # `version` ends up inside a file path; refuse anything that could point
    # outside the working directory.
    if not version or "/" in version or "\\" in version or os.sep in version:
        raise ValueError(f"invalid model version: {version!r}")

    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread signals a missing or undecodable file by returning None.
        raise ValueError(f"could not read image: {img_path!r}")

    img_mode = None
    if len(img.shape) == 3 and img.shape[2] == 4:
        img_mode = 'RGBA'
    elif len(img.shape) == 2:
        # Grayscale input: expand to three channels before enhancement.
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    if version == 'RestoreFormer':
        face_enhancer = GFPGANer(model_path='RestoreFormer.pth', upscale=2, arch='RestoreFormer', channel_multiplier=2, bg_upsampler=upsampler)
    else:
        face_enhancer = GFPGANer(model_path=f"{version}.pth", upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=upsampler)

    _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)

    # Optional rescale
    if scale != 2:
        interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
        h, w = output.shape[0:2]
        # Clamp to one pixel so a very small factor cannot request a
        # zero-sized image, which cv2.resize rejects.
        new_size = (max(1, int(w * scale / 2)), max(1, int(h * scale / 2)))
        output = cv2.resize(output, new_size, interpolation=interpolation)

    # Save output
    extension = 'png' if img_mode == 'RGBA' else 'jpg'
    save_path = f"output/out.{extension}"
    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(save_path, output):
        raise OSError(f"could not write image: {save_path!r}")
    return save_path
# Flask routes
@app.route("/", methods=["GET", "POST"])
def index():
    """Serve the upload form on GET and run face enhancement on POST.

    A POST expects a multipart form with an ``image`` file and optional
    ``version`` (default ``"v1.4"``) and ``scale`` (default ``2``) fields.
    The upload is stored under ``output/``, passed to ``enhance_face`` and
    the resulting file is returned as an attachment.

    Returns:
        The enhanced image as a download, a ``(message, 400)`` pair when the
        submitted form is invalid, or the HTML form for non-POST requests.
    """
    if request.method == "POST":
        file = request.files.get("image")
        # Browsers submit an empty filename when no file was chosen.
        if file is None or not file.filename:
            return "No image uploaded.", 400

        version = request.form.get("version", "v1.4")
        # Only accept the models offered by the form below; the value is
        # later used to build a model file path.
        if version not in ("v1.2", "v1.3", "v1.4", "RestoreFormer"):
            return "Unsupported model version.", 400

        try:
            scale = float(request.form.get("scale", 2))
        except ValueError:
            return "Rescale factor must be a number.", 400
        # Chained comparison also rejects NaN and infinity.
        if not 0 < scale < float("inf"):
            return "Rescale factor must be a positive number.", 400

        # The filename is client-controlled: keep only its final component so
        # it cannot escape the output directory.
        safe_name = os.path.basename(file.filename.replace("\\", "/"))
        if safe_name in ("", ".", ".."):
            return "Invalid file name.", 400

        filepath = os.path.join("output", safe_name)
        file.save(filepath)
        try:
            output_path = enhance_face(filepath, version, scale)
        except ValueError:
            # A ValueError here means the input was rejected as unusable;
            # report it as a client error rather than a 500.
            return "The uploaded file could not be processed as an image.", 400
        # Send an absolute path so the lookup does not depend on which base
        # directory Flask resolves relative paths against.
        return send_file(os.path.abspath(output_path), as_attachment=True)
    return """
    <h1>GFPGAN Face Restoration</h1>
    <form method="post" enctype="multipart/form-data">
    Upload Image: <input type="file" name="image"><br><br>
    Version:
    <select name="version">
    <option value="v1.2">v1.2</option>
    <option value="v1.3">v1.3</option>
    <option value="v1.4" selected>v1.4</option>
    <option value="RestoreFormer">RestoreFormer</option>
    </select><br><br>
    Rescale factor: <input type="number" step="0.1" name="scale" value="2"><br><br>
    <input type="submit" value="Enhance">
    </form>
    """
if __name__ == "__main__":
    # The server listens on all interfaces, and Flask's debug mode exposes an
    # interactive debugger that can execute arbitrary code. Keep it off unless
    # explicitly requested through the FLASK_DEBUG environment variable.
    debug = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true")
    app.run(host="0.0.0.0", port=7860, debug=debug)