# Source code for fit.optim.experimental.sam

"""
Enhanced Sharpness-Aware Minimization (SAM) optimizer implementation.

SAM is a state-of-the-art optimizer that seeks parameters that lie in neighborhoods having
uniformly low loss values, leading to better generalization than standard optimizers.

This implementation includes:
- Adaptive sharpness
- Efficient computation
- Memory optimization
- Support for all base optimizers

Paper: https://arxiv.org/abs/2010.01412
"""

import numpy as np
from typing import List, Callable, Dict, Optional, Union, Tuple, Any

from fit.core.tensor import Tensor


class SAM:
    """
    Sharpness-Aware Minimization (SAM) wrapper around a base optimizer.

    SAM looks for parameters whose whole neighbourhood has low loss. Each
    update uses two gradient evaluations:

    1. ``first_step`` snapshots the weights and moves them along the current
       gradient (the ascent direction) by a step of size ``rho``.
    2. The caller re-computes gradients at the perturbed weights.
    3. ``second_step`` restores the snapshot and lets the base optimizer apply
       its update using the gradients from the perturbed point.

    Args:
        parameters: Parameters to optimize. Each must expose ``data`` (a NumPy
            array) and ``grad`` (a NumPy array or ``None``).
        base_optimizer: Optimizer that performs the actual update; only its
            ``step()`` method is required. ``lr``, ``state_dict`` and
            ``load_state_dict`` are used when present.
        rho: Size of the neighbourhood explored by the perturbation
            (default: 0.05).
        epsilon: Small constant for numerical stability (default: 1e-12).
        adaptive: If True, the perturbation of every parameter tensor is
            additionally multiplied by that tensor's L2 norm (plus
            ``epsilon``) (default: False).
        auto_clip: Stored and round-tripped through ``state_dict``; no method
            of this class currently reads it (default: True).
    """

    def __init__(
        self,
        parameters: List["Tensor"],
        base_optimizer: Any,
        rho: float = 0.05,
        epsilon: float = 1e-12,
        adaptive: bool = False,
        auto_clip: bool = True,
    ):
        """Store the configuration and prepare the per-parameter snapshot slots."""
        self.parameters = parameters
        self.base_optimizer = base_optimizer
        self.rho = rho
        self.epsilon = epsilon
        self.adaptive = adaptive
        self.auto_clip = auto_clip

        # One slot per parameter; filled by first_step, released by second_step.
        self.param_copies = [None for _ in parameters]

        # The attributes below are initialised for callers/subclasses but are
        # not read by any method of this class.
        self._base_lr = getattr(base_optimizer, "lr", None)
        self.grad_norm_history = []
        self.max_history_size = 10

    def first_step(self, zero_grad: bool = False):
        """
        Snapshot the weights and perturb them along the current gradient.

        Gradients must already be populated on the parameters. When there are
        no gradients, or their global norm is below ``epsilon``, the weights
        are left unperturbed (the snapshot is still taken).

        Args:
            zero_grad: Whether to clear the gradients afterwards. This is
                honoured even when the perturbation itself is skipped, so
                stale gradients never leak into the second backward pass.
        """
        # Rebuild the snapshot list on every call so it always matches the
        # current parameter list.
        self.param_copies = [param.data.copy() for param in self.parameters]

        self._perturb_parameters()

        if zero_grad:
            self.zero_grad()

    def _perturb_parameters(self):
        """Move every parameter that has a gradient along the ascent direction."""
        gradients = [param.grad for param in self.parameters if param.grad is not None]
        if not gradients:
            return  # Nothing to perturb with.

        grad_norm = self._compute_grad_norm(gradients)
        if grad_norm < self.epsilon:
            return  # Direction is numerically meaningless.

        perturbation_scale = self.rho / (grad_norm + self.epsilon)
        for param in self.parameters:
            if param.grad is None:
                continue
            scale = perturbation_scale
            if self.adaptive:
                # Adaptive variant: larger tensors receive a larger step.
                scale = perturbation_scale * (np.linalg.norm(param.data) + self.epsilon)
            param.data += scale * param.grad

    def second_step(self, zero_grad: bool = False):
        """
        Restore the snapshot taken by ``first_step`` and run the base optimizer.

        Args:
            zero_grad: Whether to clear the gradients after the update.
        """
        for param, saved in zip(self.parameters, self.param_copies):
            if saved is not None:
                param.data = saved

        # Release the snapshot: the restored array is now the live weight, so
        # keeping a reference would alias it, and a repeated second_step
        # without a fresh first_step would roll the weights back.
        self.param_copies = [None for _ in self.parameters]

        self.base_optimizer.step()

        if zero_grad:
            self.zero_grad()

    def step(self, closure: Optional[Callable] = None):
        """
        Run both SAM steps in one call (for compatibility).

        Gradients at the current weights must already be populated, because
        ``first_step`` reads them before ``closure`` is invoked.

        Args:
            closure: Callable that evaluates the model and returns a loss
                object exposing ``backward()``.

        Returns:
            The loss returned by ``closure`` at the perturbed weights.

        Raises:
            ValueError: If ``closure`` is None.
        """
        if closure is None:
            raise ValueError("SAM requires a closure function for single-step mode")

        # Perturb using the existing gradients, then discard them.
        self.first_step(zero_grad=True)

        # Gradients at the perturbed position drive the real update.
        loss = closure()
        loss.backward()

        self.second_step(zero_grad=True)
        return loss

    def zero_grad(self):
        """Reset the gradient of every parameter to ``None``."""
        for param in self.parameters:
            param.grad = None

    def _compute_grad_norm(self, gradients: List[np.ndarray]) -> float:
        """Return the global L2 norm over all non-None gradients."""
        squared_total = sum(
            (np.sum(grad * grad) for grad in gradients if grad is not None),
            0.0,
        )
        return np.sqrt(squared_total)

    @property
    def lr(self):
        """Learning rate of the base optimizer, or None if it has no ``lr``."""
        return getattr(self.base_optimizer, "lr", None)

    @lr.setter
    def lr(self, value):
        """Forward the learning rate to the base optimizer if it exposes ``lr``."""
        if hasattr(self.base_optimizer, "lr"):
            self.base_optimizer.lr = value

    def state_dict(self) -> Dict[str, Any]:
        """
        Return the SAM hyper-parameters plus the base optimizer's state.

        The ``"base_optimizer"`` entry is an empty dict when the base
        optimizer does not provide ``state_dict``.
        """
        base_state_fn = getattr(self.base_optimizer, "state_dict", None)
        return {
            "rho": self.rho,
            "epsilon": self.epsilon,
            "adaptive": self.adaptive,
            "auto_clip": self.auto_clip,
            "base_optimizer": base_state_fn() if base_state_fn is not None else {},
        }

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """
        Restore hyper-parameters (and base optimizer state) from ``state_dict``.

        Keys missing from ``state_dict`` leave the current value unchanged.
        """
        self.rho = state_dict.get("rho", self.rho)
        self.epsilon = state_dict.get("epsilon", self.epsilon)
        self.adaptive = state_dict.get("adaptive", self.adaptive)
        self.auto_clip = state_dict.get("auto_clip", self.auto_clip)

        if hasattr(self.base_optimizer, "load_state_dict"):
            self.base_optimizer.load_state_dict(state_dict.get("base_optimizer", {}))
class AdaptiveSAM(SAM):
    """
    SAM variant with the adaptive perturbation always switched on.

    Behaves exactly like ``SAM`` constructed with ``adaptive=True``: the
    perturbation is adjusted according to the parameter magnitudes. Any
    ``adaptive`` value supplied by the caller is overridden.
    """

    def __init__(self, parameters: List[Tensor], base_optimizer: Any, **kwargs):
        """Forward all arguments to ``SAM`` while forcing ``adaptive=True``."""
        forced_kwargs = dict(kwargs, adaptive=True)
        super().__init__(parameters, base_optimizer, **forced_kwargs)