Skip to content

numerical_kernel_based_interdependence

Bases: interdependence

Source code in tinybig/interdependence/numerical_kernel_interdependence.py
class numerical_kernel_based_interdependence(interdependence):
    """
    Interdependence function whose matrix ``A`` is produced by a numerical kernel.

    The kernel is a user-supplied callable that is applied to the (pre-processed)
    input tensor; its output, after post-processing, is used as the
    interdependence matrix.

    Attributes
    ----------
    kernel : Callable
        The callable invoked as ``kernel(x)`` inside ``calculate_A``.
    """

    def __init__(
        self,
        b: int, m: int, kernel: Callable,
        interdependence_type: str = 'attribute',
        name: str = 'kernel_based_interdependence',
        require_data: bool = True,
        require_parameters: bool = False,
        device: str = 'cpu', *args, **kwargs
    ):
        """
        Initialize the kernel based interdependence function.

        Parameters
        ----------
        b : int
            Forwarded to the base class; compared against the first dimension of
            ``A`` in the row/instance mode.
        m : int
            Forwarded to the base class; compared against the first dimension of
            ``A`` in the column/attribute mode.
        kernel : Callable
            Callable applied to the input tensor to build ``A``. Must not be None.
        interdependence_type : str, default = 'attribute'
            Selects which shape check ``calculate_A`` performs.
        name : str, default = 'kernel_based_interdependence'
            Forwarded to the base class.
        require_data : bool, default = True
            Forwarded to the base class; when False (together with
            ``require_parameters`` being False) the computed ``A`` is cached.
        require_parameters : bool, default = False
            Forwarded to the base class.
        device : str, default = 'cpu'
            Forwarded to the base class.

        Raises
        ------
        ValueError
            If ``kernel`` is None.
        """
        super().__init__(b=b, m=m, name=name, interdependence_type=interdependence_type, require_parameters=require_parameters, require_data=require_data, device=device, *args, **kwargs)

        if kernel is None:
            raise ValueError('the kernel is required for the kernel based interdependence function')
        self.kernel = kernel

    def calculate_A(self, x: torch.Tensor = None, w: torch.nn.Parameter = None, device: str = 'cpu', *args, **kwargs):
        """
        Compute (or fetch the cached) interdependence matrix ``A``.

        Parameters
        ----------
        x : torch.Tensor, optional
            Two-dimensional input tensor. Required whenever no cached matrix is
            returned.
        w : torch.nn.Parameter, optional
            Accepted for interface compatibility; not used by this method.
        device : str, default = 'cpu'
            Passed to ``pre_process`` and ``post_process``.

        Returns
        -------
        The matrix produced by ``post_process(kernel(pre_process(x)))``, or the
        cached ``self.A`` when neither data nor parameters are required and a
        matrix was stored by an earlier call.

        Raises
        ------
        AssertionError
            If ``x`` is None or not 2-D, or if the resulting matrix does not
            have the shape expected for the configured interdependence type.
        """
        # A data- and parameter-independent matrix never changes, so reuse it.
        if not self.require_data and not self.require_parameters and self.A is not None:
            return self.A

        assert x is not None and x.ndim == 2, 'x must be a two-dimensional tensor'
        x = self.pre_process(x=x, device=device)
        A = self.kernel(x)
        A = self.post_process(x=A, device=device)

        # Validate the kernel output against the dimensions this instance was
        # configured with; unrecognized interdependence types are not checked.
        if self.interdependence_type in ['column', 'right', 'attribute', 'attribute_interdependence']:
            expected_shape = (self.m, self.calculate_m_prime())
            assert A.shape == expected_shape, \
                f'unexpected shape of A: got {tuple(A.shape)}, expected {expected_shape}'
        elif self.interdependence_type in ['row', 'left', 'instance', 'instance_interdependence']:
            expected_shape = (self.b, self.calculate_b_prime())
            assert A.shape == expected_shape, \
                f'unexpected shape of A: got {tuple(A.shape)}, expected {expected_shape}'

        # Store the matrix only when it depends on neither data nor parameters.
        if not self.require_data and not self.require_parameters and self.A is None:
            self.A = A

        return A