import json
import os
import random
import time

import numpy as np
import torch
from os.path import join as pjoin

from mGPT.config import instantiate_from_config
from mGPT.losses.mgpt import GPTLosses
from mGPT.models.base import BaseModel
import mGPT.render.matplot.plot_3d_global as plot_3d

class MotionGPT(BaseModel):
    """
    Stage 1: Motion tokenizer
    Stage 2: Motion-language pretraining
    Stage 3: Motion-language instruction tuning
    """
    def __init__(self,
                 cfg,
                 datamodule,
                 lm,
                 motion_vae,
                 codebook_size=512,
                 stage='vae',
                 debug=True,
                 condition='text',
                 task='t2m',
                 metrics_dict=['TM2TMetrics'],
                 **kwargs):
        # Save all constructor arguments (except the datamodule) to
        # self.hparams before BaseModel initialization
        self.save_hyperparameters(ignore='datamodule', logger=False)
        self.datamodule = datamodule
        super().__init__()

        # Instantiate motion tokenizer
        if motion_vae is not None:
            self.vae = instantiate_from_config(motion_vae)

        # Instantiate motion-language model
        self.lm = instantiate_from_config(lm)

        # Freeze the motion tokenizer for lm training; eval() propagates
        # training=False to every submodule, unlike setting the attribute
        # directly on the top-level module
        if 'lm' in self.hparams.stage:
            self.vae.eval()
            for p in self.vae.parameters():
                p.requires_grad = False

        # Instantiate the losses
        self._losses = torch.nn.ModuleDict({
            split: GPTLosses(cfg, self.hparams.stage, self.datamodule.njoints)
            for split in ["losses_train", "losses_test", "losses_val"]
        })

        # Data transform
        self.feats2joints = datamodule.feats2joints

        # Count codebook frequency
        self.codePred = []
        self.codeFrequency = torch.zeros((self.hparams.codebook_size, ))
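
    # Hypothetical construction sketch (illustrative names; in the repo the
    # model and its `lm`/`motion_vae` sub-modules are normally built from a
    # YAML config through `instantiate_from_config`):
    #   model = MotionGPT(cfg, datamodule, lm=lm_cfg, motion_vae=vae_cfg,
    #                     stage='lm_pretrain', task='t2m')
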
    def forward(self, batch, task="t2m"):
        texts = batch["text"]
        lengths_ref = batch["length"]

        # Forward
        # texts = ['Generate motion: ' + text for text in texts]
        outputs, output_texts = self.lm.generate_direct(texts, do_sample=True)

        # Motion Decode
        feats_rst_lst = []
        lengths = []
        max_len = 0

        for i in range(len(texts)):
            if task == "pred":
                motion = self.vae.decode(
                    torch.cat((batch["motion"][i], outputs[i])))
                lengths.append(motion.shape[1])
            elif task in ["t2m", "m2t", "inbetween"]:
                motion = self.vae.decode(outputs[i])
                # motion = self.datamodule.denormalize(motion)
                lengths.append(motion.shape[1])
            else:
                raise NotImplementedError

            if motion.shape[1] > max_len:
                max_len = motion.shape[1]

            if task in ["t2m", "m2t", "pred"]:
                feats_rst_lst.append(motion)
            elif task == "inbetween":
                # Keep the reference heading and tailing clips and fill in
                # the generated middle half of the sequence
                motion = torch.cat(
                    (batch["motion_heading"][i][None],
                     motion[:, lengths_ref[i] // 4:lengths_ref[i] // 4 * 3,
                            ...], batch["motion_tailing"][i][None]),
                    dim=1)
                feats_rst_lst.append(motion)

        # Pad every sample to the longest decoded sequence and stack
        feats_rst = torch.zeros(
            (len(feats_rst_lst), max_len, motion.shape[-1])).to(self.device)
        for i in range(len(feats_rst_lst)):
            feats_rst[i, :feats_rst_lst[i].shape[1], ...] = feats_rst_lst[i]

        # Recover joints for evaluation
        joints_rst = self.feats2joints(feats_rst)

        # return set
        outputs = {
            "texts": output_texts,
            "feats": feats_rst,
            "joints": joints_rst,
            "length": lengths
        }

        return outputs
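
    # Hypothetical inference sketch (batch keys follow the access pattern in
    # `forward`; checkpoint loading and device placement are assumed):
    #   batch = {"text": ["a person walks forward"], "length": [64]}
    #   out = model(batch, task="t2m")
    #   out["joints"]  # padded joint sequences, one per input text
    #   out["length"]  # per-sample decoded lengths before padding
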
    def train_lm_forward(self, batch):
        tokens_ref = batch["motion"]
        texts = batch["text"]
        lengths = batch["length"]
        tasks = batch["tasks"]
        all_captions = batch['all_captions']
        if self.hparams.condition == 'caption':
            texts = [random.choice(all_captions[i]) for i in range(len(texts))]

        # LLM Forward
        outputs = self.lm(texts, tokens_ref, lengths, tasks)
        # outputs = self.t2m_gpt.generate(texts)
        return {'outputs': outputs}
    def val_t2m_forward(self, batch):
        feats_ref = batch["motion"]
        texts = batch["text"]
        lengths = batch["length"]
        tasks = None
        if self.trainer.datamodule.is_mm:
            texts = texts * self.hparams.cfg.METRIC.MM_NUM_REPEATS
            feats_ref = feats_ref.repeat_interleave(
                self.hparams.cfg.METRIC.MM_NUM_REPEATS, dim=0)
            lengths = lengths * self.hparams.cfg.METRIC.MM_NUM_REPEATS
            instructions = pjoin(self.datamodule.hparams.data_root,
                                 'template_instructions.json')
            with open(instructions, 'r') as f:
                instructions = json.load(f)
            tasks = [instructions["Text-to-Motion"]["caption"]] * len(texts)

        if self.hparams.condition == 'caption':
            tasks = [{
                'input': ['<Caption_Placeholder>'],
                'output': ['']
            }] * len(texts)

        if self.hparams.cfg.DATASET.TASK_PATH:
            instructions = pjoin(self.hparams.cfg.DATASET.TASK_PATH)
            with open(instructions, 'r') as f:
                instructions = json.load(f)
            tasks = [instructions["Text-to-Motion"]["t2m"]] * len(texts)

        min_len = lengths.copy()

        # Forward
        outputs = self.lm.generate_conditional(texts,
                                               lengths=lengths,
                                               stage='test',
                                               tasks=tasks)

        # Motion Decode
        feats_rst = torch.zeros_like(feats_ref)
        for i in range(len(texts)):
            # Keep predicted token ids inside the codebook range
            outputs[i] = torch.clamp(outputs[i], 0,
                                     self.hparams.codebook_size - 1)
            if len(outputs[i]) > 1:
                motion = self.vae.decode(outputs[i])
            else:
                motion = torch.zeros_like(feats_ref[i:i + 1, ...])
            min_len[i] = min(motion.shape[1], lengths[i])

            # Cut Motion
            feats_rst[i:i + 1, :min_len[i], ...] = motion[:, :lengths[i]]

        # Recover joints for evaluation
        joints_ref = self.feats2joints(feats_ref)
        joints_rst = self.feats2joints(feats_rst)

        # Renorm for evaluation
        feats_ref = self.datamodule.renorm4t2m(feats_ref)
        feats_rst = self.datamodule.renorm4t2m(feats_rst)

        # return set
        rs_set = {
            "m_ref": feats_ref,
            "m_rst": feats_rst,
            "joints_ref": joints_ref,
            "joints_rst": joints_rst,
            "length": min_len
            # "length": lengths
        }
        return rs_set
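
    # Note: `renorm4t2m` is assumed to re-normalize features into the
    # statistics expected by the pretrained text-to-motion evaluators, so
    # that metric inputs stay comparable across models.
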
    def val_m2t_forward(self, batch):
        self.hparams.metrics_dict = []
        feats_ref = batch["motion"]
        texts = batch["text"]
        lengths = batch["length"]
        all_captions = batch['all_captions']

        # Motion Encode
        motion_tokens = []
        lengths_tokens = []
        for i in range(len(feats_ref)):
            motion_token, _ = self.vae.encode(feats_ref[i:i + 1])
            motion_tokens.append(motion_token[0])
            lengths_tokens.append(motion_token.shape[1])

        # Forward
        outputs = self.lm.generate_conditional(motion_tokens=motion_tokens,
                                               lengths=lengths_tokens,
                                               task="m2t",
                                               stage='test')

        # return set
        rs_set = {
            "m_ref": feats_ref,
            "t_ref": all_captions,
            # "t_ref": texts,
            "t_pred": outputs,
            "length": lengths
        }
        return rs_set
    def val_m2m_forward(self, batch, task="pred"):
        feats_ref = batch["motion"]
        lengths = batch["length"]

        # Motion Encode
        motion_tokens = []
        for i in range(len(feats_ref)):
            motion_token, _ = self.vae.encode(feats_ref[i:i + 1])
            motion_tokens.append(motion_token[0])

        # Forward
        outputs = self.lm.generate_conditional(motion_tokens=motion_tokens,
                                               lengths=lengths,
                                               task=task,
                                               stage='test')

        # Motion Decode
        feats_rst = torch.zeros_like(feats_ref)
        min_len = lengths.copy()
        for i in range(len(lengths)):
            # Keep predicted token ids inside the codebook range
            outputs[i] = torch.clamp(outputs[i], 0,
                                     self.hparams.codebook_size - 1)
            if len(outputs[i]) > 1:
                motion = self.vae.decode(outputs[i])
            else:
                motion = torch.zeros_like(feats_ref[i:i + 1, ...])
            min_len[i] = min(motion.shape[1], lengths[i])

            # Cut Motion
            feats_rst[i:i + 1, :min_len[i], ...] = motion[:, :lengths[i]]

        # Recover joints for evaluation
        joints_ref = self.feats2joints(feats_ref)
        joints_rst = self.feats2joints(feats_rst)

        # Renorm for evaluation
        feats_ref = self.datamodule.renorm4t2m(feats_ref)
        feats_rst = self.datamodule.renorm4t2m(feats_rst)

        # return set
        rs_set = {
            "m_ref": feats_ref,
            "m_rst": feats_rst,
            "joints_ref": joints_ref,
            "joints_rst": joints_rst,
            "length": min_len
            # "length": lengths
        }
        return rs_set
    def train_vae_forward(self, batch):
        feats_ref = batch["motion"]
        joints_ref = self.feats2joints(feats_ref)

        # motion encode & decode
        feats_rst, loss_commit, perplexity = self.vae(feats_ref)
        joints_rst = self.feats2joints(feats_rst)

        # return set
        rs_set = {
            "m_ref": feats_ref,
            "joints_ref": joints_ref,
            "m_rst": feats_rst,
            "joints_rst": joints_rst,
            "loss_commit": loss_commit,
            "perplexity": perplexity,
        }
        return rs_set
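
    # Note (interpretation, assuming a standard VQ-VAE tokenizer): the vae
    # call above returns the reconstruction together with `loss_commit`, the
    # vector-quantization commitment term, and `perplexity`, a measure of
    # codebook usage; both are consumed by GPTLosses in `allsplit_step`.
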
    def val_vae_forward(self, batch, split="train"):
        feats_ref = batch["motion"]
        lengths = batch["length"]

        # Repeat for multimodal evaluation
        if self.trainer.datamodule.is_mm:
            feats_ref = feats_ref.repeat_interleave(
                self.hparams.cfg.METRIC.MM_NUM_REPEATS, dim=0)
            lengths = lengths * self.hparams.cfg.METRIC.MM_NUM_REPEATS

        # Motion encode & decode
        feats_rst = torch.zeros_like(feats_ref)
        for i in range(len(feats_ref)):
            if lengths[i] == 0:
                continue
            feats_pred, _, _ = self.vae(feats_ref[i:i + 1, :lengths[i]])
            feats_rst[i:i + 1, :feats_pred.shape[1], :] = feats_pred

            code_pred, _ = self.vae.encode(feats_ref[i:i + 1, :lengths[i]])
            # codeFre_pred = torch.bincount(
            #     code_pred[0],
            #     minlength=self.hparams.codebook_size).to(
            #         self.codeFrequency.device)
            # self.codePred.append(code_pred[0])
            # self.codeFrequency += codeFre_pred

        # np.save('../memData/results/codeFrequency.npy',
        #         self.codeFrequency.cpu().numpy())

        # Recover joints for evaluation
        joints_ref = self.feats2joints(feats_ref)
        joints_rst = self.feats2joints(feats_rst)

        # Renorm for evaluation
        feats_ref = self.datamodule.renorm4t2m(feats_ref)
        feats_rst = self.datamodule.renorm4t2m(feats_rst)

        # Return set
        rs_set = {
            "m_ref": feats_ref,
            "joints_ref": joints_ref,
            "m_rst": feats_rst,
            "joints_rst": joints_rst,
            "length": lengths,
        }
        return rs_set
    def allsplit_step(self, split: str, batch, batch_idx):
        # Compute the losses
        loss = None

        if self.hparams.stage == "vae" and split in ["train", "val"]:
            rs_set = self.train_vae_forward(batch)
            loss = self._losses['losses_' + split].update(rs_set)
        elif (self.hparams.stage in ["lm_instruct", "lm_pretrain"]
              and split in ["train"]):
            rs_set = self.train_lm_forward(batch)
            loss = self._losses['losses_' + split].update(rs_set)
        elif self.hparams.stage == 'lm_rl' and split in ['train']:
            rs_set = self.train_rl_forward(batch)
            loss = None

        # Compute the metrics
        if split in ["val", "test"]:
            if self.hparams.stage == "vae":
                rs_set = self.val_vae_forward(batch, split)
            elif self.hparams.stage in ["lm_instruct", "lm_pretrain", "lm_rl"]:
                if self.hparams.task == "t2m":
                    rs_set = self.val_t2m_forward(batch)
                elif self.hparams.task == "m2t":
                    rs_set = self.val_m2t_forward(batch)
                elif self.hparams.task in ["m2m", "pred", "inbetween"]:
                    rs_set = self.val_m2m_forward(batch, self.hparams.task)

            if self.hparams.task not in ["m2t"]:
                # MultiModality is evaluated separately
                if self.trainer.datamodule.is_mm:
                    metrics_dicts = ['MMMetrics']
                else:
                    metrics_dicts = self.hparams.metrics_dict

                # PredMetrics only applies to prediction/in-between tasks;
                # guard the removal so other configs do not raise ValueError
                if (self.hparams.task not in ['pred', 'inbetween']
                        and 'PredMetrics' in metrics_dicts):
                    metrics_dicts.remove('PredMetrics')
                for metric in metrics_dicts:
                    lengths = batch['length']
                    if metric == "TemosMetric":
                        getattr(self.metrics,
                                metric).update(rs_set["joints_rst"],
                                               rs_set["joints_ref"], lengths)
                    elif metric == "TM2TMetrics":
                        if self.hparams.stage in [
                                "lm_instruct", "lm_pretrain", "lm_rl"
                        ]:
                            word_embs = batch['word_embs']
                            pos_ohot = batch['pos_ohot']
                            text_lengths = batch['text_len']
                            if self.trainer.datamodule.is_mm:
                                word_embs = word_embs.repeat_interleave(
                                    self.hparams.cfg.METRIC.MM_NUM_REPEATS,
                                    dim=0)
                                pos_ohot = pos_ohot.repeat_interleave(
                                    self.hparams.cfg.METRIC.MM_NUM_REPEATS,
                                    dim=0)
                                text_lengths = text_lengths.repeat_interleave(
                                    self.hparams.cfg.METRIC.MM_NUM_REPEATS,
                                    dim=0)
                        else:
                            word_embs = None
                            pos_ohot = None
                            text_lengths = None

                        getattr(self.metrics, metric).update(
                            feats_ref=rs_set["m_ref"],
                            feats_rst=rs_set["m_rst"],
                            lengths_ref=lengths,
                            lengths_rst=rs_set['length'],
                            word_embs=word_embs,
                            pos_ohot=pos_ohot,
                            text_lengths=text_lengths,
                        )
                    elif metric == "UncondMetrics":
                        getattr(self.metrics, metric).update(
                            recmotion_embeddings=rs_set["lat_rm"],
                            gtmotion_embeddings=rs_set["lat_m"],
                            lengths=lengths,
                        )
                    elif metric == "MRMetrics":
                        getattr(self.metrics,
                                metric).update(rs_set["joints_rst"],
                                               rs_set["joints_ref"], lengths)
                    elif metric == "PredMetrics":
                        getattr(self.metrics,
                                metric).update(rs_set["joints_rst"],
                                               rs_set["joints_ref"], lengths)
                    elif metric == "MMMetrics":
                        getattr(self.metrics,
                                metric).update(rs_set["m_rst"],
                                               rs_set['length'])
                    else:
                        raise TypeError(f"Unsupported metric: {metric}")
| elif self.hparams.task == "m2t" and self.hparams.stage in [ | |
| "lm_instruct", "lm_pretrain", "lm_rl" | |
| ]: | |
| self.hparams.metrics_dict = metrics_dicts = ['M2TMetrics'] | |
| for metric in metrics_dicts: | |
| if metric == "M2TMetrics": | |
| getattr(self.metrics, metric).update( | |
| feats_ref=rs_set["m_ref"], | |
| pred_texts=rs_set["t_pred"], | |
| gt_texts=batch["all_captions"], | |
| lengths=rs_set['length'], | |
| word_embs=batch["word_embs"], | |
| pos_ohot=batch["pos_ohot"], | |
| text_lengths=batch["text_len"], | |
| ) | |
| # return forward output rather than loss during test | |
| if split in ["test"]: | |
| if self.hparams.task == "t2m": | |
| return rs_set["joints_rst"], rs_set["length"], rs_set[ | |
| "joints_ref"] | |
| # pass | |
| elif self.hparams.task == "m2t": | |
| return rs_set["t_pred"], batch["length"] | |
| # return batch["length"] | |
| return loss | |
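
    # Note: `allsplit_step` is presumably dispatched from the Lightning hooks
    # (training_step / validation_step / test_step) defined in BaseModel
    # (mGPT/models/base.py), which supply the split name; this is an
    # assumption based on the class structure, not stated in this file.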