class random_matrix_adaption_reconciliation(fabrication):
    """Random-matrix-adaption parameter reconciliation.

    Expands a small learnable parameter vector ``w`` of length ``n + r`` into a
    full ``(n, D)`` weight matrix via two fixed random projections::

        W = diag(lambda_1) @ A @ diag(lambda_2) @ B^T

    where ``A`` is ``(n, r)`` and ``B`` is ``(D, r)`` random Gaussian matrices,
    sampled once and cached across calls (re-sampled only when ``n``/``D``
    change). Only the ``n + r`` diagonal entries are learnable.
    """

    def __init__(self, name: str = 'random_matrix_adaption_reconciliation', r: int = 2, *args, **kwargs):
        """Initialize the reconciliation module.

        Parameters
        ----------
        name : str
            Module name, forwarded to the ``fabrication`` base class.
        r : int
            Rank of the random adaption (size of the inner projection).
        """
        super().__init__(name=name, *args, **kwargs)
        self.r = r
        # Cached random projection matrices; lazily sampled in forward()
        # and re-sampled whenever the requested (n, r) / (D, r) shape changes.
        self.A = None
        self.B = None

    def calculate_l(self, n: int, D: int):
        """Return the number of learnable parameters: n + r diagonal entries."""
        return n + self.r

    def forward(self, n: int, D: int, w: torch.nn.Parameter, device='cpu', *args, **kwargs):
        """Reconcile the learnable vector ``w`` into an ``(n, D)`` weight matrix.

        Parameters
        ----------
        n, D : int
            Target weight-matrix dimensions.
        w : torch.nn.Parameter
            Learnable parameters holding ``n + r`` entries.
            NOTE(review): the ``dim=1`` split below effectively requires shape
            ``(1, n + r)`` — a ``(n + r, 1)`` tensor passes the numel assert
            but fails at ``torch.split``. Confirm callers always pass a row
            vector.
        device : str or torch.device
            Device on which the result (and cached projections) should live.

        Returns
        -------
        torch.Tensor
            The reconciled weight matrix of shape ``(n, D)``.
        """
        assert w.ndim == 2 and w.numel() == self.calculate_l(n=n, D=D)
        # First n entries scale the rows, the last r entries scale the inner rank.
        lambda_1, lambda_2 = torch.split(w, [n, self.r], dim=1)
        Lambda_1 = torch.diag(lambda_1.view(-1)).to(device)
        Lambda_2 = torch.diag(lambda_2.view(-1)).to(device)

        # Sample the fixed random projections on first use or on a shape change;
        # otherwise move the cached tensors to the requested device (bug fix:
        # previously a cached tensor from an earlier call on another device was
        # reused as-is, making the matmul below fail with a device mismatch).
        if self.A is None or self.A.shape != (n, self.r):
            self.A = torch.randn(n, self.r, device=device)
        else:
            self.A = self.A.to(device)
        if self.B is None or self.B.shape != (D, self.r):
            self.B = torch.randn(D, self.r, device=device)
        else:
            self.B = self.B.to(device)
        assert self.A.shape == (n, self.r) and self.B.shape == (D, self.r)

        # (n,n) @ (n,r) @ (r,r) @ (r,D) -> (n,D)
        W = torch.matmul(torch.matmul(torch.matmul(Lambda_1, self.A), Lambda_2), self.B.t())
        assert W.shape == (n, D)
        return W