| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
class LinearImplicitBackward(nn.Module):
    """Linear projection layer that does not declare ``has_backward``.

    The class defines no constructor: ``forward`` reads ``self.weight`` and
    ``self.bias`` from the instance it runs on. NOTE(review): presumably this
    forward is swapped onto an existing module that owns those attributes
    (e.g. ``nn.Linear``) -- confirm against the code that uses this layer.
    """

    can_torch_compile = True
    # ``has_backward`` is intentionally left undefined; the class name
    # suggests the consumer treats its absence as an implicit default --
    # confirm against that consumer.

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply ``F.linear`` with this module's ``weight`` and ``bias``."""
        weight, bias = self.weight, self.bias
        return F.linear(input, weight, bias)
class LinearBackward(nn.Module):
    """Linear projection layer that explicitly declares backward support.

    No constructor is defined; ``forward`` relies on ``self.weight`` and
    ``self.bias`` already being present on the instance. NOTE(review):
    presumably supplied by the module this layer's forward is attached to
    -- confirm.
    """

    # Explicit capability flags, both set to True here.
    has_backward = True
    can_torch_compile = True

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``input @ weight.T + bias`` via ``F.linear``."""
        weight = self.weight
        bias = self.bias
        return F.linear(input, weight, bias)
class LinearNoBackward(nn.Module):
    """Linear projection layer that declares it has no backward support.

    Like its siblings it is stateless: there is no constructor, and
    ``forward`` reads ``self.weight`` / ``self.bias`` from the instance.
    NOTE(review): presumably those come from the module whose forward this
    replaces -- confirm.
    """

    # torch.compile is allowed, but the layer advertises no backward pass.
    can_torch_compile = True
    has_backward = False

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Compute ``F.linear(input, self.weight, self.bias)``."""
        w, b = self.weight, self.bias
        projected = F.linear(input, w, b)
        return projected
# Public names exported by ``from <this module> import *``: the three
# linear layer variants defined in this file.
__all__ = [
    "LinearImplicitBackward",
    "LinearBackward",
    "LinearNoBackward",
]