
parameterized_concatenation_fusion

Bases: fusion

A fusion mechanism that concatenates input tensors along their last dimension, followed by a learnable parameterized transformation.

Notes

Formally, given input interdependence matrices \(\mathbf{A}_1, \mathbf{A}_2, \ldots, \mathbf{A}_k\), where each matrix \(\mathbf{A}_i \in R^{m \times n_i}\) has \(m\) rows and \(n_i\) columns, we define the fusion operator as follows:

\[
    \begin{equation}
    \begin{aligned}
    \mathbf{A} &= \text{fusion}(\mathbf{A}_1, \mathbf{A}_2, \cdots, \mathbf{A}_k) \\
    &= \left( \mathbf{A}_1 \sqcup \mathbf{A}_2 \sqcup \cdots \sqcup \mathbf{A}_k \right) \mathbf{W} \in R^{m \times n},
    \end{aligned}
    \end{equation}
\]

where \(\sqcup\) denotes concatenation along the column (last) dimension, so the concatenated matrix lies in \(R^{m \times \sum_{i=1}^k n_i}\). The term \(\mathbf{W} \in R^{(\sum_{i=1}^k n_i) \times n}\) is a learnable parameter matrix that projects the concatenated matrix to \(n\) columns.

The number of learnable parameters required by this fusion function is \(l = (\sum_{i=1}^k n_i) \times n\).
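As a minimal sketch of the formula in plain PyTorch (it does not use tinybig's class or its pre- and post-processing hooks), the operator concatenates the \(k\) inputs along their last dimension and multiplies by \(\mathbf{W}\). The sizes `m`, `dims`, and `n` are arbitrary illustrative choices.

```python
import torch

# Illustrative sizes (assumptions): m rows, k = 3 inputs with n_i columns each, output dim n.
m, dims, n = 4, [2, 3, 5], 6
D = sum(dims)  # total number of concatenated columns, sum of n_i

# Input matrices A_i with shape (m, n_i).
A_list = [torch.randn(m, n_i) for n_i in dims]

# Concatenate along the last (column) dimension: shape (m, D).
A_cat = torch.cat(A_list, dim=-1)

# Learnable projection W in R^{D x n}.
W = torch.randn(D, n, requires_grad=True)

# Fused output A = concat(A_1, ..., A_k) @ W with shape (m, n).
A = A_cat @ W

# Number of learnable parameters l = (sum_i n_i) * n.
l = W.numel()
print(A.shape, l)  # torch.Size([4, 6]) 60
```

Here \(\sum_i n_i = 10\) and \(n = 6\), so \(l = 60\).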

Attributes:

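| Name | Type | Description |
|---|---|---|
| `n` | `int` | Output dimension after transformation. |
| `dims` | `list[int] \| tuple[int]` | List or tuple specifying the dimensions of the input tensors. |
| `parameter_fabrication` | `Callable` | Function or object that generates the transformation matrix from the parameters. Initialized to `None`, in which case the flat parameter vector is reshaped directly. |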

Methods:

| Name | Description |
|---|---|
| `calculate_n` | Computes the output dimension after the parameterized transformation. |
| `calculate_l` | Computes the number of learnable parameters for the transformation. |
| `forward` | Performs the concatenation fusion followed by the parameterized transformation. |

Source code in tinybig/fusion/parameterized_concatenation_fusion.py
class parameterized_concatenation_fusion(fusion):
    r"""
        A fusion mechanism that concatenates input tensors along their last dimension, followed by a learnable parameterized transformation.

        Notes
        ----------

        Formally, given input interdependence matrices $\mathbf{A}_1, \mathbf{A}_2, \ldots, \mathbf{A}_k$,
        where each matrix $\mathbf{A}_i \in R^{m \times n_i}$ has $m$ rows and $n_i$ columns,
        we define the fusion operator as follows:

        $$
            \begin{equation}
            \begin{aligned}
            \mathbf{A} &= \text{fusion}(\mathbf{A}_1, \mathbf{A}_2, \cdots, \mathbf{A}_k) \\
            &= \left( \mathbf{A}_1 \sqcup \mathbf{A}_2 \sqcup \cdots \sqcup \mathbf{A}_k \right) \mathbf{W} \in R^{m \times n},
            \end{aligned}
            \end{equation}
        $$

        where $\sqcup$ denotes concatenation along the column (last) dimension, so the concatenated matrix lies in
        $R^{m \times \sum_{i=1}^k n_i}$. The term $\mathbf{W} \in R^{(\sum_{i=1}^k n_i) \times n}$ is a learnable
        parameter matrix that projects the concatenated matrix to $n$ columns.

        The number of learnable parameters required by this fusion function is $l = (\sum_{i=1}^k n_i) \times n$.

        Attributes
        ----------
        n : int
            Output dimension after transformation.
        dims : list[int] | tuple[int]
            List or tuple specifying the dimensions of the input tensors.
        parameter_fabrication : Callable
            Function or object to handle parameter generation or transformation.

        Methods
        -------
        calculate_n(dims=None, *args, **kwargs)
            Computes the output dimension after the parameterized transformation.
        calculate_l(*args, **kwargs)
            Computes the number of learnable parameters for the transformation.
        forward(x, w=None, device='cpu', *args, **kwargs)
            Performs the concatenation fusion followed by the parameterized transformation.
    """

    def __init__(self, n: int = None, dims: list[int] | tuple[int] = None, name: str = "parameterized_concatenation_fusion", require_parameters: bool = True, *args, **kwargs):
        """
            Initializes the parameterized concatenation fusion function.

            Parameters
            ----------
            n : int, optional
                Output dimension after transformation. Defaults to None.
            dims : list[int] | tuple[int], optional
                List or tuple specifying the dimensions of the input tensors. Defaults to None.
            name : str, optional
                Name of the fusion function. Defaults to "parameterized_concatenation_fusion".
            require_parameters : bool, optional
                Indicates whether the fusion requires learnable parameters. Defaults to True.
                The current implementation always passes True to the parent class.
            *args : tuple
                Additional positional arguments for the parent class.
            **kwargs : dict
                Additional keyword arguments for the parent class.
        """
        super().__init__(dims=dims, name=name, require_parameters=True, *args, **kwargs)
        if n is not None:
            self.n = n
        else:
            assert dims is not None and all([dim == dims[0] for dim in dims])
            self.n = dims[0]
        self.parameter_fabrication = None

    def calculate_n(self, dims: list[int] | tuple[int] = None, *args, **kwargs):
        """
            Computes the output dimension after the parameterized transformation.

            Parameters
            ----------
            dims : list[int] | tuple[int], optional
                List or tuple specifying the dimensions of the input tensors. Defaults to None.

            Returns
            -------
            int
                Output dimension after transformation.

            Raises
            ------
            AssertionError
                If `n` is unset and `dims` is missing or its entries are not all equal.
        """
        if self.n is not None:
            return self.n
        else:
            dims = dims if dims is not None else self.dims
            assert dims is not None and all([dim == dims[0] for dim in dims])
            return dims[0]

    def calculate_l(self, *args, **kwargs):
        """
            Computes the number of learnable parameters for the transformation.

            Returns
            -------
            int
                Total number of learnable parameters.

            Raises
            ------
            ValueError
                If `dims` or `n` is not specified.
        """
        if self.dims is None or self.n is None:
            raise ValueError("Both dims and the output dimension n are required...")
        if self.parameter_fabrication is None:
            return sum(self.dims) * self.n
        else:
            return self.parameter_fabrication.calculate_l(n=self.n, D=sum(self.dims))

    def forward(self, x: list[torch.Tensor] | tuple[torch.Tensor], w: torch.nn.Parameter = None, device: str = 'cpu', *args, **kwargs):
        """
            Performs the concatenation fusion followed by the parameterized transformation.

            Parameters
            ----------
            x : list[torch.Tensor] | tuple[torch.Tensor]
                List or tuple of input tensors to be concatenated and transformed.
            w : torch.nn.Parameter, optional
                Learnable weights for the transformation. Defaults to None.
            device : str, optional
                Device for computation ('cpu', 'cuda'). Defaults to 'cpu'.
            *args : tuple
                Additional positional arguments.
            **kwargs : dict
                Additional keyword arguments.

            Returns
            -------
            torch.Tensor
                Fused and transformed tensor.

            Raises
            ------
            ValueError
                If `x` is empty, if the inputs differ in any dimension other than the last,
                or if `dims` or `n` is not specified.
        """
        if not x:
            raise ValueError("The input x cannot be empty...")
        if not all(x[0].shape[:-1] == t.shape[:-1] for t in x):
            raise ValueError("Excluding the last dimension, the input x contains elements of different shapes for other dimensions...")

        if all(x[0].shape == t.shape for t in x):
            # All inputs share one shape, so stack them to allow cross-channel pre-processing operators.
            x = torch.stack(x, dim=0)
            x = self.pre_process(x=x, device=device)
            x = [t.squeeze(dim=0) for t in x.split(1, dim=0)]
        else:
            # Shapes differ in the last dimension, so cross-channel pre-processing is impossible; pre-process each input individually.
            x = [self.pre_process(x=t, device=device) for t in x]

        # Record each input's last dimension before concatenation; the final shape check needs them.
        input_dims = [t.size(-1) for t in x]
        x = torch.cat(x, dim=-1)

        if self.dims is None or self.n is None:
            raise ValueError("Both dims and the output dimension n are required...")
        if self.parameter_fabrication is None:
            # w holds sum(dims) * n values; reshape it to (n, sum(dims)), i.e., the transpose of W in the formula.
            W = w.reshape(self.n, sum(self.dims)).to(device=device)
        else:
            W = self.parameter_fabrication(w=w, n=self.n, D=sum(self.dims), device=device)

        fused_x = torch.matmul(x, W.t())

        assert fused_x.size(-1) == self.calculate_n(input_dims)
        return self.post_process(x=fused_x, device=device)

__init__(n=None, dims=None, name='parameterized_concatenation_fusion', require_parameters=True, *args, **kwargs)

Initializes the parameterized concatenation fusion function.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `n` | `int` | Output dimension after transformation. If `None`, all entries of `dims` must be equal and `dims[0]` is used. | `None` |
| `dims` | `list[int] \| tuple[int]` | List or tuple specifying the dimensions of the input tensors. | `None` |
| `name` | `str` | Name of the fusion function. | `'parameterized_concatenation_fusion'` |
| `require_parameters` | `bool` | Indicates whether the fusion requires learnable parameters. The current implementation always passes `True` to the parent class. | `True` |
| `*args` | `tuple` | Additional positional arguments for the parent class. | `()` |
| `**kwargs` | `dict` | Additional keyword arguments for the parent class. | `{}` |
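When `n` is omitted, `__init__` (and likewise `calculate_n`) falls back to `dims[0]` after asserting that every entry of `dims` is equal. The helper `resolve_output_dim` below is hypothetical and not part of tinybig; it only mirrors that rule in plain Python so the behavior can be checked without instantiating the class.

```python
def resolve_output_dim(n=None, dims=None):
    """Hypothetical helper mirroring the default-n rule (not a tinybig API)."""
    if n is not None:
        # An explicit output dimension always wins.
        return n
    # Otherwise every input must share the same last dimension, which becomes n.
    if dims is None or not all(dim == dims[0] for dim in dims):
        raise AssertionError("n is None, so dims must be given and all entries equal")
    return dims[0]


print(resolve_output_dim(n=8, dims=[2, 3]))  # 8
print(resolve_output_dim(dims=[5, 5, 5]))    # 5
```

Mixed widths such as `dims=[2, 3]` therefore require an explicit `n`.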

calculate_l(*args, **kwargs)

Computes the number of learnable parameters for the transformation.

Returns:

| Type | Description |
|---|---|
| `int` | Total number of learnable parameters. |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If `dims` or `n` is not specified. |
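Without a `parameter_fabrication` object, `calculate_l` returns `sum(dims) * n`, and `forward` reshapes the flat parameter vector `w` into an `(n, sum(dims))` matrix before computing `x @ W.t()`. The stored matrix is therefore the transpose of the \(\mathbf{W}\) in the formula. A minimal PyTorch sketch of that layout follows; the sizes and the `(1, l)` shape of `w` are illustrative assumptions.

```python
import torch

dims, n = [2, 3], 4
D = sum(dims)
l = D * n  # value calculate_l returns when parameter_fabrication is None

# Flat parameter vector holding l values (the (1, l) shape is an illustrative choice).
w = torch.nn.Parameter(torch.randn(1, l))

# forward reshapes w to (n, D); this is W^T relative to the formula's W in R^{D x n}.
W_stored = w.reshape(n, D)

x = torch.randn(7, D)                   # already-concatenated input with 7 rows
fused = torch.matmul(x, W_stored.t())   # shape (7, n)
print(l, fused.shape)  # 20 torch.Size([7, 4])
```

In other words, the formula's \(\mathbf{W}\) corresponds to `w.reshape(n, D).t()`.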


calculate_n(dims=None, *args, **kwargs)

Computes the output dimension after the parameterized transformation.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dims` | `list[int] \| tuple[int]` | List or tuple specifying the dimensions of the input tensors. | `None` |

Returns:

| Type | Description |
|---|---|
| `int` | Output dimension after transformation. |

Raises:

| Type | Description |
|---|---|
| `AssertionError` | If `n` is unset and `dims` is missing or its entries are not all equal. |


forward(x, w=None, device='cpu', *args, **kwargs)

Performs the concatenation fusion followed by the parameterized transformation.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `list[Tensor] \| tuple[Tensor]` | List or tuple of input tensors to be concatenated and transformed. | required |
| `w` | `Parameter` | Learnable weights for the transformation. | `None` |
| `device` | `str` | Device for computation (`'cpu'`, `'cuda'`). | `'cpu'` |
| `*args` | `tuple` | Additional positional arguments. | `()` |
| `**kwargs` | `dict` | Additional keyword arguments. | `{}` |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Fused and transformed tensor. |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If `x` is empty, if the inputs differ in any dimension other than the last, or if `dims` or `n` is not specified. |
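`forward` first checks that all inputs agree on every dimension except the last. When the shapes are fully identical, it stacks the inputs into one `(k, ...)` tensor so cross-channel pre-processing sees them jointly, then splits them back before concatenating. The function `concat_inputs` below is a hypothetical sketch of only that input handling in plain PyTorch; the `pre_process` hook is replaced by an identity placeholder, which is an assumption since the real hook depends on the configured functions.

```python
import torch


def concat_inputs(x, pre_process=lambda t: t):
    """Hypothetical sketch of forward's input handling; pre_process defaults to identity."""
    if not x:
        raise ValueError("The input x cannot be empty.")
    if not all(x[0].shape[:-1] == t.shape[:-1] for t in x):
        raise ValueError("All inputs must share every dimension except the last.")
    if all(x[0].shape == t.shape for t in x):
        # Identical shapes: stack to (k, ...), pre-process jointly, then split back into k tensors.
        stacked = pre_process(torch.stack(list(x), dim=0))
        x = [t.squeeze(dim=0) for t in stacked.split(1, dim=0)]
    else:
        # Mixed last dimensions: pre-process each tensor on its own.
        x = [pre_process(t) for t in x]
    dims = [t.size(-1) for t in x]  # per-input widths, recorded before concatenation
    return torch.cat(x, dim=-1), dims


same = [torch.ones(3, 4), torch.zeros(3, 4)]
fused, dims = concat_inputs(same)
print(fused.shape, dims)  # torch.Size([3, 8]) [4, 4]
```

Recording the per-input widths before `torch.cat` keeps them available for a later shape check, since the concatenated tensor no longer carries that information.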
