
fusion

Bases: Module, function

A base class for fusion operations, extending the Module and function classes.

This class stores the input dimensions and a flag indicating whether the fusion needs trainable parameters, and it optionally applies chains of preprocessing and postprocessing functions to tensors. The fusion itself is left to subclasses, which implement calculate_n, calculate_l, and forward.

Notes

In the tinyBIG library, we introduce several advanced fusion strategies that can more effectively aggregate the outputs from the wide architectures. Formally, given the input matrices \(\mathbf{A}_1, \mathbf{A}_2, \cdots, \mathbf{A}_k\), their fusion output can be represented as

\[
    \begin{equation}
    \mathbf{A} = \text{fusion}(\mathbf{A}_1, \mathbf{A}_2, \cdots, \mathbf{A}_k).
    \end{equation}
\]

The dimensions of the input matrices \(\mathbf{A}_1, \mathbf{A}_2, \cdots, \mathbf{A}_k\) may be identical or vary, depending on the specific definition of the fusion function.
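To make the formula above concrete, the sketch below uses plain PyTorch (not the tinyBIG API) to fuse \(k = 3\) matrices in two illustrative ways. Concatenation along the feature dimension tolerates differing widths. Element-wise averaging requires identical shapes. Neither rule is claimed to be one that tinyBIG ships.

```python
import torch

# Three input matrices A_1, A_2, A_3 that share a batch size of 4.
A_list = [torch.randn(4, 2), torch.randn(4, 3), torch.randn(4, 5)]

# Concatenation fusion: widths may differ; the output width is their sum.
A_concat = torch.cat(A_list, dim=-1)
print(A_concat.shape)  # torch.Size([4, 10])

# Mean fusion: every input must have the same shape.
B_list = [torch.randn(4, 3) for _ in range(3)]
A_mean = torch.stack(B_list, dim=0).mean(dim=0)
print(A_mean.shape)  # torch.Size([4, 3])
```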

Parameters:

dims (list[int] | tuple[int], default None): A list or tuple of dimensions for the input tensors.
name (str, default 'base_fusion'): The name of the fusion operation.
require_parameters (bool, default False): Whether the fusion operation requires trainable parameters.
preprocess_functions (list | tuple | callable, default None): Functions to preprocess the input tensors.
postprocess_functions (list | tuple | callable, default None): Functions to postprocess the output tensors.
preprocess_function_configs (dict, default None): Configuration for instantiating the preprocess functions.
postprocess_function_configs (dict, default None): Configuration for instantiating the postprocess functions.
device (str, default 'cpu'): The device for computations.
*args (tuple, default ()): Additional positional arguments.
**kwargs (dict, default {}): Additional keyword arguments.

Source code in tinybig/module/base_fusion.py
class fusion(Module, function):
    r"""
    A base class for fusion operations, extending the `Module` and `function` classes.

    This class stores the input dimensions and a flag indicating whether the fusion needs trainable
    parameters, and it optionally applies chains of preprocessing and postprocessing functions to tensors.
    The fusion itself is left to subclasses, which implement `calculate_n`, `calculate_l`, and `forward`.

    Notes
    ---------
    In the tinyBIG library, we introduce several advanced fusion strategies that can more effectively aggregate the outputs from the wide architectures.
    Formally, given the input matrices $\mathbf{A}_1, \mathbf{A}_2, \cdots, \mathbf{A}_k$, their fusion output can be represented as

    $$
        \begin{equation}
        \mathbf{A} = \text{fusion}(\mathbf{A}_1, \mathbf{A}_2, \cdots, \mathbf{A}_k).
        \end{equation}
    $$

    The dimensions of the input matrices $\mathbf{A}_1, \mathbf{A}_2, \cdots, \mathbf{A}_k$ may be identical or vary,
    depending on the specific definition of the fusion function.



    Parameters
    ----------
    dims : list[int] | tuple[int], optional
        A list or tuple of dimensions for the input tensors, by default None.
    name : str, optional
        The name of the fusion operation, by default 'base_fusion'.
    require_parameters : bool, optional
        Whether the fusion operation requires trainable parameters, by default False.
    preprocess_functions : list | tuple | callable, optional
        Functions to preprocess the input tensors, by default None.
    postprocess_functions : list | tuple | callable, optional
        Functions to postprocess the output tensors, by default None.
    preprocess_function_configs : dict, optional
        Configuration for instantiating the preprocess functions, by default None.
    postprocess_function_configs : dict, optional
        Configuration for instantiating the postprocess functions, by default None.
    device : str, optional
        The device for computations, by default 'cpu'.
    *args : tuple
        Additional positional arguments.
    **kwargs : dict
        Additional keyword arguments.
    """
    def __init__(
        self,
        dims: list[int] | tuple[int] = None,
        name: str = 'base_fusion',
        require_parameters: bool = False,
        preprocess_functions=None,
        postprocess_functions=None,
        preprocess_function_configs=None,
        postprocess_function_configs=None,
        device: str = 'cpu',
        *args, **kwargs
    ):
        """
        Initializes the fusion class with its parameters and preprocessing/postprocessing functions.

        Parameters
        ----------
        dims : list[int] | tuple[int], optional
            A list or tuple of dimensions for the input tensors, by default None.
        name : str, optional
            The name of the fusion operation, by default 'base_fusion'.
        require_parameters : bool, optional
            Whether the fusion operation requires trainable parameters, by default False.
        preprocess_functions : list | tuple | callable, optional
            Functions to preprocess the input tensors, by default None.
        postprocess_functions : list | tuple | callable, optional
            Functions to postprocess the output tensors, by default None.
        preprocess_function_configs : dict, optional
            Configuration for instantiating the preprocess functions, by default None.
        postprocess_function_configs : dict, optional
            Configuration for instantiating the postprocess functions, by default None.
        device : str, optional
            The device for computations, by default 'cpu'.
        *args : tuple
            Additional positional arguments.
        **kwargs : dict
            Additional keyword arguments.
        """
        Module.__init__(self)
        function.__init__(self, name=name, device=device)

        self.dims = dims
        self.require_parameters = require_parameters

        self.preprocess_functions = config.instantiation_functions(preprocess_functions, preprocess_function_configs, device=self.device)
        self.postprocess_functions = config.instantiation_functions(postprocess_functions, postprocess_function_configs, device=self.device)


    def get_dims(self):
        """
        Retrieves the dimensions of the input tensors.

        Returns
        -------
        list[int] | tuple[int] | None
            The dimensions of the input tensors, or None if not specified.
        """
        return self.dims

    def get_num(self):
        """
        Retrieves the number of dimensions.

        Returns
        -------
        int
            The number of dimensions, or 0 if `dims` is not specified.
        """
        if self.dims is not None:
            return len(self.dims)
        else:
            return 0

    def get_dim(self, index: int):
        """
        Retrieves the dimension at the specified index.

        Parameters
        ----------
        index : int
            The index of the dimension to retrieve.

        Returns
        -------
        int | None
            The dimension at the specified index, or None if `dims` is not specified.

        Raises
        ------
        ValueError
            If `dims` is specified and `index` is None or not in the range 0 to len(`dims`) - 1.
        """
        if self.dims is not None:
            # Strict upper bound: valid indices are 0 .. len(self.dims) - 1.
            if index is not None and 0 <= index < len(self.dims):
                return self.dims[index]
            else:
                raise ValueError(f'Index {index} is out of dim list bounds...')
        else:
            return None

    def pre_process(self, x: torch.Tensor, device='cpu', *args, **kwargs):
        """
        Applies preprocessing functions to the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor.
        device : str, optional
            The computational device, by default 'cpu'.

        Returns
        -------
        torch.Tensor
            The preprocessed tensor.
        """
        return function.func_x(x, self.preprocess_functions, device=device)

    def post_process(self, x: torch.Tensor, device='cpu', *args, **kwargs):
        """
        Applies postprocessing functions to the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor.
        device : str, optional
            The computational device, by default 'cpu'.

        Returns
        -------
        torch.Tensor
            The postprocessed tensor.
        """
        return function.func_x(x, self.postprocess_functions, device=device)

    def to_config(self):
        """
        Serializes the fusion instance into a configuration dictionary.

        Returns
        -------
        dict
            A dictionary containing the class name and parameters,
            along with serialized preprocessing and postprocessing function configurations.
        """
        class_name = f"{self.__class__.__module__}.{self.__class__.__name__}"
        attributes = {attr: getattr(self, attr) for attr in self.__dict__}
        attributes.pop('preprocess_functions')
        attributes.pop('postprocess_functions')

        if self.preprocess_functions is not None:
            attributes['preprocess_function_configs'] = function.functions_to_configs(self.preprocess_functions)
        if self.postprocess_functions is not None:
            attributes['postprocess_function_configs'] = function.functions_to_configs(self.postprocess_functions)

        return {
            "function_class": class_name,
            "function_parameters": attributes
        }

    @abstractmethod
    def calculate_n(self, dims: list[int] | tuple[int] = None, *args, **kwargs):
        """
        Abstract method to calculate a value `n` based on dimensions or other parameters.

        Parameters
        ----------
        dims : list[int] | tuple[int], optional
            The input dimensions, by default None.

        Raises
        ------
        NotImplementedError
            This method must be implemented in subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def calculate_l(self, *args, **kwargs):
        """
        Abstract method to calculate a value `l` based on specific parameters.

        Raises
        ------
        NotImplementedError
            This method must be implemented in subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def forward(self, x: list[torch.Tensor] | tuple[torch.Tensor], w: torch.nn.Parameter = None, device: str = 'cpu', *args, **kwargs):
        """
        Abstract method to define the forward pass of the fusion operation.

        Parameters
        ----------
        x : list[torch.Tensor] | tuple[torch.Tensor]
            A list or tuple of input tensors.
        w : torch.nn.Parameter, optional
            Trainable parameters for the fusion operation, by default None.
        device : str, optional
            The computational device, by default 'cpu'.

        Raises
        ------
        NotImplementedError
            This method must be implemented in subclasses.
        """
        raise NotImplementedError
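The class above leaves calculate_n, calculate_l, and forward to subclasses. The sketch below mirrors that contract with a self-contained stand-in base. Both FusionSketch and ConcatFusionSketch are hypothetical. The stand-in is built on torch.nn.Module and abc.ABC rather than on tinyBIG's function class. It also assumes, without confirmation from this page, that n is the fused output width and l is the number of learnable parameters.

```python
from abc import ABC, abstractmethod

import torch


class FusionSketch(torch.nn.Module, ABC):
    """Hypothetical stand-in for tinyBIG's fusion base class."""

    def __init__(self, dims=None, name="base_fusion", require_parameters=False):
        super().__init__()
        self.dims = dims
        self.name = name
        self.require_parameters = require_parameters

    @abstractmethod
    def calculate_n(self, dims=None, *args, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def calculate_l(self, *args, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def forward(self, x, w=None, device="cpu", *args, **kwargs):
        raise NotImplementedError


class ConcatFusionSketch(FusionSketch):
    """Concatenates inputs along the last dimension; learns no weights."""

    def calculate_n(self, dims=None, *args, **kwargs):
        # Fall back to the dims given at construction time.
        dims = dims if dims is not None else self.dims
        if dims is None:
            raise ValueError("dims must be provided")
        return sum(dims)

    def calculate_l(self, *args, **kwargs):
        return 0

    def forward(self, x, w=None, device="cpu", *args, **kwargs):
        return torch.cat([t.to(device) for t in x], dim=-1)


fusion_op = ConcatFusionSketch(dims=[2, 3], name="concat_sketch")
out = fusion_op([torch.ones(4, 2), torch.zeros(4, 3)])
print(out.shape, fusion_op.calculate_n(), fusion_op.calculate_l())
```

Because the abstract methods are declared through abc.ABC, the base itself cannot be instantiated. Only subclasses that implement all three methods can.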

__init__(dims=None, name='base_fusion', require_parameters=False, preprocess_functions=None, postprocess_functions=None, preprocess_function_configs=None, postprocess_function_configs=None, device='cpu', *args, **kwargs)

Initializes the fusion class with its parameters and preprocessing/postprocessing functions.

Parameters:

dims (list[int] | tuple[int], default None): A list or tuple of dimensions for the input tensors.
name (str, default 'base_fusion'): The name of the fusion operation.
require_parameters (bool, default False): Whether the fusion operation requires trainable parameters.
preprocess_functions (list | tuple | callable, default None): Functions to preprocess the input tensors.
postprocess_functions (list | tuple | callable, default None): Functions to postprocess the output tensors.
preprocess_function_configs (dict, default None): Configuration for instantiating the preprocess functions.
postprocess_function_configs (dict, default None): Configuration for instantiating the postprocess functions.
device (str, default 'cpu'): The device for computations.
*args (tuple, default ()): Additional positional arguments.
**kwargs (dict, default {}): Additional keyword arguments.

Source code in tinybig/module/base_fusion.py
def __init__(
    self,
    dims: list[int] | tuple[int] = None,
    name: str = 'base_fusion',
    require_parameters: bool = False,
    preprocess_functions=None,
    postprocess_functions=None,
    preprocess_function_configs=None,
    postprocess_function_configs=None,
    device: str = 'cpu',
    *args, **kwargs
):
    """
    Initializes the fusion class with its parameters and preprocessing/postprocessing functions.

    Parameters
    ----------
    dims : list[int] | tuple[int], optional
        A list or tuple of dimensions for the input tensors, by default None.
    name : str, optional
        The name of the fusion operation, by default 'base_fusion'.
    require_parameters : bool, optional
        Whether the fusion operation requires trainable parameters, by default False.
    preprocess_functions : list | tuple | callable, optional
        Functions to preprocess the input tensors, by default None.
    postprocess_functions : list | tuple | callable, optional
        Functions to postprocess the output tensors, by default None.
    preprocess_function_configs : dict, optional
        Configuration for instantiating the preprocess functions, by default None.
    postprocess_function_configs : dict, optional
        Configuration for instantiating the postprocess functions, by default None.
    device : str, optional
        The device for computations, by default 'cpu'.
    *args : tuple
        Additional positional arguments.
    **kwargs : dict
        Additional keyword arguments.
    """
    Module.__init__(self)
    function.__init__(self, name=name, device=device)

    self.dims = dims
    self.require_parameters = require_parameters

    self.preprocess_functions = config.instantiation_functions(preprocess_functions, preprocess_function_configs, device=self.device)
    self.postprocess_functions = config.instantiation_functions(postprocess_functions, postprocess_function_configs, device=self.device)

calculate_l(*args, **kwargs) abstractmethod

Abstract method to calculate a value l based on specific parameters.

Raises:

NotImplementedError: This method must be implemented in subclasses.

Source code in tinybig/module/base_fusion.py
@abstractmethod
def calculate_l(self, *args, **kwargs):
    """
    Abstract method to calculate a value `l` based on specific parameters.

    Raises
    ------
    NotImplementedError
        This method must be implemented in subclasses.
    """
    raise NotImplementedError
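This page does not define l. Assume for illustration that it counts the learnable parameters a fusion needs; this page does not state that. Under that assumption, two hypothetical implementations follow. A parameter-free fusion returns 0. A weighted-sum fusion with one scalar weight per input matrix returns k.

```python
from typing import Optional, Sequence


def calculate_l_parameter_free(*args, **kwargs) -> int:
    """Hypothetical: concatenation or averaging learns no weights."""
    return 0


def calculate_l_weighted_sum(dims: Optional[Sequence[int]] = None) -> int:
    """Hypothetical: one learnable scalar weight per input matrix A_i."""
    if dims is None:
        raise ValueError("dims must be specified to count the weights")
    return len(dims)


print(calculate_l_parameter_free())        # 0
print(calculate_l_weighted_sum([8, 8, 8]))  # 3
```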

calculate_n(dims=None, *args, **kwargs) abstractmethod

Abstract method to calculate a value n based on dimensions or other parameters.

Parameters:

dims (list[int] | tuple[int], default None): The input dimensions.

Raises:

NotImplementedError: This method must be implemented in subclasses.

Source code in tinybig/module/base_fusion.py
@abstractmethod
def calculate_n(self, dims: list[int] | tuple[int] = None, *args, **kwargs):
    """
    Abstract method to calculate a value `n` based on dimensions or other parameters.

    Parameters
    ----------
    dims : list[int] | tuple[int], optional
        The input dimensions, by default None.

    Raises
    ------
    NotImplementedError
        This method must be implemented in subclasses.
    """
    raise NotImplementedError
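Assume, as an interpretation this page does not state, that n denotes the width of the fused output. The hypothetical helpers below cover two common cases. Concatenation produces an output as wide as the sum of the input widths. Element-wise fusion (sum, mean, product) needs identical widths and keeps that width. A subclass would typically fall back to self.dims when no dims argument is passed.

```python
from typing import Optional, Sequence


def calculate_n_concat(dims: Optional[Sequence[int]]) -> int:
    """Hypothetical: concatenation places the features side by side."""
    if not dims:
        raise ValueError("dims must be a non-empty list or tuple")
    return sum(dims)


def calculate_n_elementwise(dims: Optional[Sequence[int]]) -> int:
    """Hypothetical: element-wise fusion requires every A_i to share one width."""
    if not dims:
        raise ValueError("dims must be a non-empty list or tuple")
    if len(set(dims)) != 1:
        raise ValueError(f"element-wise fusion needs identical dims, got {list(dims)}")
    return dims[0]


print(calculate_n_concat([2, 3, 5]))       # 10
print(calculate_n_elementwise((4, 4, 4)))  # 4
```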

forward(x, w=None, device='cpu', *args, **kwargs) abstractmethod

Abstract method to define the forward pass of the fusion operation.

Parameters:

x (list[Tensor] | tuple[Tensor], required): A list or tuple of input tensors.
w (Parameter, default None): Trainable parameters for the fusion operation.
device (str, default 'cpu'): The computational device.

Raises:

NotImplementedError: This method must be implemented in subclasses.

Source code in tinybig/module/base_fusion.py
@abstractmethod
def forward(self, x: list[torch.Tensor] | tuple[torch.Tensor], w: torch.nn.Parameter = None, device: str = 'cpu', *args, **kwargs):
    """
    Abstract method to define the forward pass of the fusion operation.

    Parameters
    ----------
    x : list[torch.Tensor] | tuple[torch.Tensor]
        A list or tuple of input tensors.
    w : torch.nn.Parameter, optional
        Trainable parameters for the fusion operation, by default None.
    device : str, optional
        The computational device, by default 'cpu'.

    Raises
    ------
    NotImplementedError
        This method must be implemented in subclasses.
    """
    raise NotImplementedError
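For a fusion created with require_parameters=True, the w argument carries the trainable weights. The sketch below assumes a hypothetical weighted-sum rule, \(\mathbf{A} = \sum_{i=1}^{k} w_i \mathbf{A}_i\), with w stored as a torch.nn.Parameter of shape (k,). The helper weighted_sum_forward is illustrative and is not part of tinyBIG.

```python
import torch


def weighted_sum_forward(x, w=None, device="cpu"):
    """Hypothetical forward pass: A = sum_i w_i * A_i over same-shape inputs."""
    if w is None:
        raise ValueError("weighted-sum fusion requires trainable weights w")
    if len(x) != w.numel():
        raise ValueError(f"expected {w.numel()} inputs, got {len(x)}")
    # Stack to shape (k, batch, n) so that each weight scales one input matrix.
    stacked = torch.stack([t.to(device) for t in x], dim=0)
    weights = w.to(device).view(-1, *([1] * (stacked.dim() - 1)))
    return (weights * stacked).sum(dim=0)


w = torch.nn.Parameter(torch.tensor([0.5, 0.25, 0.25]))
inputs = [torch.full((2, 3), 4.0), torch.full((2, 3), 8.0), torch.zeros(2, 3)]
out = weighted_sum_forward(inputs, w)
out.sum().backward()  # gradients flow back into w
```

Every entry of out equals 0.5 * 4 + 0.25 * 8 + 0.25 * 0 = 4. The gradient of w_i is the sum of the entries of A_i.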

get_dim(index)

Retrieves the dimension at the specified index.

Parameters:

index (int, required): The index of the dimension to retrieve.

Returns:

int | None: The dimension at the specified index, or None if dims is not specified.

Raises:

ValueError: If dims is specified and index is None or not in the range 0 to len(dims) - 1.

Source code in tinybig/module/base_fusion.py
def get_dim(self, index: int):
    """
    Retrieves the dimension at the specified index.

    Parameters
    ----------
    index : int
        The index of the dimension to retrieve.

    Returns
    -------
    int | None
        The dimension at the specified index, or None if `dims` is not specified.

    Raises
    ------
    ValueError
        If `dims` is specified and `index` is None or not in the range 0 to len(`dims`) - 1.
    """
    if self.dims is not None:
        # Strict upper bound: valid indices are 0 .. len(self.dims) - 1.
        if index is not None and 0 <= index < len(self.dims):
            return self.dims[index]
        else:
            raise ValueError(f'Index {index} is out of dim list bounds...')
    else:
        return None
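The bounds check in get_dim must use a strict upper bound (index < len(self.dims)). With <=, an index equal to the list length passes the check and then fails in the list lookup with an IndexError instead of the documented ValueError. The hypothetical stand-in DimsSketch reproduces get_dims, get_num, and a strictly bounded get_dim, so their behavior can be checked without installing tinyBIG.

```python
from typing import Optional, Sequence


class DimsSketch:
    """Hypothetical stand-in reproducing the dims accessors of fusion."""

    def __init__(self, dims: Optional[Sequence[int]] = None):
        self.dims = dims

    def get_dims(self):
        return self.dims

    def get_num(self) -> int:
        return len(self.dims) if self.dims is not None else 0

    def get_dim(self, index: int):
        if self.dims is None:
            return None  # same as the source: no dims, nothing to report
        # Strict upper bound: valid indices are 0 .. len(dims) - 1.
        if index is not None and 0 <= index < len(self.dims):
            return self.dims[index]
        raise ValueError(f"Index {index} is out of dim list bounds...")


holder = DimsSketch(dims=[3, 5, 7])
print(holder.get_num(), holder.get_dim(2))  # 3 7
```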

get_dims()

Retrieves the dimensions of the input tensors.

Returns:

list[int] | tuple[int] | None: The dimensions of the input tensors, or None if not specified.

Source code in tinybig/module/base_fusion.py
def get_dims(self):
    """
    Retrieves the dimensions of the input tensors.

    Returns
    -------
    list[int] | tuple[int] | None
        The dimensions of the input tensors, or None if not specified.
    """
    return self.dims

get_num()

Retrieves the number of dimensions.

Returns:

int: The number of dimensions, or 0 if dims is not specified.

Source code in tinybig/module/base_fusion.py
def get_num(self):
    """
    Retrieves the number of dimensions.

    Returns
    -------
    int
        The number of dimensions, or 0 if `dims` is not specified.
    """
    if self.dims is not None:
        return len(self.dims)
    else:
        return 0

post_process(x, device='cpu', *args, **kwargs)

Applies postprocessing functions to the input tensor.

Parameters:

x (Tensor, required): The input tensor.
device (str, default 'cpu'): The computational device.

Returns:

Tensor: The postprocessed tensor.

Source code in tinybig/module/base_fusion.py
def post_process(self, x: torch.Tensor, device='cpu', *args, **kwargs):
    """
    Applies postprocessing functions to the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        The input tensor.
    device : str, optional
        The computational device, by default 'cpu'.

    Returns
    -------
    torch.Tensor
        The postprocessed tensor.
    """
    return function.func_x(x, self.postprocess_functions, device=device)

pre_process(x, device='cpu', *args, **kwargs)

Applies preprocessing functions to the input tensor.

Parameters:

x (Tensor, required): The input tensor.
device (str, default 'cpu'): The computational device.

Returns:

Tensor: The preprocessed tensor.

Source code in tinybig/module/base_fusion.py
def pre_process(self, x: torch.Tensor, device='cpu', *args, **kwargs):
    """
    Applies preprocessing functions to the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        The input tensor.
    device : str, optional
        The computational device, by default 'cpu'.

    Returns
    -------
    torch.Tensor
        The preprocessed tensor.
    """
    return function.func_x(x, self.preprocess_functions, device=device)
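Both pre_process and post_process delegate to function.func_x, whose behavior this page does not show. As a hypothetical reading, assume it applies the stored functions in order, each to the previous result. Under that assumption, one plausible pipeline around a fusion can be sketched in plain PyTorch: preprocess each input, fuse (here by summation), and postprocess the fused output. The helper apply_chain and the chosen functions are illustrative only.

```python
import torch
import torch.nn.functional as F


def apply_chain(x, functions=None):
    """Hypothetical stand-in for function.func_x: apply callables left to right."""
    if functions is None:
        return x
    if callable(functions):
        functions = [functions]
    for f in functions:
        x = f(x)
    return x


pre = [F.normalize]  # scale each row to unit L2 norm before fusing
post = torch.tanh    # squash the fused output into (-1, 1)

inputs = [torch.randn(4, 3), torch.randn(4, 3)]
prepped = [apply_chain(a, pre) for a in inputs]
fused = torch.stack(prepped, dim=0).sum(dim=0)
output = apply_chain(fused, post)
```

Accepting either a single callable or a sequence mirrors the list | tuple | callable type documented for preprocess_functions and postprocess_functions.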

to_config()

Serializes the fusion instance into a configuration dictionary.

Returns:

dict: A dictionary containing the class name and parameters, along with serialized preprocessing and postprocessing function configurations.

Source code in tinybig/module/base_fusion.py
def to_config(self):
    """
    Serializes the fusion instance into a configuration dictionary.

    Returns
    -------
    dict
        A dictionary containing the class name and parameters,
        along with serialized preprocessing and postprocessing function configurations.
    """
    class_name = f"{self.__class__.__module__}.{self.__class__.__name__}"
    attributes = {attr: getattr(self, attr) for attr in self.__dict__}
    attributes.pop('preprocess_functions')
    attributes.pop('postprocess_functions')

    if self.preprocess_functions is not None:
        attributes['preprocess_function_configs'] = function.functions_to_configs(self.preprocess_functions)
    if self.postprocess_functions is not None:
        attributes['postprocess_function_configs'] = function.functions_to_configs(self.postprocess_functions)

    return {
        "function_class": class_name,
        "function_parameters": attributes
    }
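to_config returns a two-key dictionary. The fully qualified class path goes under function_class, and the instance attributes go under function_parameters, with the live function objects replaced by their configs. The stand-in ToyFusion below is a hypothetical plain Python class that reproduces this shape. Its describe helper is a hypothetical substitute for function.functions_to_configs. One difference from the source: the sketch calls dict.pop(key, None), so an absent attribute does not raise KeyError.

```python
class ToyFusion:
    """Hypothetical plain-Python stand-in mirroring fusion.to_config."""

    def __init__(self, dims=None, name="toy_fusion", preprocess_functions=None):
        self.dims = dims
        self.name = name
        self.require_parameters = False
        self.preprocess_functions = preprocess_functions
        self.postprocess_functions = None

    @staticmethod
    def describe(functions):
        # Hypothetical substitute for function.functions_to_configs.
        return [{"function_class": f.__name__} for f in functions]

    def to_config(self):
        class_name = f"{self.__class__.__module__}.{self.__class__.__name__}"
        attributes = dict(self.__dict__)
        # Drop the live callables; serializable configs replace them below.
        attributes.pop("preprocess_functions", None)
        attributes.pop("postprocess_functions", None)
        if self.preprocess_functions is not None:
            attributes["preprocess_function_configs"] = self.describe(self.preprocess_functions)
        if self.postprocess_functions is not None:
            attributes["postprocess_function_configs"] = self.describe(self.postprocess_functions)
        return {"function_class": class_name, "function_parameters": attributes}


cfg = ToyFusion(dims=[2, 3], preprocess_functions=[abs]).to_config()
```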