Skip to content

hermite_expansion

Bases: transformation

Source code in tinybig/expansion/orthogonal_polynomial_expansion.py
class hermite_expansion(transformation):
    """
    Probabilists' Hermite polynomial expansion of each input feature.

    For every input feature x_j the expansion produces He_1(x_j), ..., He_d(x_j),
    where the polynomials follow the three-term recurrence

        He_0(x) = 1
        He_1(x) = x
        He_n(x) = x * He_{n-1}(x) - (n - 1) * He_{n-2}(x)

    The constant term He_0 is not emitted, so an input of shape (b, m) becomes an
    output of shape (b, m * d). Columns are grouped per feature:
    [He_1(x_1), ..., He_d(x_1), He_1(x_2), ..., He_d(x_m)].
    """

    # NOTE(review): the default ``name`` says 'chebyshev' although this class
    # computes Hermite polynomials. It is kept unchanged for backward compatibility
    # with callers that may rely on it — confirm whether it should be renamed.
    def __init__(self, name: str = 'chebyshev_polynomial_expansion', d: int = 2, *args, **kwargs):
        """
        Parameters
        ----------
        name : str
            Name forwarded to the base ``transformation``.
        d : int
            Highest polynomial order; the output has ``d`` columns per input feature.
        """
        super().__init__(name=name, *args, **kwargs)
        self.d = d

    def calculate_D(self, m: int):
        """Return the output width for ``m`` input features, i.e. ``m * d``."""
        return m * self.d

    def forward(self, x: torch.Tensor, device='cpu', *args, **kwargs):
        """
        Expand a 2-D input into Hermite polynomial features of orders 1..d.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (b, m).
        device : str
            Device on which the expansion is computed.

        Returns
        -------
        torch.Tensor
            ``post_process`` applied to the (b, m * d) expansion.

        Raises
        ------
        ValueError
            If ``self.d`` is negative.
        """
        if self.d < 0:
            raise ValueError(f"Hermite expansion order d must be non-negative, got {self.d}")

        b, m = x.shape
        x = self.pre_process(x=x, device=device)

        # Keep floating-point inputs at their own precision (the previous buffer
        # was always float32); non-float inputs fall back to the default float dtype.
        dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
        # Put x on the target device up front so every product below is same-device.
        work = x.to(device=device, dtype=dtype)
        n_rows, n_features = work.shape

        # Collect each order as a separate tensor and stack once at the end.
        # This avoids in-place writes into a shared buffer, which needed
        # .clone() on every step to keep autograd's version checks happy.
        orders = []
        if self.d >= 1:
            h_prev2 = torch.ones_like(work)  # He_0
            h_prev1 = work                   # He_1
            orders.append(h_prev1)
            for n in range(2, self.d + 1):
                h_next = work * h_prev1 - (n - 1) * h_prev2
                orders.append(h_next)
                h_prev2, h_prev1 = h_prev1, h_next

        if orders:
            # (b, m, d) -> (b, m * d); the explicit width also works for an empty batch.
            expansion = torch.stack(orders, dim=-1).reshape(n_rows, n_features * len(orders))
        else:
            # d == 0: there are no non-constant terms, so the expansion is empty.
            expansion = work.new_empty((n_rows, 0))

        # Internal invariant: output width must agree with calculate_D.
        assert expansion.shape == (b, self.calculate_D(m=m))
        return self.post_process(x=expansion, device=device)