Source code for fit.nn.utils.model_io

import json
import os
import pickle

import numpy as np

from fit.core.tensor import Tensor


def save_model(model, path):
    """
    Save model to file.

    The model is reduced to a plain dictionary holding its class name, its
    parameter values (as nested lists) and its configuration, and that
    dictionary is pickled to ``path``.

    Args:
        model: Model to save. It must provide ``parameters()``; every
            parameter needs a ``data`` attribute with ``tolist()`` and a
            ``requires_grad`` attribute. The configuration comes from
            ``model.get_config()`` when available, otherwise from the
            entries of ``model.layers`` when that attribute is a list.
        path: Path to save to. Missing parent directories are created.
    """
    # Create directory if it doesn't exist
    dir_name = os.path.dirname(path)
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)

    # Collect model data
    model_data = {
        "model_type": model.__class__.__name__,
        "parameters": [
            {"data": param.data.tolist(), "requires_grad": param.requires_grad}
            for param in model.parameters()
        ],
        "config": {},
    }

    # Save configuration if available
    if hasattr(model, "get_config"):
        model_data["config"] = model.get_config()
    elif hasattr(model, "layers") and isinstance(model.layers, list):
        # For Sequential-style models, save one config entry per layer
        layers_config = []
        for layer in model.layers:
            layer_config = {"type": layer.__class__.__name__}
            # Add layer-specific config if available
            if hasattr(layer, "get_config"):
                layer_config.update(layer.get_config())
            layers_config.append(layer_config)
        model_data["config"]["layers"] = layers_config

    # Serialize before touching the destination: opening with "wb" truncates
    # the file, so pickling straight into it would destroy an existing
    # checkpoint whenever the data turns out not to be picklable.
    payload = pickle.dumps(model_data)
    with open(path, "wb") as f:
        f.write(payload)

    print(f"Model saved to {path}")
def _build_layer(layer_config):
    """Instantiate a single layer from its saved configuration dictionary."""
    # Read the type without popping so the loaded data is left untouched.
    layer_type = layer_config.get("type", "")

    # Layer classes are imported lazily, only for the types actually present.
    if layer_type == "Linear":
        from fit.nn.modules.linear import Linear

        return Linear(
            layer_config.get("in_features"), layer_config.get("out_features")
        )
    if layer_type == "ReLU":
        from fit.nn.modules.activation import ReLU

        return ReLU()
    if layer_type == "Softmax":
        from fit.nn.modules.activation import Softmax

        return Softmax()
    if layer_type == "Dropout":
        from fit.nn.modules.activation import Dropout

        return Dropout(layer_config.get("p", 0.5))
    if layer_type == "BatchNorm":
        from fit.nn.modules.normalization import BatchNorm

        return BatchNorm(
            layer_config.get("num_features"),
            layer_config.get("eps", 1e-5),
            layer_config.get("momentum", 0.1),
        )
    raise ValueError(f"Unknown layer type: {layer_type}")


def load_model(path, model_class=None):
    """
    Load model from file.

    SECURITY: the file is read with ``pickle.load``, which can execute
    arbitrary code while unpickling. Only load files from trusted sources.

    Args:
        path: Path to load from
        model_class: Model class to instantiate (optional). When given, it is
            called with the saved configuration as keyword arguments, or
            with no arguments if no configuration was saved. When omitted,
            only models whose saved type name contains "Sequential" can be
            rebuilt, from their saved layer configurations.

    Returns:
        Loaded model

    Raises:
        ValueError: If the model type cannot be inferred, a saved layer type
            is unknown, or the number of saved parameters does not match the
            number of parameters of the instantiated model.
    """
    # Load model data
    # NOTE(security): pickle is unsafe on untrusted input; see the docstring.
    with open(path, "rb") as f:
        model_data = pickle.load(f)

    config = model_data.get("config") or {}

    if model_class is not None:
        # Instantiate the class provided by the caller
        model = model_class(**config) if config else model_class()
    else:
        # Try to infer model class from model_type
        model_type = model_data.get("model_type", "")
        if "Sequential" not in model_type:
            raise ValueError(f"Cannot instantiate model of type {model_type}")

        from fit.nn.modules.container import Sequential

        # Create empty Sequential model and add layers if config is available
        model = Sequential()
        for layer_config in config.get("layers", []):
            model.add_layer(_build_layer(layer_config))

    # Load parameters
    saved_params = model_data.get("parameters")
    if saved_params is not None:
        # parameters() may return any iterable (for example a generator);
        # materialize it so it can be counted and paired with the saved data.
        params = list(model.parameters())
        if len(params) != len(saved_params):
            raise ValueError(
                f"Model has {len(params)} parameters, "
                f"but {len(saved_params)} were saved"
            )
        for param, param_data in zip(params, saved_params):
            param.data = np.array(param_data["data"])
            param.requires_grad = param_data["requires_grad"]

    print(f"Model loaded from {path}")
    return model