Source code for fit.core.autograd

"""
This module implements the autograd engine for automatic differentiation.

The autograd engine tracks computations and builds a directed acyclic graph
for efficient backpropagation.
"""

import numpy as np
from typing import Dict, List, Set, Callable, Optional, Tuple, Any, Union


[docs] class Node: """ Represents a node in the computational graph. Each node represents a value in the computation graph and tracks its dependencies (parents) as well as the backward function to compute gradients with respect to its inputs. """
[docs] def __init__(self, requires_grad: bool = False): """ Initialize a node in the computational graph. Args: requires_grad: Whether this node requires gradient computation """ self.parents: Set["Node"] = set() self.backward_fn: Callable[[], None] = lambda: None self.grad: Optional[np.ndarray] = None self.requires_grad: bool = requires_grad
[docs] def backward(self, gradient: Optional[np.ndarray] = None) -> None: """ Perform backpropagation starting from this node. Args: gradient: Upstream gradient to apply (defaults to ones) """ # Build topological ordering of the graph topo_order = [] visited = set() def build_topo(node: "Node") -> None: if node not in visited: visited.add(node) for parent in node.parents: if parent.requires_grad: build_topo(parent) topo_order.append(node) build_topo(self) # Initialize gradient at the start node if gradient is None: if hasattr(self, "data"): # Use ones with the same shape as data self.grad = np.ones_like(self.data, dtype=np.float64) else: # Scalar gradient self.grad = np.array(1.0, dtype=np.float64) else: # Ensure gradient has the correct dtype if not isinstance(gradient, np.ndarray): gradient = np.array(gradient, dtype=np.float64) elif gradient.dtype.kind != "f": gradient = gradient.astype(np.float64) self.grad = gradient # Backpropagate through the graph in reverse topological order for node in reversed(topo_order): try: node.backward_fn() except Exception as e: # Provide more helpful debugging info if backward fails node_info = f"Node type: {type(node).__name__}" if hasattr(node, "data"): node_info += f", shape: {node.data.shape}, dtype: {node.data.dtype}" print(f"Error in backward for {node_info}: {e}") raise
[docs] class Function: """ Base class for autograd functions. Each function represents an operation in the computation and defines how to compute the forward pass and the backward pass (gradient computation). """
[docs] @staticmethod def apply(ctx: Dict[str, Any], *inputs: Any) -> Any: """ Apply the function to inputs. Args: ctx: Context dictionary to store data for the backward pass *inputs: Input values Returns: Output value(s) """ raise NotImplementedError("Function subclasses must implement apply")
[docs] @staticmethod def backward( ctx: Dict[str, Any], grad_output: np.ndarray ) -> Tuple[Optional[np.ndarray], ...]: """ Compute gradients with respect to inputs. Args: ctx: Context dictionary with data stored during the forward pass grad_output: Gradient with respect to the output Returns: Tuple of gradients with respect to inputs """ raise NotImplementedError("Function subclasses must implement backward")
[docs] @classmethod def forward(cls, *inputs: "Tensor") -> "Tensor": from fit.core.tensor import Tensor # Determine if output requires gradients requires_grad = any(t.requires_grad for t in inputs if isinstance(t, Tensor)) # Create context for storing data needed in backward ctx: Dict[str, Any] = {} # Convert tensor inputs to numpy arrays for computation numpy_inputs = [] tensor_inputs = [] # Keep track of actual tensor objects for inp in inputs: if hasattr(inp, "data") and hasattr(inp, "requires_grad"): numpy_inputs.append(inp.data) tensor_inputs.append(inp) # Store the tensor object elif inp is None: numpy_inputs.append(None) tensor_inputs.append(None) else: if not isinstance(inp, np.ndarray): inp = np.array(inp, dtype=np.float64) numpy_inputs.append(inp) tensor_inputs.append(None) # Apply the operation output_data = cls.apply(ctx, *numpy_inputs) # Create output tensor output = Tensor(output_data, requires_grad=requires_grad) if requires_grad: # Store references to input tensors with requires_grad=True grad_tensors = [ t for t in tensor_inputs if t is not None and t.requires_grad ] output._prev = set(grad_tensors) # Define backward function def backward_fn(): if output.grad is not None: grads = cls.backward(ctx, output.grad) # Debug print for MatMul # if cls.__name__ == "MatMul": # print(f"MatMul backward: output.grad.shape = {output.grad.shape}") # for i, grad in enumerate(grads): # if grad is not None: # print(f" grad[{i}].shape = {grad.shape}") # print(f" Assigning gradients to {len([inp for inp in tensor_inputs if inp is not None and inp.requires_grad])} tensors") # for i, inp in enumerate(tensor_inputs): # if inp is not None and inp.requires_grad: # print(f" tensor[{i}].shape = {inp.data.shape}") # Assign gradients to the correct tensors # We need to match gradients with their corresponding input tensors for i, inp in enumerate(tensor_inputs): if inp is not None and inp.requires_grad and i < len(grads): grad = grads[i] if grad is not None: # if cls.__name__ == "MatMul": # print(f" Assigning grad[{i}] (shape {grad.shape}) to tensor[{i}] (shape {inp.data.shape})") if inp.grad is None: inp.grad = grad else: inp.grad = inp.grad + grad output._backward = backward_fn return output
# Helper function for handling broadcasting in gradients def _unbroadcast(grad: np.ndarray, shape: Tuple[int, ...]) -> np.ndarray: """ Unbroadcast a gradient to match the original tensor shape. Args: grad: Gradient that may have been broadcast shape: Original shape to unbroadcast to Returns: Unbroadcast gradient that matches the original shape """ # If shapes match, no need to unbroadcast if grad.shape == shape: return grad # For dimensions that were added in broadcasting, sum over them grad_ndim = grad.ndim shape_ndim = len(shape) if grad_ndim > shape_ndim: # Sum over added dimensions for _ in range(grad_ndim - shape_ndim): grad = np.sum(grad, axis=0) # For dimensions that were broadcast, sum over the broadcast dimension for i, (original_dim, grad_dim) in enumerate(zip(shape, grad.shape)): if original_dim == 1 and grad_dim > 1: grad = np.sum(grad, axis=i, keepdims=True) return grad # Core autograd functions
[docs] class Add(Function): """Addition function."""
[docs] @staticmethod def apply(ctx: Dict[str, Any], a: np.ndarray, b: np.ndarray) -> np.ndarray: ctx["a_shape"] = a.shape ctx["b_shape"] = b.shape return a + b
[docs] @staticmethod def backward( ctx: Dict[str, Any], grad_output: np.ndarray ) -> Tuple[np.ndarray, np.ndarray]: a_shape, b_shape = ctx["a_shape"], ctx["b_shape"] grad_a = grad_output grad_b = grad_output # Handle broadcasting if needed if a_shape != grad_output.shape: grad_a = _unbroadcast(grad_a, a_shape) if b_shape != grad_output.shape: grad_b = _unbroadcast(grad_b, b_shape) return grad_a, grad_b
[docs] class Multiply(Function): """Element-wise multiplication function."""
[docs] @staticmethod def apply(ctx: Dict[str, Any], a: np.ndarray, b: np.ndarray) -> np.ndarray: ctx["a"] = a ctx["b"] = b ctx["a_shape"] = a.shape ctx["b_shape"] = b.shape return a * b
[docs] @staticmethod def backward( ctx: Dict[str, Any], grad_output: np.ndarray ) -> Tuple[np.ndarray, np.ndarray]: a, b = ctx["a"], ctx["b"] a_shape, b_shape = ctx["a_shape"], ctx["b_shape"] # Compute raw gradients first grad_a = grad_output * b grad_b = grad_output * a # Handle broadcasting if needed if a_shape != grad_output.shape: grad_a = _unbroadcast(grad_a, a_shape) if b_shape != grad_output.shape: grad_b = _unbroadcast(grad_b, b_shape) return grad_a, grad_b
[docs] class MatMul(Function): """Matrix multiplication function."""
[docs] @staticmethod def apply(ctx: Dict[str, Any], a: np.ndarray, b: np.ndarray) -> np.ndarray: ctx["a"] = a ctx["b"] = b ctx["a_shape"] = a.shape ctx["b_shape"] = b.shape return a @ b
[docs] @staticmethod def backward( ctx: Dict[str, Any], grad_output: np.ndarray ) -> Tuple[np.ndarray, np.ndarray]: a, b = ctx["a"], ctx["b"] a_shape, b_shape = ctx["a_shape"], ctx["b_shape"] # For matrix multiplication C = A @ B: # dA = grad_output @ B.T # dB = A.T @ grad_output grad_a = grad_output @ b.T grad_b = a.T @ grad_output # Ensure gradients have the correct shapes assert ( grad_a.shape == a_shape ), f"grad_a shape {grad_a.shape} != a_shape {a_shape}" assert ( grad_b.shape == b_shape ), f"grad_b shape {grad_b.shape} != b_shape {b_shape}" return grad_a, grad_b
[docs] class Sum(Function): """Sum reduction function."""
[docs] @staticmethod def apply( ctx: Dict[str, Any], a: np.ndarray, axis: Optional[int] = None, keepdims: bool = False, ) -> np.ndarray: ctx["input_shape"] = a.shape ctx["axis"] = axis ctx["keepdims"] = keepdims return np.sum(a, axis=axis, keepdims=keepdims)
[docs] @staticmethod def backward( ctx: Dict[str, Any], grad_output: np.ndarray ) -> Tuple[np.ndarray, None, None]: input_shape = ctx["input_shape"] axis = ctx["axis"] keepdims = ctx["keepdims"] # If keepdims is False, we need to restore dimensions if not keepdims and axis is not None: # Add back reduced dimensions grad_output_reshaped = np.expand_dims(grad_output, axis=axis) else: grad_output_reshaped = grad_output # Broadcast gradient to input shape grad_input = np.broadcast_to(grad_output_reshaped, input_shape) return grad_input, None, None
[docs] class Mean(Function): """Mean reduction function."""
[docs] @staticmethod def apply( ctx: Dict[str, Any], a: np.ndarray, axis=None, keepdims=False ) -> np.ndarray: ctx["input_shape"] = a.shape # Handle problematic axis values if isinstance(axis, np.ndarray): if axis.ndim == 0: # 0-d array axis_val = axis.item() # Extract the scalar value if np.isnan(axis_val): axis = None else: axis = int(axis_val) else: axis = None # Multi-dimensional axis arrays not supported elif axis is not None and np.isnan(axis): axis = None ctx["axis"] = axis ctx["keepdims"] = keepdims if axis is None: ctx["size"] = a.size else: ctx["size"] = a.shape[axis] return np.mean(a, axis=axis, keepdims=keepdims)
[docs] @staticmethod def backward( ctx: Dict[str, Any], grad_output: np.ndarray ) -> Tuple[np.ndarray, None, None]: input_shape = ctx["input_shape"] axis = ctx["axis"] keepdims = ctx["keepdims"] size = ctx["size"] # If keepdims is False, we need to restore dimensions if not keepdims and axis is not None: # Add back reduced dimensions grad_output_reshaped = np.expand_dims(grad_output, axis=axis) else: grad_output_reshaped = grad_output # Broadcast gradient to input shape and divide by number of elements grad_input = np.broadcast_to(grad_output_reshaped, input_shape) / size return grad_input, None, None
[docs] class Exp(Function): """Exponential function."""
[docs] @staticmethod def apply(ctx: Dict[str, Any], a: np.ndarray) -> np.ndarray: # To prevent overflow, clip very large values a_safe = np.clip(a, -88, 88) # exp(88) is close to float64 max result = np.exp(a_safe) ctx["result"] = result return result
[docs] @staticmethod def backward(ctx: Dict[str, Any], grad_output: np.ndarray) -> Tuple[np.ndarray,]: result = ctx["result"] return (grad_output * result,)
[docs] class Log(Function): """Natural logarithm function."""
[docs] @staticmethod def apply(ctx: Dict[str, Any], a: np.ndarray) -> np.ndarray: # Add small epsilon for numerical stability a_safe = np.maximum(a, 1e-12) ctx["input"] = a_safe return np.log(a_safe)
[docs] @staticmethod def backward(ctx: Dict[str, Any], grad_output: np.ndarray) -> Tuple[np.ndarray,]: input_data = ctx["input"] return (grad_output / input_data,)
[docs] class Reshape(Function): """Reshape function."""
[docs] @staticmethod def apply(ctx: Dict[str, Any], a: np.ndarray, shape: Tuple[int, ...]) -> np.ndarray: ctx["input_shape"] = a.shape return a.reshape(shape)
[docs] @staticmethod def backward( ctx: Dict[str, Any], grad_output: np.ndarray ) -> Tuple[np.ndarray, None]: input_shape = ctx["input_shape"] return grad_output.reshape(input_shape), None
[docs] class ReLU(Function): """Rectified Linear Unit activation function."""
[docs] @staticmethod def apply(ctx: Dict[str, Any], a: np.ndarray) -> np.ndarray: ctx["mask"] = a > 0 return np.maximum(0, a)
[docs] @staticmethod def backward(ctx: Dict[str, Any], grad_output: np.ndarray) -> Tuple[np.ndarray,]: mask = ctx["mask"] return (grad_output * mask,)
# Function registry for dynamic lookup _function_registry = { "add": Add, "multiply": Multiply, "matmul": MatMul, "sum": Sum, "mean": Mean, "exp": Exp, "log": Log, "reshape": Reshape, "relu": ReLU, }
[docs] def get_function(name: str) -> Function: """ Get a function by name from the registry. Args: name: Name of the function Returns: Function class Raises: ValueError: If the function is not found """ if name not in _function_registry: raise ValueError(f"Function '{name}' not found in registry") return _function_registry[name]
[docs] def register_function(name: str, function: Function) -> None: """ Register a new function in the registry. Args: name: Name of the function function: Function class to register """ _function_registry[name] = function