Source code for fit.nn.modules.attention

"""
Core Attention Mechanisms for FIT Framework

This module implements the fundamental attention mechanisms that power
modern deep learning: scaled dot-product attention, multi-head attention,
and various attention variants.

The implementation is educational (showing how attention really works)
while being efficient and production-ready.
"""

import numpy as np
from typing import Optional, Tuple, Union
import math

from fit.core.tensor import Tensor
from fit.nn.modules.base import Layer
from fit.nn.modules.linear import Linear
from fit.nn.modules.activation import Softmax, Dropout


class ScaledDotProductAttention(Layer):
    """
    Scaled Dot-Product Attention: the core of all attention mechanisms.

    Attention(Q, K, V) = softmax(QK^T / (sqrt(d_k) * temperature))V

    This is the fundamental building block that makes Transformers work.
    """

    def __init__(self, dropout: float = 0.1, temperature: float = 1.0):
        """
        Initialize scaled dot-product attention.

        Args:
            dropout: Dropout probability for attention weights. No Dropout
                layer is created when this is not greater than 0.
            temperature: Extra divisor applied to the scores on top of
                sqrt(d_k) (higher = more uniform attention).
        """
        super().__init__()
        self.dropout = Dropout(dropout) if dropout > 0 else None
        self.temperature = temperature
        self.softmax = Softmax()

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        mask: Optional[Tensor] = None,
        return_attention: bool = False,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """
        Apply scaled dot-product attention.

        Args:
            query: Query tensor (batch_size, seq_len_q, d_k)
            key: Key tensor (batch_size, seq_len_k, d_k)
            value: Value tensor (batch_size, seq_len_v, d_v)
            mask: Optional attention mask; entries equal to 0 mark positions
                that must not be attended to. It only has to be
                broadcastable to the score shape
                (batch_size, seq_len_q, seq_len_k).
            return_attention: Whether to return attention weights

        Returns:
            Output tensor (batch_size, seq_len_q, d_v) and optionally
            attention weights
        """
        # Only the feature size is needed here; it drives the scaling.
        d_k = query.data.shape[-1]

        # Step 1: scores = Q * K^T / (sqrt(d_k) * temperature)
        scores = self._compute_attention_scores(query, key, d_k)

        # Step 2: push masked positions to a large negative score
        if mask is not None:
            scores = self._apply_mask(scores, mask)

        # Step 3: normalise the scores into attention weights
        attention_weights = self.softmax(scores)

        # Step 4: dropout on the weights (only when a Dropout layer exists)
        if self.dropout is not None:
            attention_weights = self.dropout(attention_weights)

        # Step 5: weighted sum of the values
        output = self._apply_attention_to_values(attention_weights, value)

        if return_attention:
            return output, attention_weights
        return output

    def _compute_attention_scores(self, query: Tensor, key: Tensor, d_k: int) -> Tensor:
        """Compute raw attention scores Q * K^T / (sqrt(d_k) * temperature)."""
        scores = query @ self._transpose_last_two_dims(key)
        scale_factor = math.sqrt(d_k) * self.temperature
        return scores / scale_factor

    def _transpose_last_two_dims(self, tensor: Tensor) -> Tensor:
        """Return a new tensor with the last two dimensions swapped."""
        # (batch, seq, dim) -> (batch, dim, seq); any leading axes are kept.
        # NOTE(review): the result is a fresh Tensor built from raw data. If
        # Tensor records its graph through its own operators, gradients may
        # not flow back to `tensor` through this step - confirm against
        # fit.core.tensor.
        transposed_data = np.swapaxes(tensor.data, -1, -2)
        return Tensor(transposed_data, requires_grad=tensor.requires_grad)

    def _apply_mask(self, scores: Tensor, mask: Tensor) -> Tensor:
        """
        Set masked positions to a large negative value.

        The mask uses 1 for positions to attend to and 0 for positions to
        ignore. It is broadcast against the scores, so masks such as
        (1, seq_len_q, seq_len_k) or (batch_size, 1, seq_len_k) are accepted.

        Raises:
            ValueError: If the mask cannot be broadcast to the score shape.
        """
        blocked = np.asarray(mask.data) == 0
        masked_scores = scores.data.copy()
        # copyto broadcasts `where` to the destination shape. Plain boolean
        # index assignment would demand an exact shape match instead.
        np.copyto(masked_scores, -1e9, where=blocked)
        # NOTE(review): same caveat as in _transpose_last_two_dims - the new
        # Tensor is created from raw data, so graph tracking should be
        # verified.
        return Tensor(masked_scores, requires_grad=scores.requires_grad)

    def _apply_attention_to_values(
        self, attention_weights: Tensor, value: Tensor
    ) -> Tensor:
        """Apply attention weights to values."""
        return attention_weights @ value
class MultiHeadAttention(Layer):
    """
    Multi-Head Attention: the key innovation that makes Transformers so powerful.

    Instead of using a single attention function, we use multiple "heads"
    that can focus on different types of relationships in the data.

    Internally the heads are folded into the batch axis, so the wrapped
    attention layer sees tensors of shape (batch_size * num_heads, seq_len, d_k).
    """

    def __init__(
        self, d_model: int, num_heads: int, dropout: float = 0.1, bias: bool = True
    ):
        """
        Initialize multi-head attention.

        Args:
            d_model: Model dimension (must be divisible by num_heads)
            num_heads: Number of attention heads
            dropout: Dropout probability
            bias: Whether to use bias in linear projections

        Raises:
            ValueError: If d_model is not divisible by num_heads.
        """
        super().__init__()

        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Dimension per head

        # Linear projections for Q, K, V
        self.w_q = Linear(d_model, d_model, bias=bias)
        self.w_k = Linear(d_model, d_model, bias=bias)
        self.w_v = Linear(d_model, d_model, bias=bias)

        # Output projection
        self.w_o = Linear(d_model, d_model, bias=bias)

        # Core attention mechanism
        self.attention = ScaledDotProductAttention(dropout=dropout)

        # Add as children for parameter collection
        self.add_child(self.w_q)
        self.add_child(self.w_k)
        self.add_child(self.w_v)
        self.add_child(self.w_o)
        self.add_child(self.attention)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        mask: Optional[Tensor] = None,
        return_attention: bool = False,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """
        Apply multi-head attention.

        Args:
            query: Query tensor (batch_size, seq_len_q, d_model)
            key: Key tensor (batch_size, seq_len_k, d_model)
            value: Value tensor (batch_size, seq_len_k, d_model)
            mask: Optional attention mask (0 = ignore). Accepted layouts are
                anything broadcastable to (batch_size, seq_len_q, seq_len_k),
                a per-head mask (batch_size, num_heads, seq_len_q, seq_len_k),
                or an already flattened
                (batch_size * num_heads, seq_len_q, seq_len_k) mask.
            return_attention: Whether to return attention weights

        Returns:
            Output tensor (batch_size, seq_len_q, d_model) and optionally the
            attention weights in the flattened layout
            (batch_size * num_heads, seq_len_q, seq_len_k).
        """
        batch_size, seq_len_q, _ = query.data.shape
        # Key and value may come from a different sequence (cross-attention),
        # so each one is reshaped with its own length.
        seq_len_k = key.data.shape[1]
        seq_len_v = value.data.shape[1]

        # Step 1: Linear projections and reshape for multiple heads
        Q = self._project_and_reshape(self.w_q(query), batch_size, seq_len_q)
        K = self._project_and_reshape(self.w_k(key), batch_size, seq_len_k)
        V = self._project_and_reshape(self.w_v(value), batch_size, seq_len_v)

        # Step 2: Bring the mask into the same flattened layout as the scores
        if mask is not None:
            mask = self._expand_mask_for_heads(mask, batch_size, seq_len_q, seq_len_k)

        # Step 3: Apply attention to all heads at once
        attention_weights = None
        if return_attention:
            attn_output, attention_weights = self.attention(
                Q, K, V, mask=mask, return_attention=True
            )
        else:
            attn_output = self.attention(Q, K, V, mask=mask)

        # Step 4: Concatenate heads and apply output projection
        output = self._concatenate_heads(attn_output, batch_size, seq_len_q)
        output = self.w_o(output)

        if return_attention:
            return output, attention_weights
        return output

    def _expand_mask_for_heads(
        self, mask: Tensor, batch_size: int, seq_len_q: int, seq_len_k: int
    ) -> Tensor:
        """
        Expand a mask to (batch_size * num_heads, seq_len_q, seq_len_k).

        Raises:
            ValueError: If the mask cannot be broadcast to a supported layout.
        """
        mask_data = np.asarray(mask.data)
        flat_batch = batch_size * self.num_heads
        score_shape = (flat_batch, seq_len_q, seq_len_k)

        if mask_data.ndim == 4:
            # Explicit (batch, heads, q, k) layout, possibly with size-1 axes.
            per_head = np.broadcast_to(
                mask_data, (batch_size, self.num_heads, seq_len_q, seq_len_k)
            )
            expanded = per_head.reshape(score_shape)
        elif mask_data.ndim == 3 and mask_data.shape[0] == flat_batch:
            # Already one entry per (batch, head) pair.
            expanded = np.broadcast_to(mask_data, score_shape)
        else:
            # Shared or per-batch mask: heads of one batch element are
            # adjacent after flattening (index = b * num_heads + h), so every
            # batch entry is repeated num_heads times in a row.
            per_batch = np.broadcast_to(
                mask_data, (batch_size, seq_len_q, seq_len_k)
            )
            expanded = np.repeat(per_batch, self.num_heads, axis=0)

        # Copy so the tensor owns a regular array, not a read-only broadcast view.
        return Tensor(np.array(expanded), requires_grad=False)

    def _project_and_reshape(self, x: Tensor, batch_size: int, seq_len: int) -> Tensor:
        """Split the model dimension into heads and fold heads into the batch axis."""
        # (batch_size, seq_len, d_model) -> (batch_size, seq_len, num_heads, d_k)
        # -> (batch_size, num_heads, seq_len, d_k)
        reshaped = x.data.reshape(batch_size, seq_len, self.num_heads, self.d_k)
        transposed = np.transpose(reshaped, (0, 2, 1, 3))

        # Flatten to (batch_size * num_heads, seq_len, d_k) for attention computation
        flattened = transposed.reshape(batch_size * self.num_heads, seq_len, self.d_k)
        # NOTE(review): a new Tensor is built from raw data here. Whether
        # gradients still reach `x` depends on how Tensor tracks its graph -
        # confirm against fit.core.tensor.
        return Tensor(flattened, requires_grad=x.requires_grad)

    def _concatenate_heads(self, x: Tensor, batch_size: int, seq_len: int) -> Tensor:
        """Concatenate attention heads back together."""
        # (batch_size * num_heads, seq_len, d_k) -> (batch_size, num_heads, seq_len, d_k)
        reshaped = x.data.reshape(batch_size, self.num_heads, seq_len, self.d_k)

        # (batch_size, num_heads, seq_len, d_k) -> (batch_size, seq_len, num_heads, d_k)
        transposed = np.transpose(reshaped, (0, 2, 1, 3))

        # (batch_size, seq_len, num_heads, d_k) -> (batch_size, seq_len, d_model)
        concatenated = transposed.reshape(batch_size, seq_len, self.d_model)
        # NOTE(review): same graph-tracking caveat as in _project_and_reshape.
        return Tensor(concatenated, requires_grad=x.requires_grad)
class SelfAttention(Layer):
    """
    Self-Attention: query, key and value are all the same sequence.

    Every position of the input can relate to every other position of that
    same input, which is what lets the model capture context and long-range
    dependencies. The work is delegated to a MultiHeadAttention child.
    """

    def __init__(self, d_model: int, num_heads: int = 8, dropout: float = 0.1):
        """
        Build the wrapped multi-head attention layer.

        Args:
            d_model: Model dimension
            num_heads: Number of attention heads
            dropout: Dropout probability
        """
        super().__init__()

        inner = MultiHeadAttention(d_model, num_heads, dropout)
        self.multihead_attn = inner
        # Registered as a child so the inner layer is part of this layer.
        self.add_child(inner)

    def forward(
        self, x: Tensor, mask: Optional[Tensor] = None, return_attention: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """
        Run the input sequence through attention against itself.

        Args:
            x: Input tensor (batch_size, seq_len, d_model)
            mask: Optional attention mask, forwarded unchanged
            return_attention: Whether to return attention weights

        Returns:
            Whatever the wrapped multi-head attention returns: the output
            tensor, plus the attention weights when requested.
        """
        # The same tensor plays all three roles.
        result = self.multihead_attn(
            query=x,
            key=x,
            value=x,
            mask=mask,
            return_attention=return_attention,
        )
        return result
class CrossAttention(Layer):
    """
    Cross-Attention: attention between two different sequences.

    Typical for encoder-decoder architectures: the query comes from the
    decoder, while one encoder tensor serves as both key and value. The work
    is delegated to a MultiHeadAttention child.
    """

    def __init__(self, d_model: int, num_heads: int = 8, dropout: float = 0.1):
        """
        Build the wrapped multi-head attention layer.

        Args:
            d_model: Model dimension
            num_heads: Number of attention heads
            dropout: Dropout probability
        """
        super().__init__()

        inner = MultiHeadAttention(d_model, num_heads, dropout)
        self.multihead_attn = inner
        # Registered as a child so the inner layer is part of this layer.
        self.add_child(inner)

    def forward(
        self,
        query: Tensor,  # From decoder
        key_value: Tensor,  # From encoder
        mask: Optional[Tensor] = None,
        return_attention: bool = False,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """
        Let the query sequence attend to the key_value sequence.

        Args:
            query: Query tensor from decoder (batch_size, seq_len_q, d_model)
            key_value: Key and value tensor from encoder
                (batch_size, seq_len_kv, d_model)
            mask: Optional attention mask, forwarded unchanged
            return_attention: Whether to return attention weights

        Returns:
            Whatever the wrapped multi-head attention returns: the output
            tensor, plus the attention weights when requested.
        """
        # A single encoder tensor is used for both the key and the value role.
        result = self.multihead_attn(
            query=query,
            key=key_value,
            value=key_value,
            mask=mask,
            return_attention=return_attention,
        )
        return result
class CausalSelfAttention(Layer):
    """
    Causal (Masked) Self-Attention: no position may attend to a later one.

    This is what autoregressive models such as GPT need: the next token is
    predicted without "cheating" by looking at future tokens. The work is
    delegated to a MultiHeadAttention child that receives a lower-triangular
    mask.
    """

    def __init__(self, d_model: int, num_heads: int = 8, dropout: float = 0.1):
        """
        Build the wrapped multi-head attention layer.

        Args:
            d_model: Model dimension
            num_heads: Number of attention heads
            dropout: Dropout probability
        """
        super().__init__()

        inner = MultiHeadAttention(d_model, num_heads, dropout)
        self.multihead_attn = inner
        # Registered as a child so the inner layer is part of this layer.
        self.add_child(inner)

    def forward(
        self, x: Tensor, return_attention: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """
        Apply self-attention with a causal mask.

        Args:
            x: Input tensor (batch_size, seq_len, d_model)
            return_attention: Whether to return attention weights

        Returns:
            Whatever the wrapped multi-head attention returns: the output
            tensor, plus the attention weights when requested.
        """
        # Three-way unpacking keeps the requirement that x is 3-D; only the
        # sequence length is actually used.
        _, seq_len, _ = x.data.shape

        causal_mask = self._create_causal_mask(seq_len)
        # NOTE(review): the mask has a leading dimension of 1. Whether that
        # is accepted depends on how the attention layers combine it with
        # their scores - confirm.
        return self.multihead_attn(
            query=x,
            key=x,
            value=x,
            mask=causal_mask,
            return_attention=return_attention,
        )

    def _create_causal_mask(self, seq_len: int) -> Tensor:
        """Build a (1, seq_len, seq_len) mask: 1 on/below the diagonal, 0 above."""
        lower_triangle = np.tril(np.ones((seq_len, seq_len)))
        # A new leading axis of size 1 stands in for the batch dimension.
        with_batch_axis = lower_triangle[np.newaxis, ...]
        return Tensor(with_batch_axis, requires_grad=False)
# Utility functions for attention
def create_padding_mask(sequences: Tensor, pad_token_id: int = 0) -> Tensor:
    """
    Build a mask that marks padded positions in variable-length sequences.

    Args:
        sequences: Input sequences (batch_size, seq_len)
        pad_token_id: Token ID used for padding

    Returns:
        Padding mask (batch_size, 1, seq_len) of float32 values:
        1 for real tokens, 0 for padding. The tensor does not require
        gradients.
    """
    is_real_token = sequences.data != pad_token_id
    # The extra axis at position 1 lets the mask broadcast over query positions.
    mask = np.expand_dims(is_real_token.astype(np.float32), axis=1)
    return Tensor(mask, requires_grad=False)
def create_look_ahead_mask(seq_len: int) -> Tensor:
    """
    Build a look-ahead mask for causal attention.

    Args:
        seq_len: Sequence length

    Returns:
        Look-ahead mask (1, seq_len, seq_len): 1 on and below the diagonal,
        0 above it. The tensor does not require gradients.
    """
    square = (seq_len, seq_len)
    lower_triangle = np.tril(np.ones(square))
    # A new leading axis of size 1 stands in for the batch dimension.
    return Tensor(lower_triangle[np.newaxis, ...], requires_grad=False)
def attention_visualization_helper(
    attention_weights: "Tensor",
    tokens: Optional[list] = None,
    num_heads: Optional[int] = None,
):
    """
    Helper function to prepare attention weights for visualization.

    Only the first batch element is used; its heads are averaged.

    Args:
        attention_weights: Attention weights in one of these layouts:
            (batch_size, num_heads, seq_len, seq_len);
            (batch_size * num_heads, seq_len, seq_len), the flattened layout
            returned by MultiHeadAttention in this module;
            or a single (seq_len, seq_len) matrix.
        tokens: Optional list of tokens for labeling. When missing or empty,
            labels "Token_0", "Token_1", ... are generated.
        num_heads: Only used for the flattened 3-D layout: how many leading
            matrices belong to the first batch element. When omitted, only
            the very first matrix is used.

    Returns:
        Dictionary with keys "attention_matrix", "tokens", "max_attention"
        and "attention_entropy" (one entropy value per matrix row).

    Raises:
        ValueError: If the weights are not 2-D, 3-D or 4-D, or if num_heads
            is smaller than 1.
    """
    if num_heads is not None and num_heads < 1:
        raise ValueError(f"num_heads must be at least 1, got {num_heads}")

    weights = np.asarray(attention_weights.data)

    if weights.ndim == 4:
        # First batch element: (num_heads, seq_len, seq_len)
        first_batch_heads = weights[0]
    elif weights.ndim == 3:
        # Flattened layout: heads of the first batch element come first.
        first_batch_heads = weights[: (num_heads or 1)]
    elif weights.ndim == 2:
        first_batch_heads = weights[np.newaxis, ...]
    else:
        raise ValueError(
            f"attention_weights must be 2-D, 3-D or 4-D, got {weights.ndim}-D"
        )

    avg_weights = np.mean(first_batch_heads, axis=0)  # (seq_len, seq_len)

    viz_data = {
        "attention_matrix": avg_weights,
        "tokens": tokens or [f"Token_{i}" for i in range(avg_weights.shape[0])],
        "max_attention": np.max(avg_weights),
        # The small epsilon keeps log() finite for weights that are exactly 0.
        "attention_entropy": -np.sum(avg_weights * np.log(avg_weights + 1e-9), axis=1),
    }

    return viz_data
# Testing and demonstration functions
def demonstrate_attention():
    """
    Demonstrate how attention mechanisms work with simple examples.

    Prints shapes and attention weights for SelfAttention and
    CausalSelfAttention on a small random input, then runs a backward pass
    and reports the gradient on the input. Seeds NumPy's global random
    generator with 42 as a side effect.
    """
    print("🔍 Attention Mechanisms Demonstration")
    print("=" * 50)

    # Create simple test data
    batch_size, seq_len, d_model = 2, 4, 8

    # Random input sequences
    np.random.seed(42)
    x = Tensor(np.random.randn(batch_size, seq_len, d_model), requires_grad=True)

    print(f"Input shape: {x.data.shape}")
    print()

    # Test Self-Attention
    print("🎯 Self-Attention:")
    self_attn = SelfAttention(d_model=d_model, num_heads=2)
    output, attention_weights = self_attn(x, return_attention=True)

    print(f"Output shape: {output.data.shape}")
    print(f"Attention weights shape: {attention_weights.data.shape}")
    print("Attention weights (first head, first batch):")
    print(attention_weights.data[0][:seq_len, :seq_len].round(3))
    print()

    # Test Causal Self-Attention
    print("🎯 Causal Self-Attention:")
    causal_attn = CausalSelfAttention(d_model=d_model, num_heads=2)
    # Only the weights are shown; the output itself is not needed here.
    _, causal_weights = causal_attn(x, return_attention=True)

    print("Causal attention weights (first head, first batch):")
    print(causal_weights.data[0][:seq_len, :seq_len].round(3))
    print("Notice the upper triangular part is nearly zero!")
    print()

    # Test gradient flow
    print("🔄 Testing Gradient Flow:")
    loss = output.sum()
    loss.backward()

    input_grad = x.grad
    if input_grad is None:
        # Do not claim success when nothing reached the input.
        print("⚠️ No gradient reached the input tensor.")
        return

    print(f"Input gradient shape: {input_grad.shape}")
    print(f"Gradient magnitude: {np.abs(input_grad).mean():.6f}")
    print("✅ Gradients flow correctly through attention!")
# Run the demonstration only when this file is executed as a script.
if __name__ == "__main__":
    demonstrate_attention()