Skip to content

constant_interdependence

Bases: interdependence

Source code in tinybig/interdependence/basic_interdependence.py
class constant_interdependence(interdependence):
    """Interdependence function backed by a fixed, user-supplied matrix ``A``.

    The matrix is stored as-is and returned unchanged by ``calculate_A``; the
    base class is initialised with ``require_data=False`` and
    ``require_parameters=False``, so neither input data nor learnable
    parameters are involved in producing ``A``.
    """

    def __init__(
        self,
        b: int, m: int,
        A: torch.Tensor,
        interdependence_type: str = 'attribute',
        name: str = 'constant_interdependence',
        device: str = 'cpu',
        *args, **kwargs
    ):
        """Store the constant matrix and place it on ``device``.

        Args:
            b: Forwarded to the base class as ``b``.
            m: Forwarded to the base class as ``m``.
            A: The constant interdependence matrix; must be a 2-D tensor.
            interdependence_type: Forwarded to the base class; also selects
                which of ``calculate_b_prime`` / ``calculate_m_prime`` reads
                the matrix shape.
            name: Forwarded to the base class.
            device: Target device for the stored matrix.

        Raises:
            ValueError: If ``A`` is ``None`` or is not 2-dimensional.
        """
        super().__init__(b=b, m=m, name=name, interdependence_type=interdependence_type, require_data=False, require_parameters=False, device=device, *args, **kwargs)
        if A is None or A.ndim != 2:
            raise ValueError('The parameter matrix A is required and should have ndim: 2 by default')
        # Tensor.to() is NOT in-place: it returns the converted tensor (or the
        # same tensor when no conversion is needed), so the result has to be
        # kept. Assigning unconditionally also avoids comparing a torch.device
        # against a plain string.
        self.A = A.to(device)

    def update_A(self, A: torch.Tensor):
        """Replace the stored matrix after validating it.

        Args:
            A: The new matrix; must be a 2-D tensor that passes
                ``check_A_shape_validity``.

        Raises:
            ValueError: If ``A`` is ``None`` or is not 2-dimensional.
        """
        if A is None or A.ndim != 2:
            raise ValueError('The parameter matrix A is required and should have ndim: 2 by default')
        self.check_A_shape_validity(A=A)
        self.A = A

    def calculate_b_prime(self, b: int = None):
        """Return the output size along the ``b`` dimension.

        For the row-style interdependence types the result is ``A.size(1)``
        (after asserting ``A.size(0) == b``); for every other type ``b`` is
        returned unchanged.

        Args:
            b: Input size; falls back to ``self.b`` when ``None``.
        """
        b = b if b is not None else self.b
        if self.interdependence_type in ['row', 'left', 'instance', 'instance_interdependence']:
            assert self.A is not None and b is not None and self.A.size(0) == b
            return self.A.size(1)
        else:
            return b

    def calculate_m_prime(self, m: int = None):
        """Return the output size along the ``m`` dimension.

        For the column-style interdependence types the result is ``A.size(1)``
        (after asserting ``A.size(0) == m``); for every other type ``m`` is
        returned unchanged.

        Args:
            m: Input size; falls back to ``self.m`` when ``None``.
        """
        m = m if m is not None else self.m
        if self.interdependence_type in ['column', 'right', 'attribute', 'attribute_interdependence']:
            assert self.A is not None and m is not None and self.A.size(0) == m
            return self.A.size(1)
        else:
            return m

    def calculate_A(self, x: torch.Tensor = None, w: torch.nn.Parameter = None, device: str = 'cpu', *args, **kwargs):
        """Return the stored constant matrix.

        ``x``, ``w`` and ``device`` are accepted for signature compatibility
        but are not used: the stored matrix is returned as-is.
        """
        assert self.A is not None and self.require_data is False and self.require_parameters is False
        return self.A