Advanced Examples

Advanced examples showcasing FIT’s capabilities.

Optimizer Comparison

Comparing different optimizers on the XOR problem:

Optimizer Comparison
"""
Enhanced XOR problem comparison of different optimizers.

This version fixes convergence issues by properly initializing the model
and providing accurate gradients for all optimizers.
"""

import matplotlib.pyplot as plt
import numpy as np

from fit.core.tensor import Tensor
from fit.nn.modules.activation import Tanh
from fit.nn.modules.linear import Linear
from fit.nn.modules.container import Sequential
from fit.loss.regression import MSELoss
from fit.optim.sgd import SGD, SGDMomentum
from fit.optim.adam import Adam
from fit.optim.experimental.sam import SAM


def train_xor_with_optimizer(optimizer_name, epochs=2000, verbose=True):
    """
    Train a small Tanh network on the XOR problem with a specific optimizer.

    Args:
        optimizer_name: Name of the optimizer to use; one of "SGD",
            "SGDMomentum", "Adam" or "SAM"
        epochs: Maximum number of training epochs (training stops early once
            the model has converged)
        verbose: Whether to print progress

    Returns:
        Dictionary with training results under the keys "model", "losses",
        "accuracies", "final_accuracy", "predictions" and "optimizer"

    Raises:
        ValueError: If optimizer_name is not one of the supported names
    """
    import zlib  # function-scope import: only needed to derive the seed

    # Use a different but reproducible seed for each optimizer. The built-in
    # hash() of a str is salted per interpreter process, so it cannot be used
    # here: the seed would change on every run. A CRC32 checksum is stable.
    seed_offset = zlib.crc32(str(optimizer_name).encode("utf-8")) % 1000
    np.random.seed(42 + seed_offset)

    # XOR dataset
    X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float64)
    y = np.array([[0], [1], [1], [0]], dtype=np.float64)

    # Convert to tensors
    X_tensor = Tensor(X, requires_grad=True)
    y_tensor = Tensor(y, requires_grad=False)

    # Create a model with adequate capacity for XOR
    # Using 32 hidden neurons with Tanh activation
    hidden_size = 32
    model = Sequential(Linear(2, hidden_size), Tanh(), Linear(hidden_size, 1))

    # Critical: initialize weights with a pattern that breaks symmetry for XOR.
    # Even hidden units get weights (+s, -s), odd hidden units get (-s, +s).
    init_scale = 1.0
    unit_index = np.arange(hidden_size)
    row_sign = np.where(unit_index % 2 == 0, init_scale, -init_scale)
    first_weights = np.column_stack((row_sign, -row_sign))  # (hidden_size, 2)

    # Bias pattern: +0.1 for the first half of the units, -0.1 for the rest
    first_bias = np.where(unit_index < hidden_size // 2, 0.1, -0.1)

    # Apply custom initialization
    model.layers[0].weight.data = first_weights
    model.layers[0].bias.data = first_bias

    # Initialize second layer with smaller random weights
    model.layers[2].weight.data = np.random.normal(
        0, 0.1, model.layers[2].weight.data.shape
    )
    model.layers[2].bias.data = np.zeros_like(model.layers[2].bias.data)

    # Set up optimizer
    if optimizer_name == "SGD":
        optimizer = SGD(model.parameters(), lr=0.1)
    elif optimizer_name == "SGDMomentum":
        optimizer = SGDMomentum(model.parameters(), lr=0.1, momentum=0.9)
    elif optimizer_name == "Adam":
        optimizer = Adam(model.parameters(), lr=0.01)
    elif optimizer_name == "SAM":
        # Use SGDMomentum as base optimizer for SAM, not SGD with momentum
        base_optimizer = SGDMomentum(model.parameters(), lr=0.1, momentum=0.9)
        optimizer = SAM(model.parameters(), base_optimizer, rho=0.05)
    else:
        raise ValueError(f"Unknown optimizer: {optimizer_name}")

    # Loss function
    loss_fn = MSELoss()

    # Track training progress
    losses = []
    accuracies = []

    # Outputs at or above this value are classified as 1
    threshold = 0.5
    true_values = y.astype(int)

    # Training loop
    for epoch in range(epochs):
        # Zero gradients
        optimizer.zero_grad()

        if optimizer_name == "SAM":
            # SAM requires two forward passes
            outputs = model(X_tensor)
            loss = loss_fn(outputs, y_tensor)
            loss.backward()

            # First step: perturb weights
            optimizer.first_step(zero_grad=True)

            # Second forward pass; note that the loss and outputs recorded
            # below for SAM come from this second pass
            outputs = model(X_tensor)
            loss = loss_fn(outputs, y_tensor)
            loss.backward()

            # Second step: update weights
            optimizer.second_step(zero_grad=True)
        else:
            # Standard optimization
            outputs = model(X_tensor)
            loss = loss_fn(outputs, y_tensor)
            loss.backward()
            optimizer.step()

        losses.append(float(loss.data))

        # Calculate accuracy as the percentage of correctly thresholded outputs
        predictions = (outputs.data >= threshold).astype(int)
        accuracy = np.mean(predictions == true_values) * 100
        accuracies.append(accuracy)

        # Print progress occasionally
        if verbose and (epoch % 200 == 0 or epoch == 1 or epoch == epochs - 1):
            print(
                f"{optimizer_name} - Epoch {epoch}/{epochs}, Loss: {losses[-1]:.4f}, Accuracy: {accuracy:.1f}%"
            )

        # Early stopping if converged (never within the first 100 epochs)
        if accuracy == 100.0 and losses[-1] < 0.01 and epoch > 100:
            if verbose:
                print(f"{optimizer_name} - Converged at epoch {epoch}/{epochs}")
            break

    # Final evaluation with the trained weights
    outputs = model(X_tensor)
    predicted_classes = (outputs.data >= threshold).astype(int)
    final_accuracy = np.mean(predicted_classes == true_values) * 100

    if verbose:
        print(f"\n{optimizer_name} final results:")
        print(f"  Final accuracy: {final_accuracy:.1f}%")
        print(f"  Final loss: {losses[-1]:.4f}")

    return {
        "model": model,
        "losses": losses,
        "accuracies": accuracies,
        "final_accuracy": final_accuracy,
        "predictions": outputs.data.flatten(),
        "optimizer": optimizer_name,
    }


def plot_decision_boundaries(results, optimizers):
    """
    Plot decision boundaries for each optimizer and save them to disk.

    Args:
        results: Dictionary mapping optimizer name to a result dictionary
            that provides at least the keys "model" and "final_accuracy"
        optimizers: Optimizer names to plot, in subplot order; names that
            are missing from results are skipped (their subplot slot is
            left empty)

    The figure is written to "optimizer_decision_boundaries.png" in the
    current working directory.
    """
    optimizers = list(optimizers)
    plt.figure(figsize=(15, 10))

    # Keep the 2x3 layout for up to six optimizers; add rows beyond that so
    # that plt.subplot never receives an out-of-range index.
    n_cols = 3
    n_rows = max(2, -(-len(optimizers) // n_cols))

    # Create a mesh grid for visualization
    h = 0.02
    x_min, x_max = -0.1, 1.1
    y_min, y_max = -0.1, 1.1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    mesh_points = np.c_[xx.ravel(), yy.ravel()]

    # One tensor holding every mesh point, so each model is evaluated with a
    # single batched forward pass instead of one call per point. The models
    # are trained on an (n, 2) batch, so they accept this shape.
    mesh_tensor = Tensor(mesh_points, requires_grad=False)

    # XOR data points
    X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
    y = np.array([[0], [1], [1], [0]]).flatten()

    # Plot each optimizer's decision boundary
    for i, opt_name in enumerate(optimizers):
        if opt_name not in results:
            continue

        result = results[opt_name]
        model = result["model"]

        # Plot in a grid
        plt.subplot(n_rows, n_cols, i + 1)

        # Classify every mesh point with the same 0.5 threshold used in training
        scores = np.asarray(model(mesh_tensor).data).reshape(-1)
        Z = (scores >= 0.5).astype(int).reshape(xx.shape)

        # Plot decision boundary
        plt.contourf(xx, yy, Z, levels=[0, 0.5, 1], cmap=plt.cm.RdBu, alpha=0.5)
        plt.contour(xx, yy, Z, levels=[0.5], colors="k", linestyles="-", linewidths=2)

        # Plot data points
        plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.RdBu_r, edgecolors="k", s=80)

        # Add title and accuracy
        plt.title(f"{opt_name} - Accuracy: {result['final_accuracy']:.1f}%")
        plt.xlabel("X1")
        plt.ylabel("X2")

    plt.tight_layout()
    plt.savefig("optimizer_decision_boundaries.png", dpi=150, bbox_inches="tight")
    print("\nDecision boundaries saved to optimizer_decision_boundaries.png")


def compare_optimizers():
    """
    Train the XOR model once per optimizer and report how they compare.

    Prints progress and a summary, saves the loss/accuracy figure to
    "optimizer_comparison.png" and then plots the decision boundaries.
    Optimizers whose training raises an exception are reported and skipped.
    """
    print("Optimizer Comparison on XOR Problem")
    print("=" * 40)

    # Results keyed by optimizer name, in training order
    results = {}
    for candidate in ("SGD", "SGDMomentum", "Adam", "SAM"):
        print(f"\nTraining with {candidate}...")
        try:
            results[candidate] = train_xor_with_optimizer(
                candidate, epochs=1000, verbose=True
            )
        except Exception as e:
            print(f"Error training with {candidate}: {e}")

    if not results:
        print("No successful training runs!")
        return

    optimizer_names = list(results)
    final_accuracies = [results[name]["final_accuracy"] for name in optimizer_names]

    plt.figure(figsize=(12, 8))

    # Top panel: loss curves on a logarithmic axis
    plt.subplot(2, 1, 1)
    for name in optimizer_names:
        plt.plot(results[name]["losses"], label=name)
    plt.title("Training Loss Comparison")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.yscale("log")
    plt.ylim(1e-4, 2)
    plt.legend()
    plt.grid(True)

    # Bottom panel: final accuracy per optimizer
    plt.subplot(2, 1, 2)
    plt.bar(optimizer_names, final_accuracies)
    plt.title("Final Accuracy Comparison")
    plt.xlabel("Optimizer")
    plt.ylabel("Accuracy (%)")
    plt.ylim(0, 105)

    # Write each accuracy value just above its bar
    for position, acc in enumerate(final_accuracies):
        plt.text(position, acc + 2, f"{acc:.1f}%", ha="center")

    plt.tight_layout()
    plt.savefig("optimizer_comparison.png", dpi=150, bbox_inches="tight")
    print("\nComparison plot saved to optimizer_comparison.png")

    # Plot decision boundaries
    plot_decision_boundaries(results, optimizer_names)

    # Print summary
    print("\nOptimizer Performance Summary:")
    for opt in optimizer_names:
        entry = results[opt]
        print(
            f"- {opt}: {entry['final_accuracy']:.1f}% accuracy, final loss = {entry['losses'][-1]:.6f}"
        )


# Script entry point: run the XOR optimizer comparison only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    compare_optimizers()

SAM Optimizer

Using Sharpness-Aware Minimization for better generalization:

SAM Optimizer Example
"""
This example demonstrates how to use the Sharpness-Aware Minimization (SAM) optimizer
for improved generalization and robustness on image classification tasks.

We'll compare standard optimizers against SAM on a grayscale subset of the
CIFAR-10 dataset to show the generalization benefits.
"""

import numpy as np
import matplotlib.pyplot as plt
import time
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from fit.core.tensor import Tensor
from fit.nn.modules.activation import ReLU
from fit.nn.modules.linear import Linear
from fit.nn.modules.container import Sequential
from fit.nn.modules.normalization import BatchNorm
from fit.loss.classification import CrossEntropyLoss
from fit.optim.sgd import SGD, SGDMomentum
from fit.optim.adam import Adam
from fit.optim.experimental.sam import SAM
from fit.data.dataset import Dataset
from fit.data.dataloader import DataLoader


def load_cifar10_subset(n_samples=10000):
    """
    Load a grayscale, standardized subset of CIFAR-10 for faster experimentation.

    The data is fetched from OpenML, a random subset of rows is selected,
    the RGB channels are averaged to one grayscale channel, and the subset is
    split 70% / 15% / 15% into train / validation / test (stratified by label).
    Features are standardized with statistics computed on the training split
    only.

    Args:
        n_samples: Number of samples to load

    Returns:
        Tuple of (X_train, X_val, X_test, y_train, y_val, y_test)
    """
    print("Loading CIFAR-10 dataset...")

    # Load data from OpenML
    cifar = fetch_openml(name="CIFAR_10", version=1, parser="auto", as_frame=False)
    y = cifar.target.astype("int")

    # Take the subset first so that only the selected rows are converted;
    # the selected indices are the same as when sub-sampling after conversion.
    indices = np.random.permutation(len(cifar.data))[:n_samples]
    X_subset = cifar.data[indices].astype("float32") / 255.0
    y_subset = y[indices]

    # Convert features to grayscale to simplify
    # Shape from (n, 3072) to (n, 1024) by averaging RGB channels
    X_gray = X_subset.reshape(-1, 3, 32, 32).mean(axis=1).reshape(-1, 1024)

    # Split data before scaling so validation/test rows never influence
    # the preprocessing statistics
    X_train, X_temp, y_train, y_temp = train_test_split(
        X_gray, y_subset, test_size=0.3, random_state=42, stratify=y_subset
    )

    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp, test_size=0.5, random_state=42, stratify=y_temp
    )

    # Normalize data: fit on the training split, reuse for the other splits
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_val = scaler.transform(X_val)
    X_test = scaler.transform(X_test)

    print(
        f"Dataset ready: {len(X_train)} train, {len(X_val)} validation, {len(X_test)} test samples"
    )

    return X_train, X_val, X_test, y_train, y_val, y_test


def create_model():
    """
    Build the fully connected classifier used for the CIFAR-10 comparison.

    The network stacks Linear layers (1024 -> 256 -> 128 -> 64 -> 10) with
    ReLU activations; a BatchNorm layer follows each of the first two
    activations.

    Returns:
        Sequential model
    """
    layers = [
        # Block 1: 1024 input features -> 256
        Linear(1024, 256),
        ReLU(),
        BatchNorm(256),
        # Block 2: 256 -> 128
        Linear(256, 128),
        ReLU(),
        BatchNorm(128),
        # Block 3: 128 -> 64, no normalization
        Linear(128, 64),
        ReLU(),
        # Output layer: 10 classes for CIFAR-10
        Linear(64, 10),
    ]
    return Sequential(*layers)


def train_model(model, train_loader, val_loader, optimizer, criterion, epochs=20):
    """
    Run the training loop and collect per-epoch metrics.

    Args:
        model: Model to train
        train_loader: Training data loader
        val_loader: Validation data loader
        optimizer: Optimizer instance; SAM instances use the two-step update
        criterion: Loss function
        epochs: Number of epochs

    Returns:
        Dictionary with the lists "train_loss", "val_loss", "val_accuracy"
        and "train_time", one entry per epoch
    """
    history = {"train_loss": [], "val_loss": [], "val_accuracy": [], "train_time": []}
    uses_sam = isinstance(optimizer, SAM)

    print(f"Starting training with {optimizer.__class__.__name__}...")

    for epoch_idx in range(epochs):
        epoch_started = time.time()
        running_loss = 0.0
        n_batches = 0

        for step, (inputs, labels) in enumerate(train_loader):
            # Clear gradients left over from the previous batch
            for p in model.parameters():
                p.grad = None

            # Forward and backward pass on the current weights
            batch_loss = criterion(model(inputs), labels)
            batch_loss.backward()

            if not uses_sam:
                optimizer.step()
            else:
                # SAM: first step, a second forward/backward pass, then the
                # second step
                optimizer.first_step(zero_grad=True)
                perturbed_loss = criterion(model(inputs), labels)
                perturbed_loss.backward()
                optimizer.second_step(zero_grad=True)

            # The loss of the first forward pass is the one that is tracked
            running_loss += batch_loss.data
            n_batches += 1

            if step % 10 == 0:
                print(f"Epoch {epoch_idx+1}, Batch {step}, Loss: {batch_loss.data:.4f}")

        # Validation runs before the timer is read, so it counts towards the
        # recorded epoch time
        val_loss, val_accuracy = evaluate_model(model, val_loader, criterion)
        mean_loss = running_loss / n_batches
        elapsed = time.time() - epoch_started

        epoch_metrics = (
            ("train_loss", mean_loss),
            ("val_loss", val_loss),
            ("val_accuracy", val_accuracy),
            ("train_time", elapsed),
        )
        for key, value in epoch_metrics:
            history[key].append(value)

        print(f"Epoch {epoch_idx+1}/{epochs}:")
        print(f"  Train Loss: {mean_loss:.4f}")
        print(f"  Val Loss: {val_loss:.4f}")
        print(f"  Val Accuracy: {val_accuracy:.4f}")
        print(f"  Time: {elapsed:.2f}s")
        print("-" * 50)

    return history


def evaluate_model(model, data_loader, criterion):
    """
    Evaluate model on validation/test data.

    Args:
        model: Model to evaluate; called once per batch, its result must
            expose the raw scores as ``.data``
        data_loader: Iterable yielding (batch_x, batch_y) pairs
        criterion: Loss function; its result must expose the value as ``.data``

    Returns:
        Tuple of (average_loss, accuracy), where average_loss is the mean of
        the per-batch losses and accuracy is the fraction of samples whose
        highest-scoring class (argmax over axis 1) equals the label

    Raises:
        ValueError: If data_loader yields no batches or no samples
    """
    total_loss = 0.0
    correct = 0
    total = 0
    batches = 0

    for batch_x, batch_y in data_loader:
        output = model(batch_x)
        loss = criterion(output, batch_y)

        total_loss += loss.data
        batches += 1

        # Calculate accuracy
        predictions = np.argmax(output.data, axis=1)

        # Unwrap tensor-like labels. A plain ndarray also has a ``.data``
        # attribute (a raw memoryview), so it must be used as-is rather than
        # unwrapped.
        if isinstance(batch_y, np.ndarray):
            targets = batch_y
        else:
            targets = np.asarray(getattr(batch_y, "data", batch_y))

        correct += int(np.sum(predictions == targets))
        total += len(targets)

    # Fail with a clear message instead of a bare ZeroDivisionError
    if batches == 0 or total == 0:
        raise ValueError("evaluate_model received no data: data_loader is empty")

    avg_loss = total_loss / batches
    accuracy = correct / total

    return avg_loss, accuracy


def compare_optimizers():
    """
    Train a fresh model with each optimizer on the CIFAR-10 subset.

    Each configuration is trained for 5 epochs and evaluated on the test
    split; a configuration that raises an exception is reported and skipped.
    The comparison plot is drawn when at least one run succeeds.

    Returns:
        Dictionary mapping configuration name to a dictionary with the keys
        "history", "test_loss" and "test_accuracy"
    """
    print("SAM vs Standard Optimizers Comparison")
    print("=" * 50)

    # Load data
    X_train, X_val, X_test, y_train, y_val, y_test = load_cifar10_subset(n_samples=5000)

    # Wrap each split in a dataset, then in a loader; only the training
    # loader shuffles
    splits = ((X_train, y_train), (X_val, y_val), (X_test, y_test))
    train_set, val_set, test_set = (Dataset(features, labels) for features, labels in splits)

    train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=32, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=32, shuffle=False)

    # Loss function
    criterion = CrossEntropyLoss()

    # Optimizer factories keyed by display name; each takes the model parameters
    optimizer_factories = {
        # Base SGD takes no momentum parameter
        "SGD": lambda params: SGD(params, lr=0.01),
        # Momentum is provided by the separate SGDMomentum class
        "SGD-Momentum": lambda params: SGDMomentum(params, lr=0.01, momentum=0.9),
        "Adam": lambda params: Adam(params, lr=0.001),
        "SAM-SGD": lambda params: SAM(
            params, SGDMomentum(params, lr=0.01, momentum=0.9), rho=0.05
        ),
        "SAM-Adam": lambda params: SAM(params, Adam(params, lr=0.001), rho=0.05),
    }

    banner = "=" * 20
    results = {}

    for name, make_optimizer in optimizer_factories.items():
        print(f"\n{banner} Training with {name} {banner}")

        try:
            # Create fresh model for each optimizer
            model = create_model()
            optimizer = make_optimizer(model.parameters())

            # Reduced epochs for demo
            history = train_model(
                model, train_loader, val_loader, optimizer, criterion, epochs=5
            )

            # Final test evaluation
            test_loss, test_accuracy = evaluate_model(model, test_loader, criterion)

            results[name] = {
                "history": history,
                "test_loss": test_loss,
                "test_accuracy": test_accuracy,
            }

            print(f"\nFinal {name} Results:")
            print(f"Test Loss: {test_loss:.4f}")
            print(f"Test Accuracy: {test_accuracy:.4f}")

        except Exception as e:
            print(f"Error training with {name}: {e}")

    # Plot comparison
    if not results:
        print("No successful training runs!")
    else:
        plot_comparison(results)

    return results


def plot_comparison(results):
    """
    Draw a 2x2 figure comparing the optimizers and save it to disk.

    The panels show training loss, validation loss and validation accuracy
    per epoch, plus a bar chart of the final test accuracy. The figure is
    saved as "sam_comparison.png" and then shown. Any exception raised while
    plotting is reported and swallowed.

    Args:
        results: Dictionary with results from each optimizer; each value
            provides "history" and "test_accuracy"
    """
    try:
        _, axes = plt.subplots(2, 2, figsize=(15, 10))

        # (axis, title, history key, y-axis label) for the three curve panels
        curve_panels = (
            (axes[0, 0], "Training Loss", "train_loss", "Loss"),
            (axes[0, 1], "Validation Loss", "val_loss", "Loss"),
            (axes[1, 0], "Validation Accuracy", "val_accuracy", "Accuracy"),
        )
        for ax, title, history_key, y_label in curve_panels:
            ax.set_title(title)
            for name, data in results.items():
                ax.plot(data["history"][history_key], label=name)
            ax.set_xlabel("Epoch")
            ax.set_ylabel(y_label)
            ax.legend()
            ax.grid(True)

        # Bar chart of the final test accuracy
        bar_ax = axes[1, 1]
        bar_ax.set_title("Final Test Accuracy")
        names = list(results.keys())
        test_accuracies = [results[name]["test_accuracy"] for name in names]

        bars = bar_ax.bar(names, test_accuracies)
        bar_ax.set_ylabel("Test Accuracy")
        bar_ax.set_ylim(0, 1)

        # Add value labels just above each bar, centered horizontally
        for bar, acc in zip(bars, test_accuracies):
            x_center = bar.get_x() + bar.get_width() / 2.0
            y_top = bar.get_height() + 0.01
            bar_ax.text(x_center, y_top, f"{acc:.3f}", ha="center", va="bottom")

        plt.tight_layout()
        plt.savefig("sam_comparison.png", dpi=150, bbox_inches="tight")
        plt.show()

        print("Comparison plot saved as 'sam_comparison.png'")

    except Exception as e:
        print(f"Could not generate plot: {e}")


def demonstrate_sam_benefits():
    """
    Print a short overview of SAM's benefits and its key parameters.
    """
    overview = (
        "\nSAM Benefits Demonstration:",
        "=" * 40,
        "1. Better Generalization: SAM typically achieves better test accuracy",
        "2. Flatter Minima: SAM finds parameters in flatter loss landscapes",
        "3. Robustness: More stable to hyperparameter choices",
        "4. Reduced Overfitting: Better gap between train and validation performance",
        "\nKey SAM Parameters:",
        "- rho: Controls the sharpness penalty (typical: 0.05-0.2)",
        "- base_optimizer: Underlying optimizer (SGD, Adam, etc.)",
        "- adaptive: Whether to use adaptive rho (default: False)",
    )
    # One print call per entry, exactly as many lines as entries plus the
    # two leading blank lines embedded in the headings
    for line in overview:
        print(line)


# Script entry point: explain SAM, then optionally run the full comparison
# and print a summary of the outcome.
if __name__ == "__main__":
    print("Sharpness-Aware Minimization (SAM) Example")
    print("=" * 50)

    # Show SAM benefits explanation
    demonstrate_sam_benefits()

    # The comparison is slow, so it only runs on an explicit "y"
    answer = input("\nRun optimizer comparison? (y/n): ").lower()

    if answer != "y":
        print("Skipping comparison. To run it later, execute this script again.")
    else:
        try:
            results = compare_optimizers()

            print("\n" + "=" * 50)
            print("SUMMARY")
            print("=" * 50)

            if not results:
                print("No successful results to analyze.")
            else:
                # Find best performing optimizer by test accuracy
                best_optimizer = max(
                    results, key=lambda name: results[name]["test_accuracy"]
                )
                best_accuracy = results[best_optimizer]["test_accuracy"]

                print(f"Best performing optimizer: {best_optimizer}")
                print(f"Best test accuracy: {best_accuracy:.4f}")

                # Split the test accuracies into SAM and non-SAM runs
                sam_scores = [
                    run["test_accuracy"] for name, run in results.items() if "SAM" in name
                ]
                regular_scores = [
                    run["test_accuracy"]
                    for name, run in results.items()
                    if "SAM" not in name
                ]

                # The averages are only comparable when both groups have runs
                if sam_scores and regular_scores:
                    avg_sam_acc = np.mean(sam_scores)
                    avg_regular_acc = np.mean(regular_scores)

                    print(f"\nAverage SAM accuracy: {avg_sam_acc:.4f}")
                    print(f"Average regular accuracy: {avg_regular_acc:.4f}")
                    print(f"SAM improvement: {avg_sam_acc - avg_regular_acc:.4f}")

        except Exception as e:
            print(f"Error running comparison: {e}")
            import traceback

            traceback.print_exc()

    print("\nSAM example completed!")

Transformer Model

Building and training a transformer model:

Attention Visualization

Visualizing attention patterns:

Ensemble Methods

Using ensemble methods for improved performance:

Custom Loss Functions

Creating and using custom loss functions: