

Bases: transformation

Discrete Wavelet Expansion Transformation.

Implements the discrete wavelet expansion transformation, enabling feature expansion based on wavelet functions.


Formally, given the input variable \(\mathbf{x} \in R^{m}\), to approximate the underlying mapping \(f: R^m \to R^n\) with wavelet analysis, we can define the approximated output as

    f(\mathbf{x}) \approx \sum_{s, t} \left \langle f(\mathbf{x}), \phi_{s, t} (\mathbf{x} | a, b) \right \rangle \cdot \phi_{s, t} (\mathbf{x} | a, b),

where \(\phi_{s, t} (\cdot | a, b)\) denotes the child wavelet defined by hyper-parameters \(a > 1\) and \(b > 0\):

    \phi_{s, t}(x | a, b) = \frac{1}{\sqrt{a^s}} \phi \left( \frac{x - t \cdot b \cdot a^s}{a^s} \right).
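
The child wavelet formula above can be sketched in a few lines of Python. This is only an illustration: the mother wavelet `mexican_hat`, the helper name `child_wavelet`, and the default values `a=2.0`, `b=1.0` are assumptions made for the example and are not taken from tinybig.

```python
import math


def mexican_hat(x: float) -> float:
    """Hypothetical mother wavelet (unnormalized Mexican hat shape)."""
    return (1.0 - x * x) * math.exp(-0.5 * x * x)


def child_wavelet(x: float, s: int, t: int, a: float = 2.0, b: float = 1.0,
                  mother=mexican_hat) -> float:
    """Evaluate phi_{s,t}(x | a, b) = a^(-s/2) * phi((x - t*b*a^s) / a^s)."""
    scale = a ** s                      # dilation a^s
    shift = t * b * scale               # translation t * b * a^s
    return mother((x - shift) / scale) / math.sqrt(scale)


# With s = 0 and t = 0 the child wavelet reduces to the mother wavelet.
print(child_wavelet(0.0, s=0, t=0))     # 1.0
# With s = 1 and t = 1 the wavelet is centred at x = 2 and scaled by 1/sqrt(2).
print(child_wavelet(2.0, s=1, t=1))
```

Larger `s` stretches the wavelet and lowers its amplitude, while `t` moves its centre in steps of `b * a**s`.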

Based on the child wavelets \(\phi_{s, t} (\cdot | a, b)\), we can introduce the first-order and second-order wavelet data expansion functions as follows:

    \kappa(\mathbf{x} | d=1) = \left[ \phi_{0, 0}(\mathbf{x}), \phi_{0, 1}(\mathbf{x}), \cdots, \phi_{s, t}(\mathbf{x}) \right] \in R^{D_1}.


    \kappa(\mathbf{x} | d=2) = \kappa(\mathbf{x} | d=1) \otimes \kappa(\mathbf{x} | d=1) \in R^{D_2}.

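The output dimensions of the order-1 and order-2 wavelet expansions are \(D_1 = s \cdot t \cdot m\) and \(D_2 = (s \cdot t \cdot m)^2\), respectively. In the implementation, the scale index runs over \(0, \ldots, s-1\) and the translation index over \(0, \ldots, t-1\), which gives \(s \cdot t\) child wavelets per input feature. For a maximum order \(d > 1\), forward concatenates the expansions of orders \(1\) through \(d\), so calculate_D returns \(\sum_{i=1}^{d} (s \cdot t \cdot m)^i\).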
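
As a rough illustration of the second-order expansion, the sketch below forms the Kronecker product of a first-order vector with itself for a single sample. The values of `s`, `t`, `m` and the placeholder vector `kappa_1` are assumptions chosen for the example; a real first-order expansion would hold wavelet responses instead.

```python
import torch

# Assumed sizes for illustration: s scales, t translations, m input features.
s, t, m = 2, 2, 3
D1 = s * t * m

# Placeholder first-order expansion of one sample (values 1..D1).
kappa_1 = torch.arange(1, D1 + 1, dtype=torch.float32)

# Second-order expansion: Kronecker product of the vector with itself.
kappa_2 = torch.kron(kappa_1, kappa_1)

print(kappa_1.shape)  # torch.Size([12])
print(kappa_2.shape)  # torch.Size([144])
```

Entry `i * D1 + j` of `kappa_2` equals `kappa_1[i] * kappa_1[j]`, so the second-order expansion contains every pairwise product of first-order features.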


Attributes:

| Name | Type | Description |
|---|---|---|
| name | str | Name of the transformation. |
| d | int | Maximum order of wavelet-based polynomial expansion. |
| s | int | Number of scaling factors for the wavelet. |
| t | int | Number of translation factors for the wavelet. |
| wavelet | callable | The wavelet function applied during the transformation. |


Methods:

| Name | Description |
|---|---|
| calculate_D | Calculate the total dimensionality of the expanded feature space. |
| wavelet_x | Apply the wavelet transformation to the input data. |
| forward | Perform the discrete wavelet expansion on the input data. |

Source code in tinybig/expansion/
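import itertools

import numpy as np
import torch

# `transformation` and `discrete_wavelet` are referenced below and are defined elsewhere in tinybig.


class discrete_wavelet_expansion(transformation):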
    r"""
        Discrete Wavelet Expansion Transformation.

        Implements the discrete wavelet expansion transformation, enabling feature expansion based on wavelet functions.

        Formally, given the input variable $\mathbf{x} \in R^{m}$, to approximate the underlying mapping $f: R^m \to R^n$ with wavelet analysis, we can define the approximated output as

            f(\mathbf{x}) \approx \sum_{s, t} \left \langle f(\mathbf{x}), \phi_{s, t} (\mathbf{x} | a, b) \right \rangle \cdot \phi_{s, t} (\mathbf{x} | a, b),

        where $\phi_{s, t} (\cdot | a, b)$ denotes the child wavelet defined by hyper-parameters $a > 1$ and $b > 0$:

            \phi_{s, t}(x | a, b) = \frac{1}{\sqrt{a^s}} \phi \left( \frac{x - t \cdot b \cdot a^s}{a^s} \right).

        Based on the child wavelets $\phi_{s, t} (\cdot | a, b)$, we can introduce the first-order and second-order wavelet data expansion functions as follows:

            \kappa(\mathbf{x} | d=1) = \left[ \phi_{0, 0}(\mathbf{x}), \phi_{0, 1}(\mathbf{x}), \cdots, \phi_{s, t}(\mathbf{x}) \right] \in R^{D_1}.

            \kappa(\mathbf{x} | d=2) = \kappa(\mathbf{x} | d=1) \otimes \kappa(\mathbf{x} | d=1) \in R^{D_2}.

        The output dimensions of the order-1 and order-2 wavelet expansions are $D_1 = s \cdot t \cdot m$ and $D_2 = (s \cdot t \cdot m)^2$, respectively.

        Attributes
        ----------
        name : str
            Name of the transformation.
        d : int
            Maximum order of wavelet-based polynomial expansion.
        s : int
            Number of scaling factors for the wavelet.
        t : int
            Number of translation factors for the wavelet.
        wavelet : callable
            The wavelet function applied during the transformation.

        Methods
        -------
        calculate_D(m: int)
            Calculate the total dimensionality of the expanded feature space.
        wavelet_x(x: torch.Tensor, device: str = 'cpu', *args, **kwargs)
            Apply the wavelet transformation to the input data.
        forward(x: torch.Tensor, device: str = 'cpu', *args, **kwargs)
            Perform the discrete wavelet expansion on the input data.
    """
    def __init__(self, name: str = 'discrete_wavelet_expansion', d: int = 1, s: int = 1, t: int = 1, *args, **kwargs):
        """
            Initializes the discrete wavelet expansion transformation.

            Parameters
            ----------
            name : str, optional
                Name of the transformation. Defaults to 'discrete_wavelet_expansion'.
            d : int, optional
                The maximum order of wavelet-based polynomial expansion. Defaults to 1.
            s : int, optional
                The number of scaling factors for the wavelet. Defaults to 1.
            t : int, optional
                The number of translation factors for the wavelet. Defaults to 1.
            *args : tuple
                Additional positional arguments.
            **kwargs : dict
                Additional keyword arguments.
        """
        super().__init__(name=name, *args, **kwargs)
        self.d = d
        self.s = s
        self.t = t
        self.wavelet = None

    def calculate_D(self, m: int):
        """
            Calculates the expanded dimensionality of the transformed data.

            Parameters
            ----------
            m : int
                The original number of features.

            Returns
            -------
            int
                The total number of features after expansion.
        """
        return np.sum([(m * self.s * self.t) ** d for d in range(1, self.d + 1)])

    def wavelet_x(self, x: torch.Tensor, device='cpu', *args, **kwargs):
        """
            Applies the wavelet function to the input data.

            Parameters
            ----------
            x : torch.Tensor
                The input tensor to be transformed.
            device : str, optional
                The device to perform computation on. Defaults to 'cpu'.
            *args : tuple
                Additional positional arguments.
            **kwargs : dict
                Additional keyword arguments.

            Returns
            -------
            torch.Tensor
                The transformed tensor after applying the wavelet function.
        """
        assert self.wavelet is not None and isinstance(self.wavelet, discrete_wavelet)

        combinations = list(itertools.product(range(self.s), range(self.t)))
        combination_index = {comb: idx for idx, comb in enumerate(combinations)}

        expansion = torch.ones(size=[x.size(0), x.size(1), self.s * self.t]).to(device)
        for s, t in combination_index:
            n = combination_index[(s, t)]
            expansion[:, :, n] = self.wavelet(x=x, s=s, t=t)
        expansion = expansion[:, :, :].contiguous().view(x.size(0), -1)
        return expansion

    def forward(self, x: torch.Tensor, device='cpu', *args, **kwargs):
        """
            Expands the input data using discrete wavelet expansion.

            Parameters
            ----------
            x : torch.Tensor
                The input tensor to be expanded.
            device : str, optional
                The device to perform computation on. Defaults to 'cpu'.
            *args : tuple
                Additional positional arguments.
            **kwargs : dict
                Additional keyword arguments.

            Returns
            -------
            torch.Tensor
                The expanded tensor.
        """
        b, m = x.shape
        x = self.pre_process(x=x, device=device)

        wavelet_x = self.wavelet_x(x, device=device, *args, **kwargs)

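        if self.d > 1:
            # Running product of orders 1..d, starting from a column of ones.
            wavelet_x_powers = torch.ones(size=[wavelet_x.size(0), 1]).to(device)
            expansion = torch.Tensor([]).to(device)

            for i in range(1, self.d + 1):
                # Row-wise outer product with the first-order expansion, flattened per sample.
                wavelet_x_powers = torch.einsum(
                    'ba,bc->bac', wavelet_x_powers.clone(), wavelet_x
                ).view(wavelet_x_powers.size(0), wavelet_x_powers.size(1) * wavelet_x.size(1))
                # Append the order-i block to the accumulated expansion.
                expansion = torch.cat((expansion, wavelet_x_powers), dim=1)
        else:
            expansion = wavelet_x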

        assert expansion.shape == (b, self.calculate_D(m=m))
        return self.post_process(x=expansion, device=device)

__init__(name='discrete_wavelet_expansion', d=1, s=1, t=1, *args, **kwargs)

Initializes the discrete wavelet expansion transformation.


Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Name of the transformation. | 'discrete_wavelet_expansion' |
| d | int | The maximum order of wavelet-based polynomial expansion. | 1 |
| s | int | The number of scaling factors for the wavelet. | 1 |
| t | int | The number of translation factors for the wavelet. | 1 |
| *args | tuple | Additional positional arguments. | () |
| **kwargs | dict | Additional keyword arguments. | {} |

Source code in tinybig/expansion/
def __init__(self, name: str = 'discrete_wavelet_expansion', d: int = 1, s: int = 1, t: int = 1, *args, **kwargs):
    """
        Initializes the discrete wavelet expansion transformation.

        Parameters
        ----------
        name : str, optional
            Name of the transformation. Defaults to 'discrete_wavelet_expansion'.
        d : int, optional
            The maximum order of wavelet-based polynomial expansion. Defaults to 1.
        s : int, optional
            The number of scaling factors for the wavelet. Defaults to 1.
        t : int, optional
            The number of translation factors for the wavelet. Defaults to 1.
        *args : tuple
            Additional positional arguments.
        **kwargs : dict
            Additional keyword arguments.
    """
    super().__init__(name=name, *args, **kwargs)
    self.d = d
    self.s = s
    self.t = t
    self.wavelet = None


calculate_D(m)

Calculates the expanded dimensionality of the transformed data.


Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| m | int | The original number of features. | required |

Returns:

| Type | Description |
|---|---|
| int | The total number of features after expansion. |

Source code in tinybig/expansion/
def calculate_D(self, m: int):
    """
        Calculates the expanded dimensionality of the transformed data.

        Parameters
        ----------
        m : int
            The original number of features.

        Returns
        -------
        int
            The total number of features after expansion.
    """
    return np.sum([(m * self.s * self.t) ** d for d in range(1, self.d + 1)])
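
The dimension formula can be checked by hand with a small standalone function. This is a sketch that mirrors the sum in `calculate_D` without NumPy; the function name `expanded_dim` and the sample values are assumptions made for the example.

```python
def expanded_dim(m: int, s: int, t: int, d: int) -> int:
    """Sum of (m * s * t) ** order for order = 1..d, as in calculate_D."""
    base = m * s * t
    return sum(base ** order for order in range(1, d + 1))


# m = 3 features, s = 2 scales, t = 2 translations.
print(expanded_dim(m=3, s=2, t=2, d=1))  # 12
print(expanded_dim(m=3, s=2, t=2, d=2))  # 156  (12 + 144)
```

The dimension grows geometrically with `d`, so even moderate values of `m`, `s`, and `t` make orders above 2 expensive.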

forward(x, device='cpu', *args, **kwargs)

Expands the input data using discrete wavelet expansion.


Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | The input tensor to be expanded. | required |
| device | str | The device to perform computation on. | 'cpu' |
| *args | tuple | Additional positional arguments. | () |
| **kwargs | dict | Additional keyword arguments. | {} |

Returns:

| Type | Description |
|---|---|
| Tensor | The expanded tensor. |

Source code in tinybig/expansion/
def forward(self, x: torch.Tensor, device='cpu', *args, **kwargs):
    """
        Expands the input data using discrete wavelet expansion.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor to be expanded.
        device : str, optional
            The device to perform computation on. Defaults to 'cpu'.
        *args : tuple
            Additional positional arguments.
        **kwargs : dict
            Additional keyword arguments.

        Returns
        -------
        torch.Tensor
            The expanded tensor.
    """
    b, m = x.shape
    x = self.pre_process(x=x, device=device)

    wavelet_x = self.wavelet_x(x, device=device, *args, **kwargs)

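    if self.d > 1:
        # Running product of orders 1..d, starting from a column of ones.
        wavelet_x_powers = torch.ones(size=[wavelet_x.size(0), 1]).to(device)
        expansion = torch.Tensor([]).to(device)

        for i in range(1, self.d + 1):
            # Row-wise outer product with the first-order expansion, flattened per sample.
            wavelet_x_powers = torch.einsum(
                'ba,bc->bac', wavelet_x_powers.clone(), wavelet_x
            ).view(wavelet_x_powers.size(0), wavelet_x_powers.size(1) * wavelet_x.size(1))
            # Append the order-i block to the accumulated expansion.
            expansion = torch.cat((expansion, wavelet_x_powers), dim=1)
    else:
        expansion = wavelet_x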

    assert expansion.shape == (b, self.calculate_D(m=m))
    return self.post_process(x=expansion, device=device)
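
The high-order step in `forward` can be tried in isolation on a tiny batch. The sketch below is a standalone approximation of that loop, not the library code: the function name `high_order_expansion` and the input values are assumptions, and it collects the per-order blocks in a list rather than concatenating onto an empty tensor.

```python
import torch


def high_order_expansion(first_order: torch.Tensor, d: int) -> torch.Tensor:
    """Concatenate orders 1..d of a batched first-order expansion of shape [b, D1]."""
    b = first_order.size(0)
    powers = torch.ones(b, 1, dtype=first_order.dtype, device=first_order.device)
    blocks = []
    for _ in range(d):
        # Row-wise outer product, flattened: [b, k] x [b, D1] -> [b, k * D1].
        powers = torch.einsum('ba,bc->bac', powers, first_order).reshape(b, -1)
        blocks.append(powers)
    return torch.cat(blocks, dim=1)


first_order = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
out = high_order_expansion(first_order, d=2)
print(out.shape)  # torch.Size([2, 6])
print(out[0])     # tensor([1., 2., 1., 2., 2., 4.])
```

Each row holds the order-1 block followed by the flattened order-2 block, so the width is `D1 + D1**2`, matching the sum computed by `calculate_D`.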

wavelet_x(x, device='cpu', *args, **kwargs)

Applies the wavelet function to the input data.


Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | The input tensor to be transformed. | required |
| device | str | The device to perform computation on. | 'cpu' |
| *args | tuple | Additional positional arguments. | () |
| **kwargs | dict | Additional keyword arguments. | {} |

Returns:

| Type | Description |
|---|---|
| Tensor | The transformed tensor after applying the wavelet function. |

Source code in tinybig/expansion/
def wavelet_x(self, x: torch.Tensor, device='cpu', *args, **kwargs):
    """
        Applies the wavelet function to the input data.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor to be transformed.
        device : str, optional
            The device to perform computation on. Defaults to 'cpu'.
        *args : tuple
            Additional positional arguments.
        **kwargs : dict
            Additional keyword arguments.

        Returns
        -------
        torch.Tensor
            The transformed tensor after applying the wavelet function.
    """
    assert self.wavelet is not None and isinstance(self.wavelet, discrete_wavelet)

    combinations = list(itertools.product(range(self.s), range(self.t)))
    combination_index = {comb: idx for idx, comb in enumerate(combinations)}

    expansion = torch.ones(size=[x.size(0), x.size(1), self.s * self.t]).to(device)
    for s, t in combination_index:
        n = combination_index[(s, t)]
        expansion[:, :, n] = self.wavelet(x=x, s=s, t=t)
    expansion = expansion[:, :, :].contiguous().view(x.size(0), -1)
    return expansion
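
To see the shape bookkeeping of `wavelet_x` without a tinybig wavelet object, the following sketch applies a stand-in function over every (s, t) pair. The names `toy_wavelet` and `wavelet_features`, the defaults `a=2.0` and `b=1.0`, and the Gaussian-shaped stand-in (which is not an admissible wavelet) are all assumptions made for illustration.

```python
import itertools

import torch


def toy_wavelet(x: torch.Tensor, s: int, t: int, a: float = 2.0, b: float = 1.0) -> torch.Tensor:
    """Hypothetical child function: Gaussian bump dilated by a^s and shifted by t*b*a^s."""
    scale = a ** s
    z = (x - t * b * scale) / scale
    return torch.exp(-0.5 * z ** 2) / scale ** 0.5


def wavelet_features(x: torch.Tensor, s_count: int, t_count: int, wavelet=toy_wavelet) -> torch.Tensor:
    """Evaluate wavelet(x, s, t) for all pairs and flatten to [b, m * s_count * t_count]."""
    pairs = list(itertools.product(range(s_count), range(t_count)))
    # Stack along a new last axis: [b, m, s_count * t_count].
    feats = torch.stack([wavelet(x, s, t) for s, t in pairs], dim=-1)
    return feats.reshape(x.size(0), -1)


x = torch.zeros(4, 3)                     # batch of 4 samples, 3 features
out = wavelet_features(x, s_count=2, t_count=2)
print(out.shape)  # torch.Size([4, 12])
```

As in the source above, the `s_count * t_count` responses of each input feature end up next to each other in the flattened output.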