
metric_fusion

Bases: fusion

A fusion mechanism that applies a specified numerical or statistical metric entrywise across input tensors of identical shape.

Notes

Formally, given the input interdependence matrices \(\mathbf{A}_1, \mathbf{A}_2, \ldots, \mathbf{A}_k \in R^{m \times n}\) of identical shapes, we can represent their fusion output as

\[
    \begin{equation}
    \text{fusion}(\mathbf{A}_1, \mathbf{A}_2, \cdots, \mathbf{A}_k) = \mathbf{A}  \in R^{m \times n},
    \end{equation}
\]

where the entry \(\mathbf{A}(i, j)\) (for \(i \in \{1, 2, \cdots, m\}\) and \(j \in \{1, 2, \cdots, n\}\)) can be represented as

\[
    \begin{equation}
    \mathbf{A}(i, j) = metric \left( \mathbf{A}_1(i,j), \mathbf{A}_2(i,j), \cdots, \mathbf{A}_k(i,j)  \right).
    \end{equation}
\]

The \(metric(\cdots)\) can be any numerical or statistical metric, such as maximum, mean, or product.
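To make the entrywise formula concrete, here is a minimal pure-Python sketch (standard library only, no torch or tinybig) in which the matrices are nested lists. The helper name `fuse_entrywise` is hypothetical and not part of tinybig; `metric` is any function that maps the k values found at one position \((i, j)\) to a single number, here the maximum, mean, and product mentioned above.

```python
import math
import statistics
from typing import Callable, Sequence

Matrix = list[list[float]]


def fuse_entrywise(mats: Sequence[Matrix], metric: Callable[[list[float]], float]) -> Matrix:
    """Return A with A[i][j] = metric([A_1[i][j], ..., A_k[i][j]]) (hypothetical helper)."""
    if not mats:
        raise ValueError("at least one matrix is required")
    m = len(mats[0])
    n = len(mats[0][0]) if m else 0
    # The formula assumes every A_t has the same m x n shape.
    if any(len(a) != m or any(len(row) != n for row in a) for a in mats):
        raise ValueError("all matrices must have identical shapes")
    # Collect the k values at (i, j) and reduce them with the metric.
    return [[metric([a[i][j] for a in mats]) for j in range(n)] for i in range(m)]


A1 = [[1.0, 4.0], [2.0, 0.0]]
A2 = [[3.0, 2.0], [2.0, 5.0]]
print(fuse_entrywise([A1, A2], max))               # [[3.0, 4.0], [2.0, 5.0]]
print(fuse_entrywise([A1, A2], statistics.fmean))  # [[2.0, 3.0], [2.0, 2.5]]
print(fuse_entrywise([A1, A2], math.prod))         # [[3.0, 8.0], [4.0, 0.0]]
```

Swapping the metric changes only the reduction; the output always keeps the common \(m \times n\) shape of the inputs.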

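Attributes:

metric : Callable[[Tensor], Tensor]
    A callable metric function to apply to the input tensors. In forward, it is called once on a tensor of shape (N, k), where k is the number of inputs and each row holds the k values found at one entry position; it must return a tensor with N elements (one per row), for example a reduction over the last dimension.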

Methods:

calculate_n
    Computes the output dimension of the fused input.
calculate_l
    Computes the number of learnable parameters, if applicable.
forward
    Performs the metric-based fusion on the input tensors.

Source code in tinybig/fusion/metric_fusion.py
class metric_fusion(base_fusion):
    r"""
        A fusion mechanism that applies a specified numerical/statistical metric across input tensors.

        Notes
        ----------

        Formally, given the input interdependence matrices $\mathbf{A}_1, \mathbf{A}_2, \ldots, \mathbf{A}_k \in R^{m \times n}$ of identical shapes,
        we can represent their fusion output as

        $$
            \begin{equation}
            \text{fusion}(\mathbf{A}_1, \mathbf{A}_2, \cdots, \mathbf{A}_k) = \mathbf{A}  \in R^{m \times n},
            \end{equation}
        $$

        where the entry $\mathbf{A}(i, j)$ (for $i \in \{1, 2, \cdots, m\}$ and $j \in \{1, 2, \cdots, n\}$) can be represented as

        $$
            \begin{equation}
            \mathbf{A}(i, j) = metric \left( \mathbf{A}_1(i,j), \mathbf{A}_2(i,j), \cdots, \mathbf{A}_k(i,j)  \right).
            \end{equation}
        $$

        The $metric(\cdots)$ can be any numerical or statistical metric, such as maximum, mean, or product.

        Attributes
        ----------
        metric : Callable[[torch.Tensor], torch.Tensor]
            A callable metric function to apply to the input tensors.

        Methods
        -------
        calculate_n(dims=None, *args, **kwargs)
            Computes the output dimension of the fused input.
        calculate_l(*args, **kwargs)
            Computes the number of learnable parameters, if applicable.
        forward(x, w=None, device='cpu', *args, **kwargs)
            Performs the metric-based fusion on the input tensors.
    """

    def __init__(
        self,
        dims: list[int] | tuple[int],
        metric: Callable[[torch.Tensor], torch.Tensor],
        name: str = "metric_fusion",
        *args, **kwargs
    ):
        """
            Initializes the metric-based fusion function.

            Parameters
            ----------
            dims : list[int] | tuple[int]
                Dimensions of the input tensors.
            metric : Callable[[torch.Tensor], torch.Tensor]
                A callable metric function to apply to the input tensors.
            name : str, optional
                Name of the fusion function. Defaults to "metric_fusion".
            *args : tuple
                Additional positional arguments for the parent class.
            **kwargs : dict
                Additional keyword arguments for the parent class.
        """
        super().__init__(dims=dims, name=name, require_parameters=False, *args, **kwargs)
        self.metric = metric

    def calculate_n(self, dims: list[int] | tuple[int] = None, *args, **kwargs):
        """
            Computes the output dimension of the fused input.

            Parameters
            ----------
            dims : list[int] | tuple[int], optional
                List of dimensions of the input tensors. Defaults to None, in which case self.dims is used.

            Returns
            -------
            int
                Output dimension, equal to the input dimension if consistent.

            Raises
            ------
            AssertionError
                If input dimensions are inconsistent.
        """
        dims = dims if dims is not None else self.dims
        assert dims is not None and all(dim == dims[0] for dim in dims)
        return dims[0]

    def calculate_l(self, *args, **kwargs):
        """
            Computes the number of learnable parameters, if applicable.

            Returns
            -------
            int
                Number of learnable parameters. Returns 0 as metrics are non-parameterized.
        """
        return 0

    def forward(self, x: list[torch.Tensor] | tuple[torch.Tensor], w: torch.nn.Parameter = None, device: str = 'cpu', *args, **kwargs):
        """
            Performs the metric-based fusion on the input tensors.

            Parameters
            ----------
            x : list[torch.Tensor] | tuple[torch.Tensor]
                List or tuple of input tensors to be fused.
            w : torch.nn.Parameter, optional
                Unused, since metric fusion has no learnable weights; accepted for interface compatibility. Defaults to None.
            device : str, optional
                Device for computation ('cpu', 'cuda'). Defaults to 'cpu'.
            *args : tuple
                Additional positional arguments.
            **kwargs : dict
                Additional keyword arguments.

            Returns
            -------
            torch.Tensor
                Fused tensor after applying the metric.

            Raises
            ------
            ValueError
                If `x` is empty or if input tensors have inconsistent shapes.
            AssertionError
                If the metric is not callable.
        """
        if not x:
            raise ValueError("The input x cannot be empty...")
        if not all(x[0].shape == t.shape for t in x):
            raise ValueError("The input x must have the same shape.")

        x = torch.stack(x, dim=0)
        x = self.pre_process(x=x, device=device)

        assert self.metric is not None and isinstance(self.metric, Callable)

        x_shape = x.shape
        x_permuted = x.permute(*range(1, x.ndim), 0)

        fused_x = self.metric(x_permuted.view(-1, x_shape[0])).view(x_shape[1:])

        assert fused_x.size(-1) == self.calculate_n([element.size(-1) for element in x])
        return self.post_process(x=fused_x, device=device)
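The forward method avoids a per-entry loop: it stacks the k inputs along a new leading axis, moves that axis to the end with `permute`, flattens everything else into N rows with `view(-1, k)`, calls `metric` once on the resulting (N, k) tensor, and reshapes the N results back to the shape of a single (pre-processed) input. The sketch below mimics that data flow with flat row-major Python lists instead of tensors; the names `unflatten`, `metric_fusion_sketch`, `row_max`, and `row_mean` are hypothetical, and pre/post-processing is omitted. With real tensors, a suitable metric reduces over the last dimension, for example `lambda t: torch.mean(t, dim=-1)` or `lambda t: torch.amax(t, dim=-1)`; note that `torch.max(t, dim=-1)` returns a `(values, indices)` pair rather than a tensor, so it would need `.values`.

```python
import math
from typing import Callable, Sequence

Rows = list[list[float]]


def unflatten(flat: list[float], shape: tuple[int, ...]):
    """Rebuild a nested list of the given shape from row-major flat data."""
    if len(shape) == 1:
        return list(flat)
    step = math.prod(shape[1:])
    return [unflatten(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


def metric_fusion_sketch(flat_inputs: Sequence[list[float]], shape: tuple[int, ...],
                         metric: Callable[[Rows], list[float]]):
    """Mimic forward's stack -> permute -> view(-1, k) -> metric -> view flow."""
    k = len(flat_inputs)
    n_entries = math.prod(shape)
    # After stacking on a new axis 0 and moving it last, row r of the
    # (N, k) matrix holds entry r of every input tensor.
    rows = [[flat_inputs[t][r] for t in range(k)] for r in range(n_entries)]
    fused = metric(rows)  # the metric is called once on all rows
    if len(fused) != n_entries:
        raise ValueError("metric must return one value per row")
    return unflatten(fused, shape)


def row_max(rows: Rows) -> list[float]:
    return [max(row) for row in rows]


def row_mean(rows: Rows) -> list[float]:
    return [sum(row) / len(row) for row in rows]


# Two 2x3 inputs, written in row-major (flattened) order.
x1 = [1.0, 5.0, 2.0, 0.0, 4.0, 6.0]
x2 = [3.0, 1.0, 2.0, 8.0, 4.0, 0.0]
print(metric_fusion_sketch([x1, x2], (2, 3), row_max))   # [[3.0, 5.0, 2.0], [8.0, 4.0, 6.0]]
print(metric_fusion_sketch([x1, x2], (2, 3), row_mean))  # [[2.0, 3.0, 2.0], [4.0, 4.0, 3.0]]
```

Calling the metric once on all rows, rather than once per entry, is what lets a vectorized tensor reduction do the work in the real implementation.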

__init__(dims, metric, name='metric_fusion', *args, **kwargs)

Initializes the metric-based fusion function.

Parameters:

dims : list[int] | tuple[int] (required)
    Dimensions of the input tensors; calculate_n requires them all to be equal.
metric : Callable[[Tensor], Tensor] (required)
    A callable metric function to apply to the input tensors.
name : str (default: 'metric_fusion')
    Name of the fusion function.
*args : tuple (default: ())
    Additional positional arguments for the parent class.
**kwargs : dict (default: {})
    Additional keyword arguments for the parent class.

calculate_l(*args, **kwargs)

Computes the number of learnable parameters, if applicable.

Returns:

int
    Number of learnable parameters. Always 0, since metrics are non-parameterized.


calculate_n(dims=None, *args, **kwargs)

Computes the output dimension of the fused input.

Parameters:

dims : list[int] | tuple[int] (default: None)
    List of dimensions of the input tensors. If None, the dims given at construction (self.dims) are used.

Returns:

int
    Output dimension, equal to the common input dimension.

Raises:

AssertionError
    If no dimensions are available or the input dimensions are inconsistent.
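calculate_n only checks that all fused inputs agree on their dimension and returns that shared value. A minimal stand-alone sketch of the same logic follows; the function name `calculate_n_sketch` and the `default_dims` argument (standing in for `self.dims`) are hypothetical and not part of tinybig.

```python
from typing import Optional, Sequence


def calculate_n_sketch(dims: Optional[Sequence[int]] = None,
                       default_dims: Optional[Sequence[int]] = None) -> int:
    """Stand-alone mirror of metric_fusion.calculate_n (hypothetical helper)."""
    # Fall back to the dims supplied at construction time (self.dims in the class).
    dims = dims if dims is not None else default_dims
    # Metric fusion keeps the size unchanged, so all inputs must agree on it.
    assert dims is not None and all(d == dims[0] for d in dims)
    return dims[0]


print(calculate_n_sketch([8, 8, 8]))              # 8
print(calculate_n_sketch(default_dims=(16, 16)))  # 16
try:
    calculate_n_sketch([8, 4])
except AssertionError:
    print("inconsistent dims rejected")
```

As in the original, an empty dims list passes the `all(...)` check vacuously and then fails at `dims[0]` with an IndexError rather than an AssertionError.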


forward(x, w=None, device='cpu', *args, **kwargs)

Performs the metric-based fusion on the input tensors.

Parameters:

x : list[Tensor] | tuple[Tensor] (required)
    List or tuple of input tensors to be fused.
w : Parameter (default: None)
    Unused, since metric fusion has no learnable weights; accepted for interface compatibility.
device : str (default: 'cpu')
    Device for computation ('cpu', 'cuda').
*args : tuple (default: ())
    Additional positional arguments.
**kwargs : dict (default: {})
    Additional keyword arguments.

Returns:

Tensor
    Fused tensor after applying the metric.

Raises:

ValueError
    If x is empty or if the input tensors do not all have the same shape.
AssertionError
    If the metric is None or not callable.
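The two ValueError conditions above are checked before any stacking, and they depend only on the shapes of the inputs. The following sketch applies the same checks to plain shape tuples; `check_input_shapes` is a hypothetical helper, not part of tinybig, and the error messages are copied from the source listing.

```python
from typing import Sequence

Shape = tuple[int, ...]


def check_input_shapes(shapes: Sequence[Shape]) -> Shape:
    """Apply forward's two ValueError checks to input shapes (hypothetical helper)."""
    if not shapes:
        raise ValueError("The input x cannot be empty...")
    if not all(shapes[0] == s for s in shapes):
        raise ValueError("The input x must have the same shape.")
    # Every input shares this shape, so stacking them along a new axis is valid.
    return shapes[0]


print(check_input_shapes([(4, 5), (4, 5), (4, 5)]))  # (4, 5)
for bad in ([], [(4, 5), (5, 4)]):
    try:
        check_input_shapes(bad)
    except ValueError as err:
        print(err)
```

Note that a (4, 5) and a (5, 4) input are rejected even though they hold the same number of entries, because the entrywise formula pairs values by position.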
