Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import os, requests | |
| import numpy as np | |
| import torch.nn.functional as F | |
| from model.model import ResHalf | |
| from inference import Inferencer | |
| from utils import util | |
## local | remote
RUN_MODE = "remote"
_HF_REPO_URL = "https://huggingface.co/menghanxia/ReversibleHalftoning/resolve/main"


def _download(url, dest):
    """Download ``url`` to ``dest`` unless ``dest`` already exists.

    Creates the parent directory if needed. The data is streamed into a
    temporary ``<dest>.part`` file and renamed into place only after the
    whole body was written, so an interrupted download never leaves a
    truncated file at ``dest``.

    Raises:
        requests.HTTPError: if the server answers with a non-2xx status.
    """
    if os.path.exists(dest):
        # Space restarts keep the filesystem; skip redundant downloads.
        return
    parent = os.path.dirname(dest)
    if parent:
        # The original os.rename into ./checkpoints failed when the
        # directory did not exist yet.
        os.makedirs(parent, exist_ok=True)
    tmp_path = dest + ".part"
    with requests.get(url, stream=True, timeout=60) as resp:
        resp.raise_for_status()
        with open(tmp_path, "wb") as fh:
            for chunk in resp.iter_content(chunk_size=1 << 20):
                fh.write(chunk)
    os.replace(tmp_path, dest)


if RUN_MODE != "local":
    _download(f"{_HF_REPO_URL}/model_best.pth.tar", "./checkpoints/model_best.pth.tar")
    ## examples
    for _example_name in ("girl.png", "wave.png", "painting.png"):
        _download(f"{_HF_REPO_URL}/{_example_name}", _example_name)
## step 1: set up model
device = "cpu"
checkpt_path = "checkpoints/model_best.pth.tar"
# One inferencer is built at import time and reused by every callback.
# use_cuda follows the `device` setting above (False for "cpu").
invhalfer = Inferencer(
    checkpoint_path=checkpt_path,
    model=ResHalf(train=False),
    use_cuda=(device == "cuda"),
    multi_gpu=False,
)
def prepare_data(input_img, decoding_only=False):
    """Convert an image array with values in [0, 255] into a network tensor.

    Args:
        input_img: NumPy image array, either (H, W) or (H, W, C), with
            pixel values in [0, 255].
        decoding_only: truthy in restoration mode; only the first channel
            is kept, since the input is then a halftone.

    Returns:
        ``util.img2tensor`` applied to the float32 image rescaled to [-1, 1].
    """
    img = np.asarray(input_img / 255., dtype=np.float32)
    if img.ndim == 2:
        # Grayscale arrays have no channel axis; add one so the channel
        # handling below (and img2tensor) always see (H, W, C).
        img = img[:, :, np.newaxis]
    if decoding_only:
        img = img[:, :, :1]
    elif img.shape[2] == 1:
        # NOTE(review): assumes the halftoning network takes a 3-channel
        # RGB input (restoration output is 3-channel) — confirm.
        img = np.repeat(img, 3, axis=2)
    elif img.shape[2] == 4:
        # Drop an alpha channel under the same 3-channel assumption.
        img = img[:, :, :3]
    return util.img2tensor(img * 2. - 1.)
def run_invhalf(invhalfer, input_img, decoding_only, device="cuda"):
    """Run one forward pass of ``invhalfer`` and return a uint8 image.

    In restoration mode (``decoding_only`` truthy) the inferencer's single
    result is shown. Otherwise the inferencer returns a pair and only its
    first element (the halftone) is shown.
    """
    batch = prepare_data(input_img, decoding_only).to(device)
    if not decoding_only:
        print('>>>:halftoning mode')
        result, _ = invhalfer(batch, decoding_only=decoding_only)
    else:
        print('>>>:restoration mode')
        result = invhalfer(batch, decoding_only=decoding_only)
    # Undo the [-1, 1] scaling applied in prepare_data, then back to [0, 255].
    pixels = util.tensor2img(result / 2. + 0.5) * 255.
    return np.clip(pixels, 0, 255).astype(np.uint8)
def click_run(input_img, decoding_only):
    """Gradio callback for the "Run" button.

    Args:
        input_img: the image currently in the input component, or None when
            nothing has been uploaded.
        decoding_only: radio index; truthy selects restoration mode.

    Returns:
        The uint8 output image produced by ``run_invhalf``.

    Raises:
        gr.Error: when no input image is present. Gradio shows this message
            to the user, instead of the opaque TypeError that ``None / 255.``
            would otherwise raise during preprocessing.
    """
    if input_img is None:
        raise gr.Error("Please upload an input image first.")
    return run_invhalf(invhalfer, input_img, decoding_only, device)
def click_move(output_img, decoding_only):
    """Gradio callback for "Use it as input": feed the output back in.

    The radio is flipped to the opposite mode: a restored photo is halftoned
    next, and a halftone is restored next. The output image is cleared.

    Returns:
        (new input image, new radio choice label, new output image).
    """
    # index 0 -> current mode is halftoning, index 1 -> restoration
    next_mode = ("Restoration (Halftone2Photo)", "Halftoning (Photo2Halftone)")[bool(decoding_only)]
    return output_img, next_mode, None
## step 2: configure interface
def _tall_image(**kwargs):
    """Create a 480px-tall ``gr.Image`` on both old and new Gradio releases.

    ``Component.style()`` was deprecated in Gradio 3.x and removed in 4.0,
    so calling it crashes current Gradio. When the constructor accepts
    ``height``, it is passed there. Otherwise, on older releases, ``.style()``
    is used.
    """
    import inspect  # local: only needed for this one-off version probe

    if "height" in inspect.signature(gr.Image.__init__).parameters:
        return gr.Image(height=480, **kwargs)
    return gr.Image(**kwargs).style(height=480)


demo = gr.Blocks(title="ReversibleHalftoning")
with demo:
    gr.Markdown(value="""
**Gradio demo for ReversibleHalftoning: Deep Halftoning with Reversible Binary Pattern**. Check our [github page](https://github.com/MenghanXia/ReversibleHalftoning) 😛.
""")
    with gr.Row():
        with gr.Column():
            Image_input = _tall_image(type="numpy", label="Input", interactive=True)
            with gr.Row():
                Radio_mode = gr.Radio(
                    type="index",
                    choices=["Halftoning (Photo2Halftone)", "Restoration (Halftone2Photo)"],
                    label="Choose a running mode",
                    value="Halftoning (Photo2Halftone)",
                )
                Button_run = gr.Button(value="Run")
        with gr.Column():
            Image_output = _tall_image(type="numpy", label="Output")
            Button_move = gr.Button(value="Use it as input")
    Button_run.click(fn=click_run, inputs=[Image_input, Radio_mode], outputs=Image_output)
    Button_move.click(
        fn=click_move,
        inputs=[Image_output, Radio_mode],
        outputs=[Image_input, Radio_mode, Image_output],
    )
    # The example images are only downloaded in remote mode.
    if RUN_MODE != "local":
        gr.Examples(
            examples=[
                ['girl.png', "Halftoning (Photo2Halftone)"],
                ['wave.png', "Halftoning (Photo2Halftone)"],
                ['painting.png', "Restoration (Halftone2Photo)"],
            ],
            inputs=[Image_input, Radio_mode],
            outputs=[Image_output],
            label="Examples",
        )
# Local runs bind to a fixed host/port; the hosted Space uses Gradio defaults.
launch_kwargs = (
    {"server_name": '9.134.253.83', "server_port": 7788}
    if RUN_MODE == "local"
    else {}
)
demo.launch(**launch_kwargs)