class legendre_expansion(transformation):
    """Expand each input feature into Legendre polynomials of orders 1..d.

    ``forward`` takes a 2-D input of shape (b, m) and evaluates P_1..P_d
    element-wise via Bonnet's three-term recurrence. The constant P_0 term is
    dropped. The result is flattened to shape (b, m * d) in feature-major
    order: columns ``[i*d:(i+1)*d]`` hold P_1..P_d of feature ``i``.
    """

    def __init__(self, name='legendre_polynomial_expansion', d: int = 2, *args, **kwargs):
        """Configure the expansion.

        Args:
            name: identifier forwarded to the ``transformation`` base class.
            d: highest polynomial order to emit (orders 1..d are produced).
            *args, **kwargs: forwarded unchanged to ``transformation.__init__``.

        Raises:
            ValueError: if ``d`` is negative.
        """
        if d < 0:
            raise ValueError(f"d must be a non-negative integer, got {d}")
        super().__init__(name=name, *args, **kwargs)
        self.d = d

    def calculate_D(self, m: int):
        """Return the output width for ``m`` input features: ``m * d``."""
        return m * self.d

    def forward(self, x: torch.Tensor, device='cpu', *args, **kwargs):
        """Apply the Legendre expansion to ``x``.

        Args:
            x: 2-D tensor of shape (b, m).
            device: passed to ``pre_process``/``post_process`` (defined on the
                base class); the expansion is also moved here before
                post-processing.

        Returns:
            The output of ``post_process`` applied to the (b, m * d) expansion.
        """
        b, m = x.shape
        x = self.pre_process(x=x, device=device)
        # The recurrence uses fractional coefficients, so work in a floating
        # dtype; floating inputs keep their own precision (e.g. float64).
        if not x.is_floating_point():
            x = x.to(torch.get_default_dtype())

        rows, cols = x.size(0), x.size(1)
        if self.d >= 1:
            # P_0 = 1 and P_1 = x seed the recurrence.
            p_prev, p_curr = torch.ones_like(x), x
            orders = [p_curr]
            for n in range(2, self.d + 1):
                # Bonnet's recursion: n P_n = (2n-1) x P_{n-1} - (n-1) P_{n-2}.
                # Built functionally (no in-place slice writes), so autograd
                # needs no defensive clones.
                p_prev, p_curr = p_curr, (2 * n - 1) / n * x * p_curr - (n - 1) / n * p_prev
                orders.append(p_curr)
            expansion = torch.stack(orders, dim=-1)  # (rows, cols, d)
        else:
            # d == 0: only P_0 exists and it is dropped, so the result is empty.
            expansion = x.new_empty(rows, cols, 0)

        # Give the width explicitly instead of -1: a -1 view is ambiguous and
        # raises when the batch is empty.
        expansion = expansion.reshape(rows, cols * self.d).to(device)
        # Internal sanity check that pre_process preserved the input shape.
        assert expansion.shape == (b, self.calculate_D(m=m)), (
            f"unexpected expansion shape {tuple(expansion.shape)}, "
            f"expected {(b, self.calculate_D(m=m))}"
        )
        return self.post_process(x=expansion, device=device)