Spaces:
Runtime error
Runtime error
| import itertools | |
| import functools | |
| import math | |
| import multiprocessing | |
| from pathlib import Path | |
| import matplotlib | |
# Global Matplotlib styling for every figure produced by this script: 24pt
# text rendered through LaTeX with the Linux Libertine / Biolinum fonts.
# Matplotlib's usetex mode calls out to LaTeX when text is drawn, so a LaTeX
# installation providing these packages is needed.  (biolinum is listed twice
# in the preamble; the second \usepackage is a no-op.)  "ps.usedistiller"
# only affects PostScript/EPS output.
matplotlib.rcParams.update(
    {
        "font.size": 24,
        "text.usetex": True,
        "text.latex.preamble": r"\usepackage{biolinum} \usepackage{libertineRoman} \usepackage{libertineMono} \usepackage{biolinum} \usepackage[libertine]{newtxmath}",
        "ps.usedistiller": "xpdf",
    }
)
| import matplotlib.pyplot as plt | |
| import matplotlib.gridspec as gridspec | |
| import numpy as np | |
| import tqdm | |
| from scipy.stats import wasserstein_distance | |
| import pose_estimation | |
def cub(x, a, b, c):
    """Evaluate the cubic ``a*x**3 + b*x**2 + c*x`` (no constant term).

    Only ``*`` and ``+`` are used, so *x* may be a scalar or a NumPy array
    (evaluated element-wise).  Since the constant term is zero,
    ``cub(0, ...) == 0`` and ``cub(1, a, b, c) == a + b + c``.
    """
    squared = x * x
    cubed = squared * x
    return a * cubed + b * squared + c * x
def subsample(a, p=0.0005, seed=0):
    """Return a random subset of ``int(p * len(a))`` entries of *a*.

    Indices are drawn uniformly *with replacement*, exactly as before, so
    entries may repeat.  The draw is reproducible for a given *seed* and
    yields the same indices the former ``np.random.seed(seed)`` +
    ``np.random.choice`` implementation did.  However, it uses a private
    ``RandomState``, so the global NumPy RNG is no longer reseeded (which
    silently reset any other random state in the process).

    Args:
        a: Array supporting ``len`` and integer-array indexing.
        p: Fraction of entries to draw.
        seed: Seed for the private random generator.

    Returns:
        A new array with the sampled entries (fancy indexing always copies).
    """
    rng = np.random.RandomState(seed)
    n = len(a)
    inds = rng.choice(n, size=int(p * n))
    return a[inds]
def read_cos_opt(path, fname="cos_hist.npy"):
    """Load every *fname* file found recursively under *path* and stack them.

    Files are visited in sorted path order so the row order of the result is
    reproducible.  ``Path.rglob`` on its own yields entries in filesystem
    order, which may differ between machines and runs.

    Args:
        path: Root directory to search (``str`` or ``Path``).
        fname: File name / glob pattern to match.

    Returns:
        ``np.array`` built from the loaded arrays, one entry per matching
        file; an empty array if nothing matched.  Callers in this module
        index the result as ``[:, bone_idx]``, so each file presumably holds
        a 1-D per-bone vector — TODO confirm.
    """
    files = sorted(Path(path).rglob(fname))
    return np.array([np.load(f) for f in files])
def plot_hist(cos_opt_dir, hist_smpl_fpath, params, out_dir, bins=10, xy=None, bone_name="Left Upper Leg"):
    """Plot per-bone angle histograms before and after the fitted cubic remap.

    Loads the optimised cosines (``read_cos_opt(cos_opt_dir)``) and the
    sampled cosines (``np.load(hist_smpl_fpath)``) and converts both to angles
    in degrees.  It then remaps the optimised angles with ``cub(..., *params)``
    and writes ``hist_<i_name>_<j_name>.png`` per selected bone into
    *out_dir*, containing:

    * top-left: histogram of the sampled angles (A_i);
    * bottom-left: histogram of the optimised angles (B_i);
    * middle (only when *xy* is given): ``xy[1]`` vs ``xy[0]`` plus identity;
    * top-right: histogram of the remapped angles f(B_i);
    * bottom-right: the three histograms overlaid.

    The cubic is applied in the normalised domain used by ``optimize`` /
    ``step`` (angle / (pi/2)) and scaled back to radians afterwards, the same
    way ``main`` builds the plotted curve.  Previously it was applied to raw
    radians, so the f(B_i) panels did not reflect the fitted parameters.

    Args:
        cos_opt_dir: Directory searched recursively for ``cos_hist.npy``.
        hist_smpl_fpath: ``.npy`` file with the sampled cosines.
        params: ``(a, b, c)`` coefficients passed to ``cub``.
        out_dir: Output directory (``str`` or ``Path``); created if missing.
        bins: Number of histogram bins over [0, 90] degrees.
        xy: Optional ``(x, y)`` curve (degrees) for the middle panel.
        bone_name: Plot only bones whose first joint is named this;
            ``None`` plots every bone in ``pose_estimation.SKELETON``.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    half_pi = math.pi / 2

    # Clip like the sampled cosines so values a rounding step outside
    # [-1, 1] do not become NaN under arccos.
    angle_opt = np.arccos(np.clip(read_cos_opt(cos_opt_dir), -1, 1))
    # The cubic was fitted on angles normalised to [0, 1]: normalise,
    # remap, then scale back to radians.
    angle_opt2 = cub(angle_opt / half_pi, *params) * half_pi

    cos_smpl = np.load(hist_smpl_fpath)
    # Optionally thin the samples for speed: cos_smpl = subsample(cos_smpl)
    print(cos_smpl.shape)
    angle_smpl = np.arccos(np.clip(cos_smpl, -1, 1))

    # Everything below is plotted in degrees.
    deg_opt = np.degrees(angle_opt)
    deg_opt2 = np.degrees(angle_opt2)
    deg_smpl = np.degrees(angle_smpl)

    max_range = 90  # degrees, i.e. pi / 2
    xticks = [0, 15, 30, 45, 60, 75, 90]
    hist_kw = dict(bins=bins, range=(0, max_range), density=True)

    for idx, (i, j) in enumerate(pose_estimation.SKELETON):
        i_name = pose_estimation.KPS[i]
        j_name = pose_estimation.KPS[j]
        if bone_name is not None and i_name != bone_name:
            continue
        name = f"{i_name}_{j_name}"

        gs = gridspec.GridSpec(2, 4)
        fig = plt.figure(tight_layout=True, figsize=(16, 8), dpi=300)
        try:
            ax0 = fig.add_subplot(gs[0, 0])
            ax0.hist(deg_smpl[:, idx], **hist_kw)
            ax0.set_xticks(xticks)
            ax0.tick_params(labelbottom=False, labelleft=True)

            ax1 = fig.add_subplot(gs[1, 0], sharex=ax0)
            ax1.hist(deg_opt[:, idx], **hist_kw)
            ax1.set_xticks(xticks)

            if xy is not None:
                ax2 = fig.add_subplot(gs[:, 1:3])
                ax2.plot(xy[0], xy[1], linewidth=8)
                ax2.plot(xy[0], xy[0], linewidth=4, linestyle="dashed")
                ax2.set_xticks(xticks)
                ax2.set_yticks(xticks)

            ax3 = fig.add_subplot(gs[0, 3], sharey=ax0)
            ax3.hist(deg_opt2[:, idx], **hist_kw)
            ax3.set_xticks(xticks)
            ax3.tick_params(labelbottom=False, labelleft=False)

            ax4 = fig.add_subplot(gs[1, 3], sharex=ax3, sharey=ax1)
            alpha = 0.5
            ax4.hist(deg_opt[:, idx], label=r"$\mathcal{B}_i$", alpha=alpha, **hist_kw)
            ax4.hist(deg_opt2[:, idx], label=r"$f(\mathcal{B}_i)$", alpha=alpha, **hist_kw)
            ax4.hist(deg_smpl[:, idx], label=r"$\mathcal{A}_i$", alpha=alpha, **hist_kw)
            ax4.set_xticks(xticks)
            ax4.tick_params(labelbottom=True, labelleft=False)
            ax4.legend()

            fig.savefig(out_dir / f"hist_{name}.png")
        finally:
            # Always release the figure, even if drawing/saving failed.
            plt.close(fig)
def kldiv(p_hist, q_hist, support=None):
    """Return the 1-D Wasserstein (earth mover's) distance between two histograms.

    Despite its name this is *not* a KL divergence; the name is kept so
    existing callers keep working.

    The histograms are treated as non-negative weights over a shared support
    of bin positions.  The previous version passed the bin heights to
    ``scipy.stats.wasserstein_distance`` as *sample values*.  That compares
    the distribution of the heights themselves and ignores which bin each
    height sits in (all mass in the first bin vs. all mass in the last bin
    scored 0).

    Args:
        p_hist: Bin weights of the reference histogram (densities or counts).
        q_hist: Bin weights of the other histogram; same length as *p_hist*.
        support: Optional bin positions (e.g. bin centres).  Defaults to the
            bin indices ``0..len(p_hist)-1``, which for equal-width bins
            gives the distance in units of one bin width.

    Returns:
        The Wasserstein-1 distance as a float.

    Raises:
        ValueError: if the two histograms differ in shape, or (from scipy)
            if the weights are negative or sum to zero.
    """
    p_hist = np.asarray(p_hist, dtype=float)
    q_hist = np.asarray(q_hist, dtype=float)
    if p_hist.shape != q_hist.shape:
        raise ValueError(f"histogram shapes differ: {p_hist.shape} vs {q_hist.shape}")
    if support is None:
        support = np.arange(len(p_hist), dtype=float)
    return wasserstein_distance(support, support, u_weights=p_hist, v_weights=q_hist)
def calc_histogram(x, bins=10, range=(0, 1)):
    """Return only the density values of ``np.histogram(x, ...)``.

    Values outside *range* are ignored, and the densities are normalised so
    they integrate to 1 over *range*.  The bin edges are discarded.
    (``range`` shadows the builtin, but it is part of the keyword interface,
    so the name stays.)
    """
    return np.histogram(x, bins=bins, range=range, density=True)[0]
def step(params, angles_opt, p_hist, bone_idx=None):
    """Score one candidate ``(a, b, c)`` for the cubic remapping.

    The score is the sum of ``kldiv(p_hist[i], calc_histogram(mapped_i))``,
    where ``mapped_i = cub(angles_opt[:, i], *params)``.  It is taken over
    every bone index of ``pose_estimation.SKELETON``, or over ``bone_idx``
    alone when given.  A candidate is rejected with ``math.inf`` when
    ``a + b + c > 1`` (that is, ``cub(1) > 1``) or when any remapped angle
    falls outside [0, 1].

    Returns:
        ``(score, params)``, so each pool result carries its own candidate.
    """
    if sum(params) > 1:
        return math.inf, params

    score = 0
    for bone, _ in enumerate(pose_estimation.SKELETON):
        if bone_idx is not None and bone != bone_idx:
            continue
        remapped = cub(angles_opt[:, bone], *params)
        if remapped.max() > 1 or remapped.min() < 0:
            return math.inf, params
        score += kldiv(p_hist[bone], calc_histogram(remapped))
    return score, params
def optimize(cos_opt_dir, hist_smpl_fpath, bone_idx=None, processes=8, grid=None, chunksize=1024):
    """Grid-search the cubic coefficients that best align the two angle distributions.

    Both the optimised cosines (``read_cos_opt(cos_opt_dir)``) and the sampled
    cosines (``np.load(hist_smpl_fpath)``) are clipped to [-1, 1], converted
    to angles and divided by pi/2.  Every ``(a, b, c)`` in the Cartesian
    product of *grid* is scored with ``step`` in a process pool.  The
    lowest-scoring triple is returned and the ten best are printed.

    Results are collected in grid order (``imap``, not ``imap_unordered``),
    so ties in the score always resolve to the same candidate.

    Args:
        cos_opt_dir: Directory searched recursively for ``cos_hist.npy``.
        hist_smpl_fpath: ``.npy`` file with the sampled cosines; every value
            must be >= 0 (angles within [0, pi/2]).
        bone_idx: Score only this bone index; ``None`` scores all bones.
        processes: Number of worker processes (default 8, as before).
        grid: Optional ``(a_values, b_values, c_values)``; defaults to
            ``linspace(0, 20, 100)``, ``linspace(-20, 20, 200)`` and
            ``linspace(-20, 1, 100)``.
        chunksize: Candidates sent to a worker per task.  The default of 1 in
            ``Pool.imap`` is slow for very long iterables like this grid.

    Returns:
        The best ``(a, b, c)`` tuple.

    Raises:
        ValueError: if any sampled cosine is negative or NaN, or if *grid*
            is empty.
    """
    half_pi = math.pi / 2

    # Clip before arccos so values a rounding step above 1 do not become NaN.
    cos_opt = np.clip(read_cos_opt(cos_opt_dir), -1, 1)
    angles_opt = np.arccos(cos_opt) / half_pi

    cos_smpl = np.load(hist_smpl_fpath)
    # Optionally thin the samples for speed: cos_smpl = subsample(cos_smpl)
    print(cos_smpl.shape)
    cos_smpl = np.clip(cos_smpl, -1, 1)
    # The normalised angles must lie in [0, 1], i.e. cos >= 0; the upper
    # bound is already guaranteed by the clip.  Raise instead of assert so
    # the check survives ``python -O``.
    bad = ~(cos_smpl >= 0)
    if bad.any():
        raise ValueError(f"{bad.mean()} of the sampled cosines are negative or NaN")
    angles_smpl = np.arccos(cos_smpl) / half_pi

    # Target histogram per bone.
    p_hist = [
        calc_histogram(angles_smpl[:, i])
        for i, _ in enumerate(pose_estimation.SKELETON)
    ]

    if grid is None:
        grid = (
            np.linspace(0, 20, 100),
            np.linspace(-20, 20, 200),
            np.linspace(-20, 1, 100),
        )
    # Derive the progress-bar total from the grid instead of hard-coding it.
    total = int(np.prod([len(axis) for axis in grid]))
    if total == 0:
        raise ValueError("parameter grid is empty")

    score = functools.partial(step, angles_opt=angles_opt, p_hist=p_hist, bone_idx=bone_idx)
    with multiprocessing.Pool(processes) as pool:
        results = list(
            tqdm.tqdm(
                pool.imap(score, itertools.product(*grid), chunksize=chunksize),
                total=total,
            )
        )

    kls, params = zip(*results)
    ind = np.argmin(kls)
    best_params = params[ind]
    print(kls[ind], best_params)

    # Stable sort keeps the printed ranking deterministic on ties.
    for i in np.argsort(kls, kind="stable")[:10]:
        print(kls[i])
        print(params[i])
        print()
    return best_params
def main():
    """Fit the cubic remapping, then plot its effect and the curve itself.

    Runs ``optimize`` on the hard-coded inputs and prints the chosen
    parameters.  It then samples the fitted curve over [0, pi/2] (plotted in
    degrees) and calls ``plot_hist`` with output to ``hists/``.  Finally it
    saves the curve on its own, against the identity line, as
    ``hists/new_out.png``.
    """
    cos_opt_dir = "paper_single2_150mse"
    hist_smpl_fpath = "./data/hist_smpl.npy"
    # hist_smpl_fpath = "./testtest.npy"
    params = optimize(cos_opt_dir, hist_smpl_fpath)
    # Earlier search results, handy for skipping the slow optimisation:
    # params = (1.2121212121212122, -1.105527638190953, 0.787878787878789)
    # params = (0.20202020202020202, 0.30150753768844396, 0.3636363636363633)
    print(params)

    half_pi = math.pi / 2
    rad = np.linspace(0, half_pi, 100)
    # The cubic acts on angles normalised to [0, 1]; scale back to radians.
    mapped = cub(rad / half_pi, *params) * half_pi
    x = rad * 180 / math.pi
    y = mapped * 180 / math.pi

    out_dir = Path("hists")
    out_dir.mkdir(parents=True, exist_ok=True)
    plot_hist(cos_opt_dir, hist_smpl_fpath, params, out_dir, xy=(x, y))

    ticks = [0, 15, 30, 45, 60, 75, 90]
    plt.figure(figsize=(4, 4), dpi=300)
    plt.plot(x, y, linewidth=6)  # fitted mapping
    plt.plot(x, x, linewidth=2, linestyle="dashed")  # identity reference
    plt.xticks(ticks)
    plt.yticks(ticks)
    plt.axis("equal")
    plt.tight_layout()
    plt.savefig(out_dir / "new_out.png")


if __name__ == "__main__":
    main()