Skip to content

duplicated_diagonal_padding_reconciliation

Bases: duplicated_padding_reconciliation

Source code in tinybig/reconciliation/basic_reconciliation.py
class duplicated_diagonal_padding_reconciliation(duplicated_padding_reconciliation):
    """
    Reconciliation that places copies of a parameter tensor along a block diagonal.

    The reconciled matrix is built with ``torch.block_diag`` from ``self.p`` copies of
    the parameter ``w`` and then reshaped to ``(n, D)``. The attributes ``p``, ``q`` and
    the method ``calculate_l`` are inherited from ``duplicated_padding_reconciliation``
    (defined elsewhere).
    """

    def __init__(self, name='duplicated_diagonal_padding_reconciliation', *args, **kwargs):
        """
        Initialize the reconciliation.

        Parameters
        ----------
        name : str, default='duplicated_diagonal_padding_reconciliation'
            Name of this reconciliation instance.
        *args, **kwargs
            Forwarded unchanged to ``duplicated_padding_reconciliation.__init__``
            (presumably where ``p``/``q`` are configured -- confirm against the parent).
        """
        super().__init__(name, *args, **kwargs)

    def forward(self, n: int, D: int, w: torch.nn.Parameter, device: str = 'cpu', *args, **kwargs):
        """
        Build the reconciled ``(n, D)`` matrix from the parameter ``w``.

        Parameters
        ----------
        n : int
            Number of output rows; must equal ``self.p``.
        D : int
            Number of output columns; must equal ``self.q * l`` where
            ``l = self.calculate_l(n=n, D=D)``.
        w : torch.nn.Parameter
            2-D parameter tensor whose second dimension has size ``l``.
        device : str, default='cpu'
            Target device. Anything accepted by ``torch.device`` whose type is
            ``'mps'`` (e.g. ``'mps'``, ``'mps:0'``, ``torch.device('mps')``) yields
            a dense result.

        Returns
        -------
        torch.Tensor
            The ``(n, D)`` matrix: dense on MPS devices, sparse COO on all others.

        Raises
        ------
        AssertionError
            If ``w`` is not 2-D, ``w.size(1) != l``, ``self.p != n`` or
            ``self.q * l != D``.
        RuntimeError
            Raised by ``view`` when the block-diagonal matrix, of shape
            ``(self.p * w.size(0), self.p * l)``, does not contain exactly
            ``n * D`` elements.
        """
        # Check dimensionality first so calculate_l is only evaluated for 2-D w,
        # matching the short-circuit order of the original combined assert.
        assert w.ndim == 2, f"w must be 2-D, got {w.ndim} dimensions"
        l_dim = self.calculate_l(n=n, D=D)
        assert w.size(1) == l_dim, f"w.size(1) must be {l_dim}, got {w.size(1)}"
        assert self.p == n and self.q * l_dim == D, (
            f"expected p == n and q * l == D, got p={self.p}, n={n}, "
            f"q={self.q}, l={l_dim}, D={D}"
        )

        # self.p diagonal copies of w -> shape (p * w.size(0), p * l_dim); view
        # reinterprets that contiguous storage as (n, D).
        W = torch.block_diag(*[w] * self.p).view(n, D)

        # Compare the backend type so 'mps:0' / torch.device('mps') also take the
        # dense path; presumably sparse COO is avoided on MPS because the backend
        # lacks sparse support -- NOTE(review): confirm for the targeted torch version.
        if torch.device(device).type == 'mps':
            return W.to(device)
        return W.to_sparse_coo().to(device)