Skip to content

statistical_kernel_based_interdependence

Bases: interdependence

A statistical kernel-based interdependence function.

This class computes the interdependence matrix using a specified statistical kernel function.

Notes

Formally, given a data batch \(\mathbf{X} \in R^{b \times m}\), we can define the statistical kernel-based interdependence function as:

\[
    \begin{equation}
    \xi(\mathbf{X}) = \mathbf{A} \in R^{m \times m'} \text{, where } \mathbf{A}(i, j) = \text{kernel} \left(\mathbf{X}(:, i), \mathbf{X}(:, j)\right).
    \end{equation}
\]

Attributes:

| Name | Type | Description |
|------|------|-------------|
| `kernel` | `Callable` | The kernel function used to compute the interdependence matrix. |

Methods:

| Name | Description |
|------|-------------|
| `__init__` | Initializes the statistical kernel-based interdependence function. |
| `calculate_A` | Computes the interdependence matrix using the specified kernel. |

Source code in tinybig/interdependence/statistical_kernel_interdependence.py
class statistical_kernel_based_interdependence(interdependence):
    r"""
        A statistical kernel-based interdependence function.

        This class computes the interdependence matrix by applying a user-supplied statistical
        kernel function to the (pre-processed) input data batch.

        Notes
        ----------

        Formally, given a data batch $\mathbf{X} \in R^{b \times m}$, we can define the statistical kernel-based interdependence function as:

        $$
            \begin{equation}
            \xi(\mathbf{X}) = \mathbf{A} \in R^{m \times m'} \text{, where } \mathbf{A}(i, j) = \text{kernel} \left(\mathbf{X}(:, i), \mathbf{X}(:, j)\right).
            \end{equation}
        $$

        Attributes
        ----------
        kernel : Callable
            The kernel function used to compute the interdependence matrix. It is called with the
            pre-processed input tensor. Its output is passed through `post_process` and then
            shape-checked.

        Methods
        -------
        __init__(...)
            Initializes the statistical kernel-based interdependence function.
        calculate_A(x=None, w=None, device='cpu', *args, **kwargs)
            Computes the interdependence matrix using the specified kernel.
    """

    def __init__(
        self,
        b: int, m: int, kernel: Callable,
        interdependence_type: str = 'attribute',
        name: str = 'statistical_kernel_based_interdependence',
        require_data: bool = True,
        require_parameters: bool = False,
        device: str = 'cpu', *args, **kwargs
    ):
        """
            Initializes the statistical kernel-based interdependence function.

            Parameters
            ----------
            b : int
                Number of rows in the input tensor.
            m : int
                Number of columns in the input tensor.
            kernel : Callable
                The kernel function used to compute the interdependence matrix.
            interdependence_type : str, optional
                Type of interdependence ('attribute', 'instance', etc.). Defaults to 'attribute'.
            name : str, optional
                Name of the interdependence function. Defaults to 'statistical_kernel_based_interdependence'.
            require_data : bool, optional
                If True, requires input data for matrix computation. Defaults to True.
            require_parameters : bool, optional
                If True, requires parameters for matrix computation. Defaults to False.
            device : str, optional
                Device for computation (e.g., 'cpu' or 'cuda'). Defaults to 'cpu'.
            *args : tuple
                Additional positional arguments for the parent class.
            **kwargs : dict
                Additional keyword arguments for the parent class.

            Raises
            ------
            ValueError
                If no kernel function is provided.
        """
        super().__init__(b=b, m=m, name=name, interdependence_type=interdependence_type, require_parameters=require_parameters, require_data=require_data, device=device, *args, **kwargs)

        # The kernel is the only way this class produces a matrix, so a missing one is a
        # construction-time error rather than a failure deferred to calculate_A.
        if kernel is None:
            raise ValueError('the kernel is required for the statistical kernel based interdependence function')
        self.kernel = kernel

    def calculate_A(self, x: torch.Tensor = None, w: torch.nn.Parameter = None, device: str = 'cpu', *args, **kwargs):
        """
            Computes the interdependence matrix using the specified kernel function.

            Parameters
            ----------
            x : torch.Tensor, optional
                2-D input tensor, nominally of shape `(batch_size, num_features)`. Required
                unless a cached matrix can be returned. Defaults to None.
            w : torch.nn.Parameter, optional
                Parameter tensor. Not used by this implementation. Defaults to None.
            device : str, optional
                Device for computation (e.g., 'cpu' or 'cuda'). Defaults to 'cpu'.
            *args : tuple
                Additional positional arguments for interdependence functions.
            **kwargs : dict
                Additional keyword arguments for interdependence functions.

            Returns
            -------
            torch.Tensor
                The computed interdependence matrix. For attribute-type interdependence its shape
                is `(m, m')`; for instance-type interdependence it is `(b, b')`.

            Raises
            ------
            AssertionError
                If `x` is not provided, is not 2-D, or the kernel output has an unexpected shape.
        """
        # A matrix that depends on neither data nor parameters is computed once and reused.
        if not self.require_data and not self.require_parameters and self.A is not None:
            return self.A
        else:
            assert x is not None and x.ndim == 2
            x = self.pre_process(x=x, device=device)
            A = self.kernel(x)
            A = self.post_process(x=A, device=device)

            # Verify the kernel produced a matrix of the size implied by the interdependence type.
            if self.interdependence_type in ['column', 'right', 'attribute', 'attribute_interdependence']:
                assert A.shape == (self.m, self.calculate_m_prime())
            elif self.interdependence_type in ['row', 'left', 'instance', 'instance_interdependence']:
                assert A.shape == (self.b, self.calculate_b_prime())

            # Populate the cache on the first computation of a data/parameter-free matrix.
            if not self.require_data and not self.require_parameters and self.A is None:
                self.A = A

            return A

__init__(b, m, kernel, interdependence_type='attribute', name='statistical_kernel_based_interdependence', require_data=True, require_parameters=False, device='cpu', *args, **kwargs)

Initializes the statistical kernel-based interdependence function.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `b` | `int` | Number of rows in the input tensor. | *required* |
| `m` | `int` | Number of columns in the input tensor. | *required* |
| `kernel` | `Callable` | The kernel function used to compute the interdependence matrix. | *required* |
| `interdependence_type` | `str` | Type of interdependence ('attribute', 'instance', etc.). Defaults to 'attribute'. | `'attribute'` |
| `name` | `str` | Name of the interdependence function. Defaults to 'statistical_kernel_based_interdependence'. | `'statistical_kernel_based_interdependence'` |
| `require_data` | `bool` | If True, requires input data for matrix computation. Defaults to True. | `True` |
| `require_parameters` | `bool` | If True, requires parameters for matrix computation. Defaults to False. | `False` |
| `device` | `str` | Device for computation (e.g., 'cpu' or 'cuda'). Defaults to 'cpu'. | `'cpu'` |
| `*args` | `tuple` | Additional positional arguments for the parent class. | `()` |
| `**kwargs` | `dict` | Additional keyword arguments for the parent class. | `{}` |

Raises:

| Type | Description |
|------|-------------|
| `ValueError` | If no kernel function is provided. |

Source code in tinybig/interdependence/statistical_kernel_interdependence.py
def __init__(
    self,
    b: int, m: int, kernel: Callable,
    interdependence_type: str = 'attribute',
    name: str = 'statistical_kernel_based_interdependence',
    require_data: bool = True,
    require_parameters: bool = False,
    device: str = 'cpu', *args, **kwargs
):
    """
        Initializes the statistical kernel-based interdependence function.

        Parameters
        ----------
        b : int
            Number of rows in the input tensor.
        m : int
            Number of columns in the input tensor.
        kernel : Callable
            The kernel function used to compute the interdependence matrix.
        interdependence_type : str, optional
            Type of interdependence ('attribute', 'instance', etc.). Defaults to 'attribute'.
        name : str, optional
            Name of the interdependence function. Defaults to 'statistical_kernel_based_interdependence'.
        require_data : bool, optional
            If True, requires input data for matrix computation. Defaults to True.
        require_parameters : bool, optional
            If True, requires parameters for matrix computation. Defaults to False.
        device : str, optional
            Device for computation (e.g., 'cpu' or 'cuda'). Defaults to 'cpu'.
        *args : tuple
            Additional positional arguments for the parent class.
        **kwargs : dict
            Additional keyword arguments for the parent class.

        Raises
        ------
        ValueError
            If no kernel function is provided.
    """
    super().__init__(b=b, m=m, name=name, interdependence_type=interdependence_type, require_parameters=require_parameters, require_data=require_data, device=device, *args, **kwargs)

    # Without a kernel there is no way to compute the matrix, so fail at construction time.
    if kernel is None:
        raise ValueError('the kernel is required for the statistical kernel based interdependence function')
    self.kernel = kernel

calculate_A(x=None, w=None, device='cpu', *args, **kwargs)

Computes the interdependence matrix using the specified kernel function.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x` | `Tensor` | Input tensor of shape `(batch_size, num_features)`. Required for computation. Defaults to None. | `None` |
| `w` | `Parameter` | Parameter tensor. Defaults to None. | `None` |
| `device` | `str` | Device for computation (e.g., 'cpu' or 'cuda'). Defaults to 'cpu'. | `'cpu'` |
| `*args` | `tuple` | Additional positional arguments for interdependence functions. | `()` |
| `**kwargs` | `dict` | Additional keyword arguments for interdependence functions. | `{}` |

Returns:

| Type | Description |
|------|-------------|
| `Tensor` | The computed interdependence matrix. |

Raises:

| Type | Description |
|------|-------------|
| `AssertionError` | If `x` is not provided or has an incorrect shape. |

Source code in tinybig/interdependence/statistical_kernel_interdependence.py
def calculate_A(self, x: torch.Tensor = None, w: torch.nn.Parameter = None, device: str = 'cpu', *args, **kwargs):
    """
        Build the interdependence matrix by running the configured kernel over the input batch.

        Parameters
        ----------
        x : torch.Tensor, optional
            Two-dimensional input batch, nominally `(batch_size, num_features)`. Must be given
            whenever no cached matrix is available. Defaults to None.
        w : torch.nn.Parameter, optional
            Parameter tensor; accepted for interface compatibility and ignored here. Defaults to None.
        device : str, optional
            Device forwarded to the pre- and post-processing hooks. Defaults to 'cpu'.
        *args : tuple
            Extra positional arguments (unused).
        **kwargs : dict
            Extra keyword arguments (unused).

        Returns
        -------
        torch.Tensor
            The interdependence matrix produced by the kernel after post-processing.

        Raises
        ------
        AssertionError
            If `x` is missing or not 2-D, or the resulting matrix has the wrong shape.
    """
    # Guard clause: a matrix that needs neither data nor parameters is served from the cache.
    if not (self.require_data or self.require_parameters) and self.A is not None:
        return self.A

    assert x is not None and x.ndim == 2

    features = self.pre_process(x=x, device=device)
    raw_matrix = self.kernel(features)
    matrix = self.post_process(x=raw_matrix, device=device)

    attribute_kinds = ('column', 'right', 'attribute', 'attribute_interdependence')
    instance_kinds = ('row', 'left', 'instance', 'instance_interdependence')
    if self.interdependence_type in attribute_kinds:
        assert matrix.shape == (self.m, self.calculate_m_prime())
    elif self.interdependence_type in instance_kinds:
        assert matrix.shape == (self.b, self.calculate_b_prime())

    # First computation of a static matrix: remember it for later calls.
    if not (self.require_data or self.require_parameters) and self.A is None:
        self.A = matrix

    return matrix