Spaces:
Runtime error
Runtime error
| import pytest | |
| import torch | |
| from mmcv import Config | |
| from risk_biased.models.mlp import MLP | |
@pytest.fixture
def params():
    """Provide a small configuration used to build and exercise ``MLP``.

    Seeds torch's global RNG with 0 before building the config, so anything
    the requesting test draws from torch's default generator afterwards
    (e.g. ``torch.rand``) is reproducible.

    Returns:
        An ``mmcv.Config`` holding batch size, input/output/hidden sizes,
        number of hidden layers, target device and the residual flag.
    """
    torch.manual_seed(0)
    cfg = Config()
    cfg.batch_size = 4
    cfg.input_dim = 10
    cfg.output_dim = 15
    # NOTE(review): latent_dim is not read by test_mlp; presumably used by
    # other tests sharing this config — confirm or drop.
    cfg.latent_dim = 3
    cfg.h_dim = 64
    cfg.num_h_layers = 2
    cfg.device = "cpu"
    cfg.is_mlp_residual = True
    return cfg
def test_mlp(params):
    """Check that ``MLP`` maps a (batch_size, input_dim) tensor to (batch_size, output_dim).

    Args:
        params: The ``params`` fixture supplying the sizes and the residual flag.
    """
    # NOTE(review): arguments are passed positionally; their order must match
    # the MLP constructor signature (input, output, hidden, n_layers, residual).
    mlp = MLP(
        params.input_dim,
        params.output_dim,
        params.h_dim,
        params.num_h_layers,
        params.is_mlp_residual,
    )
    # Named ``inputs`` rather than ``input`` to avoid shadowing the builtin.
    inputs = torch.rand(params.batch_size, params.input_dim)
    output = mlp(inputs)
    # check shape
    assert output.shape == (params.batch_size, params.output_dim)