"""
High-level training API for FIT framework.
This module provides simple, one-line training functions that handle
all the boilerplate while remaining flexible for advanced use cases.
"""
import numpy as np
from typing import Union, Optional, List, Dict, Any, Callable
import time
from fit.core.tensor import Tensor
from fit.nn.modules.container import Sequential
from fit.nn.modules.linear import Linear
from fit.nn.modules.activation import ReLU, Softmax, Tanh
from fit.nn.modules.normalization import BatchNorm
from fit.loss.classification import CrossEntropyLoss
from fit.loss.regression import MSELoss
from fit.optim.adam import Adam
from fit.optim.sgd import SGD, SGDMomentum
from fit.optim.experimental.sam import SAM
from fit.data.dataset import Dataset
from fit.data.dataloader import DataLoader
from fit.monitor.tracker import TrainingTracker
class SimpleTrainer:
    """
    High-level trainer that handles all the boilerplate.

    Makes training as simple as::

        trainer = SimpleTrainer(model, data)
        trainer.fit()
    """

    def __init__(
        self,
        model,
        data: "Union[tuple, DataLoader]",
        validation_data: "Optional[Union[tuple, DataLoader]]" = None,
        loss: Union[str, Any] = "auto",
        optimizer: Union[str, Any] = "adam",
        metrics: Optional[List[str]] = None,
        callbacks: Optional[List[str]] = None,
        **kwargs,
    ):
        """
        Initialize the simple trainer.

        Args:
            model: The model to train
            data: Training data as (X, y) tuple or DataLoader
            validation_data: Validation data (optional)
            loss: Loss function ('auto', 'mse', 'crossentropy', or loss object)
            optimizer: Optimizer ('adam', 'sgd', 'sam', or optimizer object)
            metrics: List of metrics to track ['accuracy', 'loss']
            callbacks: List of callback names ['early_stopping', 'lr_scheduler']
            **kwargs: Additional arguments (lr, batch_size, momentum, rho,
                base_optimizer, experiment_name, early_stopping, patience,
                min_delta, monitor)
        """
        self.model = model
        self.kwargs = kwargs
        # Kept so load() can rebuild a string-specified optimizer for the
        # freshly loaded model's parameters.
        self._optimizer_spec = optimizer
        # Kept so the early-stopping config can honour callbacks=[...].
        self._callback_names = list(callbacks or [])

        # Set up data
        self.train_loader = self._setup_data(data, shuffle=True)
        self.val_loader = (
            self._setup_data(validation_data, shuffle=False)
            if validation_data
            else None
        )

        # Set up loss function (may run one forward pass when loss="auto")
        self.loss_fn = self._setup_loss(loss)

        # Set up optimizer
        self.optimizer = self._setup_optimizer(optimizer)

        # Set up metrics
        self.metrics = metrics or ["loss", "accuracy"]

        # Set up callbacks
        self.callbacks = self._setup_callbacks(self._callback_names)

        # Set up tracker
        self.tracker = TrainingTracker(
            experiment_name=kwargs.get("experiment_name"),
            early_stopping=self._get_early_stopping_config(),
        )

    def _setup_data(self, data, shuffle=True):
        """
        Convert data to a DataLoader if needed.

        Args:
            data: None, a DataLoader, or an (X, y) tuple
            shuffle: Whether a newly created DataLoader shuffles its samples

        Returns:
            None when data is None, otherwise a DataLoader

        Raises:
            ValueError: If data is neither a 2-tuple nor a DataLoader
        """
        if data is None:
            return None
        if isinstance(data, DataLoader):
            return data
        if isinstance(data, tuple) and len(data) == 2:
            X, y = data
            batch_size = self.kwargs.get("batch_size", 32)
            return DataLoader(Dataset(X, y), batch_size=batch_size, shuffle=shuffle)
        raise ValueError("Data must be a tuple (X, y) or DataLoader")

    def _setup_loss(self, loss):
        """
        Resolve a loss specification to a loss object.

        Loss names are matched case-insensitively; a non-string value is
        returned unchanged.

        Raises:
            ValueError: If the loss name is not recognised
        """
        if not isinstance(loss, str):
            return loss

        name = loss.lower()
        if name == "auto":
            # Try to infer from model output size
            return self._infer_loss()
        if name in ("mse", "mean_squared_error"):
            return MSELoss()
        if name in ("crossentropy", "cross_entropy", "ce"):
            return CrossEntropyLoss()
        raise ValueError(f"Unknown loss function: {loss}")

    def _infer_loss(self):
        """
        Infer an appropriate loss function from the model's output.

        Runs the model on one training batch. An output with a single value
        per sample selects MSELoss, anything wider selects CrossEntropyLoss.
        If the probe fails for any reason, a warning is printed and MSELoss
        is used.
        """
        try:
            sample_x, _ = next(iter(self.train_loader))
            output = self.model(sample_x)
            out_shape = np.shape(output.data)
            # A 1-D output is one scalar per sample, i.e. regression-like.
            single_value = len(out_shape) < 2 or out_shape[-1] == 1
        except Exception:
            # Deliberate best-effort: default to MSE if we can't infer
            print("Warning: Could not infer loss function, defaulting to MSE")
            return MSELoss()

        return MSELoss() if single_value else CrossEntropyLoss()

    def _setup_optimizer(self, optimizer):
        """
        Resolve an optimizer specification to an optimizer object.

        Reads 'lr' (default 0.001), 'momentum', 'base_optimizer' and 'rho'
        from the trainer kwargs. A non-string value is returned unchanged.

        Raises:
            ValueError: If the optimizer name is not recognised
        """
        if not isinstance(optimizer, str):
            return optimizer

        lr = self.kwargs.get("lr", 0.001)
        name = optimizer.lower()

        if name == "adam":
            return Adam(self.model.parameters(), lr=lr)

        if name == "sgd":
            momentum = self.kwargs.get("momentum", 0.0)
            if momentum > 0:
                return SGDMomentum(
                    self.model.parameters(), lr=lr, momentum=momentum
                )
            return SGD(self.model.parameters(), lr=lr)

        if name == "sam":
            base_opt = self.kwargs.get("base_optimizer", "sgd")
            rho = self.kwargs.get("rho", 0.05)
            if base_opt.lower() == "adam":
                base = Adam(self.model.parameters(), lr=lr)
            else:
                base = SGD(self.model.parameters(), lr=lr)
            return SAM(self.model.parameters(), base, rho=rho)

        raise ValueError(f"Unknown optimizer: {optimizer}")

    def _setup_callbacks(self, callbacks):
        """
        Set up training callbacks.

        No callback objects are created at present: 'early_stopping' is
        delegated to the tracker and 'lr_scheduler' is not implemented, so
        the returned list is always empty.
        """
        callback_instances = []
        for callback in callbacks:
            if callback == "early_stopping":
                # Early stopping will be handled by tracker
                continue
            elif callback == "lr_scheduler":
                # Could implement learning rate scheduling here
                continue
        return callback_instances

    def _get_early_stopping_config(self):
        """
        Build the early stopping configuration passed to the tracker.

        Early stopping is enabled by a truthy ``early_stopping`` keyword
        argument or by listing ``"early_stopping"`` in ``callbacks``.

        Returns:
            Dict with 'patience', 'min_delta' and 'metric', or None when
            early stopping was not requested
        """
        callback_names = getattr(self, "_callback_names", ())
        requested = bool(self.kwargs.get("early_stopping")) or (
            "early_stopping" in callback_names
        )
        if not requested:
            return None
        return {
            "patience": self.kwargs.get("patience", 10),
            "min_delta": self.kwargs.get("min_delta", 1e-4),
            "metric": self.kwargs.get("monitor", "val_loss"),
        }

    def fit(self, epochs: int = 100, verbose: int = 1):
        """
        Train the model.

        Args:
            epochs: Number of epochs to train
            verbose: Verbosity level (0=silent, >=1 prints a header and one
                summary line per epoch)

        Returns:
            Training history dictionary with 'train_loss', 'val_loss' and
            'val_accuracy' lists (the validation lists stay empty when no
            validation data was given)
        """
        if verbose >= 1:
            print(f"Starting training for {epochs} epochs...")
            print(f"Model: {self.model.__class__.__name__}")
            print(f"Optimizer: {self.optimizer.__class__.__name__}")
            print(f"Loss: {self.loss_fn.__class__.__name__}")
            print(f"Batch size: {self.train_loader.batch_size}")
            print("-" * 50)

        history = {"train_loss": [], "val_loss": [], "val_accuracy": []}

        for epoch in range(epochs):
            # Training phase
            train_loss = self._train_epoch(epoch, verbose)
            history["train_loss"].append(train_loss)

            # Validation phase
            if self.val_loader:
                val_loss, val_acc = self._validate_epoch(epoch, verbose)
                history["val_loss"].append(val_loss)
                history["val_accuracy"].append(val_acc)
            else:
                val_loss, val_acc = None, None

            # Update tracker
            metrics = {"train_loss": train_loss}
            if val_loss is not None:
                metrics["val_loss"] = val_loss
            if val_acc is not None:
                metrics["val_accuracy"] = val_acc
            should_stop = self.tracker.update(epoch, metrics)

            # Print progress
            if verbose >= 1:
                self._print_epoch_results(epoch, train_loss, val_loss, val_acc)

            # Early stopping
            if should_stop:
                if verbose >= 1:
                    print(f"Early stopping at epoch {epoch + 1}")
                break

        if verbose >= 1:
            print("Training completed!")
        return history

    def _train_epoch(self, epoch, verbose):
        """
        Train for one epoch and return the mean batch loss.

        Raises:
            ValueError: If the training loader yields no batches
        """
        self.model.train()
        total_loss = 0.0
        batch_count = 0

        for batch_x, batch_y in self.train_loader:
            # Zero gradients
            for param in self.model.parameters():
                param.grad = None

            # Forward pass
            output = self.model(batch_x)
            loss = self.loss_fn(output, batch_y)

            # Backward pass
            loss.backward()

            # Optimizer step (SAM needs a second forward/backward pass)
            if isinstance(self.optimizer, SAM):
                self.optimizer.first_step(zero_grad=True)
                output2 = self.model(batch_x)
                loss2 = self.loss_fn(output2, batch_y)
                loss2.backward()
                self.optimizer.second_step(zero_grad=True)
            else:
                self.optimizer.step()

            total_loss += loss.data
            batch_count += 1

        if batch_count == 0:
            raise ValueError("Cannot train: the training data loader produced no batches")
        return total_loss / batch_count

    @staticmethod
    def _as_array(value):
        """Return value (or its ``.data`` payload) as a numpy array."""
        # ndarray also exposes a ``.data`` attribute (a raw memoryview), so
        # it must be checked first rather than probed with hasattr.
        if isinstance(value, np.ndarray):
            return value
        return np.asarray(getattr(value, "data", value))

    def _count_correct(self, output, batch_y):
        """
        Count correct argmax predictions for one batch.

        Targets may be a flat label vector, a column vector of labels, or
        one-hot rows with the same shape as the model output.

        Returns:
            Tuple (number of correct predictions, number of samples)
        """
        scores = self._as_array(output)
        targets = self._as_array(batch_y)

        predictions = np.argmax(scores, axis=1).reshape(-1)
        if (
            targets.ndim > 1
            and targets.shape == scores.shape
            and targets.shape[-1] > 1
        ):
            # One-hot targets: reduce to class indices
            targets = np.argmax(targets, axis=1)
        # Flatten so (N, 1) labels do not broadcast into an (N, N) comparison
        targets = targets.reshape(-1)

        return int(np.sum(predictions == targets)), int(targets.shape[0])

    def _evaluate_loader(self, loader, with_accuracy=True):
        """
        Run the model in eval mode over a loader.

        Returns:
            Tuple (mean batch loss, correct predictions, samples counted)

        Raises:
            ValueError: If the loader yields no batches
        """
        self.model.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        batch_count = 0

        for batch_x, batch_y in loader:
            output = self.model(batch_x)
            loss = self.loss_fn(output, batch_y)
            total_loss += loss.data
            batch_count += 1

            if with_accuracy:
                hits, seen = self._count_correct(output, batch_y)
                correct += hits
                total += seen

        if batch_count == 0:
            raise ValueError("Cannot evaluate: the data loader produced no batches")
        return total_loss / batch_count, correct, total

    def _validate_epoch(self, epoch, verbose):
        """
        Validate for one epoch.

        Returns:
            Tuple (mean validation loss, accuracy); accuracy is 0.0 when
            'accuracy' is not among the tracked metrics
        """
        avg_loss, correct, total = self._evaluate_loader(
            self.val_loader, with_accuracy="accuracy" in self.metrics
        )
        accuracy = correct / total if total > 0 else 0.0
        return avg_loss, accuracy

    def _print_epoch_results(self, epoch, train_loss, val_loss, val_acc):
        """Print results for one epoch."""
        msg = f"Epoch {epoch + 1}: train_loss={train_loss:.4f}"
        if val_loss is not None:
            msg += f", val_loss={val_loss:.4f}"
        if val_acc is not None:
            msg += f", val_acc={val_acc:.4f}"
        print(msg)

    def evaluate(self, test_data: "Union[tuple, DataLoader]") -> Dict[str, float]:
        """
        Evaluate the model on test data.

        Args:
            test_data: Test data as (X, y) tuple or DataLoader

        Returns:
            Dictionary with 'test_loss' and 'test_accuracy'

        Raises:
            ValueError: If test_data is None, has an unsupported type, or
                yields no batches
        """
        test_loader = self._setup_data(test_data, shuffle=False)
        if test_loader is None:
            raise ValueError("Data must be a tuple (X, y) or DataLoader")

        print("Evaluating model...")
        avg_loss, correct, total = self._evaluate_loader(test_loader)

        results = {
            "test_loss": avg_loss,
            "test_accuracy": correct / total if total > 0 else 0.0,
        }

        print("Evaluation Results:")
        for metric, value in results.items():
            print(f" {metric}: {value:.4f}")
        return results

    def predict(self, X):
        """
        Make predictions on new data.

        Args:
            X: Input data; a numpy array is wrapped in a Tensor, anything
                else is passed to the model unchanged

        Returns:
            The ``.data`` payload of the model output
        """
        self.model.eval()
        if isinstance(X, np.ndarray):
            X = Tensor(X, requires_grad=False)
        output = self.model(X)
        return output.data

    def save(self, filepath: str):
        """Save the trained model."""
        from fit.nn.utils.model_io import save_model

        save_model(self.model, filepath)
        print(f"Model saved to {filepath}")

    def load(self, filepath: str):
        """
        Load a trained model and make it the trainer's model.

        When the optimizer was specified by name, it is rebuilt so that it
        updates the loaded model's parameters. A user-supplied optimizer
        object is left untouched and still refers to whatever parameters it
        was created with.
        """
        from fit.nn.utils.model_io import load_model

        self.model = load_model(filepath)

        optimizer_spec = getattr(self, "_optimizer_spec", None)
        if isinstance(optimizer_spec, str):
            self.optimizer = self._setup_optimizer(optimizer_spec)

        print(f"Model loaded from {filepath}")
# Convenience functions for quick training
def fit_classifier(
    model,
    train_data,
    validation_data=None,
    epochs=100,
    lr=0.001,
    batch_size=32,
    optimizer="adam",
    **kwargs,
):
    """
    Quick function to train a classification model.

    Builds a SimpleTrainer with a cross-entropy loss and runs ``fit``.

    Args:
        model: Model to train
        train_data: Training data as (X, y) tuple
        validation_data: Validation data (optional)
        epochs: Number of epochs
        lr: Learning rate
        batch_size: Batch size
        optimizer: Optimizer name or instance
        **kwargs: Additional arguments forwarded to SimpleTrainer; a
            ``verbose`` entry is passed to ``fit`` instead (default 1)

    Returns:
        Trained model and training history
    """
    # verbose is a fit() option, not a trainer option, so route it there.
    verbose = kwargs.pop("verbose", 1)

    trainer = SimpleTrainer(
        model=model,
        data=train_data,
        validation_data=validation_data,
        loss="crossentropy",
        optimizer=optimizer,
        lr=lr,
        batch_size=batch_size,
        **kwargs,
    )
    history = trainer.fit(epochs=epochs, verbose=verbose)
    return trainer.model, history
def fit_regressor(
    model,
    train_data,
    validation_data=None,
    epochs=100,
    lr=0.001,
    batch_size=32,
    optimizer="adam",
    **kwargs,
):
    """
    Quick function to train a regression model.

    Builds a SimpleTrainer with an MSE loss and runs ``fit``.

    Args:
        model: Model to train
        train_data: Training data as (X, y) tuple
        validation_data: Validation data (optional)
        epochs: Number of epochs
        lr: Learning rate
        batch_size: Batch size
        optimizer: Optimizer name or instance
        **kwargs: Additional arguments forwarded to SimpleTrainer; a
            ``verbose`` entry is passed to ``fit`` instead (default 1), and
            ``metrics`` defaults to ``["loss"]``

    Returns:
        Trained model and training history
    """
    # verbose is a fit() option, not a trainer option, so route it there.
    verbose = kwargs.pop("verbose", 1)
    # Argmax-based accuracy is meaningless for regression outputs, so do not
    # track it unless the caller explicitly asks for it.
    kwargs.setdefault("metrics", ["loss"])

    trainer = SimpleTrainer(
        model=model,
        data=train_data,
        validation_data=validation_data,
        loss="mse",
        optimizer=optimizer,
        lr=lr,
        batch_size=batch_size,
        **kwargs,
    )
    history = trainer.fit(epochs=epochs, verbose=verbose)
    return trainer.model, history
def quick_train(model, X, y, **kwargs):
    """
    Ultra-simple training function.

    Args:
        model: Model to train
        X: Input features
        y: Target labels
        **kwargs: Training parameters. ``epochs`` (default 50) and
            ``verbose`` (default 1) are passed to ``fit``; everything else
            is forwarded to SimpleTrainer

    Returns:
        Trained model
    """
    # epochs/verbose are fit() options; consume them here so they are not
    # also pushed into the trainer's catch-all kwargs.
    epochs = kwargs.pop("epochs", 50)
    verbose = kwargs.pop("verbose", 1)

    trainer = SimpleTrainer(model=model, data=(X, y), **kwargs)
    trainer.fit(epochs=epochs, verbose=verbose)
    return trainer.model