"""
Base classes for neural network modules.
"""
from typing import List, Iterator, Dict, Any
from abc import ABC, abstractmethod
from fit.core.tensor import Tensor
class Layer(ABC):
    """
    Base class for all neural network layers.

    A layer keeps a list of parameters registered directly on it and a
    name -> child-layer mapping. Subclasses must implement ``forward``;
    calling a layer instance delegates to ``forward``.

    Attributes:
        _parameters: Parameters registered directly on this layer.
        _modules: Child layers keyed by name.
        training: True in training mode, False in evaluation mode.
    """

    def __init__(self):
        """Initialize the layer with no parameters, no children, in training mode."""
        self._parameters = []
        self._modules = {}
        self.training = True

    def add_parameter(self, param: "Tensor"):
        """
        Register a parameter on this layer.

        Registering the same object twice has no effect.

        Args:
            param: Parameter tensor to add
        """
        # Compare by identity rather than with `in`: list membership falls
        # back to ==, so a distinct parameter that merely compares equal to
        # a registered one would be dropped (and an element-wise __eq__, if
        # the tensor type defines one, may not yield a usable boolean).
        if not any(existing is param for existing in self._parameters):
            self._parameters.append(param)

    def parameters(self) -> List["Tensor"]:
        """
        Return all parameters of this layer and of its child modules.

        Returns:
            New list with this layer's own parameters first, followed by
            each child's parameters in child-registration order
        """
        params = list(self._parameters)
        for module in self._modules.values():
            params.extend(module.parameters())
        return params

    def named_parameters(self, prefix: str = "") -> Iterator[tuple]:
        """
        Return iterator over module parameters with names.

        Own parameters are named ``param_<index>``; parameters of child
        modules are prefixed with ``<child name>.``.

        Args:
            prefix: Prefix to add to parameter names

        Yields:
            (name, parameter) tuples
        """
        for index, param in enumerate(self._parameters):
            yield f"{prefix}param_{index}", param
        for name, module in self._modules.items():
            yield from module.named_parameters(f"{prefix}{name}.")

    def zero_grad(self):
        """Call ``zero_grad()`` on every parameter returned by ``parameters()``."""
        for param in self.parameters():
            param.zero_grad()

    def train(self):
        """Set this layer and all child modules to training mode."""
        self.training = True
        for module in self._modules.values():
            module.train()

    def eval(self):
        """Set this layer and all child modules to evaluation mode."""
        self.training = False
        for module in self._modules.values():
            module.eval()

    @abstractmethod
    def forward(self, *args, **kwargs):
        """
        Forward pass of the layer.

        This method must be implemented by all subclasses.

        Raises:
            NotImplementedError: If a subclass delegates to this base
                implementation.
        """
        raise NotImplementedError("Subclasses must implement forward method")

    def __call__(self, *args, **kwargs):
        """
        Make the layer callable.

        Forwards all arguments to ``forward`` and returns its result.
        """
        return self.forward(*args, **kwargs)

    def extra_repr(self) -> str:
        """
        Return extra text to show inside the parentheses of ``repr()``.

        The base implementation returns an empty string; subclasses
        override it to describe their configuration.
        """
        return ""

    def __repr__(self):
        """Return ``ClassName(<extra_repr>)``, or ``ClassName()`` when there is none."""
        extra_repr = self.extra_repr()
        if extra_repr:
            return f"{self.__class__.__name__}({extra_repr})"
        return f"{self.__class__.__name__}()"

    def state_dict(self) -> Dict[str, Any]:
        """
        Return state dictionary containing layer's state.

        Returns:
            Dictionary mapping parameter names (``param_<index>`` for own
            parameters, ``<child name>.<key>`` for child entries) to a copy
            of each parameter's ``data``
        """
        state = {}
        for index, param in enumerate(self._parameters):
            state[f"param_{index}"] = param.data.copy()
        for name, module in self._modules.items():
            for key, value in module.state_dict().items():
                state[f"{name}.{key}"] = value
        return state

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """
        Load state from state dictionary.

        Keys that are missing from ``state_dict`` leave the corresponding
        parameter untouched; keys that match nothing are ignored.

        Args:
            state_dict: Dictionary containing state to load
        """
        for index, param in enumerate(self._parameters):
            key = f"param_{index}"
            if key in state_dict:
                param.data = state_dict[key].copy()
        for name, module in self._modules.items():
            prefix = f"{name}."
            # Strip this child's prefix so it sees keys relative to itself.
            module_state = {
                key[len(prefix):]: value
                for key, value in state_dict.items()
                if key.startswith(prefix)
            }
            if module_state:
                module.load_state_dict(module_state)

    def apply(self, fn):
        """
        Apply function to all parameters.

        ``fn`` is called on this layer's own parameters first, then
        recursively on each child module's parameters.

        Args:
            fn: Function to apply to each parameter
        """
        for param in self._parameters:
            fn(param)
        for module in self._modules.values():
            module.apply(fn)

    def cuda(self):
        """Move layer to CUDA (placeholder - prints a notice and returns self)."""
        print("CUDA support not implemented YET in FIT framework")
        return self

    def cpu(self):
        """Move layer to CPU (no-op; returns self)."""
        return self

    def to(self, device):
        """Move layer to specified device (placeholder; only "cpu" is accepted silently)."""
        if device != "cpu":
            print(f"Device '{device}' not supported in FIT framework")
        return self

    def add_child(self, module: "Layer"):
        """
        Add a child module to this layer under a generated name.

        Args:
            module: Child module to add
        """
        # Start from the current child count and skip names already taken,
        # so an existing entry in _modules is never overwritten.
        index = len(self._modules)
        while f"child_{index}" in self._modules:
            index += 1
        self._modules[f"child_{index}"] = module
class Module(Layer):
    """
    PyTorch-style name for ``Layer``.

    Adds nothing to ``Layer``; subclasses still have to implement
    ``forward``.
    """
class Identity(Layer):
    """
    Pass-through layer: the output is the input object itself.
    """

    def forward(self, x):
        """Return ``x`` unchanged."""
        return x
class Lambda(Layer):
    """
    Layer that wraps a callable and applies it to its input.
    """

    def __init__(self, func):
        """
        Store the callable used by ``forward``.

        Args:
            func: Function to apply; it is called with the single
                ``forward`` argument
        """
        super().__init__()
        self.func = func

    def forward(self, x):
        """Return the result of calling the wrapped function on ``x``."""
        return self.func(x)
class ParameterList(Layer):
    """
    Layer that only holds an ordered collection of parameters.

    Parameters are registered through ``add_parameter``, so they appear
    in ``parameters()`` of this layer.
    """

    def __init__(self, parameters=None):
        """
        Create the list, optionally pre-filled.

        Args:
            parameters: Optional iterable of parameters to register, in order
        """
        super().__init__()
        if parameters is None:
            return
        for tensor in parameters:
            self.append(tensor)

    def append(self, parameter: "Tensor"):
        """Register a single parameter on this list."""
        self.add_parameter(parameter)

    def extend(self, parameters: List["Tensor"]):
        """Register each of the given parameters, in order."""
        for tensor in parameters:
            self.append(tensor)

    def forward(self, x):
        """Always raise: a parameter list is a container and cannot be called."""
        raise RuntimeError("ParameterList has no forward method")
class ParameterDict(Layer):
    """
    Container for a dictionary of parameters.

    Parameters are stored by key in ``_param_dict`` and are also
    registered through ``add_parameter`` so they appear in
    ``parameters()``. Replacing or deleting a key unregisters the old
    tensor once no key refers to it any more.
    """

    def __init__(self, parameters=None):
        """
        Initialize parameter dictionary.

        Args:
            parameters: Initial dictionary of parameters
        """
        super().__init__()
        self._param_dict = {}
        if parameters is not None:
            for key, param in parameters.items():
                self[key] = param

    def _discard_parameter(self, parameter):
        """Unregister ``parameter`` unless another key still maps to it."""
        if any(kept is parameter for kept in self._param_dict.values()):
            return
        # Filter by identity and update the list in place.
        self._parameters[:] = [
            p for p in self._parameters if p is not parameter
        ]

    def __setitem__(self, key: str, parameter: "Tensor"):
        """Set a parameter in the dictionary, replacing any previous one for ``key``."""
        had_previous = key in self._param_dict
        previous = self._param_dict.get(key)
        self._param_dict[key] = parameter
        # Drop the replaced tensor so it is no longer reported as a parameter.
        if had_previous and previous is not parameter:
            self._discard_parameter(previous)
        self.add_parameter(parameter)

    def __getitem__(self, key: str) -> "Tensor":
        """Get a parameter from the dictionary; raises KeyError if absent."""
        return self._param_dict[key]

    def __delitem__(self, key: str):
        """
        Delete a parameter from the dictionary and unregister it.

        Note that own parameters are named by position (``param_<index>``)
        in ``state_dict()``, so later parameters shift down by one.

        Raises:
            KeyError: If ``key`` is not present.
        """
        removed = self._param_dict.pop(key)
        self._discard_parameter(removed)

    def __contains__(self, key) -> bool:
        """Return True if ``key`` is present in the dictionary."""
        return key in self._param_dict

    def __iter__(self) -> Iterator[str]:
        """Iterate over the keys, in insertion order."""
        return iter(self._param_dict)

    def keys(self):
        """Return a view of the parameter keys."""
        return self._param_dict.keys()

    def values(self):
        """Return a view of the parameter tensors."""
        return self._param_dict.values()

    def items(self):
        """Return a view of (key, parameter) pairs."""
        return self._param_dict.items()

    def forward(self, x):
        """Parameter dicts don't have forward pass."""
        raise RuntimeError("ParameterDict has no forward method")