class constant_c_interdependence(constant_interdependence):
    """Interdependence whose matrix ``A`` has every entry equal to the constant ``c``.

    The shape of ``A`` depends on ``interdependence_type``:

    * row-style aliases (``'row'``, ``'left'``, ``'instance'``,
      ``'instance_interdependence'``) -> ``(b, b_prime)``
    * column-style aliases (``'column'``, ``'right'``, ``'attribute'``,
      ``'attribute_interdependence'``) -> ``(m, m_prime)``

    The matrix is then passed as ``A`` to the ``constant_interdependence``
    constructor together with the remaining arguments.
    """

    def __init__(
        self,
        b: int, m: int,
        b_prime: int | None = None, m_prime: int | None = None,
        c: float | int = 1.0,
        name: str = 'constant_c_interdependence',
        interdependence_type: str = 'attribute',
        device: str = 'cpu',
        *args, **kwargs
    ):
        """Build the constant matrix ``A`` and initialize the parent class.

        Args:
            b: First dimension of ``A`` for row-style types; always forwarded
                to the parent as ``b``.
            m: First dimension of ``A`` for column-style types; always
                forwarded to the parent as ``m``.
            b_prime: Second dimension of ``A`` for row-style types. Required
                (must not be ``None``) for those types; ignored otherwise.
            m_prime: Second dimension of ``A`` for column-style types.
                Required (must not be ``None``) for those types; ignored
                otherwise.
            c: Value written into every entry of ``A``.
            name: Forwarded to the parent constructor.
            interdependence_type: Selects the shape of ``A`` (see class
                docstring); also forwarded to the parent constructor.
            device: Device on which ``A`` is allocated; also forwarded to the
                parent constructor.
            *args: Extra positional arguments forwarded to the parent.
            **kwargs: Extra keyword arguments forwarded to the parent.

        Raises:
            ValueError: If ``interdependence_type`` is not one of the
                supported aliases, or if the second dimension it requires
                (``b_prime`` or ``m_prime``) is ``None``.
        """
        if interdependence_type in ('row', 'left', 'instance', 'instance_interdependence'):
            # Explicit check instead of `assert`: asserts are removed under `python -O`.
            if b_prime is None:
                raise ValueError(
                    f"b_prime must be provided for interdependence type '{interdependence_type}'"
                )
            shape = (b, b_prime)
        elif interdependence_type in ('column', 'right', 'attribute', 'attribute_interdependence'):
            if m_prime is None:
                raise ValueError(
                    f"m_prime must be provided for interdependence type '{interdependence_type}'"
                )
            shape = (m, m_prime)
        else:
            raise ValueError(f'Interdependence type {interdependence_type} is not supported')

        # Scaling a ones tensor keeps the default floating dtype even when `c` is an int.
        A = c * torch.ones(shape, device=device)

        # `*args` is written first to make the binding order explicit. Python always binds
        # unpacked positionals before keywords, so this is equivalent to the original call.
        # NOTE(review): a non-empty *args would fill the parent's leading positional
        # parameters and may collide with b=/m= — confirm against
        # constant_interdependence.__init__'s signature.
        super().__init__(
            *args,
            b=b, m=m, A=A,
            name=name,
            interdependence_type=interdependence_type,
            device=device,
            **kwargs,
        )