
discrete_wavelet_expansion

Bases: transformation

Discrete Wavelet Expansion Transformation.

Implements the discrete wavelet expansion transformation, enabling feature expansion based on wavelet functions.

Notes

Formally, given the input variable \(\mathbf{x} \in R^{m}\), to approximate the underlying mapping \(f: R^m \to R^n\) with wavelet analysis, we can define the approximated output as

\[
    \begin{equation}
    f(\mathbf{x}) \approx \sum_{s, t} \left \langle f(\mathbf{x}), \phi_{s, t} (\mathbf{x} | a, b) \right \rangle \cdot \phi_{s, t} (\mathbf{x} | a, b),
    \end{equation}
\]

where \(\phi_{s, t} (\cdot | a, b)\) denotes the child wavelet defined by hyper-parameters \(a > 1\) and \(b > 0\):

\[
    \begin{equation}
    \phi_{s, t}(x | a, b) = \frac{1}{\sqrt{a^s}} \phi \left( \frac{x - t \cdot b \cdot a^s}{a^s} \right).
    \end{equation}
\]
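
As a rough illustration of the child-wavelet formula above, the following pure-Python sketch evaluates \(\phi_{s, t}(x | a, b)\) for a scalar input. The unnormalised Mexican-hat mother wavelet, the defaults `a = 2` and `b = 1`, and the function names are assumptions chosen for this example; they are not taken from tinybig.

```python
import math


def mexican_hat(u: float) -> float:
    """Unnormalised Mexican-hat mother wavelet (an assumed example choice)."""
    return (1.0 - u * u) * math.exp(-u * u / 2.0)


def child_wavelet(x: float, s: int, t: int, a: float = 2.0, b: float = 1.0,
                  mother=mexican_hat) -> float:
    """phi_{s,t}(x | a, b) = a^(-s/2) * phi((x - t*b*a^s) / a^s)."""
    scale = a ** s
    return mother((x - t * b * scale) / scale) / math.sqrt(scale)


# With s = 0 and t = 0 the child wavelet is the mother wavelet itself.
print(child_wavelet(0.5, s=0, t=0))

# With s = 1, t = 1, a = 2, b = 1 the wavelet is centred at t*b*a^s = 2
# and its amplitude is scaled by 1/sqrt(2).
print(child_wavelet(2.0, s=1, t=1))
```

Larger `s` stretches the wavelet and lowers its amplitude, while `t` shifts its centre in steps of `b * a**s`.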

Based on the wavelet mapping \(\phi_{s, t} (\cdot | a, b)\), we can introduce the first-order and second-order wavelet data expansion functions as follows:

\[
    \begin{equation}
    \kappa(\mathbf{x} | d=1) = \left[ \phi_{0, 0}(\mathbf{x}), \phi_{0, 1}(\mathbf{x}), \cdots, \phi_{s-1, t-1}(\mathbf{x}) \right] \in R^{D_1},
    \end{equation}
\]

and

\[
    \begin{equation}
    \kappa(\mathbf{x} | d=2) = \kappa(\mathbf{x} | d=1) \otimes \kappa(\mathbf{x} | d=1) \in R^{D_2}.
    \end{equation}
\]

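The output dimensions of the order-1 and order-2 wavelet expansions are \(D_1 = s \cdot t \cdot m\) and \(D_2 = (s \cdot t \cdot m)^2\), respectively. When the maximum order is \(d > 1\), `forward` concatenates the expansions of every order from 1 to \(d\), so `calculate_D` returns \(\sum_{k=1}^{d} (s \cdot t \cdot m)^k\).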
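
The two expansion functions can be sketched in plain Python as follows. This is only a minimal illustration, assuming a Haar mother wavelet with `a = 2` and `b = 1`; the helper names `haar`, `child`, `kappa_1`, and `kappa_2` are hypothetical and do not come from tinybig.

```python
import itertools
import math


def haar(u: float) -> float:
    """Haar mother wavelet: +1 on [0, 0.5), -1 on [0.5, 1), 0 elsewhere."""
    if 0.0 <= u < 0.5:
        return 1.0
    if 0.5 <= u < 1.0:
        return -1.0
    return 0.0


def child(x: float, s: int, t: int, a: float = 2.0, b: float = 1.0) -> float:
    """Child wavelet phi_{s,t}(x | a, b) built from the Haar mother wavelet."""
    scale = a ** s
    return haar((x - t * b * scale) / scale) / math.sqrt(scale)


def kappa_1(x, s: int, t: int):
    """First-order expansion: s * t child wavelets per input feature."""
    pairs = list(itertools.product(range(s), range(t)))
    return [child(xi, si, ti) for xi in x for si, ti in pairs]


def kappa_2(x, s: int, t: int):
    """Second-order expansion: Kronecker product of kappa_1 with itself."""
    k1 = kappa_1(x, s, t)
    return [u * v for u in k1 for v in k1]


x = [0.25, 1.5]                 # m = 2 input features
k1 = kappa_1(x, s=2, t=2)
k2 = kappa_2(x, s=2, t=2)
print(len(k1), len(k2))         # 8 64, i.e. s*t*m and (s*t*m)^2
```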

Attributes:

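| Name | Type | Description |
| --- | --- | --- |
| `name` | `str` | Name of the transformation. |
| `d` | `int` | Maximum order of wavelet-based polynomial expansion. |
| `s` | `int` | Number of scaling factors for the wavelet. |
| `t` | `int` | Number of translation factors for the wavelet. |
| `wavelet` | `callable` | The wavelet function applied during the transformation. Initialized to `None`; it must be set to a `discrete_wavelet` instance before `wavelet_x` is called. |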

Methods:

| Name | Description |
| --- | --- |
| `calculate_D` | Calculate the total dimensionality of the expanded feature space. |
| `wavelet_x` | Apply the wavelet transformation to the input data. |
| `forward` | Perform the discrete wavelet expansion on the input data. |


Source code in tinybig/expansion/wavelet_expansion.py
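import itertools

import numpy as np
import torch

# `transformation` and `discrete_wavelet` are defined elsewhere in tinybig;
# their import lines are not shown in this listing.

class discrete_wavelet_expansion(transformation):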
    r"""
        Discrete Wavelet Expansion Transformation.

        Implements the discrete wavelet expansion transformation, enabling feature expansion based on wavelet functions.

        Notes
        ---------

        Formally, given the input variable $\mathbf{x} \in R^{m}$, to approximate the underlying mapping $f: R^m \to R^n$ with wavelet analysis, we can define the approximated output as

        $$
            \begin{equation}
            f(\mathbf{x}) \approx \sum_{s, t} \left \langle f(\mathbf{x}), \phi_{s, t} (\mathbf{x} | a, b) \right \rangle \cdot \phi_{s, t} (\mathbf{x} | a, b),
            \end{equation}
        $$

        where $\phi_{s, t} (\cdot | a, b)$ denotes the child wavelet defined by hyper-parameters $a > 1$ and $b > 0$:

        $$
            \begin{equation}
            \phi_{s, t}(x | a, b) = \frac{1}{\sqrt{a^s}} \phi \left( \frac{x - t \cdot b \cdot a^s}{a^s} \right).
            \end{equation}
        $$

        Based on the wavelet mapping $\phi_{s, t} (\cdot | a, b)$, we can introduce the first-order and second-order wavelet data expansion functions as follows:

        $$
            \begin{equation}
            \kappa(\mathbf{x} | d=1) = \left[ \phi_{0, 0}(\mathbf{x}), \phi_{0, 1}(\mathbf{x}), \cdots, \phi_{s-1, t-1}(\mathbf{x}) \right] \in R^{D_1},
            \end{equation}
        $$

        and

        $$
            \begin{equation}
            \kappa(\mathbf{x} | d=2) = \kappa(\mathbf{x} | d=1) \otimes \kappa(\mathbf{x} | d=1) \in R^{D_2}.
            \end{equation}
        $$

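        The output dimensions of the order-1 and order-2 wavelet expansions are $D_1 = s \cdot t \cdot m$ and $D_2 = (s \cdot t \cdot m)^2$, respectively. When the maximum order is $d > 1$, `forward` concatenates the expansions of every order from 1 to $d$, so `calculate_D` returns $\sum_{k=1}^{d} (s \cdot t \cdot m)^k$.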


        Attributes
        ----------
        name : str
            Name of the transformation.
        d : int
            Maximum order of wavelet-based polynomial expansion.
        s : int
            Number of scaling factors for the wavelet.
        t : int
            Number of translation factors for the wavelet.
        wavelet : callable
            The wavelet function applied during the transformation. Initialized to None;
            it must be set to a `discrete_wavelet` instance before `wavelet_x` is called.

        Methods
        -------
        calculate_D(m: int)
            Calculate the total dimensionality of the expanded feature space.
        wavelet_x(x: torch.Tensor, device: str = 'cpu', *args, **kwargs)
            Apply the wavelet transformation to the input data.
        forward(x: torch.Tensor, device: str = 'cpu', *args, **kwargs)
            Perform the discrete wavelet expansion on the input data.
    """
    def __init__(self, name: str = 'discrete_wavelet_expansion', d: int = 1, s: int = 1, t: int = 1, *args, **kwargs):
        """
            Initializes the discrete wavelet expansion transformation.

            Parameters
            ----------
            name : str, optional
                Name of the transformation. Defaults to 'discrete_wavelet_expansion'.
            d : int, optional
                The maximum order of wavelet-based polynomial expansion. Defaults to 1.
            s : int, optional
                The number of scaling factors for the wavelet. Defaults to 1.
            t : int, optional
                The number of translation factors for the wavelet. Defaults to 1.
            *args : tuple
                Additional positional arguments.
            **kwargs : dict
                Additional keyword arguments.
        """
        super().__init__(name=name, *args, **kwargs)
        self.d = d
        self.s = s
        self.t = t
        self.wavelet = None

    def calculate_D(self, m: int):
        """
            Calculates the expanded dimensionality of the transformed data.

            Parameters
            ----------
            m : int
                The original number of features.

            Returns
            -------
            int
                The total number of features after expansion.
        """
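        # Sum the sizes of the order-1 .. order-d expansions; cast so a plain int is returned as documented.
        return int(np.sum([(m * self.s * self.t) ** k for k in range(1, self.d + 1)]))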

    def wavelet_x(self, x: torch.Tensor, device='cpu', *args, **kwargs):
        """
            Applies the wavelet function to the input data.

            Parameters
            ----------
            x : torch.Tensor
                The input tensor to be transformed.
            device : str, optional
                The device to perform computation on. Defaults to 'cpu'.
            *args : tuple
                Additional positional arguments.
            **kwargs : dict
                Additional keyword arguments.

            Returns
            -------
            torch.Tensor
                The transformed tensor after applying the wavelet function.
        """
        assert self.wavelet is not None and isinstance(self.wavelet, discrete_wavelet)

        # One output channel per (scale, translation) pair, with the translation index varying fastest.
        expansion = torch.ones(size=[x.size(0), x.size(1), self.s * self.t]).to(device)
        for n, (s, t) in enumerate(itertools.product(range(self.s), range(self.t))):
            expansion[:, :, n] = self.wavelet(x=x, s=s, t=t)
        # Flatten (batch, m, s * t) to (batch, m * s * t).
        return expansion.contiguous().view(x.size(0), -1)

    def forward(self, x: torch.Tensor, device='cpu', *args, **kwargs):
        """
            Expands the input data using discrete wavelet expansion.

            Parameters
            ----------
            x : torch.Tensor
                The input tensor to be expanded.
            device : str, optional
                The device to perform computation on. Defaults to 'cpu'.
            *args : tuple
                Additional positional arguments.
            **kwargs : dict
                Additional keyword arguments.

            Returns
            -------
            torch.Tensor
                The expanded tensor.
        """
        b, m = x.shape
        x = self.pre_process(x=x, device=device)

        wavelet_x = self.wavelet_x(x, device=device, *args, **kwargs)

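        if self.d > 1:
            # Seed with a (batch, 1) tensor of ones so the first product yields the order-1 expansion.
            wavelet_x_powers = torch.ones(size=[wavelet_x.size(0), 1]).to(device)
            expansion = torch.Tensor([]).to(device)

            for _ in range(1, self.d + 1):
                # A batched outer product with wavelet_x raises the order by one; flatten it back to 2-D.
                wavelet_x_powers = torch.einsum('ba,bc->bac', wavelet_x_powers.clone(), wavelet_x).view(wavelet_x_powers.size(0), wavelet_x_powers.size(1) * wavelet_x.size(1))
                # Concatenate orders 1..d along the feature dimension.
                expansion = torch.cat((expansion, wavelet_x_powers), dim=1)
        else:
            expansion = wavelet_x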

        assert expansion.shape == (b, self.calculate_D(m=m))
        return self.post_process(x=expansion, device=device)

__init__(name='discrete_wavelet_expansion', d=1, s=1, t=1, *args, **kwargs)

Initializes the discrete wavelet expansion transformation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `name` | `str` | Name of the transformation. | `'discrete_wavelet_expansion'` |
| `d` | `int` | The maximum order of wavelet-based polynomial expansion. | `1` |
| `s` | `int` | The number of scaling factors for the wavelet. | `1` |
| `t` | `int` | The number of translation factors for the wavelet. | `1` |
| `*args` | `tuple` | Additional positional arguments. | `()` |
| `**kwargs` | `dict` | Additional keyword arguments. | `{}` |

Source code in tinybig/expansion/wavelet_expansion.py
def __init__(self, name: str = 'discrete_wavelet_expansion', d: int = 1, s: int = 1, t: int = 1, *args, **kwargs):
    """
        Initializes the discrete wavelet expansion transformation.

        Parameters
        ----------
        name : str, optional
            Name of the transformation. Defaults to 'discrete_wavelet_expansion'.
        d : int, optional
            The maximum order of wavelet-based polynomial expansion. Defaults to 1.
        s : int, optional
            The number of scaling factors for the wavelet. Defaults to 1.
        t : int, optional
            The number of translation factors for the wavelet. Defaults to 1.
        *args : tuple
            Additional positional arguments.
        **kwargs : dict
            Additional keyword arguments.
    """
    super().__init__(name=name, *args, **kwargs)
    self.d = d
    self.s = s
    self.t = t
    self.wavelet = None

calculate_D(m)

Calculates the expanded dimensionality of the transformed data.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `m` | `int` | The original number of features. | required |

Returns:

| Type | Description |
| --- | --- |
| `int` | The total number of features after expansion. |

Source code in tinybig/expansion/wavelet_expansion.py
def calculate_D(self, m: int):
    """
        Calculates the expanded dimensionality of the transformed data.

        Parameters
        ----------
        m : int
            The original number of features.

        Returns
        -------
        int
            The total number of features after expansion.
    """
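    # Sum the sizes of the order-1 .. order-d expansions; cast so a plain int is returned as documented.
    return int(np.sum([(m * self.s * self.t) ** k for k in range(1, self.d + 1)]))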
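
To make the arithmetic concrete, here is a small stand-alone sketch that mirrors the sum computed in `calculate_D`. The free function `expanded_dim` is a hypothetical name used only for this example.

```python
def expanded_dim(m: int, s: int, t: int, d: int) -> int:
    """Sum of (m * s * t) ** k for k = 1..d, mirroring calculate_D."""
    base = m * s * t
    return sum(base ** k for k in range(1, d + 1))


print(expanded_dim(m=3, s=2, t=2, d=1))  # 12
print(expanded_dim(m=3, s=2, t=2, d=2))  # 12 + 144 = 156
print(expanded_dim(m=3, s=2, t=2, d=3))  # 12 + 144 + 1728 = 1884
```

The output width grows geometrically with `d`, so even modest values of `m`, `s`, and `t` produce large expansions at higher orders.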

forward(x, device='cpu', *args, **kwargs)

Expands the input data using discrete wavelet expansion.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | The input tensor to be expanded. | required |
| `device` | `str` | The device to perform computation on. | `'cpu'` |
| `*args` | `tuple` | Additional positional arguments. | `()` |
| `**kwargs` | `dict` | Additional keyword arguments. | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The expanded tensor. |

Source code in tinybig/expansion/wavelet_expansion.py
def forward(self, x: torch.Tensor, device='cpu', *args, **kwargs):
    """
        Expands the input data using discrete wavelet expansion.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor to be expanded.
        device : str, optional
            The device to perform computation on. Defaults to 'cpu'.
        *args : tuple
            Additional positional arguments.
        **kwargs : dict
            Additional keyword arguments.

        Returns
        -------
        torch.Tensor
            The expanded tensor.
    """
    b, m = x.shape
    x = self.pre_process(x=x, device=device)

    wavelet_x = self.wavelet_x(x, device=device, *args, **kwargs)

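    if self.d > 1:
        # Seed with a (batch, 1) tensor of ones so the first product yields the order-1 expansion.
        wavelet_x_powers = torch.ones(size=[wavelet_x.size(0), 1]).to(device)
        expansion = torch.Tensor([]).to(device)

        for _ in range(1, self.d + 1):
            # A batched outer product with wavelet_x raises the order by one; flatten it back to 2-D.
            wavelet_x_powers = torch.einsum('ba,bc->bac', wavelet_x_powers.clone(), wavelet_x).view(wavelet_x_powers.size(0), wavelet_x_powers.size(1) * wavelet_x.size(1))
            # Concatenate orders 1..d along the feature dimension.
            expansion = torch.cat((expansion, wavelet_x_powers), dim=1)
    else:
        expansion = wavelet_x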

    assert expansion.shape == (b, self.calculate_D(m=m))
    return self.post_process(x=expansion, device=device)
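
The order-by-order accumulation in `forward` can be sketched for a single sample without torch. This is a simplified illustration that assumes the first-order expansion is already available as a flat list; `expand_up_to_order` is a hypothetical helper, not a tinybig function.

```python
def expand_up_to_order(k1, d: int):
    """Concatenate the order-1..d Kronecker powers of k1 for one sample."""
    powers = [1.0]   # order-0 seed, playing the role of the (batch, 1) tensor of ones
    out = []
    for _ in range(d):
        # Outer product of the running power with k1, flattened row-major,
        # which matches einsum('ba,bc->bac') followed by view(batch, a * c).
        powers = [p * v for p in powers for v in k1]
        out.extend(powers)
    return out


k1 = [2.0, 3.0]
print(expand_up_to_order(k1, d=2))
# [2.0, 3.0, 4.0, 6.0, 6.0, 9.0]
```

The first two entries are the order-1 terms and the remaining four are the order-2 terms, giving 2 + 4 = 6 columns in total.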

wavelet_x(x, device='cpu', *args, **kwargs)

Applies the wavelet function to the input data.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | The input tensor to be transformed. | required |
| `device` | `str` | The device to perform computation on. | `'cpu'` |
| `*args` | `tuple` | Additional positional arguments. | `()` |
| `**kwargs` | `dict` | Additional keyword arguments. | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The transformed tensor after applying the wavelet function. |

Source code in tinybig/expansion/wavelet_expansion.py
def wavelet_x(self, x: torch.Tensor, device='cpu', *args, **kwargs):
    """
        Applies the wavelet function to the input data.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor to be transformed.
        device : str, optional
            The device to perform computation on. Defaults to 'cpu'.
        *args : tuple
            Additional positional arguments.
        **kwargs : dict
            Additional keyword arguments.

        Returns
        -------
        torch.Tensor
            The transformed tensor after applying the wavelet function.
    """
    assert self.wavelet is not None and isinstance(self.wavelet, discrete_wavelet)

    # One output channel per (scale, translation) pair, with the translation index varying fastest.
    expansion = torch.ones(size=[x.size(0), x.size(1), self.s * self.t]).to(device)
    for n, (s, t) in enumerate(itertools.product(range(self.s), range(self.t))):
        expansion[:, :, n] = self.wavelet(x=x, s=s, t=t)
    # Flatten (batch, m, s * t) to (batch, m * s * t).
    return expansion.contiguous().view(x.size(0), -1)
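
The column layout that `wavelet_x` produces can be worked out from the listing above: the `(s, t)` pairs are enumerated by `itertools.product`, and the `(batch, m, s * t)` tensor is then flattened with features as the outer block. The sketch below computes the flat column for a given feature and wavelet index; `column_index` is a hypothetical helper introduced only for illustration.

```python
import itertools


def column_index(feature: int, s_idx: int, t_idx: int, s: int, t: int) -> int:
    """Flat column holding phi_{s_idx, t_idx} applied to the given input feature."""
    # itertools.product(range(s), range(t)) varies t_idx fastest,
    # so the pair (s_idx, t_idx) sits at position s_idx * t + t_idx.
    n = s_idx * t + t_idx
    # Flattening (batch, m, s * t) to (batch, m * s * t) keeps features outermost.
    return feature * (s * t) + n


s, t = 2, 3
pairs = list(itertools.product(range(s), range(t)))

# Sanity check: for feature 0 the column index equals the pair's enumeration order.
assert all(pairs[column_index(0, si, ti, s, t)] == (si, ti) for si, ti in pairs)

print(column_index(feature=1, s_idx=1, t_idx=2, s=s, t=t))  # 11
```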