
hm_parameterized_bilinear_interdependence

Bases: parameterized_bilinear_interdependence

A hierarchical mapping (HM) parameterized bilinear interdependence function.

Notes

Formally, given a data batch \(\mathbf{X} \in R^{b \times m}\), we can represent the parameterized bilinear form-based interdependence function as follows:

\[
    \begin{equation}\label{equ:bilinear_interdependence_function}
    \xi(\mathbf{X} | \mathbf{w}) = \mathbf{X}^\top \mathbf{W} \mathbf{X} = \mathbf{A} \in R^{m \times m}.
    \end{equation}
\]
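As a quick shape check, the bilinear form can be evaluated directly. This is a minimal NumPy sketch with random data and an unconstrained dense \(\mathbf{W}\); the toy sizes \(b = 4\) and \(m = 3\) are arbitrary, and it does not call tinybig.

```python
import numpy as np

# Toy sizes: a batch of b = 4 instances, each with m = 3 attributes.
b, m = 4, 3
rng = np.random.default_rng(0)

X = rng.standard_normal((b, m))   # data batch X in R^{b x m}
W = rng.standard_normal((b, b))   # parameter matrix W in R^{b x b}

# Bilinear interdependence: A = X^T W X, an m x m matrix
# relating every pair of attributes (columns of X).
A = X.T @ W @ X
print(A.shape)  # (3, 3)
```

Entry \((i, j)\) of the result is \(\mathbf{x}_i^\top \mathbf{W} \mathbf{x}_j\), where \(\mathbf{x}_i\) is the \(i\)-th column of \(\mathbf{X}\).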

Here, \(\mathbf{W} \in R^{b \times b}\) is the parameter matrix fabricated from the learnable parameter vector \(\mathbf{w} \in R^{l_{\xi}}\) via the hierarchical mapping \(\psi\):

\[
    \begin{equation}
    \mathbf{W} = \psi(\mathbf{w}) = \mathbf{A} \otimes \mathbf{B} \in R^{b \times b},
    \end{equation}
\]

where \(\mathbf{A} \in R^{p \times q}\) and \(\mathbf{B} \in R^{s \times t}\), with \(s = \frac{b}{p}\) and \(t = \frac{b}{q}\), are obtained by splitting \(\mathbf{w}\) into two segments of lengths \(pq\) and \(st\) and reshaping each segment into a matrix. These Kronecker factors are distinct from the output interdependence matrix \(\mathbf{A} \in R^{m \times m}\) in the bilinear form above.
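The fabrication step can be sketched with `numpy.kron`. The helper `hm_fabricate` below is hypothetical, not tinybig's `hm_reconciliation`; it assumes the first \(pq\) entries of \(\mathbf{w}\) form \(\mathbf{A}\) and the remaining \(st\) entries form \(\mathbf{B}\), each reshaped in row-major order.

```python
import numpy as np

def hm_fabricate(w: np.ndarray, b: int, p: int, q: int) -> np.ndarray:
    """Build W = kron(A, B) in R^{b x b} from a flat parameter vector w.

    Hypothetical helper. Assumption: the first p*q entries of w form A and
    the remaining s*t entries form B (row-major reshape).
    """
    if b % p != 0 or b % q != 0:
        raise ValueError("b must be divisible by both p and q")
    s, t = b // p, b // q
    if w.size != p * q + s * t:
        raise ValueError(f"expected {p * q + s * t} parameters, got {w.size}")
    A = w[: p * q].reshape(p, q)   # coarse p x q factor
    B = w[p * q :].reshape(s, t)   # fine s x t factor
    return np.kron(A, B)           # (p*s) x (q*t) = b x b

b, p, q = 6, 2, 3
w = np.arange(1, p * q + (b // p) * (b // q) + 1, dtype=float)  # 12 values
W = hm_fabricate(w, b, p, q)
print(W.shape)  # (6, 6)
```

Each \(s \times t\) block of \(\mathbf{W}\) is a scaled copy of \(\mathbf{B}\): block \((i, j)\) equals \(A_{ij}\mathbf{B}\), which is what makes the mapping hierarchical.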

The required length of the parameter vector for this interdependence function is therefore \(l_{\xi} = pq + st = pq + \frac{b^2}{pq}\).
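To see how much the factorization saves, the count \(l_{\xi}\) can be compared against the \(b^2\) entries of a dense \(b \times b\) matrix. This is a small standalone sketch; `hm_parameter_count` is a hypothetical name, not a tinybig function.

```python
def hm_parameter_count(b: int, p: int, q: int) -> int:
    """l_xi = p*q + (b/p)*(b/q) = p*q + b^2/(p*q)."""
    if b % p != 0 or b % q != 0:
        raise ValueError("b must be divisible by both p and q")
    return p * q + (b // p) * (b // q)

# Compare against the b^2 entries of a dense b x b parameter matrix.
for b, p, q in [(64, 8, 8), (64, 2, 2), (64, 1, 1)]:
    print(b, p, q, hm_parameter_count(b, p, q), b * b)
# 64 8 8 128 4096
# 64 2 2 1028 4096
# 64 1 1 4097 4096
```

Since \(x + \frac{b^2}{x} \ge 2b\) with equality at \(x = b\), the count is smallest when \(pq = b\). The degenerate choice \(p = q = 1\) uses one parameter more than a dense matrix.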

Attributes:

| Name | Type | Description |
|------|------|-------------|
| `p` | `int` | Number of partitions in the input dimension. |
| `q` | `int` | Number of partitions in the output dimension. |

Methods:

| Name | Description |
|------|-------------|
| `__init__` | Initializes the hierarchical mapping parameterized bilinear interdependence function. |

Source code in tinybig/interdependence/parameterized_bilinear_interdependence.py
class hm_parameterized_bilinear_interdependence(parameterized_bilinear_interdependence):
    r"""
        A hierarchical mapping (HM) parameterized bilinear interdependence function.

        Notes
        ----------
        Formally, given a data batch $\mathbf{X} \in R^{b \times m}$, we can represent the parameterized bilinear form-based interdependence function as follows:

        $$
            \begin{equation}\label{equ:bilinear_interdependence_function}
            \xi(\mathbf{X} | \mathbf{w}) = \mathbf{X}^\top \mathbf{W} \mathbf{X} = \mathbf{A} \in R^{m \times m}.
            \end{equation}
        $$

        Notation $\mathbf{W} \in R^{b \times b}$ denotes the parameter matrix fabricated from the learnable parameter vector $\mathbf{w} \in R^{l_{\xi}}$,
        which can be represented as follows:

        $$
            \begin{equation}
            \psi(\mathbf{w}) = \mathbf{A} \otimes \mathbf{B} \in R^{b \times b},
            \end{equation}
        $$
        where $\mathbf{A} \in R^{p \times q}$ and $\mathbf{B} \in R^{s \times t}$ (where $s =\frac{b}{p}$ and $t = \frac{b}{q}$) are partitioned and reshaped from the parameter vector $\mathbf{w}$.

        The required length of parameter vector of this interdependence function is $l_{\xi} = pq + \frac{b^2}{pq}$.

        Attributes
        ----------
        p : int
            Number of partitions in the input dimension.
        q : int
            Number of partitions in the output dimension.

        Methods
        -------
        __init__(...)
            Initializes the hierarchical mapping parameterized bilinear interdependence function.
    """
    def __init__(self, p: int, q: int = None, name: str = 'hm_parameterized_bilinear_interdependence', *args, **kwargs):
        """
            Initializes the hierarchical mapping parameterized bilinear interdependence function.

            Parameters
            ----------
            p : int
                Number of partitions in the input dimension.
            q : int, optional
                Number of partitions in the output dimension. Defaults to `p`.
            name : str, optional
                Name of the interdependence function. Defaults to 'hm_parameterized_bilinear_interdependence'.
            *args : tuple
                Additional positional arguments.
            **kwargs : dict
                Additional keyword arguments.
        """

        super().__init__(name=name, *args, **kwargs)

        self.p = p
        self.q = q if q is not None else p

        if self.interdependence_type in ['row', 'left', 'instance', 'instance_interdependence']:
            d, d_prime = self.m, self.calculate_m_prime()
        elif self.interdependence_type in ['column', 'right', 'attribute', 'attribute_interdependence']:
            d, d_prime = self.b, self.calculate_b_prime()
        else:
            raise ValueError(f'Interdependence type {self.interdependence_type} not supported')

        assert d % self.p == 0 and d_prime % self.q == 0

        self.parameter_fabrication = hm_reconciliation(p=self.p, q=self.q)
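The validation performed by `__init__` can be mirrored in a standalone function to check partition choices before constructing the object. This is a hypothetical sketch: `check_hm_partitions`, `b_prime` and `m_prime` are illustrative names standing in for `self.b`, `self.m`, `calculate_b_prime()` and `calculate_m_prime()`, and the source's bare `assert` is replaced by a `ValueError` with a message.

```python
from typing import Optional, Tuple

ROW_TYPES = {'row', 'left', 'instance', 'instance_interdependence'}
COLUMN_TYPES = {'column', 'right', 'attribute', 'attribute_interdependence'}

def check_hm_partitions(interdependence_type: str, b: int, m: int,
                        b_prime: int, m_prime: int,
                        p: int, q: Optional[int] = None) -> Tuple[int, int]:
    """Mirror the checks in __init__ (hypothetical standalone helper).

    Returns the resolved (p, q); q falls back to p when omitted.
    """
    q = q if q is not None else p
    # Instance-type interdependence checks (m, m'); attribute-type checks (b, b').
    if interdependence_type in ROW_TYPES:
        d, d_prime = m, m_prime
    elif interdependence_type in COLUMN_TYPES:
        d, d_prime = b, b_prime
    else:
        raise ValueError(f'Interdependence type {interdependence_type} not supported')
    # Both dimensions of the fabricated matrix must split evenly into blocks.
    if d % p != 0 or d_prime % q != 0:
        raise ValueError(f'dimensions ({d}, {d_prime}) are not divisible by (p, q) = ({p}, {q})')
    return p, q

print(check_hm_partitions('attribute', b=32, m=10, b_prime=32, m_prime=10, p=4))  # (4, 4)
try:
    check_hm_partitions('instance', b=32, m=10, b_prime=32, m_prime=10, p=4)
except ValueError as err:
    print('rejected:', err)
```

The same `p` that divides the batch dimension \(b = 32\) is rejected for instance-type interdependence because \(m = 10\) is not divisible by 4.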

__init__(p, q=None, name='hm_parameterized_bilinear_interdependence', *args, **kwargs)

Initializes the hierarchical mapping parameterized bilinear interdependence function.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `p` | `int` | Number of partitions in the input dimension. | *required* |
| `q` | `int` | Number of partitions in the output dimension. Defaults to `p`. | `None` |
| `name` | `str` | Name of the interdependence function. Defaults to `'hm_parameterized_bilinear_interdependence'`. | `'hm_parameterized_bilinear_interdependence'` |
| `*args` | `tuple` | Additional positional arguments. | `()` |
| `**kwargs` | `dict` | Additional keyword arguments. | `{}` |