Matrix Multiplications
I use np.einsum throughout: the subscripts name the axes, and repeated indices are summed over.
import numpy as np

A = np.array([[1, 2], [3, 4]])
B = np.array([[5, 6], [7, 8]])
C = np.einsum('ij,jk->ik', A, B)  # same as A @ B
Softmax
In math, $\mathrm{softmax}(x)_i = \dfrac{e^{x_i}}{\sum_j e^{x_j}}$.
def softmax(x, axis=-1):
    # e^x for each element, shifted by the max for numerical stability
    e_x = np.exp(x - x.max(axis=axis, keepdims=True))
    return e_x / e_x.sum(axis=axis, keepdims=True)
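A quick sanity check with toy scores (the values here are made up): softmax is applied along the last axis, so every row of the output sums to 1.
scores = np.array([[1.0, 2.0, 3.0],
                   [0.0, 0.0, 0.0]])
w = softmax(scores, axis=-1)
print(w.sum(axis=-1))  # [1. 1.] -- each row is a probability distribution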
Transformer
Caution
Beware of the dimensions.
Embedding
Two steps here: embedding lookup + positional encoding.
Look-ups
Let $E \in \mathbb{R}^{|V| \times d}$ be the embedding matrix. For a token id $t$, its embedding is just the row $E[t] \in \mathbb{R}^{d}$.
Info
If the token is given as a one-hot vector $x \in \{0, 1\}^{|V|}$, instead of a scalar index, then $x^\top E = E[t]$. This is a full-fledged matrix multiplication, and not a lookup (although it essentially works as a lookup).
Therefore, $X = E[\text{tokens}]$ has shape $(n, d)$, where $n$ is the seq len (or ctx window) and $d$ is d_model.
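A minimal sketch with a toy embedding matrix (the sizes and values below are made up) showing that the one-hot matmul and the row lookup agree:
E = np.arange(12, dtype=float).reshape(4, 3)  # toy embedding matrix: |V| = 4, d = 3
t = 2                                         # token id
one_hot = np.zeros(4)
one_hot[t] = 1.0                              # one-hot vector for token t
print(np.allclose(one_hot @ E, E[t]))         # True -- the matmul acts as a lookup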
Positional Encoding
Positional encoding, as per the Transformer paper, is:
$PE_{(\mathrm{pos},\,2i)} = \sin\!\bigl(\mathrm{pos} / 10000^{2i/d}\bigr), \qquad PE_{(\mathrm{pos},\,2i+1)} = \cos\!\bigl(\mathrm{pos} / 10000^{2i/d}\bigr)$
where $\mathrm{pos}$ is the token position, $i$ indexes the dimension pairs, and $d$ is d_model.
Note that [None, :] adds an axis, so the shape becomes (1, d/2); broadcasting against pos, which is (n, 1), then gives (n, d/2).
def positional_encoding(n, d):
PE = np.zeros((n, d))
pos = np.arange(n)[:, None] # (n, 1)
i = np.arange(0, d, 2)[None, :] # (1, d/2)
div = 10000 ** (i / d)
PE[:, 0::2] = np.sin(pos / div) # even dims
PE[:, 1::2] = np.cos(pos / div) # odd dims
return PE # (n, d)
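A quick check on toy sizes (made up here): the output comes out as (n, d), and position 0 is sin(0) = 0 on even dims and cos(0) = 1 on odd dims.
PE = positional_encoding(4, 6)
print(PE.shape)  # (4, 6)
print(PE[0])     # [0. 1. 0. 1. 0. 1.] -- sin(0) and cos(0) interleaved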
def embed(tokens, E, n, d):
# In this case, tokens: (n,)
X = E[tokens] # (n, d)
X = X + positional_encoding(n, d) # (n, d)
    return X
Attention
Normal Attention
The attention mechanism: $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\dfrac{QK^\top}{\sqrt{d_k}}\right)V$.
Initially, the queries $Q$ are $(n, d_k)$, the keys $K$ are $(m, d_k)$, and the values $V$ are $(m, d_v)$.
Once we multiply $Q$ by $K^\top$, we get an $(n, m)$ matrix of scores, which we scale by $1/\sqrt{d_k}$.
Softmax is applied row-wise: every row in the $(n, m)$ weight matrix sums to 1.
Finally, we do a matrix multiplication with $V$, giving an output of shape $(n, d_v)$.
def attn(Q, K, V, mask=None):
    # Q: (n, d_k), K: (m, d_k), V: (m, d_v)
    d_k = K.shape[-1]
    scores = np.einsum('nk,mk->nm', Q, K)       # (n, m) raw dot products
    scores /= np.sqrt(d_k)                      # scale to keep softmax well-behaved
    if mask is not None:
        scores = np.where(mask, scores, -1e9)   # masked-out positions get ~ -inf
    weights = softmax(scores, axis=-1)          # each row sums to 1
    return np.einsum('nm,mv->nv', weights, V)   # (n, d_v)
Read Lecture 4 — Attention & Transformer for more details.
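As a usage sketch (toy shapes and random values, just to show the mask convention where True means "may attend"), a causal mask can be built with np.tril:
n, d_k = 4, 8
Q = np.random.randn(n, d_k)
K = np.random.randn(n, d_k)
V = np.random.randn(n, d_k)
causal = np.tril(np.ones((n, n), dtype=bool))  # row i may attend to positions <= i
out = attn(Q, K, V, mask=causal)               # (n, d_k)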
Multi-Head Attention
The difference from normal attention is that the projections $W_Q$, $W_K$, $W_V$ are split into $h$ heads, so each head works with the smaller dimensions $d_k$ and $d_v$ (the projection width divided by $h$), and attention runs independently per head.
After attention, we concatenate all the heads, and then apply an output projection by multiplying with $W_O \in \mathbb{R}^{h d_v \times d}$.
def mha(X, W_Q, W_K, W_V, W_O, h):
    n, d = X.shape
d_k = W_Q.shape[-1] // h
d_v = W_V.shape[-1] // h
# 1. Project + reshape into heads in one einsum
Q = np.einsum('nd,dhk->hnk', X, W_Q.reshape(d, h, d_k))
K = np.einsum('nd,dhk->hnk', X, W_K.reshape(d, h, d_k))
V = np.einsum('nd,dhv->hnv', X, W_V.reshape(d, h, d_v))
# 2. Attention scores
scores = np.einsum('hnk,hmk->hnm', Q, K) / np.sqrt(d_k)
weights = softmax(scores, axis=-1)
# 3. Weighted sum
attn = np.einsum('hnm,hmv->hnv', weights, V)
# 4. Concat heads
concat = np.einsum('hnv->nhv', attn).reshape(n, h * d_v)
# 5. Project to output
out = np.einsum('nx,xd->nd', concat, W_O)
    return out
What's important is that everything must return to shape $(n, d)$, so it can be added back into the residual stream.
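A quick shape check with random toy weights (all names and sizes below are made up for illustration):
n, d, h = 5, 16, 4
X = np.random.randn(n, d)
W_Q = np.random.randn(d, d)
W_K = np.random.randn(d, d)
W_V = np.random.randn(d, d)
W_O = np.random.randn(d, d)
print(mha(X, W_Q, W_K, W_V, W_O, h).shape)  # (5, 16) -- back to (n, d)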
LayerNorm
Right after MHA, we do LayerNorm.
$\mathrm{LayerNorm}(x) = \gamma \odot \hat{x} + \beta$, where $\hat{x} = \dfrac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}$ and $\mu$, $\sigma^2$ are the mean and variance of $x$.
def layer_norm(x, gamma, beta, eps=1e-5):
# Remember that x: (n, d)
mean = x.mean(axis=-1, keepdims=True) # (n, 1)
var = x.var(axis=-1, keepdims=True) # (n, 1)
x_hat = (x - mean) / np.sqrt(var + eps) # (n, d)
    return gamma * x_hat + beta
The mean and variance are taken over the feature dimension (d_model) for each token, not over the positions (seq len).
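A small check on random input (toy sizes): after normalization, each row has roughly zero mean and unit variance.
x = np.random.randn(3, 8)
y = layer_norm(x, gamma=np.ones(8), beta=np.zeros(8))
print(y.mean(axis=-1))  # ~0 for every token
print(y.std(axis=-1))   # ~1 for every token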
Feed-Forward Network (FFN)
def ffn(x, W1, b1, W2, b2):
z = np.einsum('nd,df->nf', x, W1) + b1 # (n, dff)
h = np.maximum(0, z) # ReLU
    return np.einsum('nf,fd->nd', h, W2) + b2  # (n, d)
Normal weights + activation + weights. Remember that only the hidden-layer output goes through the activation function. The final weights don't. Why would you zero out half of the signal?
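In formula form (matching the code above):
$\mathrm{FFN}(x) = \max(0,\; xW_1 + b_1)\,W_2 + b_2, \qquad W_1 \in \mathbb{R}^{d \times d_{ff}},\; W_2 \in \mathbb{R}^{d_{ff} \times d}$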
One Transformer Block
def transformer_block(x, W_Q, W_K, W_V, W_O, h,
gamma1, beta1, W1, b1, W2, b2,
gamma2, beta2):
# Remember. x: (n, d)
attn_out = mha(x, W_Q, W_K, W_V, W_O, h) # (n,d)
x = layer_norm(x + attn_out, gamma1, beta1) # (n, d)
ffn_out = ffn(x, W1, b1, W2, b2) # (n, d)
x = layer_norm(x + ffn_out, gamma2, beta2) # (n, d)
    return x
Attention → Add + Layer Norm → FFN → Add + Layer Norm. By adding, I mean adding to the residual stream.
The Entire Transformer Mechanism
def transformer(tokens, E, W_Q, W_K, W_V, W_O, h, gamma1, beta1,
W1, b1, W2, b2, gamma2, beta2, n_layers=6):
# tokens: (n,) -- as above
n = len(tokens)
d = E.shape[1]
# 1. Embed + PE
x = embed(tokens, E, n, d) # (n, d)
    # 2. Stack transformer blocks (this simplified version reuses the same weights in every layer)
    for _ in range(n_layers):
x = transformer_block(
            x, W_Q, W_K, W_V, W_O, h, gamma1, beta1,
            W1, b1, W2, b2, gamma2, beta2)  # (n, d)
# 3. Final logits
logits = np.einsum('nd,vd->nv', x, E) # (n, |V|)
# 4. Softmax
probs = softmax(logits, axis=-1) #(n, |V|)
return probs
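A toy end-to-end run (all weights random and untrained; names and sizes are made up just to exercise the shapes):
rng = np.random.default_rng(0)
V_size, d, d_ff, h = 50, 16, 64, 4
E = rng.normal(size=(V_size, d))
W_Q, W_K, W_V, W_O = (rng.normal(size=(d, d)) for _ in range(4))
gamma1, beta1 = np.ones(d), np.zeros(d)
gamma2, beta2 = np.ones(d), np.zeros(d)
W1, b1 = rng.normal(size=(d, d_ff)), np.zeros(d_ff)
W2, b2 = rng.normal(size=(d_ff, d)), np.zeros(d)
tokens = np.array([3, 17, 8, 42])
probs = transformer(tokens, E, W_Q, W_K, W_V, W_O, h, gamma1, beta1,
                    W1, b1, W2, b2, gamma2, beta2, n_layers=2)
print(probs.shape)         # (4, 50) -- one distribution per position
print(probs[-1].argmax())  # greedy guess for the next token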