Skip to content

naive_bayes_head

Bases: head

Source code in tinybig/head/basic_heads.py
class naive_bayes_head(head):
    """
    Head assembled from a distribution-specific "naive" data expansion.

    The constructor only chooses and builds three components and hands them,
    together with the remaining settings, to ``head.__init__``:

    * data transformation -- one of the ``naive_*_expansion`` classes,
      selected by ``distribution``;
    * parameter fabrication -- ``lorr_reconciliation`` when ``with_lorr`` is
      set, otherwise ``identity_reconciliation``;
    * remainder -- ``linear_remainder`` when ``with_residual`` is set,
      otherwise ``zero_remainder``.
    """

    def __init__(
        self, m: int, n: int,
        name: str = 'perceptron_head',
        distribution: str = 'normal',
        enable_bias: bool = False,
        # optional parameters
        with_lorr: bool = False,
        r: int = 3,
        with_residual: bool = False,
        channel_num: int = 1,
        # other parameters
        parameters_init_method: str = 'xavier_normal',
        device: str = 'cpu', *args, **kwargs
    ):
        """
        Build the components for this head and initialise the base ``head``.

        Parameters
        ----------
        m, n : int
            Forwarded unchanged to ``head.__init__`` (presumably the input and
            output widths -- confirm against ``head``).
        name : str
            Forwarded to ``head.__init__``. NOTE(review): the default
            'perceptron_head' looks copied from another head class; it is kept
            as-is so existing callers see no change.
        distribution : str
            Which expansion to use: 'normal', 'exponential', 'cauchy',
            'gamma', 'chi2' or 'laplace'.
        enable_bias : bool
            Passed to whichever reconciliation is constructed.
        with_lorr : bool
            Use ``lorr_reconciliation`` instead of ``identity_reconciliation``.
        r : int
            Passed as ``r`` to ``lorr_reconciliation``; unused otherwise.
        with_residual : bool
            Use ``linear_remainder`` instead of ``zero_remainder``.
        channel_num : int
            Forwarded to ``head.__init__``.
        parameters_init_method : str
            Forwarded to ``head.__init__``.
        device : str
            Given to every constructed component and to ``head.__init__``.
        *args, **kwargs
            Forwarded to ``head.__init__``.

        Raises
        ------
        ValueError
            If ``distribution`` is not one of the supported names.
        """
        # Resolve the expansion class first so an unsupported distribution is
        # rejected before any other component is built.
        candidates = (
            ('normal', naive_normal_expansion),
            ('exponential', naive_exponential_expansion),
            ('cauchy', naive_cauchy_expansion),
            ('gamma', naive_gamma_expansion),
            ('chi2', naive_chi2_expansion),
            ('laplace', naive_laplace_expansion),
        )
        for label, expansion_cls in candidates:
            if distribution == label:
                break
        else:
            raise ValueError('tinybig only supports normal, exponential, cauchy, gamma, laplace or chi2 distributions...')
        data_transformation = expansion_cls(device=device)

        parameter_fabrication = (
            lorr_reconciliation(r=r, enable_bias=enable_bias, device=device)
            if with_lorr
            else identity_reconciliation(enable_bias=enable_bias, device=device)
        )

        remainder = (
            linear_remainder(device=device)
            if with_residual
            else zero_remainder(device=device)
        )

        # NOTE(review): *args are passed positionally alongside m=/n= keywords;
        # if head.__init__ takes m and n as its leading positional parameters,
        # any non-empty *args would collide with them -- confirm.
        super().__init__(
            m=m, n=n, name=name,
            data_transformation=data_transformation,
            parameter_fabrication=parameter_fabrication,
            remainder=remainder,
            channel_num=channel_num,
            parameters_init_method=parameters_init_method,
            device=device, *args, **kwargs
        )