class parameterized_bilinear_interdependence(interdependence):
    """
    Interdependence whose matrix is a learnable bilinear form of the input.

    ``calculate_A`` turns the parameter tensor ``w`` into a weight matrix ``W``
    of shape ``(d, d_prime)`` and returns ``post_process(x^T @ W @ x)``.
    The dimension that ``W`` spans depends on ``interdependence_type``:

    * instance spellings ('row', 'left', 'instance', 'instance_interdependence'):
      ``d = m``, ``d_prime = calculate_m_prime()`` (``m`` in this class);
    * attribute spellings ('column', 'right', 'attribute', 'attribute_interdependence'):
      ``d = b``, ``d_prime = calculate_b_prime()`` (``b`` in this class).

    ``self.parameter_fabrication`` is ``None`` after ``__init__``. In that
    case ``w`` is reshaped directly into ``W``. If it is set to something
    else, that object is called to build ``W`` and to report the required
    parameter count.
    """

    # Accepted spellings of each interdependence type. These are kept in one
    # place so calculate_l and calculate_A cannot disagree on the dispatch.
    _INSTANCE_TYPES = ('row', 'left', 'instance', 'instance_interdependence')
    _ATTRIBUTE_TYPES = ('column', 'right', 'attribute', 'attribute_interdependence')

    def __init__(
        self,
        b: int, m: int,
        interdependence_type: str = 'instance',
        name: str = 'parameterized_bilinear_interdependence',
        require_parameters: bool = True,
        require_data: bool = True,
        device: str = 'cpu', *args, **kwargs
    ):
        # NOTE(review): any extra positional *args are bound to the parent's
        # leading positional parameters before the keyword arguments, so they
        # may collide with b/m/... in the parent signature. Confirm that no
        # caller passes extra positional args.
        super().__init__(b=b, m=m, name=name, interdependence_type=interdependence_type, require_data=require_data, require_parameters=require_parameters, device=device, *args, **kwargs)
        # Optional builder for W. None means "reshape w directly".
        self.parameter_fabrication = None

    def calculate_b_prime(self, b: int = None):
        """Return ``b`` if given, else ``self.b`` (the output size is unchanged)."""
        b = b if b is not None else self.b
        return b

    def calculate_m_prime(self, m: int = None):
        """Return ``m`` if given, else ``self.m`` (the output size is unchanged)."""
        m = m if m is not None else self.m
        return m

    def _bilinear_dims(self):
        """
        Return ``(d, d_prime)``, the shape of the bilinear weight matrix ``W``.

        Raises
        ------
        ValueError
            If ``self.interdependence_type`` is not one of the supported spellings.
        """
        if self.interdependence_type in self._INSTANCE_TYPES:
            return self.m, self.calculate_m_prime()
        elif self.interdependence_type in self._ATTRIBUTE_TYPES:
            return self.b, self.calculate_b_prime()
        else:
            raise ValueError(f'Interdependence type {self.interdependence_type} not supported')

    def calculate_l(self):
        """
        Return the number of parameters that ``w`` must contain.

        Without parameter fabrication this is ``d * d_prime``, which is
        ``m ** 2`` for instance types and ``b ** 2`` for attribute types.
        Otherwise the count is delegated to
        ``self.parameter_fabrication.calculate_l(n=d, D=d_prime)``.

        Raises
        ------
        ValueError
            If ``self.interdependence_type`` is unsupported.
        """
        d, d_prime = self._bilinear_dims()
        if self.parameter_fabrication is None:
            return d * d_prime
        return self.parameter_fabrication.calculate_l(n=d, D=d_prime)

    def calculate_A(self, x: torch.Tensor = None, w: torch.nn.Parameter = None, device: str = 'cpu', *args, **kwargs):
        """
        Compute the interdependence matrix ``A = post_process(x^T @ W @ x)``.

        Parameters
        ----------
        x : torch.Tensor
            2-D input data. It is passed through ``self.pre_process`` first.
        w : torch.nn.Parameter
            2-D parameter tensor with exactly ``self.calculate_l()`` elements.
        device : str
            Device used for pre/post-processing and for ``W``.

        Returns
        -------
        torch.Tensor
            The post-processed bilinear product. When neither data nor
            parameters are required, the first result is stored in ``self.A``
            and returned on later calls.

        Raises
        ------
        AssertionError
            If ``x`` or ``w`` is missing, not 2-D, or ``w`` has the wrong size.
        ValueError
            If ``self.interdependence_type`` is unsupported.
        """
        # Cached result for interdependences that depend on neither data nor parameters.
        if not self.require_data and not self.require_parameters and self.A is not None:
            return self.A

        assert x is not None and x.ndim == 2
        # With assertions enabled, calculate_l() here also rejects an
        # unsupported interdependence_type with ValueError.
        assert w is not None and w.ndim == 2 and w.numel() == self.calculate_l()

        x = self.pre_process(x=x, device=device)

        d, d_prime = self._bilinear_dims()
        if self.parameter_fabrication is None:
            W = w.reshape(d, d_prime).to(device=device)
        else:
            W = self.parameter_fabrication(w=w, n=d, D=d_prime, device=device)

        # W @ x requires x to have d_prime rows after pre_process. For instance
        # types d = m, so x is presumably the transposed (m-row) data here.
        # NOTE(review): confirm the orientation expected from pre_process/caller.
        A = torch.matmul(x.t(), torch.matmul(W, x))
        A = self.post_process(x=A, device=device)

        if not self.require_data and not self.require_parameters and self.A is None:
            self.A = A
        return A