class identity_interdependence(constant_interdependence):
    """Constant interdependence whose fixed matrix ``A`` is an identity matrix.

    ``A`` is built with ``torch.eye`` and passed to ``constant_interdependence``
    as its constant matrix. For row-type interdependence it has shape
    ``(b, b_prime)``; for column-type interdependence it has shape
    ``(m, m_prime)``. When the two sizes differ, ``torch.eye`` produces a
    rectangular identity and a warning is emitted, since the operation then
    changes that dimension of the input.
    """

    def __init__(
        self,
        b: int, m: int,
        b_prime: int = None, m_prime: int = None,
        name: str = 'identity_interdependence',
        interdependence_type: str = 'attribute',
        device: str = 'cpu',
        *args, **kwargs
    ):
        """Build the identity matrix ``A`` and initialise the parent class.

        Args:
            b: Row size of the input.
            m: Column size of the input.
            b_prime: Output row size for row-type interdependence. Defaults
                to ``b`` (square identity) when ``None``.
            m_prime: Output column size for column-type interdependence.
                Defaults to ``m`` (square identity) when ``None``.
            name: Name forwarded to the parent class.
            interdependence_type: One of ``'row'``, ``'left'``, ``'instance'``,
                ``'instance_interdependence'`` (``A`` is ``b x b_prime``) or
                ``'column'``, ``'right'``, ``'attribute'``,
                ``'attribute_interdependence'`` (``A`` is ``m x m_prime``).
            device: Torch device on which ``A`` is allocated.
            *args, **kwargs: Forwarded unchanged to
                ``constant_interdependence.__init__``.

        Raises:
            ValueError: If ``interdependence_type`` is not one of the
                supported values.
        """
        if interdependence_type in ['row', 'left', 'instance', 'instance_interdependence']:
            # A missing target size means "keep the row dimension": use a
            # square identity. (Previously an `assert`, which `python -O`
            # strips, letting `None` reach torch.eye.)
            if b_prime is None:
                b_prime = b
            A = torch.eye(b, b_prime, device=device)
            if b != b_prime:
                warnings.warn(
                    "b and b_prime are different, this function will change the row dimensions of the inputs and cannot guarantee identity interdependence...",
                    stacklevel=2,  # report the caller's line, not this one
                )
        elif interdependence_type in ['column', 'right', 'attribute', 'attribute_interdependence']:
            # Same defaulting rule for the column dimension.
            if m_prime is None:
                m_prime = m
            A = torch.eye(m, m_prime, device=device)
            if m != m_prime:
                warnings.warn(
                    "m and m_prime are different, this function will change the column dimensions of the inputs and cannot guarantee identity interdependence...",
                    stacklevel=2,  # report the caller's line, not this one
                )
        else:
            raise ValueError(f'Interdependence type {interdependence_type} is not supported')
        # NOTE(review): b_prime / m_prime are not forwarded; presumably the
        # parent derives them from A's shape — confirm in constant_interdependence.
        super().__init__(
            *args,
            b=b, m=m, A=A,
            name=name,
            interdependence_type=interdependence_type,
            device=device,
            **kwargs,
        )