
random_matrix_adaption_reconciliation

Bases: fabrication

A reconciliation mechanism using random matrices for parameter adaptation.

This class generates a reconciliation matrix W based on random matrices and diagonal parameter matrices.

Notes

Formally, given a parameter vector \(\mathbf{w} \in R^l\) of length \(l\), we can partition it into two vectors \(\lambda_1 \in R^{n}\) and \(\lambda_2 \in R^r\). These two vectors will define two diagonal matrices \(\Lambda_1 = diag( \lambda_1) \in R^{n \times n}\) and \(\Lambda_2 = diag(\lambda_2) \in R^{r \times r}\).
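A minimal PyTorch sketch of this partition follows. It assumes that \(\mathbf{w}\) is stored as a 2-D row of shape (1, n + r), which is the layout implied by the assertion and the `torch.split(..., dim=1)` call in `forward`. The sizes are illustrative only.

```python
import torch

# Hypothetical sizes for illustration: n rows of the output, rank r.
n, r = 3, 2

# Parameter vector of length l = n + r, stored as a (1, l) row.
w = torch.arange(1.0, n + r + 1).view(1, n + r)  # [[1., 2., 3., 4., 5.]]

# Partition along the column dimension into lambda_1 (length n) and lambda_2 (length r).
lambda_1, lambda_2 = torch.split(w, [n, r], dim=1)

# Each piece becomes a diagonal matrix.
Lambda_1 = torch.diag(lambda_1.view(-1))  # shape (n, n)
Lambda_2 = torch.diag(lambda_2.view(-1))  # shape (r, r)

print(Lambda_1.shape, Lambda_2.shape)  # torch.Size([3, 3]) torch.Size([2, 2])
```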

Together with two fixed random matrices \(\mathbf{A}\) and \(\mathbf{B}\), these diagonal matrices fabricate a parameter matrix of shape \(n \times D\) as follows:

\[
    \begin{equation}
    \xi(\mathbf{w}) = \Lambda_1 \mathbf{A} \Lambda_2 \mathbf{B}^\top \in R^{n \times D},
    \end{equation}
\]

where matrices \(\mathbf{A} \in R^{n \times r}\) and \(\mathbf{B} \in R^{D \times r}\) are randomly sampled from the Gaussian distribution \(\mathcal{N}(\mathbf{0}, \mathbf{I})\). The required length of vector \(\mathbf{w}\) is \(l = n + r\).
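The formula can be checked numerically with a short PyTorch sketch. The sizes and seed are arbitrary, and the broadcasting variant `W_fast` is an illustrative alternative that avoids building the diagonal matrices; it is not what the class itself does.

```python
import torch

torch.manual_seed(0)  # fixed seed so the random draws are reproducible

n, D, r = 4, 6, 2  # hypothetical sizes

# Diagonal scaling vectors that would come from w (lengths n and r).
lambda_1 = torch.randn(n)
lambda_2 = torch.randn(r)

# Fixed Gaussian random matrices A (n x r) and B (D x r).
A = torch.randn(n, r)
B = torch.randn(D, r)

# Direct translation of the formula: Lambda_1 A Lambda_2 B^T.
W = torch.diag(lambda_1) @ A @ torch.diag(lambda_2) @ B.t()

# Equivalent form: scale the rows of A by lambda_1 and its columns by lambda_2.
W_fast = (lambda_1[:, None] * A * lambda_2[None, :]) @ B.t()

print(W.shape)                    # torch.Size([4, 6])
print(torch.allclose(W, W_fast))  # True
```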

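Attributes:

| Name | Type | Description |
|------|------|-------------|
| r | int | Rank of the random matrices used in the adaptation. |
| A | Tensor | Random Gaussian matrix of shape (n, r), sampled on the first call to forward and reused afterwards (resampled only if its required shape changes). |
| B | Tensor | Random Gaussian matrix of shape (D, r), sampled on the first call to forward and reused afterwards (resampled only if its required shape changes). |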

Methods:

| Name | Description |
|------|-------------|
| calculate_l | Computes the number of parameters required for the reconciliation. |
| forward | Computes the reconciliation matrix using the provided parameters and random matrices. |

Source code in tinybig/reconciliation/random_matrix_reconciliation.py
class random_matrix_adaption_reconciliation(fabrication):
    r"""
        A reconciliation mechanism using random matrices for parameter adaptation.

        This class generates a reconciliation matrix `W` based on random matrices and diagonal parameter matrices.

        Notes
        ----------

        Formally, given a parameter vector $\mathbf{w} \in R^l$ of length $l$, we can partition it into two vectors $\lambda_1 \in R^{n}$ and $\lambda_2 \in R^r$.
        These two vectors will define two diagonal matrices $\Lambda_1 = diag( \lambda_1) \in R^{n \times n}$ and $\Lambda_2 = diag(\lambda_2) \in R^{r \times r}$.

        Together with two fixed random matrices $\mathbf{A}$ and $\mathbf{B}$, these diagonal matrices fabricate a parameter matrix of shape $n \times D$ as follows:

        $$
            \begin{equation}
            \xi(\mathbf{w}) = \Lambda_1 \mathbf{A} \Lambda_2 \mathbf{B}^\top \in R^{n \times D},
            \end{equation}
        $$

        where matrices $\mathbf{A} \in R^{n \times r}$ and $\mathbf{B} \in R^{D \times r}$ are randomly sampled from the Gaussian distribution $\mathcal{N}(\mathbf{0}, \mathbf{I})$.
        The required length of vector $\mathbf{w}$ is $l = n + r$.

        Attributes
        ----------
        r : int
            Rank of the random matrices used in the adaptation.
        A : torch.Tensor
            Random matrix of shape `(n, r)`, initialized and reused during computation.
        B : torch.Tensor
            Random matrix of shape `(D, r)`, initialized and reused during computation.

        Methods
        -------
        calculate_l(n, D)
            Computes the number of parameters required for the reconciliation.
        forward(n, D, w, device='cpu', *args, **kwargs)
            Computes the reconciliation matrix using the provided parameters and random matrices.
    """

    def __init__(self, name: str = 'random_matrix_adaption_reconciliation', r: int = 2, *args, **kwargs):
        """
            Initializes the random matrix adaption reconciliation mechanism.

            Parameters
            ----------
            name : str, optional
                Name of the reconciliation instance. Defaults to 'random_matrix_adaption_reconciliation'.
            r : int, optional
                Rank of the random matrices used in the adaptation. Defaults to 2.
            *args : tuple
                Additional positional arguments for the parent class.
            **kwargs : dict
                Additional keyword arguments for the parent class.
        """
        super().__init__(name=name, *args, **kwargs)
        self.r = r
        self.A = None
        self.B = None

    def calculate_l(self, n: int, D: int):
        """
            Computes the number of parameters required for the reconciliation.

            Parameters
            ----------
            n : int
                Number of rows in the reconciliation matrix.
            D : int
                Number of columns in the reconciliation matrix.

            Returns
            -------
            int
                Total number of parameters required, which is `n + r`.
        """
        return n + self.r

    def forward(self, n: int, D: int, w: torch.nn.Parameter, device='cpu', *args, **kwargs):
        """
            Computes the reconciliation matrix using the provided parameters and random matrices.

            Parameters
            ----------
            n : int
                Number of rows in the reconciliation matrix.
            D : int
                Number of columns in the reconciliation matrix.
            w : torch.nn.Parameter
                Parameter tensor of shape `(1, n + r)`, where `r` is the rank of the random matrices.
            device : str, optional
                Device for computation ('cpu', 'cuda', etc.). Defaults to 'cpu'.
            *args : tuple
                Additional positional arguments.
            **kwargs : dict
                Additional keyword arguments.

            Returns
            -------
            torch.Tensor
                Reconciliation matrix of shape `(n, D)`.

            Raises
            ------
            AssertionError
                If the dimensions of `w`, `A`, or `B` are inconsistent with the expected shapes.
        """
        assert w.ndim == 2 and w.numel() == self.calculate_l(n=n, D=D)
        lambda_1, lambda_2 = torch.split(w, [n, self.r], dim=1)

        Lambda_1 = torch.diag(lambda_1.view(-1)).to(device)
        Lambda_2 = torch.diag(lambda_2.view(-1)).to(device)

        # A and B are sampled once from N(0, I) and reused; they are resampled
        # only when the required shape (n, r) or (D, r) changes.
        if self.A is None or self.A.shape != (n, self.r):
            self.A = torch.randn(n, self.r, device=device)
        if self.B is None or self.B.shape != (D, self.r):
            self.B = torch.randn(D, self.r, device=device)
        assert self.A.shape == (n, self.r) and self.B.shape == (D, self.r)

        # W = Lambda_1 A Lambda_2 B^T, of shape (n, D)
        W = Lambda_1 @ self.A @ Lambda_2 @ self.B.t()
        assert W.shape == (n, D)
        return W
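Because the class above subclasses tinybig's `fabrication`, the sketch below uses a hypothetical standalone stand-in, `RandomMatrixAdaptionSketch`, that mirrors `forward` without the base class (and without the `device`, `*args`, `**kwargs` handling), so it runs with PyTorch alone. It illustrates two properties of the design: `A` and `B` are cached across calls, and gradients reach only `w`.

```python
import torch


class RandomMatrixAdaptionSketch:
    """Hypothetical stand-in mirroring the forward() logic, without tinybig."""

    def __init__(self, r: int = 2):
        self.r = r
        self.A = None
        self.B = None

    def calculate_l(self, n: int, D: int) -> int:
        return n + self.r

    def forward(self, n: int, D: int, w: torch.Tensor) -> torch.Tensor:
        assert w.ndim == 2 and w.numel() == self.calculate_l(n, D)
        lambda_1, lambda_2 = torch.split(w, [n, self.r], dim=1)
        # Sample A and B once; regenerate only if the required shape changes.
        if self.A is None or self.A.shape != (n, self.r):
            self.A = torch.randn(n, self.r)
        if self.B is None or self.B.shape != (D, self.r):
            self.B = torch.randn(D, self.r)
        return torch.diag(lambda_1.view(-1)) @ self.A @ torch.diag(lambda_2.view(-1)) @ self.B.t()


torch.manual_seed(0)
rec = RandomMatrixAdaptionSketch(r=2)
n, D = 5, 8
w = torch.nn.Parameter(torch.randn(1, rec.calculate_l(n, D)))

W1 = rec.forward(n, D, w)
A_first = rec.A
W2 = rec.forward(n, D, w)

print(W1.shape)             # torch.Size([5, 8])
print(rec.A is A_first)     # True: A is cached and reused
print(torch.equal(W1, W2))  # True: same w, A, B give the same W

# Only w is trainable; A and B carry no gradient.
W1.sum().backward()
print(w.grad.shape)         # torch.Size([1, 7])
print(rec.A.requires_grad)  # False
```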

__init__(name='random_matrix_adaption_reconciliation', r=2, *args, **kwargs)

Initializes the random matrix adaption reconciliation mechanism.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| name | str | Name of the reconciliation instance. | 'random_matrix_adaption_reconciliation' |
| r | int | Rank of the random matrices used in the adaptation. | 2 |
| *args | tuple | Additional positional arguments for the parent class. | () |
| **kwargs | dict | Additional keyword arguments for the parent class. | {} |
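The rank `r` bounds the rank of the fabricated matrix, since \(\mathbf{A}\) has only r columns. A rough numerical check, evaluating the formula directly rather than through the class, with hypothetical sizes:

```python
import torch

torch.manual_seed(0)
n, D = 16, 32  # hypothetical sizes

ranks = {}
for r in (1, 2, 4):
    lambda_1, lambda_2 = torch.randn(n), torch.randn(r)
    A, B = torch.randn(n, r), torch.randn(D, r)
    # W = Lambda_1 A Lambda_2 B^T has inner dimension r, so rank(W) <= r.
    W = torch.diag(lambda_1) @ A @ torch.diag(lambda_2) @ B.t()
    ranks[r] = int(torch.linalg.matrix_rank(W))

print(ranks)  # numerical rank of W for each choice of r
```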

calculate_l(n, D)

Computes the number of parameters required for the reconciliation.

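Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| n | int | Number of rows in the reconciliation matrix. | required |
| D | int | Number of columns in the reconciliation matrix. | required |

Returns:

| Type | Description |
|------|-------------|
| int | Total number of parameters required, which is n + r (independent of D). |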
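A worked example of the count, with illustrative sizes. The helper `random_matrix_adaption_l` is hypothetical and simply restates `calculate_l`; the dense count n × D is shown only for comparison.

```python
def random_matrix_adaption_l(n: int, D: int, r: int) -> int:
    """Hypothetical helper restating calculate_l: l = n + r (D is not used)."""
    return n + r


n, D, r = 64, 128, 2  # illustrative sizes
l = random_matrix_adaption_l(n, D, r)
dense = n * D  # entries of an unconstrained n x D matrix

print(l)      # 66
print(dense)  # 8192
```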


forward(n, D, w, device='cpu', *args, **kwargs)

Computes the reconciliation matrix using the provided parameters and random matrices.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| n | int | Number of rows in the reconciliation matrix. | required |
| D | int | Number of columns in the reconciliation matrix. | required |
| w | Parameter | Parameter tensor of shape (1, n + r), where r is the rank of the random matrices. | required |
| device | str | Device for computation ('cpu', 'cuda', etc.). | 'cpu' |
| *args | tuple | Additional positional arguments. | () |
| **kwargs | dict | Additional keyword arguments. | {} |

Returns:

| Type | Description |
|------|-------------|
| Tensor | Reconciliation matrix of shape (n, D). |

Raises:

| Type | Description |
|------|-------------|
| AssertionError | If the dimensions of w, A, or B are inconsistent with the expected shapes. |
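A small sketch of which layouts of `w` satisfy the checks in `forward`. The helper `check_w` is hypothetical and restates the first assertion. A column of shape (n + r, 1) passes that assertion but then fails in `torch.split` along `dim=1`, which raises a `RuntimeError` rather than an `AssertionError`.

```python
import torch


def check_w(w: torch.Tensor, n: int, r: int) -> bool:
    # Same condition as the first assertion in forward(): 2-D with n + r elements.
    return w.ndim == 2 and w.numel() == n + r


n, r = 4, 2
print(check_w(torch.zeros(1, n + r), n, r))  # True: expected (1, n + r) row
print(check_w(torch.zeros(n, n + r), n, r))  # False: 24 elements, not 6
print(check_w(torch.zeros(n + r), n, r))     # False: 1-D

col = torch.zeros(n + r, 1)
print(check_w(col, n, r))  # True: passes the assertion...
split_failed = False
try:
    torch.split(col, [n, r], dim=1)  # ...but dim 1 has size 1, not n + r
except RuntimeError:
    split_failed = True
print(split_failed)  # True
```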
