class dog_wavelet(discrete_wavelet):
    """Difference-of-Gaussians (DoG) wavelet.

    The wavelet function is the difference of two zero-mean, unit-area
    Gaussian densities with standard deviations ``sigma_1`` and ``sigma_2``::

        psi(tau) = N(tau; 0, sigma_1**2) - N(tau; 0, sigma_2**2)

    Args:
        sigma_1: Standard deviation of the Gaussian that is added.
            Must be strictly positive.
        sigma_2: Standard deviation of the Gaussian that is subtracted.
            Must be strictly positive.
        name: Name forwarded to ``discrete_wavelet.__init__``.
        *args, **kwargs: Forwarded unchanged to ``discrete_wavelet.__init__``.

    Raises:
        ValueError: If either sigma is not strictly positive (this includes
            zero, negative values, and NaN).
    """

    def __init__(self, sigma_1: float = 1.0, sigma_2: float = 2.0,
                 name: str = 'difference_of_Gaussians_wavelet', *args, **kwargs):
        # Validate before initialising the base class so a bad configuration
        # fails fast. Zero must be rejected because psi divides by sigma
        # (sigma == 0 yields NaN everywhere). The `not (x > 0)` form also
        # rejects NaN, which a plain `x <= 0` comparison would let through.
        if not (sigma_1 > 0.0 and sigma_2 > 0.0):
            raise ValueError('sigma_1 and sigma_2 must be > 0.')
        super().__init__(name=name, *args, **kwargs)
        self.sigma_1 = sigma_1
        self.sigma_2 = sigma_2

    @staticmethod
    def _gaussian(tau: torch.Tensor, sigma: float) -> torch.Tensor:
        """Return the zero-mean normal density with std ``sigma`` at ``tau``.

        Evaluated element-wise. ``sigma`` is assumed to be > 0, which
        ``__init__`` enforces for the values used by ``psi``.
        """
        # sigma * sqrt(2*pi) == sqrt(2*pi*sigma**2) for sigma > 0.
        return torch.exp(-0.5 * (tau / sigma) ** 2) / (sigma * math.sqrt(2.0 * math.pi))

    def psi(self, tau: torch.Tensor) -> torch.Tensor:
        """Evaluate the DoG wavelet function element-wise at ``tau``.

        Args:
            tau: Points at which to evaluate the wavelet.

        Returns:
            A tensor of the same shape as ``tau`` holding
            ``N(tau; 0, sigma_1**2) - N(tau; 0, sigma_2**2)``.
        """
        return self._gaussian(tau, self.sigma_1) - self._gaussian(tau, self.sigma_2)