Source code for fit.utils.engine

from typing import Any, Callable, Dict, Optional

from fit.monitor.tracker import TrainingTracker
from fit.utils.trainer import Trainer
from fit.data.dataloader import DataLoader


def train_epoch(model, dataloader, loss_fn, optimizer, device=None):
    """
    Train for a single epoch.

    Args:
        model: Model to train. Called as ``model(x)`` on each batch input and
            must expose ``train()``.
        dataloader: Iterable of ``(x, y)`` batches. It does not need to define
            ``__len__``; the loss is averaged over the batches actually seen.
        loss_fn: Loss function called as ``loss_fn(outputs, y)``. Its result
            must support ``backward()`` and expose ``.data``.
        optimizer: Optimizer exposing ``step()`` and ``zero_grad()``.
        device: Device to use (not used in this version, but kept for PyTorch
            compatibility).

    Returns:
        Dict with epoch metrics:
          - ``"loss"``: mean per-batch loss, or ``nan`` if the dataloader
            yielded no batches.
          - ``"accuracy"``: present only when at least one batch produced
            outputs with more than one column.
    """
    # Set model to training mode (affects Dropout, BatchNorm, etc.)
    model.train()

    # Clear any gradients left over from before this call so the first
    # batch's update reflects only its own loss.
    optimizer.zero_grad()

    total_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0

    for x, y in dataloader:
        # Forward pass
        outputs = model(x)
        loss = loss_fn(outputs, y)

        # Backward pass and optimize. Gradients are cleared after each step
        # so they do not accumulate into the next batch.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.data
        num_batches += 1

        # Accuracy is only tracked when outputs have more than one column
        # (treated as per-class scores; prediction = argmax over columns).
        if outputs.data.ndim > 1 and outputs.data.shape[1] > 1:
            predictions = outputs.data.argmax(axis=1)
            correct += (predictions == y.data).sum()
            total += len(y.data)

    # Average over the batches actually consumed rather than len(dataloader):
    # this works for loaders without __len__ and avoids dividing by zero when
    # the loader is empty.
    metrics = {
        "loss": total_loss / num_batches if num_batches else float("nan"),
    }
    if total > 0:
        metrics["accuracy"] = correct / total

    return metrics
def evaluate(model, dataloader, loss_fn, device=None):
    """
    Evaluate model on a dataset.

    Args:
        model: Model to evaluate. Called as ``model(x)`` on each batch input
            and must expose ``eval()``.
        dataloader: Iterable of ``(x, y)`` batches. It does not need to define
            ``__len__``; the loss is averaged over the batches actually seen.
        loss_fn: Loss function called as ``loss_fn(outputs, y)``. Its result
            must expose ``.data``.
        device: Device to use (not used in this version, but kept for PyTorch
            compatibility).

    Returns:
        Dict with evaluation metrics:
          - ``"loss"``: mean per-batch loss, or ``nan`` if the dataloader
            yielded no batches.
          - ``"accuracy"``: present only when at least one batch produced
            outputs with more than one column.
    """
    # Set model to evaluation mode (affects Dropout, BatchNorm, etc.)
    model.eval()

    total_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0

    # NOTE(review): nothing here disables gradient tracking. The loss is
    # computed but backward() is never called and no optimizer is involved,
    # so parameters are not updated. If the framework offers a no-grad
    # context, wrapping this loop in it would avoid building an unused
    # graph — TODO confirm availability.
    for x, y in dataloader:
        # Forward pass
        outputs = model(x)
        loss = loss_fn(outputs, y)

        total_loss += loss.data
        num_batches += 1

        # Accuracy is only tracked when outputs have more than one column
        # (treated as per-class scores; prediction = argmax over columns).
        if outputs.data.ndim > 1 and outputs.data.shape[1] > 1:
            predictions = outputs.data.argmax(axis=1)
            correct += (predictions == y.data).sum()
            total += len(y.data)

    # Average over the batches actually consumed rather than len(dataloader):
    # this works for loaders without __len__ and avoids dividing by zero when
    # the loader is empty.
    metrics = {
        "loss": total_loss / num_batches if num_batches else float("nan"),
    }
    if total > 0:
        metrics["accuracy"] = correct / total

    return metrics
def train(
    model,
    train_loader,
    val_loader,
    loss_fn,
    optimizer,
    epochs=10,
    device=None,
    scheduler=None,
    early_stopping=None,
    tracker=None,
):
    """
    Complete training loop.

    Each epoch runs ``train_epoch`` on ``train_loader``. It then runs
    ``evaluate`` on ``val_loader`` when one is given and steps the scheduler
    when one is given. Finally it logs metrics to the tracker, prints a
    one-epoch summary, and stops early if the tracker says so.

    Args:
        model: Model to train
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data (optional, may be None)
        loss_fn: Loss function
        optimizer: Optimizer. If it has an ``lr`` attribute, that value is
            logged each epoch.
        epochs: Number of epochs to train
        device: Device to use (not used in this version, but kept for PyTorch
            compatibility)
        scheduler: Learning rate scheduler (optional). ``step()`` is called
            once per epoch.
        early_stopping: Early stopping settings (optional). Only used when
            ``tracker`` is None and a new TrainingTracker is created.
        tracker: TrainingTracker for logging (optional)

    Returns:
        TrainingTracker with training history
    """
    # Create tracker if none provided.
    # NOTE(review): when a tracker is passed in, ``early_stopping`` is not
    # forwarded to it. Early stopping then depends solely on how that tracker
    # was constructed.
    if tracker is None:
        tracker = TrainingTracker(early_stopping=early_stopping)

    for _ in range(epochs):
        tracker.start_epoch()

        train_metrics = train_epoch(model, train_loader, loss_fn, optimizer, device)

        val_metrics = None
        if val_loader is not None:
            val_metrics = evaluate(model, val_loader, loss_fn, device)

        # Read the learning rate BEFORE the scheduler advances it. The logged
        # value is then the one in effect for this epoch's training rather
        # than the next epoch's (assuming only the scheduler changes it).
        lr = getattr(optimizer, "lr", None)

        if scheduler is not None:
            scheduler.step()

        # Validation metrics are logged with a "val_" prefix.
        custom_metrics = {}
        if val_metrics:
            custom_metrics = {f"val_{k}": v for k, v in val_metrics.items()}

        tracker.log(
            loss=train_metrics["loss"],
            acc=train_metrics.get("accuracy"),
            lr=lr,
            custom_metrics=custom_metrics,
        )

        # Print progress for the epoch just finished.
        tracker.summary(last_n=1, style="box")

        if tracker.should_early_stop():
            print("Early stopping triggered!")
            break

    return tracker