# Provenance (scrape residue, rewritten as a comment so the file parses):
# uploaded to the Hugging Face Hub via huggingface_hub, commit 302920f (verified).
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The implementation is based on "Parameter-Efficient Orthogonal Finetuning
# via Butterfly Factorization" (https://huggingface.co/papers/2311.06243) in ICLR 2024.
import glob
import os
from pathlib import Path
import cv2
import face_alignment
import numpy as np
import torch
from accelerate import Accelerator
from skimage.io import imread
from torchvision.utils import save_image
from tqdm import tqdm
from transformers import AutoTokenizer
from utils.args_loader import parse_args
from utils.dataset import make_dataset
# Determine the best available device
if torch.cuda.is_available():
    device = "cuda:0"
else:
    # TODO: xpu support in facealignment will be ready after this PR is merged:https://github.com/1adrianb/face-alignment/pull/371
    device = "cpu"

# Module-level 68-point 2D facial-landmark detector, shared by
# generate_landmark2d() below. Instantiated at import time (downloads /
# loads model weights on first use).
detect_model = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, device=device, flip_input=False)

# with open('./data/celebhq-text/prompt_val_blip_full.json', 'rt') as f: # fill50k, COCO
#     for line in f:
#         val_data = json.loads(line)

# 0-based indices (after the -1) of the last keypoint of each contour
# segment — presumably the standard 68-point groups (jaw, brows, nose,
# eyes, mouth); plot_kpts() breaks the connecting polyline at these
# indices instead of linking into the next segment.
end_list = np.array([17, 22, 27, 42, 48, 31, 36, 68], dtype=np.int32) - 1
def count_txt_files(directory, pattern="*.txt"):
    """Count the files in *directory* that match *pattern* (non-recursive).

    Args:
        directory: directory to scan.
        pattern: glob pattern to match; defaults to ``"*.txt"`` so existing
            callers keep their behavior.

    Returns:
        int: number of matching files directly inside *directory*.
    """
    return len(glob.glob(os.path.join(directory, pattern)))
def plot_kpts(image, kpts, color="g"):
    """Draw 68 key points and their contour polylines on a copy of *image*.

    Args:
        image: HxW x3 image array; the input is not modified (a copy is drawn on).
        kpts: (68, 2), (68, 3) or (68, 4) array; columns 0-1 are x, y pixel
            coordinates. With 4 columns, column 3 is treated as a per-point
            confidence score and overrides *color* per point.
        color: "r", "g" or "b" selecting the point color. Unknown values now
            fall back to green (previously they left ``c`` unbound and raised
            NameError at the first circle).

    Returns:
        The annotated copy of *image*.
    """
    # Fix: the original mapped "b" to (255, 0, 0), byte-identical to "r".
    # Blue is (0, 0, 255) assuming RGB channel order — NOTE(review): callers in
    # this file pass cv2.imread (BGR) images but only ever use color="g", so
    # the channel-order ambiguity does not affect current behavior; confirm
    # before relying on "r"/"b" here.
    palette = {"r": (255, 0, 0), "g": (0, 255, 0), "b": (0, 0, 255)}
    c = palette.get(color, (0, 255, 0))

    image = image.copy()
    kpts = kpts.copy()
    # Scale marker size with resolution, but never below 1 px.
    radius = max(int(min(image.shape[0], image.shape[1]) / 200), 1)
    for i in range(kpts.shape[0]):
        st = kpts[i, :2]
        # With a confidence column, recolor each point by its score.
        if kpts.shape[1] == 4:
            c = (0, 255, 0) if kpts[i, 3] > 0.5 else (0, 0, 255)
        image = cv2.circle(image, (int(st[0]), int(st[1])), radius, c, radius * 2)
        # end_list marks the last point of each contour segment: do not
        # connect it to the first point of the next segment.
        if i in end_list:
            continue
        ed = kpts[i + 1, :2]
        image = cv2.line(image, (int(st[0]), int(st[1])), (int(ed[0]), int(ed[1])), (255, 255, 255), radius)
    return image
def generate_landmark2d(dataset, input_dir, pred_lmk_dir, gt_lmk_dir, vis=False):
    """Detect 68-point 2D landmarks for predicted and ground-truth images.

    For every ``pred*.png`` in *input_dir* (named ``..._<idx>.png``), runs the
    module-level ``detect_model`` on the predicted image and on the
    ground-truth sample ``dataset[idx]``, saving keypoints as ``<idx>.txt``
    into *pred_lmk_dir* / *gt_lmk_dir*. Samples where detection returns None
    are skipped. With ``vis=True``, side-by-side overlay JPEGs are written too.

    Args:
        dataset: indexable dataset; items expose "conditioning_pixel_values"
            (landmark-map tensor) and "pixel_values" (image tensor, assumed
            scaled to [-1, 1] — rescaled here via * 0.5 + 0.5; TODO confirm).
        input_dir: directory containing the predicted ``pred*.png`` images.
        pred_lmk_dir: output directory for predicted-image landmarks (created).
        gt_lmk_dir: output directory for ground-truth files
            (NOTE(review): assumed to exist — only pred_lmk_dir is created here).
        vis: when True, write overlay visualizations next to the .txt files.
    """
    print("Generate 2d landmarks ...")
    os.makedirs(pred_lmk_dir, exist_ok=True)

    imagepath_list = sorted(glob.glob(f"{input_dir}/pred*.png"))
    for imagepath in tqdm(imagepath_list):
        name = Path(imagepath).stem
        # Sample index is the trailing "_<idx>" of the filename.
        idx = int(name.split("_")[-1])
        pred_txt_path = os.path.join(pred_lmk_dir, f"{idx}.txt")
        gt_lmk_path = os.path.join(gt_lmk_dir, f"{idx}_gt_lmk.jpg")
        gt_txt_path = os.path.join(gt_lmk_dir, f"{idx}.txt")
        gt_img_path = os.path.join(gt_lmk_dir, f"{idx}_gt_img.jpg")

        # Only (re)process samples missing either landmark file, so the
        # script can resume after an interrupted run.
        if (not os.path.exists(pred_txt_path)) or (not os.path.exists(gt_txt_path)):
            image = imread(imagepath)  # [:, :, :3]
            out = detect_model.get_landmarks(image)
            if out is None:
                # No face found in the predicted image; skip this sample.
                continue
            pred_kpt = out[0].squeeze()
            np.savetxt(pred_txt_path, pred_kpt)

            # Your existing code for obtaining the image tensor
            gt_lmk_img = dataset[idx]["conditioning_pixel_values"]
            save_image(gt_lmk_img, gt_lmk_path)
            gt_img = (dataset[idx]["pixel_values"]) * 0.5 + 0.5
            save_image(gt_img, gt_img_path)
            # CHW float tensor -> HWC uint8 numpy array for the detector.
            gt_img = (gt_img.permute(1, 2, 0) * 255).type(torch.uint8).cpu().numpy()
            out = detect_model.get_landmarks(gt_img)
            if out is None:
                # NOTE(review): the pred .txt was already written above, so
                # this sample ends up with a pred file but no gt file.
                continue
            gt_kpt = out[0].squeeze()
            np.savetxt(gt_txt_path, gt_kpt)

            # gt_image = cv2.resize(cv2.imread(gt_lmk_path), (512, 512))
            if vis:
                gt_lmk_image = cv2.imread(gt_lmk_path)
                # visualize predicted landmarks
                vis_path = os.path.join(pred_lmk_dir, f"{idx}_overlay.jpg")
                image = cv2.imread(imagepath)
                image_point = plot_kpts(image, pred_kpt)
                cv2.imwrite(vis_path, np.concatenate([image_point, gt_lmk_image], axis=1))
                # visualize gt landmarks
                vis_path = os.path.join(gt_lmk_dir, f"{idx}_overlay.jpg")
                image = cv2.imread(gt_img_path)
                image_point = plot_kpts(image, gt_kpt)
                cv2.imwrite(vis_path, np.concatenate([image_point, gt_lmk_image], axis=1))
def landmark_comparison(val_dataset, lmk_dir, gt_lmk_dir):
    """Report the mean reprojection error between predicted and GT landmarks.

    For each sample index i, loads ``gt_lmk_dir/i.txt`` and ``lmk_dir/i.txt``
    (keypoint arrays written by generate_landmark2d) and accumulates the mean
    per-point Euclidean distance. Prints a running mean while iterating, the
    final mean at the end, and saves the per-sample errors to
    ``lmk_dir/lmk_err.npy``.
    """
    print("Calculating reprojection error")
    errors = []
    progress = tqdm(range(len(val_dataset)))
    for sample_idx in progress:
        # line = val_dataset[i]
        # img_name = line["image"].split(".")[0]
        gt_landmarks = np.loadtxt(os.path.join(gt_lmk_dir, f"{sample_idx}.txt"))
        pred_path = os.path.join(lmk_dir, f"{sample_idx}.txt")
        if not os.path.exists(pred_path):
            # Detection failed for this sample in generate_landmark2d; skip.
            print(f"{pred_path} not exist")
            continue
        pred_landmarks = np.loadtxt(pred_path)
        # Mean L2 distance over this sample's keypoints.
        per_point_dist = np.linalg.norm(gt_landmarks - pred_landmarks, axis=1)
        errors.append(np.mean(per_point_dist))
        progress.set_description(f"lmk_err: {np.mean(errors):.5f}")
    print("Reprojection error:", np.mean(errors))
    np.save(os.path.join(lmk_dir, "lmk_err.npy"), errors)
def main(args):
    """Run the landmark-reprojection evaluation pipeline.

    Builds the tokenizer and test dataset, detects landmarks for the images
    in ``<output_dir>/results``, and — once every sample has both a predicted
    and a ground-truth landmark file — computes the reprojection error.
    """
    logging_dir = Path(args.output_dir, args.logging_dir)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_dir=logging_dir,
    )

    # Load the tokenizer: an explicit tokenizer name wins over the model's
    # bundled "tokenizer" subfolder.
    if args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
    elif args.pretrained_model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
        )

    val_dataset = make_dataset(args, tokenizer, accelerator, "test")

    gt_lmk_dir = os.path.join(args.output_dir, "gt_lmk")
    pred_lmk_dir = os.path.join(args.output_dir, "pred_lmk")
    for lmk_dir in (gt_lmk_dir, pred_lmk_dir):
        if not os.path.exists(lmk_dir):
            os.makedirs(lmk_dir, exist_ok=True)

    input_dir = os.path.join(args.output_dir, "results")
    generate_landmark2d(val_dataset, input_dir, pred_lmk_dir, gt_lmk_dir, args.vis_overlays)

    # Only score once every sample has both a predicted and a GT .txt file.
    n_expected = len(val_dataset)
    if count_txt_files(pred_lmk_dir) == n_expected and count_txt_files(gt_lmk_dir) == n_expected:
        landmark_comparison(val_dataset, pred_lmk_dir, gt_lmk_dir)
if __name__ == "__main__":
    # CLI entry point: parse arguments once and run the evaluation pipeline.
    args = parse_args()
    main(args)