
Fix small typos in the model architecture

Merged Alexandru-Mihai GHERGHESCU requested to merge fix/model into main
1 file  +2 −4
@@ -146,9 +146,7 @@ class TransformerBlock(nn.Module):
         mask = torch.triu(mask, diagonal=1).to(x.device)
         # compute normed attention and normed feed forward
-        nned = self.attention_norm(x)
-        attn = self.attention(nned, mask)
-        x = x + self.dropout(attn)
+        x = x + self.dropout(self.attention(self.attention_norm(x), mask))
         x = x + self.dropout(self.ffn(self.ffn_norm(x)))
         return x
@@ -186,7 +184,7 @@ class Transformer(nn.Module):
         self.positional_encodings = self._compute_freqs(dim)
         self.input_dropout = nn.Dropout(p=p_drop)
         self.layers = nn.ModuleList([TransformerBlock(n_heads, dim, p_drop) for _ in range(n_layers)])
-        self.output = nn.Linear(dim, vocab_sz)
+        self.output = nn.Linear(dim, vocab_sz, bias=False)
         self.output_norm = Norm(dim)
         if weight_tying:
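
Note on the second hunk: with bias=False, the output projection is a single (vocab_sz, dim) weight matrix with the same shape as the token embedding table, which is what lets it be tied directly to the embedding weights under weight_tying. A minimal, self-contained sketch of that relationship; the tok_embeddings name and the sizes below are assumptions for illustration, not code from this repository:

import torch.nn as nn

vocab_sz, dim = 32000, 512                     # assumed sizes, for illustration only
tok_embeddings = nn.Embedding(vocab_sz, dim)   # assumed name; not taken from this diff
output = nn.Linear(dim, vocab_sz, bias=False)  # the line changed in this MR

# with no bias, the projection is a single (vocab_sz, dim) matrix, the same
# shape as the embedding table, so the two can share one parameter
output.weight = tok_embeddings.weight
assert output.weight.data_ptr() == tok_embeddings.weight.data_ptr()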