Skip to content

constant_c_interdependence

Bases: constant_interdependence

Source code in tinybig/interdependence/basic_interdependence.py
class constant_c_interdependence(constant_interdependence):
    """
    Constant interdependence whose matrix ``A`` has every entry equal to ``c``.

    ``A`` is built here as ``c * torch.ones(shape)`` and passed to
    ``constant_interdependence`` as its ``A`` argument. The shape depends on
    ``interdependence_type``:

    * ``'row'``, ``'left'``, ``'instance'``, ``'instance_interdependence'``
      -> ``(b, b_prime)``
    * ``'column'``, ``'right'``, ``'attribute'``, ``'attribute_interdependence'``
      -> ``(m, m_prime)``
    """

    def __init__(
        self,
        b: int, m: int,
        b_prime: int | None = None, m_prime: int | None = None,
        c: float | int = 1.0,
        name: str = 'constant_c_interdependence',
        interdependence_type: str = 'attribute',
        device: str = 'cpu',
        *args, **kwargs
    ):
        """
        Build the constant-``c`` matrix and initialize the parent with it.

        Parameters
        ----------
        b : int
            First dimension of ``A`` for row-type interdependence; also
            forwarded to the parent constructor.
        m : int
            First dimension of ``A`` for column-type interdependence; also
            forwarded to the parent constructor.
        b_prime : int | None
            Second dimension of ``A`` for row-type interdependence.
            Required (must not be None) for that type.
        m_prime : int | None
            Second dimension of ``A`` for column-type interdependence.
            Required (must not be None) for that type.
        c : float | int
            Value every entry of ``A`` is set to (default ``1.0``).
        name : str
            Name forwarded to the parent constructor.
        interdependence_type : str
            Selects which dimensions ``A`` uses; see the class docstring for
            the accepted aliases.
        device : str
            Device on which ``A`` is allocated; also forwarded to the parent.
        *args, **kwargs
            Forwarded unchanged to the parent constructor.

        Raises
        ------
        ValueError
            If ``interdependence_type`` is not one of the supported aliases,
            or if the second dimension it requires (``b_prime`` or
            ``m_prime``) is None.
        """
        if interdependence_type in ['row', 'left', 'instance', 'instance_interdependence']:
            # Explicit check rather than `assert`: asserts are removed under
            # `python -O`, which would let None reach torch.ones.
            if b_prime is None:
                raise ValueError(
                    f'b_prime must be provided for interdependence type {interdependence_type}'
                )
            A = c * torch.ones((b, b_prime), device=device)
        elif interdependence_type in ['column', 'right', 'attribute', 'attribute_interdependence']:
            if m_prime is None:
                raise ValueError(
                    f'm_prime must be provided for interdependence type {interdependence_type}'
                )
            A = c * torch.ones((m, m_prime), device=device)
        else:
            raise ValueError(f'Interdependence type {interdependence_type} is not supported')
        # *args are forwarded positionally either way; listing them first just
        # makes that explicit.
        super().__init__(*args, b=b, m=m, A=A, name=name, interdependence_type=interdependence_type, device=device, **kwargs)