class discrete_wavelet(object):
    """Base class for a discretely sampled wavelet family.

    A subclass supplies the mother wavelet via ``psi``. The family member at
    integer scale index ``s`` and translation index ``t`` is evaluated as::

        psi_{s,t}(x) = a**(-s/2) * psi(x / a**s - t * b)

    where ``a`` is the base dilation step and ``b`` the base translation step.
    Calling an instance is equivalent to calling ``forward``.
    """

    def __init__(self, name: str = 'discrete_wavelet', a: float = 1.0, b: float = 1.0, *args, **kwargs):
        """Validate and store the discretization parameters.

        Args:
            name: Identifier stored on the instance as ``self.name``.
            a: Base dilation step; must be >= 1.
            b: Base translation step; must be >= 0.
            *args, **kwargs: Accepted for caller/subclass compatibility; ignored here.

        Raises:
            ValueError: If ``a < 1`` or ``b < 0``, or if either is NaN.
        """
        # Negated conjunction (rather than ``a < 1 or b < 0``) so NaN, which
        # fails every comparison, is rejected as well.
        if not (a >= 1 and b >= 0):
            raise ValueError(f'a must be >= 1 and b must be >= 0, got a={a}, b={b}')
        self.name = name
        self.a = a
        self.b = b

    @abstractmethod
    def psi(self, tau: torch.Tensor):
        """Evaluate the mother wavelet at ``tau``; subclasses must override this."""
        # This class does not use ABCMeta, so @abstractmethod alone does not stop
        # instantiation; fail loudly here instead of returning None and letting
        # ``forward`` crash with an unrelated TypeError.
        raise NotImplementedError(f'{type(self).__name__} must implement psi()')

    def forward(self, x: torch.Tensor, s: int, t: int):
        """Evaluate the family member with scale index ``s`` and shift index ``t`` at ``x``.

        Args:
            x: Points at which to evaluate the wavelet.
            s: Scale index; the dilation applied is ``a**s``.
            t: Translation index; the shift applied is ``t * b`` (in dilated units).

        Returns:
            ``a**(-s/2) * psi(x / a**s - t * b)``.
        """
        scale = self.a ** s
        tau = x / scale - t * self.b
        # Normalization factor keeps the L2 norm independent of the scale index.
        return 1.0 / math.sqrt(scale) * self.psi(tau=tau)

    def __call__(self, x: torch.Tensor, s: int, t: int):
        """Alias for :meth:`forward`."""
        return self.forward(x=x, s=s, t=t)