
constant_interdependence

Bases: interdependence

A class for constant interdependence.

This class defines a constant interdependence matrix (A) for the relationship between rows or columns of the input tensor. It does not require input data or additional parameters for computation.

Notes

Formally, based on the (optional) input data batch \(\mathbf{X} \in {R}^{b \times m}\), we define the constant interdependence function as:

\[\begin{equation}
\xi(\mathbf{X}) = \mathbf{A} \in {R}^{m \times m'}.
\end{equation}\]

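This function lets users define customized constant interdependence matrices: a manually specified matrix \(\mathbf{A}\) is passed as a hyper-parameter at initialization and returned unchanged for every input. The equation above shows the attribute (column) case; for instance (row) interdependence, \(\mathbf{A} \in {R}^{b \times b'}\).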

Two special cases are worth noting: when \(\mathbf{A}\) consists entirely of zeros, it is called the "zero interdependence matrix", and when it consists entirely of ones, it is called the "one interdependence matrix".
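As a minimal sketch, the two special cases can be built with `torch.zeros` and `torch.ones`. The sketch assumes that, for attribute interdependence, the matrix is applied by right-multiplication, \(\mathbf{X}\mathbf{A} \in {R}^{b \times m'}\); that application step is an illustrative assumption, not code taken from the class.

```python
import torch

b, m, m_prime = 4, 3, 2
# Integer-valued batch X (b x m) so the results below are exact.
X = torch.arange(b * m, dtype=torch.float32).reshape(b, m)

# The two special cases of the constant matrix A, each of shape (m, m').
A_zero = torch.zeros(m, m_prime)  # "zero interdependence matrix"
A_one = torch.ones(m, m_prime)    # "one interdependence matrix"

# Assumed application for attribute interdependence: X @ A has shape (b, m').
out_zero = X @ A_zero
out_one = X @ A_one

print(out_zero.shape)               # torch.Size([4, 2])
print(out_zero.abs().sum().item())  # 0.0: every entry is zero
print(out_one[:, 0].tolist())       # [3.0, 12.0, 21.0, 30.0]: the row sums of X
```

With the all-ones matrix every output column is identical, because each entry of \(\mathbf{X}\mathbf{A}\) is then simply a row sum of \(\mathbf{X}\).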

Attributes:

Name | Type | Description
A | Tensor | The interdependence matrix of shape (b, b_prime) or (m, m_prime), depending on the interdependence type.
b | int | Number of rows in the input tensor.
m | int | Number of columns in the input tensor.
interdependence_type | str | Type of interdependence ('attribute', 'instance', etc.).
name | str | Name of the interdependence function.
device | str | Device for computation (e.g., 'cpu' or 'cuda').

Methods:

Name | Description
__init__ | Initializes the constant interdependence function.
update_A | Updates the interdependence matrix A.
calculate_b_prime | Computes the number of rows in the output tensor after interdependence.
calculate_m_prime | Computes the number of columns in the output tensor after interdependence.
calculate_A | Returns the constant interdependence matrix A.

Source code in tinybig/interdependence/basic_interdependence.py
class constant_interdependence(interdependence):
    r"""
        A class for constant interdependence.

        This class defines a constant interdependence matrix (`A`) for the relationship between rows or columns
        of the input tensor. It does not require input data or additional parameters for computation.

        Notes
        ----------
        Formally, based on the (optional) input data batch $\mathbf{X} \in {R}^{b \times m}$, we define the constant interdependence function as:

        \begin{equation}
        \xi(\mathbf{X}) = \mathbf{A} \in {R}^{m \times m'}.
        \end{equation}

        This function facilitates the definition of customized constant interdependence matrices, allowing for a
        manually defined matrix $\mathbf{A}$ to be provided as a hyper-parameter during function initialization.

        Two special cases are worth noting: when $\mathbf{A}$ consists entirely of zeros, it is called the
        "zero interdependence matrix", and when it consists entirely of ones, it is called the "one interdependence matrix".

        Attributes
        ----------
        A : torch.Tensor
            The interdependence matrix of shape `(b, b_prime)` or `(m, m_prime)`, depending on the interdependence type.
        b : int
            Number of rows in the input tensor.
        m : int
            Number of columns in the input tensor.
        interdependence_type : str
            Type of interdependence ('attribute', 'instance', etc.).
        name : str
            Name of the interdependence function.
        device : str
            Device for computation (e.g., 'cpu' or 'cuda').

        Methods
        -------
        __init__(b, m, A, interdependence_type='attribute', name='constant_interdependence', ...)
            Initializes the constant interdependence function.
        update_A(A)
            Updates the interdependence matrix `A`.
        calculate_b_prime(b=None)
            Computes the number of rows in the output tensor after interdependence.
        calculate_m_prime(m=None)
            Computes the number of columns in the output tensor after interdependence.
        calculate_A(x=None, w=None, device='cpu', ...)
            Returns the constant interdependence matrix `A`.
    """
    def __init__(
        self,
        b: int, m: int,
        A: torch.Tensor,
        interdependence_type: str = 'attribute',
        name: str = 'constant_interdependence',
        device: str = 'cpu',
        *args, **kwargs
    ):
        """
            Initializes the constant interdependence function.

            Parameters
            ----------
            b : int
                Number of rows in the input tensor.
            m : int
                Number of columns in the input tensor.
            A : torch.Tensor
                The interdependence matrix of shape `(b, b_prime)` or `(m, m_prime)`, depending on the interdependence type.
            interdependence_type : str, optional
                Type of interdependence ('attribute', 'instance', etc.). Defaults to 'attribute'.
            name : str, optional
                Name of the interdependence function. Defaults to 'constant_interdependence'.
            device : str, optional
                Device for computation (e.g., 'cpu' or 'cuda'). Defaults to 'cpu'.
            *args : tuple
                Additional positional arguments for the parent `interdependence` class.
            **kwargs : dict
                Additional keyword arguments for the parent `interdependence` class.

            Raises
            ------
            ValueError
                If `A` is None or does not have 2 dimensions.
        """
        super().__init__(b=b, m=m, name=name, interdependence_type=interdependence_type, require_data=False, require_parameters=False, device=device, *args, **kwargs)
        if A is None or A.ndim != 2:
            raise ValueError('The parameter matrix A is required and should have ndim: 2 by default')
        # Tensor.to is not in-place; keep the returned tensor so A actually lives on `device`.
        self.A = A.to(device)

    def update_A(self, A: torch.Tensor):
        """
            Updates the interdependence matrix `A`.

            Parameters
            ----------
            A : torch.Tensor
                The new interdependence matrix of shape `(b, b_prime)` or `(m, m_prime)`.

            Raises
            ------
            ValueError
                If `A` is None or does not have 2 dimensions.
        """
        if A is None or A.ndim != 2:
            raise ValueError('The parameter matrix A is required and should have ndim: 2 by default')
        self.check_A_shape_validity(A=A)
        self.A = A

    def calculate_b_prime(self, b: int = None):
        """
            Computes the number of rows in the output tensor after applying the interdependence function.

            Parameters
            ----------
            b : int, optional
                Number of rows in the input tensor. If None, defaults to `self.b`.

            Returns
            -------
            int
                The number of rows in the output tensor.

            Raises
            ------
            AssertionError
                If `b` does not match the shape of `A` for row-based interdependence.
        """
        b = b if b is not None else self.b
        if self.interdependence_type in ['row', 'left', 'instance', 'instance_interdependence']:
            assert self.A is not None and b is not None and self.A.size(0) == b
            return self.A.size(1)
        else:
            return b

    def calculate_m_prime(self, m: int = None):
        """
            Computes the number of columns in the output tensor after applying the interdependence function.

            Parameters
            ----------
            m : int, optional
                Number of columns in the input tensor. If None, defaults to `self.m`.

            Returns
            -------
            int
                The number of columns in the output tensor.

            Raises
            ------
            AssertionError
                If `m` does not match the shape of `A` for column-based interdependence.
        """
        m = m if m is not None else self.m
        if self.interdependence_type in ['column', 'right', 'attribute', 'attribute_interdependence']:
            assert self.A is not None and m is not None and self.A.size(0) == m
            return self.A.size(1)
        else:
            return m

    def calculate_A(self, x: torch.Tensor = None, w: torch.nn.Parameter = None, device: str = 'cpu', *args, **kwargs):
        """
            Returns the constant interdependence matrix `A`.

            Parameters
            ----------
            x : torch.Tensor, optional
                Ignored for constant interdependence. Defaults to None.
            w : torch.nn.Parameter, optional
                Ignored for constant interdependence. Defaults to None.
            device : str, optional
                Ignored for constant interdependence; `A` is returned as stored. Defaults to 'cpu'.
            *args : tuple
                Additional positional arguments.
            **kwargs : dict
                Additional keyword arguments.

            Returns
            -------
            torch.Tensor
                The constant interdependence matrix `A`.

            Raises
            ------
            AssertionError
                If `A` is not set, or if the function is configured to require data or parameters.
        """
        assert self.A is not None and self.require_data is False and self.require_parameters is False
        return self.A

__init__(b, m, A, interdependence_type='attribute', name='constant_interdependence', device='cpu', *args, **kwargs)

Initializes the constant interdependence function.

Parameters:

Name | Type | Description | Default
b | int | Number of rows in the input tensor. | required
m | int | Number of columns in the input tensor. | required
A | Tensor | The interdependence matrix of shape (b, b_prime) or (m, m_prime), depending on the interdependence type. | required
interdependence_type | str | Type of interdependence ('attribute', 'instance', etc.). | 'attribute'
name | str | Name of the interdependence function. | 'constant_interdependence'
device | str | Device for computation (e.g., 'cpu' or 'cuda'). | 'cpu'
*args | tuple | Additional positional arguments for the parent interdependence class. | ()
**kwargs | dict | Additional keyword arguments for the parent interdependence class. | {}

Raises:

Type | Description
ValueError | If A is None or does not have 2 dimensions.
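The validation and device placement performed by `__init__` can be sketched as a standalone function. `prepare_A` is a hypothetical helper, not part of tinybig. It also illustrates a detail worth keeping in mind: `Tensor.to` returns a tensor rather than converting the original in place, so its result has to be stored.

```python
import torch


def prepare_A(A, device="cpu"):
    """Hypothetical helper mirroring the checks in __init__; not part of tinybig."""
    # A is mandatory and must be a 2-D matrix.
    if A is None or A.ndim != 2:
        raise ValueError("The parameter matrix A is required and should have ndim: 2")
    # Tensor.to returns a tensor on the target device instead of moving A in place,
    # so the returned value is what must be kept.
    return A.to(device)


A = prepare_A(torch.eye(3), device="cpu")
print(A.device)  # cpu

# Calling .to without keeping the result leaves the original tensor unchanged.
t = torch.zeros(2, 2)
t.to(torch.float64)
print(t.dtype)  # torch.float32

try:
    prepare_A(torch.ones(3))  # a 1-D tensor is rejected
except ValueError as err:
    print("rejected:", err)
```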


calculate_A(x=None, w=None, device='cpu', *args, **kwargs)

Returns the constant interdependence matrix A.

Parameters:

Name | Type | Description | Default
x | Tensor | Ignored for constant interdependence. | None
w | Parameter | Ignored for constant interdependence. | None
device | str | Ignored for constant interdependence; A is returned as stored. | 'cpu'
*args | tuple | Additional positional arguments. | ()
**kwargs | dict | Additional keyword arguments. | {}

Returns:

Type | Description
Tensor | The constant interdependence matrix A.

Raises:

Type | Description
AssertionError | If A is not set, or if the function is configured to require data or parameters.
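Because `calculate_A` ignores `x` and `w`, every batch sees the same matrix. The sketch below shows how a caller might apply that matrix in each orientation; the products \(\mathbf{X}\mathbf{A}\) (attribute, column) and \(\mathbf{A}^\top\mathbf{X}\) (instance, row) are assumptions chosen to agree with the output shapes reported by `calculate_m_prime` and `calculate_b_prime`, not code from the class.

```python
import torch

b, m = 5, 3
X = torch.randn(b, m)  # input batch with b rows and m columns

# Attribute (column) interdependence: A has shape (m, m').
A_col = torch.randn(m, 4)
out_col = X @ A_col      # assumed application, shape (b, m') = (5, 4)

# Instance (row) interdependence: A has shape (b, b').
A_row = torch.randn(b, 2)
out_row = A_row.T @ X    # assumed application, shape (b', m) = (2, 3)

print(out_col.shape)  # torch.Size([5, 4])
print(out_row.shape)  # torch.Size([2, 3])
```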


calculate_b_prime(b=None)

Computes the number of rows in the output tensor after applying the interdependence function.

Parameters:

Name | Type | Description | Default
b | int | Number of rows in the input tensor. If None, defaults to self.b. | None

Returns:

Type | Description
int | The number of rows in the output tensor.

Raises:

Type | Description
AssertionError | If b does not match the shape of A for row-based interdependence.
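The shape rule in `calculate_b_prime` can be restated without tensors. `b_prime` below is a hypothetical pure-Python helper that takes the shape of `A` as a tuple; the row-type names are copied from the source, and the `None` checks of the original assertion are omitted for brevity.

```python
ROW_TYPES = {"row", "left", "instance", "instance_interdependence"}


def b_prime(b, A_shape, interdependence_type):
    """Hypothetical restatement of calculate_b_prime using A's shape tuple."""
    if interdependence_type in ROW_TYPES:
        # Row-based: A must be (b, b'), and the output has b' rows.
        assert A_shape[0] == b, "A.size(0) must equal b"
        return A_shape[1]
    # Any other interdependence type leaves the number of rows unchanged.
    return b


print(b_prime(8, (8, 3), "instance"))   # 3
print(b_prime(8, (5, 2), "attribute"))  # 8
```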


calculate_m_prime(m=None)

Computes the number of columns in the output tensor after applying the interdependence function.

Parameters:

Name | Type | Description | Default
m | int | Number of columns in the input tensor. If None, defaults to self.m. | None

Returns:

Type | Description
int | The number of columns in the output tensor.

Raises:

Type | Description
AssertionError | If m does not match the shape of A for column-based interdependence.
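Assuming the attribute-type matrix is applied by right-multiplication \(\mathbf{X}\mathbf{A}\), the check `A.size(0) == m` is exactly the condition under which that product is defined. The sketch compares `A.size(1)`, the value `calculate_m_prime` returns for attribute interdependence, with the width of the actual product, and shows that a mismatched `A` cannot be multiplied.

```python
import torch

b, m = 4, 6
X = torch.randn(b, m)

A = torch.randn(m, 2)            # valid for 'attribute': A.size(0) == m
m_prime = A.size(1)              # the value calculate_m_prime returns
print((X @ A).shape == (b, m_prime))  # True

A_bad = torch.randn(m + 1, 2)    # invalid: A.size(0) != m
mismatch = False
try:
    X @ A_bad
except RuntimeError:
    mismatch = True              # the product is undefined, as the assertion anticipates
print(mismatch)  # True
```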


update_A(A)

Updates the interdependence matrix A.

Parameters:

Name | Type | Description | Default
A | Tensor | The new interdependence matrix of shape (b, b_prime) or (m, m_prime). | required

Raises:

Type | Description
ValueError | If A is None or does not have 2 dimensions.
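`update_A` validates the new matrix before assigning it, so an update that raises leaves the previous `A` in place. The stand-in class below is hypothetical and reproduces only the `ndim` guard; the parent class's `check_A_shape_validity` is not reproduced because its rules are not shown on this page.

```python
import torch


class ConstantA:
    """Hypothetical stand-in that stores A the way update_A does; not tinybig."""

    def __init__(self, A):
        self.A = A

    def update_A(self, A):
        # Same guard as update_A; check_A_shape_validity is intentionally omitted.
        if A is None or A.ndim != 2:
            raise ValueError("The parameter matrix A is required and should have ndim: 2")
        self.A = A


holder = ConstantA(torch.zeros(3, 3))
holder.update_A(torch.ones(3, 3))    # accepted: 2-D
try:
    holder.update_A(torch.ones(3))   # rejected: 1-D, so A is not replaced
except ValueError:
    pass
print(holder.A.sum().item())  # 9.0: the last valid matrix is kept
```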
