random_matrix_hypernet_reconciliation

Bases: fabrication

A reconciliation mechanism using a hypernetwork approach with random matrices.

This class computes a reconciliation matrix W using a series of random matrices and a hypernetwork-like architecture.

Notes

Formally, based on the parameter vector \(\mathbf{w} \in R^l\), the random_matrix_hypernet_reconciliation function will fabricate it into a parameter matrix \(\mathbf{W} \in R^{n \times D}\) as follows:

\[
    \begin{equation}
    \begin{aligned}
    \text{Hypernet}(\mathbf{w}) &= \sigma(\mathbf{w} \mathbf{H}_1) \mathbf{H}_2 \\
    &= \sigma \left( \mathbf{w} (\mathbf{P} \mathbf{Q}^\top) \right) \left( \mathbf{S} \mathbf{T}^\top \right)\\
    &= \left( \sigma \left( (\mathbf{w} \mathbf{P}) \mathbf{Q}^\top \right) \mathbf{S} \right) \mathbf{T}^\top \in R^{n \times D},
    \end{aligned}
    \end{equation}
\]

where \(\mathbf{P} \in R^{l \times r}\), \(\mathbf{Q} \in R^{d \times r}\), \(\mathbf{S} \in R^{d \times r}\) and \(\mathbf{T} \in R^{(n \times D) \times r}\) are low-rank, random, and frozen sub-matrices that compose the matrices \(\mathbf{H}_1 \in R^{l \times d}\) and \(\mathbf{H}_2 \in R^{d \times (n \times D)}\) of the hypernet (the hidden dimension \(d\) corresponds to the `hidden_dim` attribute below). Moreover, by leveraging the associativity of matrix multiplication, we avoid explicitly calculating and storing \(\mathbf{H}_1\) and \(\mathbf{H}_2\), as the last line of the equation indicates. These low-rank random matrix representations reduce the space consumption of this function to \(\mathcal{O}\left(r \cdot (l + 2d + n \cdot D)\right)\).
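The associativity trick above can be checked numerically: materializing \(\mathbf{H}_1\) and \(\mathbf{H}_2\) explicitly and using only the factorized products give the same result. A minimal sketch with arbitrary illustrative sizes for `l`, `d`, `r`, `n`, and `D`:

```python
import torch

torch.manual_seed(0)
l, d, r, n, D = 8, 16, 2, 4, 5  # illustrative sizes

w = torch.randn(1, l)   # parameter vector
P = torch.randn(l, r)   # frozen low-rank factors
Q = torch.randn(d, r)
S = torch.randn(d, r)
T = torch.randn(n * D, r)

# Naive route: materialize H1 (l x d) and H2 (d x (n*D)) explicitly.
H1 = P @ Q.t()
H2 = S @ T.t()
W_naive = (torch.sigmoid(w @ H1) @ H2).view(n, D)

# Factorized route: only r-column products; H1 and H2 are never formed.
W_factored = ((torch.sigmoid((w @ P) @ Q.t()) @ S) @ T.t()).view(n, D)

assert torch.allclose(W_naive, W_factored, atol=1e-4)
```

Both routes agree up to floating-point error, but the factorized one stores only \(r \cdot (l + 2d + n \cdot D)\) entries instead of \(l \cdot d + d \cdot n \cdot D\).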

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `r` | `int` | Rank of the random matrices. |
| `l` | `int` | Dimension of the hypernetwork input. |
| `hidden_dim` | `int` | Hidden dimension of the hypernetwork. |
| `P` | `Tensor` | Random matrix of shape `(l, r)`, initialized and reused during computation. |
| `Q` | `Tensor` | Random matrix of shape `(hidden_dim, r)`, initialized and reused during computation. |
| `S` | `Tensor` | Random matrix of shape `(hidden_dim, r)`, initialized and reused during computation. |
| `T` | `Tensor` | Random matrix of shape `(n * D, r)`, initialized and reused during computation. |

Methods:

| Name | Description |
| --- | --- |
| `calculate_l` | Computes the number of parameters required for the hypernetwork. |
| `forward` | Computes the reconciliation matrix using the hypernetwork approach. |

Source code in tinybig/reconciliation/random_matrix_reconciliation.py
class random_matrix_hypernet_reconciliation(fabrication):
    r"""
        A reconciliation mechanism using a hypernetwork approach with random matrices.

        This class computes a reconciliation matrix `W` using a series of random matrices and a hypernetwork-like architecture.

        Notes
        ----------

        Formally, based on the parameter vector $\mathbf{w} \in R^l$, the random_matrix_hypernet_reconciliation function will fabricate it into a parameter matrix $\mathbf{W} \in R^{n \times D}$ as follows:

        $$
            \begin{equation}
            \begin{aligned}
            \text{Hypernet}(\mathbf{w}) &= \sigma(\mathbf{w} \mathbf{H}_1) \mathbf{H}_2 \\
            &= \sigma \left( \mathbf{w} (\mathbf{P} \mathbf{Q}^\top) \right) \left( \mathbf{S} \mathbf{T}^\top \right)\\
            &= \left( \sigma \left( (\mathbf{w} \mathbf{P}) \mathbf{Q}^\top \right) \mathbf{S} \right) \mathbf{T}^\top \in R^{n \times D},
            \end{aligned}
            \end{equation}
        $$

        where $\mathbf{P} \in R^{l \times r}$, $\mathbf{Q} \in R^{d \times r}$, $\mathbf{S} \in R^{d \times r}$ and $\mathbf{T} \in R^{(n \times D) \times r}$
        are low-rank, random, and frozen sub-matrices that compose the matrices $\mathbf{H}_1 \in R^{l \times d}$ and $\mathbf{H}_2 \in R^{d \times (n \times D)}$ of the hypernet.
        Moreover, by leveraging the associative law of matrix multiplication, we can avoid explicitly calculating and storing $\mathbf{H}_1$ and $\mathbf{H}_2$ as indicated by the above equation.
        These low-rank random matrix representations reduce the space consumption of this function to $\mathcal{O}\left(r \cdot (l + 2d + n \cdot D)\right)$.

        Attributes
        ----------
        r : int
            Rank of the random matrices.
        l : int
            Dimension of the hypernetwork input.
        hidden_dim : int
            Hidden dimension of the hypernetwork.
        P : torch.Tensor
            Random matrix of shape `(l, r)`, initialized and reused during computation.
        Q : torch.Tensor
            Random matrix of shape `(hidden_dim, r)`, initialized and reused during computation.
        S : torch.Tensor
            Random matrix of shape `(hidden_dim, r)`, initialized and reused during computation.
        T : torch.Tensor
            Random matrix of shape `(n * D, r)`, initialized and reused during computation.

        Methods
        -------
        calculate_l(n=None, D=None)
            Computes the number of parameters required for the hypernetwork.
        forward(n, D, w, device='cpu', *args, **kwargs)
            Computes the reconciliation matrix using the hypernetwork approach.
    """
    def __init__(self, name='random_matrix_hypernet_reconciliation', r: int = 2, l: int = 64, hidden_dim: int = 128, *args, **kwargs):
        """
            Initializes the random matrix hypernetwork reconciliation mechanism.

            Parameters
            ----------
            name : str, optional
                Name of the reconciliation instance. Defaults to 'random_matrix_hypernet_reconciliation'.
            r : int, optional
                Rank of the random matrices. Defaults to 2.
            l : int, optional
                Dimension of the hypernetwork input. Defaults to 64.
            hidden_dim : int, optional
                Hidden dimension of the hypernetwork. Defaults to 128.
            *args : tuple
                Additional positional arguments for the parent class.
            **kwargs : dict
                Additional keyword arguments for the parent class.
        """
        super().__init__(name=name, *args, **kwargs)
        self.r = r
        self.l = l
        self.hidden_dim = hidden_dim

        self.P = None
        self.Q = None
        self.S = None
        self.T = None

    def calculate_l(self, n: int = None, D: int = None):
        """
            Computes the number of parameters required for the hypernetwork.

            Parameters
            ----------
            n : int, optional
                Number of rows in the reconciliation matrix (unused here). Defaults to None.
            D : int, optional
                Number of columns in the reconciliation matrix (unused here). Defaults to None.

            Returns
            -------
            int
                Total number of parameters required, equal to `l`.
        """
        assert self.l is not None
        return self.l

    def forward(self, n: int, D: int, w: torch.nn.Parameter, device='cpu', *args, **kwargs):
        """
            Computes the reconciliation matrix using the hypernetwork approach.

            Parameters
            ----------
            n : int
                Number of rows in the reconciliation matrix.
            D : int
                Number of columns in the reconciliation matrix.
            w : torch.nn.Parameter
                Parameter tensor of shape `(1, l)` where `l` is the hypernetwork input dimension.
            device : str, optional
                Device for computation ('cpu', 'cuda', etc.). Defaults to 'cpu'.
            *args : tuple
                Additional positional arguments.
            **kwargs : dict
                Additional keyword arguments.

            Returns
            -------
            torch.Tensor
                Reconciliation matrix of shape `(n, D)`.

            Raises
            ------
            AssertionError
                If the dimensions of `w`, `P`, `Q`, `S`, or `T` are inconsistent with the expected shapes.
        """

        assert w.ndim == 2 and w.numel() == self.calculate_l(n=n, D=D)

        if self.P is None or self.P.shape != (self.l, self.r):
            self.P = torch.randn(self.l, self.r, device=device)
        if self.Q is None or self.Q.shape != (self.hidden_dim, self.r):
            self.Q = torch.randn(self.hidden_dim, self.r, device=device)
        assert self.P.shape == (self.l, self.r) and self.Q.shape == (self.hidden_dim, self.r)

        if self.S is None or self.S.shape != (self.hidden_dim, self.r):
            self.S = torch.randn(self.hidden_dim, self.r, device=device)
        if self.T is None or self.T.shape != (n*D, self.r):
            self.T = torch.randn(n*D, self.r, device=device)
        assert self.S.shape == (self.hidden_dim, self.r) and self.T.shape == (n*D, self.r)

        W = torch.matmul(
            torch.matmul(
                torch.sigmoid(torch.matmul(torch.matmul(w, self.P), self.Q.t())),
                self.S),
            self.T.t()
        ).view(n, D)

        assert W.shape == (n, D)
        return W
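To see the shapes flow through `forward`, here is a standalone re-enactment in plain `torch` (a sketch, not the library API; the class defaults `l=64`, `hidden_dim=128`, `r=2` are used, and `n`, `D` are illustrative):

```python
import torch

torch.manual_seed(1)
l, hidden_dim, r = 64, 128, 2   # the class defaults
n, D = 10, 20                   # target reconciliation matrix size

w = torch.nn.Parameter(torch.randn(1, l))  # the only learnable tensor
P = torch.randn(l, r)                      # frozen random factors
Q = torch.randn(hidden_dim, r)
S = torch.randn(hidden_dim, r)
T = torch.randn(n * D, r)

h = torch.sigmoid((w @ P) @ Q.t())  # (1, l) -> (1, hidden_dim)
W = ((h @ S) @ T.t()).view(n, D)    # (1, n*D) -> (n, D)
assert h.shape == (1, hidden_dim) and W.shape == (n, D)

# Storage of the frozen factors vs. the full hypernet matrices:
factored = r * (l + 2 * hidden_dim + n * D)  # O(r(l + 2d + nD)) entries
full = l * hidden_dim + hidden_dim * n * D   # H1 and H2 entries
print(factored, full)  # 1040 vs 33792
```

Only `w` receives gradients; the four random factors are drawn once and reused across calls, as the shape-guarded reinitialization in `forward` shows.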

__init__(name='random_matrix_hypernet_reconciliation', r=2, l=64, hidden_dim=128, *args, **kwargs)

Initializes the random matrix hypernetwork reconciliation mechanism.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `name` | `str` | Name of the reconciliation instance. | `'random_matrix_hypernet_reconciliation'` |
| `r` | `int` | Rank of the random matrices. | `2` |
| `l` | `int` | Dimension of the hypernetwork input. | `64` |
| `hidden_dim` | `int` | Hidden dimension of the hypernetwork. | `128` |
| `*args` | `tuple` | Additional positional arguments for the parent class. | `()` |
| `**kwargs` | `dict` | Additional keyword arguments for the parent class. | `{}` |
Source code in tinybig/reconciliation/random_matrix_reconciliation.py
def __init__(self, name='random_matrix_hypernet_reconciliation', r: int = 2, l: int = 64, hidden_dim: int = 128, *args, **kwargs):
    """
        Initializes the random matrix hypernetwork reconciliation mechanism.

        Parameters
        ----------
        name : str, optional
            Name of the reconciliation instance. Defaults to 'random_matrix_hypernet_reconciliation'.
        r : int, optional
            Rank of the random matrices. Defaults to 2.
        l : int, optional
            Dimension of the hypernetwork input. Defaults to 64.
        hidden_dim : int, optional
            Hidden dimension of the hypernetwork. Defaults to 128.
        *args : tuple
            Additional positional arguments for the parent class.
        **kwargs : dict
            Additional keyword arguments for the parent class.
    """
    super().__init__(name=name, *args, **kwargs)
    self.r = r
    self.l = l
    self.hidden_dim = hidden_dim

    self.P = None
    self.Q = None
    self.S = None
    self.T = None

calculate_l(n=None, D=None)

Computes the number of parameters required for the hypernetwork.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `n` | `int` | Number of rows in the reconciliation matrix (unused here). | `None` |
| `D` | `int` | Number of columns in the reconciliation matrix (unused here). | `None` |

Returns:

| Type | Description |
| --- | --- |
| `int` | Total number of parameters required, equal to `l`. |

Source code in tinybig/reconciliation/random_matrix_reconciliation.py
def calculate_l(self, n: int = None, D: int = None):
    """
        Computes the number of parameters required for the hypernetwork.

        Parameters
        ----------
        n : int, optional
            Number of rows in the reconciliation matrix (unused here). Defaults to None.
        D : int, optional
            Number of columns in the reconciliation matrix (unused here). Defaults to None.

        Returns
        -------
        int
            Total number of parameters required, equal to `l`.
    """
    assert self.l is not None
    return self.l
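Note that `calculate_l` is independent of the target size: the learnable budget stays at `l` no matter how large `W` becomes, so the compression ratio `n * D / l` grows with the output. A small arithmetic sketch (sizes illustrative):

```python
# Learnable parameter count is fixed at l, while the fabricated
# matrix W has n * D entries; the ratio is the compression achieved.
l = 64  # the class default
ratios = []
for n, D in [(10, 20), (100, 200), (1000, 2000)]:
    ratios.append((n * D) / l)
print(ratios)  # [3.125, 312.5, 31250.0]
```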

forward(n, D, w, device='cpu', *args, **kwargs)

Computes the reconciliation matrix using the hypernetwork approach.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `n` | `int` | Number of rows in the reconciliation matrix. | *required* |
| `D` | `int` | Number of columns in the reconciliation matrix. | *required* |
| `w` | `Parameter` | Parameter tensor of shape `(1, l)`, where `l` is the hypernetwork input dimension. | *required* |
| `device` | `str` | Device for computation (`'cpu'`, `'cuda'`, etc.). | `'cpu'` |
| `*args` | `tuple` | Additional positional arguments. | `()` |
| `**kwargs` | `dict` | Additional keyword arguments. | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Reconciliation matrix of shape `(n, D)`. |

Raises:

| Type | Description |
| --- | --- |
| `AssertionError` | If the dimensions of `w`, `P`, `Q`, `S`, or `T` are inconsistent with the expected shapes. |

Source code in tinybig/reconciliation/random_matrix_reconciliation.py
def forward(self, n: int, D: int, w: torch.nn.Parameter, device='cpu', *args, **kwargs):
    """
        Computes the reconciliation matrix using the hypernetwork approach.

        Parameters
        ----------
        n : int
            Number of rows in the reconciliation matrix.
        D : int
            Number of columns in the reconciliation matrix.
        w : torch.nn.Parameter
            Parameter tensor of shape `(1, l)` where `l` is the hypernetwork input dimension.
        device : str, optional
            Device for computation ('cpu', 'cuda', etc.). Defaults to 'cpu'.
        *args : tuple
            Additional positional arguments.
        **kwargs : dict
            Additional keyword arguments.

        Returns
        -------
        torch.Tensor
            Reconciliation matrix of shape `(n, D)`.

        Raises
        ------
        AssertionError
            If the dimensions of `w`, `P`, `Q`, `S`, or `T` are inconsistent with the expected shapes.
    """

    assert w.ndim == 2 and w.numel() == self.calculate_l(n=n, D=D)

    if self.P is None or self.P.shape != (self.l, self.r):
        self.P = torch.randn(self.l, self.r, device=device)
    if self.Q is None or self.Q.shape != (self.hidden_dim, self.r):
        self.Q = torch.randn(self.hidden_dim, self.r, device=device)
    assert self.P.shape == (self.l, self.r) and self.Q.shape == (self.hidden_dim, self.r)

    if self.S is None or self.S.shape != (self.hidden_dim, self.r):
        self.S = torch.randn(self.hidden_dim, self.r, device=device)
    if self.T is None or self.T.shape != (n*D, self.r):
        self.T = torch.randn(n*D, self.r, device=device)
    assert self.S.shape == (self.hidden_dim, self.r) and self.T.shape == (n*D, self.r)

    W = torch.matmul(
        torch.matmul(
            torch.sigmoid(torch.matmul(torch.matmul(w, self.P), self.Q.t())),
            self.S),
        self.T.t()
    ).view(n, D)

    assert W.shape == (n, D)
    return W