# LLMFromScratch / feedForward.py
# Author: Ashish Reddy
# Commit dde4f12: "added all blocks of codes" (776 bytes)
import torch, torch.nn as nn
# Hyperparameters shared by this file.
batch_size = 64   # leading dimension of the demo input built under __main__
max_len = 256     # second (sequence) dimension of the demo input
d_model = 384     # feature width: last dimension of the input and of the output
n_head = 6        # presumably the attention head count; only used here to derive d_q
# Per-head width (assumed to be the query/key size used elsewhere -- confirm).
# Floor division keeps this an exact int without a float round-trip.
d_q = d_model // n_head
dropout = 0.2     # default probability handed to nn.Dropout
class FeedForward(nn.Module):
    """Feed-forward block: Linear -> ReLU -> Linear -> Dropout.

    The hidden layer is four times wider than the input so that the ReLU,
    which zeroes negative activations, discards less information before the
    result is projected back down to ``d_model`` features.

    Args:
        d_model: Size of the last dimension of both the input and the output.
        dropout_p: Probability passed to ``nn.Dropout``. When ``None`` (the
            default) the module-level ``dropout`` setting is used, which is
            what the class always did before this parameter existed.
    """

    def __init__(self, d_model: int, dropout_p: "float | None" = None):
        super().__init__()
        if dropout_p is None:
            # Fall back to the file-wide hyperparameter so existing callers
            # that only pass ``d_model`` keep their behaviour.
            dropout_p = dropout
        hidden = 4 * d_model
        # Attribute name and layer order are kept stable so that saved
        # ``state_dict`` keys (``seq.0.weight`` ...) still load.
        self.seq = nn.Sequential(
            nn.Linear(d_model, hidden),   # expand
            nn.ReLU(),
            nn.Linear(hidden, d_model),   # project back
            nn.Dropout(dropout_p),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the block; the output has the same shape as ``x``.

        The last dimension of ``x`` must equal ``d_model``.
        """
        return self.seq(x)
if __name__ == "__main__":
    # Smoke test: push one random batch through the block and print the
    # shapes on either side. The input is drawn before the module is built
    # so the random-number stream is consumed in the same order as before.
    sample = torch.randn(batch_size, max_len, d_model)
    block = FeedForward(d_model)
    result = block(sample)

    print("Input shape:", sample.shape)
    print("Output shape from FeedForward network:", result.shape)