Skip to content

dog_wavelet

Bases: discrete_wavelet

Source code in tinybig/koala/signal_processing/wavelet.py
class dog_wavelet(discrete_wavelet):
    """
    Difference-of-Gaussians (DoG) wavelet.

    The mother wavelet is the difference of two zero-mean Gaussian
    probability density functions with standard deviations ``sigma_1`` and
    ``sigma_2``::

        psi(tau) = N(tau; 0, sigma_1^2) - N(tau; 0, sigma_2^2)
        N(tau; 0, s^2) = exp(-tau^2 / (2 s^2)) / sqrt(2 pi s^2)

    If ``sigma_1 == sigma_2`` the two terms cancel and ``psi`` is identically
    zero.

    Parameters
    ----------
    sigma_1 : float, default=1.0
        Standard deviation of the positive Gaussian term. Must be > 0.
    sigma_2 : float, default=2.0
        Standard deviation of the subtracted Gaussian term. Must be > 0.
    name : str, default='difference_of_Gaussians_wavelet'
        Name forwarded to ``discrete_wavelet.__init__``.
    *args, **kwargs
        Additional arguments forwarded unchanged to
        ``discrete_wavelet.__init__``.

    Raises
    ------
    ValueError
        If ``sigma_1`` or ``sigma_2`` is not strictly positive (NaN is
        rejected as well).
    """

    def __init__(self, sigma_1: float = 1.0, sigma_2: float = 2.0, name: str = 'difference_of_Gaussians_wavelet', *args, **kwargs):
        # Validate before touching the base class. The bound must be strict:
        # sigma == 0 makes psi evaluate 0/0 or inf/0, i.e. NaN everywhere.
        # Phrasing it as `not (... > 0)` also rejects NaN, for which every
        # comparison is False.
        if not (sigma_1 > 0.0 and sigma_2 > 0.0):
            raise ValueError('sigma_1 and sigma_2 must be > 0.')
        super().__init__(*args, name=name, **kwargs)
        self.sigma_1 = sigma_1
        self.sigma_2 = sigma_2

    @staticmethod
    def _gaussian_pdf(tau: torch.Tensor, sigma: float) -> torch.Tensor:
        """Evaluate the zero-mean Gaussian PDF with std ``sigma`` elementwise at ``tau``."""
        # sigma > 0 is guaranteed by __init__, so sqrt(2*pi*sigma^2) == sqrt(2*pi)*sigma.
        return torch.exp(-0.5 * (tau / sigma) ** 2) / (math.sqrt(2.0 * math.pi) * sigma)

    def psi(self, tau: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the DoG mother wavelet elementwise.

        Parameters
        ----------
        tau : torch.Tensor
            Points at which to evaluate the wavelet.

        Returns
        -------
        torch.Tensor
            ``N(tau; 0, sigma_1^2) - N(tau; 0, sigma_2^2)``, same shape as ``tau``.
        """
        return self._gaussian_pdf(tau, self.sigma_1) - self._gaussian_pdf(tau, self.sigma_2)