Skip to content

discrete_wavelet

Bases: object

Source code in tinybig/koala/signal_processing/wavelet.py
class discrete_wavelet(object):
    """
    Base class for discrete wavelets evaluated as

        psi_{s,t}(x) = a^(-s/2) * psi(x / a^s - t * b)

    Subclasses provide the mother wavelet by overriding :meth:`psi`. This base
    class only applies the dilation (base ``a``) and translation (step ``b``).

    Parameters
    ----------
    name : str
        Identifier stored on the instance.
    a : float
        Dilation base; must be >= 1.
    b : float
        Translation step; must be >= 0.
    *args, **kwargs
        Accepted for signature compatibility and ignored here.

    Raises
    ------
    ValueError
        If ``a < 1`` or ``b < 0``.
    """

    def __init__(self, name: str = 'discrete_wavelet', a: float = 1.0, b: float = 1.0, *args, **kwargs):
        # The bounds are inclusive (a == 1 and b == 0 are accepted) so that the
        # default a=1.0 stays valid; the message states exactly what is enforced.
        if a < 1 or b < 0:
            raise ValueError('a must be >= 1 and b must be >= 0...')

        self.name = name
        self.a = a
        self.b = b

    @abstractmethod
    def psi(self, tau: torch.Tensor):
        """
        Evaluate the mother wavelet at the (already dilated/translated) points ``tau``.

        Subclasses must override this.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        # The class does not use ABCMeta, so @abstractmethod is not enforced at
        # instantiation. Raise explicitly instead of returning None, which would
        # otherwise surface later in forward() as an opaque TypeError.
        raise NotImplementedError(f'{type(self).__name__} must implement psi()')

    def forward(self, x: torch.Tensor, s: int, t: int):
        """
        Evaluate the wavelet at scale index ``s`` and shift index ``t``.

        Computes ``a**(-s/2) * psi(x / a**s - t * b)``.

        Parameters
        ----------
        x : torch.Tensor
            Points at which to evaluate the wavelet.
        s : int
            Scale (dilation) index.
        t : int
            Translation index.

        Returns
        -------
        Whatever ``psi`` returns, scaled by ``1 / sqrt(a**s)``.
        """
        tau = x/(self.a**s) - t*self.b
        # Normalising factor a^(-s/2) applied to the mother wavelet's output.
        return 1.0/math.sqrt(self.a**s) * self.psi(tau=tau)

    def __call__(self, x: torch.Tensor, s: int, t: int):
        """Alias for :meth:`forward`."""
        return self.forward(x=x, s=s, t=t)