Bases: duplicated_padding_reconciliation
Source code in tinybig/reconciliation/basic_reconciliation.py
class duplicated_diagonal_padding_reconciliation(duplicated_padding_reconciliation):
    """
    Duplicated padding reconciliation that tiles ``w`` along a block diagonal.

    ``self.p`` copies of the parameter matrix ``w`` are placed on the diagonal
    of an otherwise zero matrix via ``torch.block_diag``, and the result is
    reinterpreted as an ``(n, D)`` matrix.

    Notes
    -----
    ``self.p``, ``self.q`` and ``calculate_l`` are read from the parent class
    ``duplicated_padding_reconciliation``; their exact definitions are not
    visible from this class.
    """

    def __init__(self, name='duplicated_diagonal_padding_reconciliation', *args, **kwargs):
        """
        Initialize the reconciliation function.

        Parameters
        ----------
        name : str, optional
            Name forwarded (positionally) to the parent constructor.
        *args, **kwargs
            Forwarded unchanged to the parent constructor.
        """
        super().__init__(name, *args, **kwargs)

    def forward(self, n: int, D: int, w: torch.nn.Parameter, device: str = 'cpu', *args, **kwargs):
        """
        Build the ``(n, D)`` matrix from ``self.p`` diagonal copies of ``w``.

        Parameters
        ----------
        n : int
            Number of rows of the returned matrix; must equal ``self.p``.
        D : int
            Number of columns of the returned matrix; must equal
            ``self.q * self.calculate_l(n=n, D=D)``.
        w : torch.nn.Parameter
            2-D parameter whose second dimension equals
            ``self.calculate_l(n=n, D=D)``.
        device : str, optional
            Target device, e.g. ``'cpu'``, ``'cuda:0'``, ``'mps'`` or
            ``'mps:0'``. A ``torch.device`` instance is handled as well.

        Returns
        -------
        torch.Tensor
            A dense tensor when the target device type is ``'mps'``, otherwise
            a sparse COO tensor, moved to ``device``.

        Raises
        ------
        AssertionError
            If ``w`` or the ``(n, D)`` pair violates the shape constraints.
        RuntimeError
            Raised by ``view`` if the block-diagonal matrix, of shape
            ``(p * w.size(0), p * w.size(1))``, does not hold exactly
            ``n * D`` elements.
        """
        assert w.ndim == 2, f"w must be 2-D, got ndim={w.ndim}"

        # Compute the expected parameter width once; it is needed by both checks.
        expected_l = self.calculate_l(n=n, D=D)
        assert w.size(1) == expected_l, \
            f"w.size(1)={w.size(1)} does not match calculate_l(n, D)={expected_l}"
        assert self.p == n and self.q * expected_l == D, \
            f"expected p == n and q * l == D, got p={self.p}, n={n}, q={self.q}, l={expected_l}, D={D}"

        # block_diag returns a fresh contiguous tensor, so view() is safe here.
        W = torch.block_diag(*[w] * self.p).view(n, D)

        # Compare on the device *type* so that 'mps:0' and torch.device('mps')
        # are recognised too; an exact `device == 'mps'` test misses both.
        if isinstance(device, torch.device):
            device_type = device.type
        else:
            device_type = str(device).split(':', 1)[0]

        # The MPS branch keeps the matrix dense (presumably because sparse COO
        # tensors are not supported on that backend -- confirm against the
        # targeted PyTorch version).
        if device_type == 'mps':
            return W.to(device)
        return W.to_sparse_coo().to(device)