class meyer_wavelet(discrete_wavelet):
    """
    Meyer wavelet basis, evaluated through a closed-form expression.

    NOTE(review): the expression used in ``psi`` matches the published closed
    form of the Meyer *scaling* function phi(t) (Valenzuela & de Oliveira),
    not the mother wavelet psi(t). Presumably ``discrete_wavelet`` treats
    ``psi`` as its generic basis function, so the formula is kept as-is —
    confirm against the base class.
    """

    def __init__(self, name: str = 'meyer_wavelet', *args, **kwargs):
        """
        Initialize the Meyer wavelet.

        Parameters
        ----------
        name : str
            Name forwarded to ``discrete_wavelet.__init__``.
        *args, **kwargs
            Forwarded unchanged to ``discrete_wavelet.__init__``.
        """
        # NOTE(review): Python binds ``*args`` positionally *before* the
        # ``name=`` keyword, so a non-empty ``args`` would raise "multiple
        # values for argument 'name'" if the base class's first positional
        # parameter is ``name`` — confirm against ``discrete_wavelet``.
        super().__init__(name=name, *args, **kwargs)

    def psi(self, tau: torch.Tensor):
        """
        Evaluate the closed-form Meyer function elementwise.

        For t not in {0, +3/4, -3/4}:

            f(t) = (sin(2*pi*t/3) + (4/3)*t*cos(4*pi*t/3))
                   / (pi*t - (16*pi/9)*t**3)

        The denominator pi*t*(1 - 16*t**2/9) vanishes at t = 0 and t = +-3/4,
        where the numerator vanishes too (removable singularities). The
        limits are filled in explicitly:

            f(0)     = 2/3 + 4/(3*pi)
            f(+-3/4) = 2/(3*pi)

        Parameters
        ----------
        tau : torch.Tensor
            Input positions (expected to be a floating-point tensor; any shape).

        Returns
        -------
        torch.Tensor
            Tensor with the same shape, dtype and device as ``tau``.
        """
        result = torch.zeros_like(tau)

        # t = 0: numerator and denominator both vanish; use the limit value.
        zero_mask = (tau == 0)
        result[zero_mask] = 2.0 / 3.0 + 4.0 / (3.0 * torch.pi)

        # t = +-3/4: the denominator vanishes and so does the numerator
        # (sin(pi/2) + cos(pi) = 0), so direct evaluation yields 0/0 = NaN.
        # By L'Hopital: f'(3/4) = -4/3 and g'(3/4) = -2*pi, giving 2/(3*pi).
        # f is even (odd / odd), so the same limit holds at -3/4.
        # 0.75 is exactly representable in every float dtype, so == is exact.
        singular_mask = (tau.abs() == 0.75)
        result[singular_mask] = 2.0 / (3.0 * torch.pi)

        # Evaluate the formula only on regular points so NaNs from the
        # singularities never enter the forward (or backward) computation.
        regular_mask = ~(zero_mask | singular_mask)
        t = tau[regular_mask]
        result[regular_mask] = (
            torch.sin((2.0 * torch.pi / 3.0) * t)
            + 4.0 / 3.0 * t * torch.cos((4.0 * torch.pi / 3.0) * t)
        ) / (torch.pi * t - (16.0 * torch.pi / 9.0) * t ** 3)
        return result