Skip to content

identity_interdependence

Bases: constant_interdependence

Source code in tinybig/interdependence/basic_interdependence.py
class identity_interdependence(constant_interdependence):
    """
    Constant interdependence whose fixed matrix ``A`` is an identity matrix.

    For row-type interdependence ``A = torch.eye(b, b_prime)``; for
    column-type interdependence ``A = torch.eye(m, m_prime)``. The matrix is
    then passed to ``constant_interdependence`` as its constant ``A``.

    Parameters
    ----------
    b : int
        Row dimension of the inputs.
    m : int
        Column dimension of the inputs.
    b_prime : int, optional
        Output row dimension, used only for row-type interdependence.
        Defaults to ``b`` (a square identity) when omitted.
    m_prime : int, optional
        Output column dimension, used only for column-type interdependence.
        Defaults to ``m`` (a square identity) when omitted.
    name : str
        Name forwarded to ``constant_interdependence``.
    interdependence_type : str
        One of ``'row'``, ``'left'``, ``'instance'``,
        ``'instance_interdependence'`` (row-type) or ``'column'``,
        ``'right'``, ``'attribute'``, ``'attribute_interdependence'``
        (column-type).
    device : str
        Torch device on which ``A`` is allocated.
    *args, **kwargs
        Forwarded unchanged to ``constant_interdependence.__init__``.

    Raises
    ------
    ValueError
        If ``interdependence_type`` is not one of the supported values.

    Warns
    -----
    UserWarning
        If the requested identity matrix is not square. In that case the
        transformation changes the corresponding input dimension and is not
        a true identity.
    """

    def __init__(
        self,
        b: int, m: int,
        b_prime: int = None, m_prime: int = None,
        name: str = 'identity_interdependence',
        interdependence_type: str = 'attribute',
        device: str = 'cpu',
        *args, **kwargs
    ):
        if interdependence_type in ('row', 'left', 'instance', 'instance_interdependence'):
            # An omitted output dimension means "keep the input dimension":
            # the square identity. The previous bare ``assert`` was stripped
            # under ``python -O``.
            if b_prime is None:
                b_prime = b
            A = torch.eye(b, b_prime, device=device)
            if b != b_prime:
                warnings.warn(
                    "b and b_prime are different, this function will change the row dimensions of the inputs and cannot guarantee identity interdependence...",
                    stacklevel=2,
                )
        elif interdependence_type in ('column', 'right', 'attribute', 'attribute_interdependence'):
            if m_prime is None:
                m_prime = m
            A = torch.eye(m, m_prime, device=device)
            if m != m_prime:
                warnings.warn(
                    "m and m_prime are different, this function will change the column dimensions of the inputs and cannot guarantee identity interdependence...",
                    stacklevel=2,
                )
        else:
            raise ValueError(f'Interdependence type {interdependence_type} is not supported')
        super().__init__(b=b, m=m, A=A, name=name, interdependence_type=interdependence_type, device=device, *args, **kwargs)