Source code for fit.monitor.tracker

"""
Training tracker for monitoring and logging training progress.
"""

import json
import os
import time
from datetime import datetime
from typing import Dict, List, Optional, Any, Union
import numpy as np

from fit.core.tensor import Tensor


class TrainingTracker:
    """
    Comprehensive training tracker for monitoring metrics, early stopping, and logging.

    Epoch numbers passed to :meth:`update` are expected to be 0-based: the
    progress table prints its header when ``epoch == 0``, displays
    ``epoch + 1`` and indexes ``epoch_times`` with the raw epoch number.
    """

    def __init__(
        self,
        experiment_name: Optional[str] = None,
        log_dir: str = "./logs",
        early_stopping: Optional[Dict] = None,
        save_best: bool = True,
        verbose: int = 1,
    ):
        """
        Initialize training tracker.

        Args:
            experiment_name: Name of the experiment. Defaults to
                ``experiment_<unix timestamp>``.
            log_dir: Directory to save logs (created if missing).
            early_stopping: Early stopping configuration. Recognised keys:
                ``patience`` (default 10), ``min_delta`` (default 1e-4),
                ``metric`` (default ``"val_loss"``) and ``mode``
                (``"min"`` or ``"max"``, default ``"min"``).
            save_best: Whether to save best model state. Stored on the
                instance; this class does not act on it itself.
            verbose: Verbosity level (0=silent, 1=normal, 2=verbose)
        """
        self.experiment_name = experiment_name or f"experiment_{int(time.time())}"
        self.log_dir = log_dir
        self.verbose = verbose
        self.save_best = save_best

        # Create log directory
        os.makedirs(log_dir, exist_ok=True)

        # Initialize tracking data
        self.logs = {}
        self.current_epoch = 0
        self.start_time = None
        self.best_values = {}
        self.best_epochs = {}

        # Early stopping. `wait` and `stopped_epoch` are always defined because
        # summary() and __str__() read `stopped_epoch` even when early stopping
        # is disabled.
        self.early_stopping = early_stopping
        self.wait = 0
        self.stopped_epoch = 0
        if early_stopping:
            self.patience = early_stopping.get("patience", 10)
            self.min_delta = early_stopping.get("min_delta", 1e-4)
            self.monitor_metric = early_stopping.get("metric", "val_loss")
            self.mode = early_stopping.get("mode", "min")  # 'min' or 'max'
            self.best_value = float("inf") if self.mode == "min" else float("-inf")

        # Metrics history
        self.epoch_times = []
        self.learning_rates = []

    @staticmethod
    def _to_scalar(value: Any) -> Any:
        """Convert NumPy scalars, size-1 arrays and Tensors to plain Python numbers."""
        # NumPy scalars first: np.float64 is also a `float` subclass and should
        # still be unwrapped so logs hold only JSON-serializable builtins.
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            # .item() accepts any size-1 array; float(array) is deprecated
            # for arrays with ndim > 0.
            return float(value.item())
        if isinstance(value, (int, float)):
            return value
        if isinstance(value, Tensor):
            return float(value.data)
        return value

    @staticmethod
    def _lower_is_better(metric_name: str) -> bool:
        """Return True for metrics whose name contains 'loss' or 'error'."""
        lowered = metric_name.lower()
        return "loss" in lowered or "error" in lowered

    @staticmethod
    def _json_default(obj: Any) -> Any:
        """``json.dump`` fallback that unwraps NumPy scalars and arrays."""
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    def start_training(self):
        """Mark the start of training."""
        self.start_time = time.time()
        if self.verbose >= 1:
            print(f"Starting training experiment: {self.experiment_name}")
            print(f"Logs will be saved to: {self.log_dir}")

    def update(self, epoch: int, metrics: Dict[str, float]) -> bool:
        """
        Update tracker with metrics for current epoch.

        Args:
            epoch: Current epoch number
            metrics: Dictionary of metric names and values. Tensor, NumPy
                scalar and size-1 array values are converted to plain numbers.

        Returns:
            True if training should stop (early stopping), False otherwise
        """
        self.current_epoch = epoch

        # Normalize once so history, early stopping and logging all see the
        # same plain-number values.
        scalars = {name: self._to_scalar(raw) for name, raw in metrics.items()}

        for name, value in scalars.items():
            self.logs.setdefault(name, []).append(value)

            # Track best values (lower is better for loss/error metrics,
            # higher is better for everything else).
            if name not in self.best_values:
                is_better = True
            elif self._lower_is_better(name):
                is_better = value < self.best_values[name]
            else:
                is_better = value > self.best_values[name]

            if is_better:
                self.best_values[name] = value
                self.best_epochs[name] = epoch

        # Check early stopping
        should_stop = False
        if self.early_stopping and self.monitor_metric in scalars:
            should_stop = self._check_early_stopping(scalars[self.monitor_metric])

        # Log progress
        if self.verbose >= 1:
            self._log_epoch(epoch, scalars)

        return should_stop

    def _check_early_stopping(self, current_value: float) -> bool:
        """
        Check if early stopping criteria are met.

        Args:
            current_value: Current value of monitored metric

        Returns:
            True if training should stop
        """
        if self.mode == "min":
            improved = current_value < (self.best_value - self.min_delta)
        else:
            improved = current_value > (self.best_value + self.min_delta)

        if improved:
            self.best_value = current_value
            self.wait = 0
            return False

        self.wait += 1
        if self.wait < self.patience:
            return False

        self.stopped_epoch = self.current_epoch
        if self.verbose >= 1:
            print(f"\nEarly stopping at epoch {self.current_epoch}")
            print(
                f"Best {self.monitor_metric}: {self.best_value} at epoch {self.current_epoch - self.wait}"
            )
        return True

    def _log_epoch(self, epoch: int, metrics: Dict[str, float]):
        """Print one row of the progress table (plus the header at epoch 0)."""
        if epoch == 0:
            # Print header
            header = f"{'Epoch':<6}"
            header += "".join(f"{metric:<12}" for metric in metrics)
            header += f"{'Time':<8}"
            print(header)
            print("-" * len(header))

        # Print metrics
        log_line = f"{epoch + 1:<6}"
        for value in metrics.values():
            if isinstance(value, float):
                log_line += f"{value:<12.4f}"
            else:
                log_line += f"{value:<12}"

        # Add time if available
        if len(self.epoch_times) > epoch:
            log_line += f"{self.epoch_times[epoch]:<8.2f}s"

        print(log_line)

    def log_epoch_time(self, epoch_time: float):
        """Log time taken for current epoch."""
        self.epoch_times.append(epoch_time)

    def log_learning_rate(self, lr: float):
        """Log current learning rate."""
        self.learning_rates.append(lr)

    def plot_metrics(
        self, metrics: Optional[List[str]] = None, save_path: Optional[str] = None
    ) -> bool:
        """
        Plot training metrics.

        Args:
            metrics: List of metrics to plot (default: all)
            save_path: Path to save plot (optional). When omitted the plot is
                saved as ``<experiment_name>_metrics.png`` in the log directory.

        Returns:
            True if a plot was produced, False if matplotlib is unavailable
            or there is nothing to plot.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            print("Matplotlib not available. Install with: pip install matplotlib")
            return False

        if not self.logs:
            print("No metrics to plot")
            return False

        requested = metrics or list(self.logs.keys())
        metrics_to_plot = [m for m in requested if m in self.logs and self.logs[m]]

        if not metrics_to_plot:
            print("No valid metrics found to plot")
            return False

        # Create subplots. squeeze=False always yields a 2-D array of axes,
        # so one flatten() covers every grid shape.
        n_metrics = len(metrics_to_plot)
        n_cols = min(2, n_metrics)
        n_rows = (n_metrics + n_cols - 1) // n_cols
        fig, axes = plt.subplots(
            n_rows, n_cols, figsize=(12, 4 * n_rows), squeeze=False
        )
        axes = axes.flatten()

        for ax, metric in zip(axes, metrics_to_plot):
            history = self.logs[metric]
            epochs = range(1, len(history) + 1)
            ax.plot(epochs, history, "b-", linewidth=2, label=metric)

            # Mark best value
            if metric in self.best_values:
                best_epoch = self.best_epochs[metric]
                best_value = self.best_values[metric]
                ax.plot(
                    best_epoch + 1,
                    best_value,
                    "ro",
                    markersize=8,
                    label=f"Best: {best_value:.4f}",
                )

            ax.set_title(metric.replace("_", " ").title())
            ax.set_xlabel("Epoch")
            ax.set_ylabel(metric)
            ax.legend()
            ax.grid(True, alpha=0.3)

        # Remove empty subplots
        for unused in axes[n_metrics:]:
            unused.remove()

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches="tight")
            print(f"Plot saved to {save_path}")
        else:
            # Save to log directory
            plot_path = os.path.join(
                self.log_dir, f"{self.experiment_name}_metrics.png"
            )
            plt.savefig(plot_path, dpi=150, bbox_inches="tight")
            print(f"Plot saved to {plot_path}")

        plt.show()
        return True

    def export(self, filepath: Optional[str] = None, format: str = "json") -> str:
        """
        Export training logs to file.

        Args:
            filepath: Path to save file (auto-generated if None)
            format: Export format ("json" or "csv"). Any value other than
                "json" (case-insensitive) is exported as CSV.

        Returns:
            Path to exported file
        """
        as_json = format.lower() == "json"

        if filepath is None:
            ext = ".json" if as_json else ".csv"
            filepath = os.path.join(self.log_dir, f"{self.experiment_name}_logs{ext}")

        if as_json:
            return self._export_json(filepath)
        return self._export_csv(filepath)

    def _export_json(self, filepath: str) -> str:
        """Export logs to JSON format."""
        early_stopping_info = None
        if self.early_stopping:
            early_stopping_info = {
                "stopped": self.stopped_epoch > 0,
                "stopped_epoch": self.stopped_epoch,
                "patience": getattr(self, "patience", None),
                "monitor_metric": getattr(self, "monitor_metric", None),
            }

        export_data = {
            "experiment_name": self.experiment_name,
            "total_epochs": self.current_epoch + 1,
            "training_time": time.time() - self.start_time if self.start_time else None,
            "best_values": self.best_values,
            "best_epochs": self.best_epochs,
            "logs": self.logs,
            "early_stopping": early_stopping_info,
            "epoch_times": self.epoch_times,
            "learning_rates": self.learning_rates,
            "timestamp": datetime.now().isoformat(),
        }

        with open(filepath, "w", encoding="utf-8") as f:
            # `default` lets NumPy values logged via log_learning_rate() or
            # log_epoch_time() serialize instead of raising TypeError.
            json.dump(export_data, f, indent=2, default=self._json_default)

        print(f"Logs exported to {filepath}")
        return filepath

    def _export_csv(self, filepath: str) -> str:
        """Export logs to CSV format (one row per epoch, blank for missing values)."""
        import csv

        if not self.logs:
            print("No logs to export")
            return filepath

        def cell(series, index):
            # Series can have different lengths; pad the short ones.
            return series[index] if index < len(series) else ""

        # Get all metric names
        metrics = list(self.logs.keys())
        max_epochs = max(len(values) for values in self.logs.values())

        # Optional columns appear only when something was logged for them.
        extra_columns = []
        if self.epoch_times:
            extra_columns.append(("epoch_time", self.epoch_times))
        if self.learning_rates:
            extra_columns.append(("learning_rate", self.learning_rates))

        with open(filepath, "w", newline="") as csvfile:
            writer = csv.writer(csvfile)

            # Write header
            writer.writerow(["epoch"] + metrics + [name for name, _ in extra_columns])

            # Write data
            for epoch in range(max_epochs):
                row = [epoch + 1]
                row.extend(cell(self.logs[metric], epoch) for metric in metrics)
                row.extend(cell(series, epoch) for _, series in extra_columns)
                writer.writerow(row)

        print(f"Logs exported to {filepath}")
        return filepath

    def load(self, filepath: str) -> bool:
        """
        Load training logs from file.

        The file is parsed and validated before any attribute is assigned, so
        a failed load leaves the tracker unchanged.

        Args:
            filepath: Path to log file (.json)

        Returns:
            True if successful, False otherwise
        """
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                data = json.load(f)
            if not isinstance(data, dict):
                raise ValueError("top-level JSON value is not an object")
            logs = data.get("logs", {})
            last_epoch = (
                max(len(values) for values in logs.values()) - 1 if logs else None
            )
        except (OSError, ValueError, TypeError, AttributeError) as e:
            print(f"Error loading logs: {e}")
            return False

        self.experiment_name = data.get("experiment_name", self.experiment_name)
        self.logs = logs
        self.best_values = data.get("best_values", {})
        self.best_epochs = data.get("best_epochs", {})
        self.epoch_times = data.get("epoch_times", [])
        self.learning_rates = data.get("learning_rates", [])

        # Update current epoch
        if last_epoch is not None:
            self.current_epoch = last_epoch

        print(f"Logs loaded from {filepath}")
        return True

    def summary(self) -> Dict[str, Any]:
        """
        Get training summary.

        Returns:
            Dictionary with training summary. ``avg_epoch_time`` is included
            only when a training time is available.
        """
        total_time = time.time() - self.start_time if self.start_time else None
        early_stopped = self.stopped_epoch > 0

        summary = {
            "experiment_name": self.experiment_name,
            "total_epochs": self.current_epoch + 1,
            "training_time": total_time,
            "best_metrics": self.best_values,
            "early_stopped": early_stopped,
            "stopped_epoch": self.stopped_epoch if early_stopped else None,
        }

        if total_time:
            summary["avg_epoch_time"] = total_time / (self.current_epoch + 1)

        return summary

    def __str__(self) -> str:
        """String representation of tracker."""
        summary = self.summary()

        lines = [f"TrainingTracker: {summary['experiment_name']}"]
        lines.append(f"  Epochs: {summary['total_epochs']}")

        if summary.get("training_time"):
            lines.append(f"  Training time: {summary['training_time']:.2f}s")

        if summary["best_metrics"]:
            lines.append("  Best metrics:")
            for metric, value in summary["best_metrics"].items():
                epoch = self.best_epochs.get(metric, 0)
                lines.append(f"    {metric}: {value:.4f} (epoch {epoch + 1})")

        if summary["early_stopped"]:
            lines.append(f"  Early stopped at epoch {summary['stopped_epoch']}")

        return "\n".join(lines)