Skip to content

beta_wavelet

Bases: discrete_wavelet

Source code in tinybig/koala/signal_processing/wavelet.py
class beta_wavelet(discrete_wavelet):
    """
    Beta wavelet: a mother wavelet shaped like the Beta-distribution density.

    ``psi`` evaluates ``tau**(alpha - 1) * (1 - tau)**(beta - 1) / B(alpha, beta)``
    on the unit interval, where ``B`` is the Beta function.

    Parameters
    ----------
    alpha : float, default = 1.0
        First shape parameter; must be >= 1.
    beta : float, default = 1.0
        Second shape parameter; must be >= 1.
    name : str, default = 'beta_wavelet'
        Name forwarded to the ``discrete_wavelet`` base class.
    *args, **kwargs
        Forwarded unchanged to ``discrete_wavelet.__init__``.

    Raises
    ------
    ValueError
        If ``alpha`` or ``beta`` is smaller than 1 (or is NaN).
    """

    def __init__(self, alpha: float = 1.0, beta: float = 1.0, name: str = 'beta_wavelet', *args, **kwargs):
        # NOTE(review): positional ``*args`` are bound before the ``name=``
        # keyword, so if the base class's first positional parameter is
        # ``name``, passing extra positional args raises TypeError — confirm
        # against discrete_wavelet.__init__.
        super().__init__(name=name, *args, **kwargs)
        self.alpha = alpha
        self.beta = beta
        # Written as "not (>= and >=)" instead of "< or <" so that NaN shape
        # parameters are rejected as well (every comparison with NaN is False).
        if not (self.alpha >= 1.0 and self.beta >= 1.0):
            raise ValueError('alpha and beta must be >= 1.')

    def psi(self, tau: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the Beta wavelet at ``tau``.

        If any element of ``tau`` lies outside [0, 1], the whole tensor
        (including elements already in range) is first squashed into [0, 1]
        with ``torch.sigmoid``.

        Parameters
        ----------
        tau : torch.Tensor
            Points at which to evaluate the wavelet.

        Returns
        -------
        torch.Tensor
            ``tau**(alpha - 1) * (1 - tau)**(beta - 1) / B(alpha, beta)``,
            elementwise.

        Raises
        ------
        ValueError
            If ``tau`` contains NaN values.
        """
        if not torch.all((tau >= 0) & (tau <= 1)):
            tau = torch.sigmoid(tau)
            # sigmoid maps every finite or infinite value into [0, 1]; only NaN
            # stays outside that range. An explicit raise (not ``assert``) keeps
            # this check active when Python runs with -O.
            if not torch.all((tau >= 0) & (tau <= 1)):
                raise ValueError('tau must not contain NaN values.')
        # ``beta`` here is the module-level Beta function (presumably
        # scipy.special.beta — confirm), not the shape parameter ``self.beta``.
        beta_coeff = 1.0 / beta(self.alpha, self.beta)
        return beta_coeff * tau ** (self.alpha - 1) * (1.0 - tau) ** (self.beta - 1.0)