Spaces:
Runtime error
Runtime error
| from dataclasses import dataclass, field | |
| from typing import Any, List, Optional | |
| import torch.nn as nn | |
| from jaxtyping import Float | |
| from torch import Tensor | |
| from sf3d.models.network import get_activation | |
| from sf3d.models.utils import BaseModule | |
@dataclass
class HeadSpec:
    """Declarative description of one output head of ``MultiHeadEstimator``.

    Must be a dataclass: the fields below are plain annotations, so without
    the decorator the class would have no generated ``__init__`` and could not
    be instantiated with any of these values.
    """

    # Key of the head's sub-network and (optionally prefixed) output entry.
    name: str
    # Width of the head's final linear layer.
    out_channels: int
    # Number of (Linear + activation) pairs placed before the final layer.
    n_hidden_layers: int
    # Name handed to ``get_activation`` to post-process the head output.
    # NOTE(review): ``None`` presumably selects an identity mapping - confirm
    # against ``sf3d.models.network.get_activation``.
    output_activation: Optional[str] = None
    # Constant added to the raw head output before the output activation.
    output_bias: float = 0.0
    # When true, the output key is prefixed with "decoder_".
    add_to_decoder_features: bool = False
    # If set (and non-empty), the whole output tensor - batch axis included -
    # is reshaped with ``reshape(*shape)``.
    shape: Optional[list[int]] = None
class MultiHeadEstimator(BaseModule):
    """Estimate several global quantities from a triplane with a shared trunk.

    The three planes are concatenated along the channel axis, run through a
    stack of strided convolutions, pooled over the two spatial axes and then
    fed to one small MLP per configured head.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Channels of a single plane; the first convolution receives three
        # times as many because the planes are folded into the channel axis.
        triplane_features: int = 1024

        # Number of (Conv2d + activation) pairs in the shared trunk.
        n_layers: int = 2
        # Width of the trunk output and of every hidden layer in the heads.
        hidden_features: int = 512
        # Activation used in the trunk and the heads: "relu" or "silu".
        activation: str = "relu"

        # Spatial pooling applied after the trunk: "max" or "mean".
        pool: str = "max"

        # One entry per output head; ``configure`` requires at least one.
        heads: List[HeadSpec] = field(default_factory=list)

    cfg: Config

    def configure(self) -> None:
        """Build the shared convolutional trunk and one MLP per head.

        Raises:
            AssertionError: If no heads are configured.
            NotImplementedError: If ``cfg.activation`` is not supported.
        """
        layers = []
        # All three planes are stacked on the channel axis before the trunk.
        cur_features = self.cfg.triplane_features * 3
        for _ in range(self.cfg.n_layers):
            layers.append(
                nn.Conv2d(
                    cur_features,
                    self.cfg.hidden_features,
                    kernel_size=3,
                    padding=0,
                    stride=2,
                )
            )
            layers.append(self.make_activation(self.cfg.activation))
            cur_features = self.cfg.hidden_features
        self.layers = nn.Sequential(*layers)

        assert len(self.cfg.heads) > 0, "MultiHeadEstimator needs at least one head"

        heads = {}
        for head in self.cfg.heads:
            head_layers = []
            for _ in range(head.n_hidden_layers):
                head_layers += [
                    nn.Linear(
                        self.cfg.hidden_features,
                        self.cfg.hidden_features,
                    ),
                    self.make_activation(self.cfg.activation),
                ]
            # Final projection to the head's output width; the output
            # activation and bias are applied in ``forward``.
            head_layers += [
                nn.Linear(
                    self.cfg.hidden_features,
                    head.out_channels,
                ),
            ]
            heads[head.name] = nn.Sequential(*head_layers)
        self.heads = nn.ModuleDict(heads)

    def make_activation(self, activation):
        """Return a fresh in-place activation module for ``activation``.

        Args:
            activation: Either ``"relu"`` or ``"silu"``.

        Raises:
            NotImplementedError: For any other name.
        """
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError(f"Unsupported activation: {activation!r}")

    def forward(
        self,
        triplane: Float[Tensor, "B 3 F Ht Wt"],
    ) -> dict[str, Any]:
        """Run the trunk and every head on ``triplane``.

        Args:
            triplane: Triplane features; every axis between the batch axis and
                the last two (spatial) axes is flattened into channels.

        Returns:
            A dict with one entry per configured head. The key is the head
            name, prefixed with ``"decoder_"`` when the head sets
            ``add_to_decoder_features``. The value is the head output after
            adding ``output_bias``, applying the output activation and, if the
            head defines ``shape``, reshaping to it.

        Raises:
            NotImplementedError: If ``cfg.pool`` is neither "max" nor "mean".
        """
        x = self.layers(
            triplane.reshape(
                triplane.shape[0], -1, triplane.shape[-2], triplane.shape[-1]
            )
        )

        # Collapse the spatial axes so each sample becomes one feature vector.
        if self.cfg.pool == "max":
            x = x.amax(dim=[-2, -1])
        elif self.cfg.pool == "mean":
            x = x.mean(dim=[-2, -1])
        else:
            raise NotImplementedError(f"Unsupported pool: {self.cfg.pool!r}")

        out = {}
        for head in self.cfg.heads:
            key = ("decoder_" if head.add_to_decoder_features else "") + head.name
            # The bias is added before the output activation.
            value = get_activation(head.output_activation)(
                self.heads[head.name](x) + head.output_bias
            )
            if head.shape:
                value = value.reshape(*head.shape)
            out[key] = value
        return out