Skip to content

random_matrix_adaption_reconciliation

Bases: fabrication

Source code in tinybig/reconciliation/random_matrix_reconciliation.py
class random_matrix_adaption_reconciliation(fabrication):
    """Random-matrix adaption parameter reconciliation.

    Fabricates an ``(n, D)`` matrix from ``n + r`` values as

        W = diag(lambda_1) @ A @ diag(lambda_2) @ B^T

    where ``lambda_1`` is the first ``n`` entries of ``w``, ``lambda_2`` the
    remaining ``r`` entries, and ``A`` (``n x r``) / ``B`` (``D x r``) are
    standard-normal random matrices that are sampled on first use and cached
    on the instance (re-sampled only when the requested shape changes).
    """

    def __init__(self, name: str = 'random_matrix_adaption_reconciliation', r: int = 2, *args, **kwargs):
        """Store the rank ``r``; remaining arguments are forwarded to the base class."""
        super().__init__(name=name, *args, **kwargs)
        self.r = r
        # Random matrices are created lazily in forward(), once n and D are known.
        self.A = None
        self.B = None

    def calculate_l(self, n: int, D: int) -> int:
        """Return the number of values ``w`` must hold: ``n + r`` (``D`` is unused)."""
        return n + self.r

    def _cached_random_matrix(self, cached, rows: int, device):
        """Return ``cached`` on ``device`` if it is ``(rows, r)``, else a fresh normal sample."""
        if cached is None or cached.shape != (rows, self.r):
            return torch.randn(rows, self.r, device=device)
        # Keep the same random values but follow the requested device; this is
        # a no-op when the tensor already lives there.
        return cached.to(device)

    def forward(self, n: int, D: int, w: torch.nn.Parameter, device='cpu', *args, **kwargs):
        """Build the ``(n, D)`` matrix from the parameter tensor ``w``.

        Args:
            n: Number of rows of the result.
            D: Number of columns of the result.
            w: 2-D tensor with ``n + r`` elements, split along ``dim=1`` into
                ``n`` row scales and ``r`` column scales.
            device: Device on which the result is computed.

        Returns:
            Tensor of shape ``(n, D)``.

        Raises:
            AssertionError: If ``w`` is not 2-D or does not hold ``n + r`` elements.
        """
        assert w.ndim == 2 and w.numel() == self.calculate_l(n=n, D=D), (
            f"w must be 2-D with {n + self.r} elements, got shape {tuple(w.shape)}"
        )
        lambda_1, lambda_2 = torch.split(w, [n, self.r], dim=1)

        # Multiplying by a diagonal matrix is a row/column scaling, so broadcast
        # instead of materialising dense (n, n) and (r, r) diagonal matrices.
        row_scale = lambda_1.reshape(-1, 1).to(device)
        col_scale = lambda_2.reshape(1, -1).to(device)

        # A is resolved before B so the sampling order stays A-then-B.
        self.A = self._cached_random_matrix(self.A, n, device)
        self.B = self._cached_random_matrix(self.B, D, device)

        W = torch.matmul(row_scale * self.A * col_scale, self.B.t())
        assert W.shape == (n, D)
        return W