Skip to content

random_matrix_hypernet_reconciliation

Bases: fabrication

Source code in tinybig/reconciliation/random_matrix_reconciliation.py
class random_matrix_hypernet_reconciliation(fabrication):
    """
    Hypernetwork-style parameter reconciliation driven by fixed random matrices.

    A parameter tensor ``w`` with ``l`` elements is mapped to an ``n x D``
    matrix through a one-hidden-layer network whose weights are random
    low-rank factors:

        W = (sigmoid(w @ P @ Q^T) @ S @ T^T).view(n, D)

    with ``P: (l, r)``, ``Q: (hidden_dim, r)``, ``S: (hidden_dim, r)`` and
    ``T: (n * D, r)``. The factors are drawn lazily with ``torch.randn`` on the
    first call to :meth:`forward` and cached on the instance as plain tensors
    (not ``nn.Parameter``). They are redrawn only when their required shape
    changes; otherwise they are reused (moved to the requested device/dtype
    if necessary).

    Parameters
    ----------
    name : str
        Name forwarded to the ``fabrication`` base class.
    r : int
        Rank of the random factor matrices.
    l : int
        Number of elements the input parameter ``w`` must contain.
    hidden_dim : int
        Width of the hidden (sigmoid) layer.
    """

    def __init__(self, name='random_matrix_hypernet_reconciliation', r: int = 2, l: int = 64, hidden_dim: int = 128, *args, **kwargs):
        super().__init__(name=name, *args, **kwargs)
        self.r = r
        self.l = l
        self.hidden_dim = hidden_dim

        # Random factors; created lazily in forward() once n and D are known.
        self.P = None
        self.Q = None
        self.S = None
        self.T = None

    def calculate_l(self, n: int = None, D: int = None):
        """
        Return the required number of elements of ``w``.

        The value is the fixed ``self.l`` chosen at construction; ``n`` and
        ``D`` are accepted for interface compatibility but ignored.
        """
        assert self.l is not None
        return self.l

    @staticmethod
    def _ensure_random_matrix(current, shape, device, dtype):
        """Return ``current`` moved to ``device``/``dtype`` if it has ``shape``, else a fresh ``torch.randn`` matrix."""
        if current is None or tuple(current.shape) != tuple(shape):
            return torch.randn(*shape, device=device, dtype=dtype)
        # Reuse the cached draw, but never leave it on a stale device or dtype.
        return current.to(device=device, dtype=dtype)

    def forward(self, n: int, D: int, w: torch.nn.Parameter, device='cpu', *args, **kwargs):
        """
        Build the ``n x D`` matrix from the parameter ``w``.

        Parameters
        ----------
        n, D : int
            Output matrix dimensions; ``T`` is sized ``(n * D, r)``.
        w : torch.Tensor
            2-D tensor with exactly ``l`` elements; it is treated as a single
            row vector. It must live on ``device`` for the matmuls to succeed.
        device : str or torch.device
            Device on which random factors are created; previously cached
            factors are moved here.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(n, D)``.

        Raises
        ------
        AssertionError
            If ``w`` is not 2-D or does not contain exactly ``l`` elements.
        """
        assert w.ndim == 2 and w.numel() == self.calculate_l(n=n, D=D)

        # Match w's dtype so non-default-dtype inputs do not fail in matmul.
        dtype = w.dtype
        # Draw order P, Q, S, T is kept so seeded runs produce the same factors.
        self.P = self._ensure_random_matrix(self.P, (self.l, self.r), device, dtype)
        self.Q = self._ensure_random_matrix(self.Q, (self.hidden_dim, self.r), device, dtype)
        self.S = self._ensure_random_matrix(self.S, (self.hidden_dim, self.r), device, dtype)
        self.T = self._ensure_random_matrix(self.T, (n * D, self.r), device, dtype)

        # Flatten to (1, l): a (1, l) input is unchanged, and other 2-D shapes
        # that pass the numel check no longer crash in the matmul below.
        w_row = w.reshape(1, -1)
        hidden = torch.sigmoid(w_row @ self.P @ self.Q.t())  # (1, hidden_dim)
        W = (hidden @ self.S @ self.T.t()).view(n, D)

        assert W.shape == (n, D)
        return W