class ricker_wavelet(discrete_wavelet):
    """Ricker wavelet parameterized by a width ``sigma``.

    The mother wavelet evaluated by :meth:`psi` is

        psi(tau) = 2 / (sqrt(3 * sigma) * pi ** 0.25)
                   * (1 - (tau / sigma) ** 2)
                   * exp(-tau ** 2 / (2 * sigma ** 2))

    Args:
        sigma: Width parameter of the wavelet. Must be strictly positive,
            because ``psi`` divides by ``sigma`` and by ``sqrt(3 * sigma)``.
        name: Name forwarded to the base-class constructor.
        *args: Extra positional arguments forwarded to the base class.
        **kwargs: Extra keyword arguments forwarded to the base class.

    Raises:
        ValueError: If ``sigma`` is not strictly positive.
    """

    def __init__(self, sigma: float = 1.0, name: str = 'ricker_wavelet', *args, **kwargs):
        super().__init__(name=name, *args, **kwargs)
        # sigma == 0 must be rejected as well as negative values: psi() divides
        # by sigma, sqrt(3 * sigma) and 2 * sigma ** 2, so a zero width can
        # only produce inf/NaN output.
        if sigma <= 0.0:
            raise ValueError('sigma must be > 0.')
        self.sigma = sigma

    def psi(self, tau: torch.Tensor) -> torch.Tensor:
        """Evaluate the wavelet element-wise at the points ``tau``.

        Args:
            tau: Tensor of evaluation points.

        Returns:
            Tensor holding ``psi(tau)`` for every element of ``tau``.
        """
        # Normalization constant: sqrt(3 * sigma) * pi ** (1 / 4).
        norm = math.sqrt(3 * self.sigma) * (torch.pi ** 0.25)
        polynomial = 2.0 * (1.0 - (tau / self.sigma) ** 2) / norm
        envelope = torch.exp(-tau ** 2 / (2 * self.sigma ** 2))
        return polynomial * envelope