kan_head

Bases: head

A Kolmogorov-Arnold Network (KAN)-based head using B-spline expansion.

Supports B-spline-based data transformation, parameter reconciliation, and flexible output processing.
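
Example (illustrative): the sketch below constructs the head with the arguments documented on this page. It assumes tinybig is installed and that kan_head is importable from tinybig.head.basic_heads (the source path listed here); in a full model the head is normally wrapped inside a tinybig layer, so only construction is shown.

from tinybig.head.basic_heads import kan_head

# Plain KAN head: B-spline expansion with t=5 knots of degree d=3 and identity reconciliation.
plain_head = kan_head(m=8, n=4)

# Variant with low-rank (LORR) reconciliation plus batch norm and softmax on the output.
lorr_head = kan_head(
    m=8, n=4,
    grid_range=(-1, 1), t=5, d=3,
    with_lorr=True, r=2,
    with_batch_norm=True,
    with_softmax=True,
    device='cpu',
)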

Attributes:

Name                     Type   Description
m                        int    Input dimension of the head.
n                        int    Output dimension of the head.
grid_range               tuple  Range for B-spline grid.
t                        int    Number of knots for B-spline.
d                        int    Degree of the B-spline.
channel_num              int    Number of channels for multi-channel processing.
parameters_init_method   str    Initialization method for parameters.
device                   str    Device to host the head (e.g., 'cpu' or 'cuda').

Methods:

Name        Description
__init__    Initializes the KAN head with specified configurations.

Source code in tinybig/head/basic_heads.py
class kan_head(head):
    """
    A Kolmogorov-Arnold Network (KAN)-based head using B-spline expansion.

    Supports B-spline-based data transformation, parameter reconciliation, and flexible output processing.

    Attributes
    ----------
    m : int
        Input dimension of the head.
    n : int
        Output dimension of the head.
    grid_range : tuple
        Range for B-spline grid.
    t : int
        Number of knots for B-spline.
    d : int
        Degree of the B-spline.
    channel_num : int
        Number of channels for multi-channel processing.
    parameters_init_method : str
        Initialization method for parameters.
    device : str
        Device to host the head (e.g., 'cpu' or 'cuda').

    Methods
    -------
    __init__(...)
        Initializes the KAN head with specified configurations.
    """
    def __init__(
        self, m: int, n: int,
        grid_range=(-1, 1), t: int = 5, d: int = 3,
        name: str = 'kan_head',
        enable_bias: bool = False,
        # optional parameters
        with_lorr: bool = False, r: int = 3,
        channel_num: int = 1,
        with_batch_norm: bool = False,
        with_softmax: bool = False,
        # other parameters
        parameters_init_method: str = 'xavier_normal',
        device: str = 'cpu', *args, **kwargs
    ):
        """
        Initializes the KAN head.

        Parameters
        ----------
        m : int
            Input dimension.
        n : int
            Output dimension.
        grid_range : tuple, optional
            Range for B-spline grid, default is (-1, 1).
        t : int, optional
            Number of knots for B-spline, default is 5.
        d : int, optional
            Degree of the B-spline, default is 3.
        name : str, optional
            Name of the KAN head, default is 'kan_head'.
        enable_bias : bool, optional
            Whether to enable bias in reconciliation functions, default is False.
        with_lorr : bool, optional
            Whether to use LORR reconciliation, default is False.
        r : int, optional
            Parameter for reconciliation functions, default is 3.
        channel_num : int, optional
            Number of channels, default is 1.
        with_batch_norm : bool, optional
            Whether to include batch normalization, default is False.
        with_softmax : bool, optional
            Whether to use softmax activation, default is False.
        parameters_init_method : str, optional
            Initialization method for parameters, default is 'xavier_normal'.
        device : str, optional
            Device to host the head, default is 'cpu'.

        Returns
        -------
        None
        """
        data_transformation = bspline_expansion(
            grid_range=grid_range,
            t=t, d=d,
            device=device,
        )

        if with_lorr:
            parameter_fabrication = lorr_reconciliation(
                r=r,
                enable_bias=enable_bias,
                device=device,
            )
        else:
            parameter_fabrication = identity_reconciliation(
                enable_bias=enable_bias,
                device=device,
            )

        remainder = linear_remainder(
            require_remainder_parameters=True,
            activation_functions=[torch.nn.SiLU()],
            device=device,
        )

        output_process_functions = []
        if with_batch_norm:
            output_process_functions.append(torch.nn.BatchNorm1d(num_features=n, device=device))
        if with_softmax:
            output_process_functions.append(torch.nn.Softmax(dim=-1))

        super().__init__(
            m=m, n=n, name=name,
            data_transformation=data_transformation,
            parameter_fabrication=parameter_fabrication,
            remainder=remainder,
            output_process_functions=output_process_functions,
            channel_num=channel_num,
            parameters_init_method=parameters_init_method,
            device=device, *args, **kwargs
        )
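
The bspline_expansion built in the constructor is the data-transformation component of the head. As a rough, self-contained illustration of what a B-spline expansion computes (not the tinybig implementation, whose grid construction and output layout may differ), the sketch below evaluates a degree-d B-spline basis on a uniform knot grid over grid_range using the Cox-de Boor recursion.

import torch

def bspline_basis(x: torch.Tensor, grid: torch.Tensor, d: int) -> torch.Tensor:
    """
    Evaluate B-spline basis functions of degree d on the knot vector grid.
    x:    (batch, m) input features
    grid: (num_knots,) monotonically increasing knots
    Returns a tensor of shape (batch, m, num_knots - d - 1).
    """
    x = x.unsqueeze(-1)                                       # (batch, m, 1)
    # Degree-0 basis: indicator of the knot interval containing x.
    bases = ((x >= grid[:-1]) & (x < grid[1:])).to(x.dtype)   # (batch, m, num_knots - 1)
    # Cox-de Boor recursion up to degree d.
    for k in range(1, d + 1):
        left = (x - grid[: -(k + 1)]) / (grid[k:-1] - grid[: -(k + 1)]) * bases[..., :-1]
        right = (grid[k + 1:] - x) / (grid[k + 1:] - grid[1:-k]) * bases[..., 1:]
        bases = left + right
    return bases

# Uniform knot grid over grid_range=(-1, 1) with t intervals, extended by the degree d on each
# side (one common convention; tinybig's exact grid construction may differ).
grid_range, t, d = (-1.0, 1.0), 5, 3
h = (grid_range[1] - grid_range[0]) / t
grid = torch.arange(-d, t + d + 1, dtype=torch.float32) * h + grid_range[0]

x = torch.rand(4, 8) * 2 - 1           # batch of 4 samples with m = 8 features in [-1, 1)
expansion = bspline_basis(x, grid, d)  # (4, 8, t + d) basis values per feature
print(expansion.shape)                 # torch.Size([4, 8, 8])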

__init__(m, n, grid_range=(-1, 1), t=5, d=3, name='kan_head', enable_bias=False, with_lorr=False, r=3, channel_num=1, with_batch_norm=False, with_softmax=False, parameters_init_method='xavier_normal', device='cpu', *args, **kwargs)

Initializes the KAN head.

Parameters:

Name                     Type   Description
m                        int    Input dimension (required).
n                        int    Output dimension (required).
grid_range               tuple  Range for B-spline grid, default is (-1, 1).
t                        int    Number of knots for B-spline, default is 5.
d                        int    Degree of the B-spline, default is 3.
name                     str    Name of the KAN head, default is 'kan_head'.
enable_bias              bool   Whether to enable bias in reconciliation functions, default is False.
with_lorr                bool   Whether to use LORR reconciliation, default is False.
r                        int    Parameter for reconciliation functions, default is 3.
channel_num              int    Number of channels, default is 1.
with_batch_norm          bool   Whether to include batch normalization, default is False.
with_softmax             bool   Whether to use softmax activation, default is False.
parameters_init_method   str    Initialization method for parameters, default is 'xavier_normal'.
device                   str    Device to host the head, default is 'cpu'.

Returns:

Type Description
None
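
The with_lorr flag switches the parameter-reconciliation component from identity_reconciliation to lorr_reconciliation. Low-rank reconciliation fabricates the full parameter matrix from a much smaller learnable vector, roughly as the product of two rank-r factors; the sketch below illustrates that idea under this assumption and is not the lorr_reconciliation source code.

import torch

def lorr_fabricate(w: torch.Tensor, n: int, D: int, r: int) -> torch.Tensor:
    # Conceptual low-rank reconciliation: split a flat parameter vector of length
    # (n + D) * r into two factors and return their rank-r product, an (n, D) matrix.
    A = w[:n * r].view(n, r)       # (n, r) factor
    B = w[n * r:].view(D, r)       # (D, r) factor
    return A @ B.T                 # (n, D) fabricated parameter matrix

n, D, r = 4, 64, 3                 # output dim, expansion dim, rank
w = torch.randn((n + D) * r)       # 204 learnable parameters instead of n * D = 256
W = lorr_fabricate(w, n, D, r)
print(W.shape)                     # torch.Size([4, 64])
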
Source code in tinybig/head/basic_heads.py
def __init__(
    self, m: int, n: int,
    grid_range=(-1, 1), t: int = 5, d: int = 3,
    name: str = 'kan_head',
    enable_bias: bool = False,
    # optional parameters
    with_lorr: bool = False, r: int = 3,
    channel_num: int = 1,
    with_batch_norm: bool = False,
    with_softmax: bool = False,
    # other parameters
    parameters_init_method: str = 'xavier_normal',
    device: str = 'cpu', *args, **kwargs
):
    """
    Initializes the KAN head.

    Parameters
    ----------
    m : int
        Input dimension.
    n : int
        Output dimension.
    grid_range : tuple, optional
        Range for B-spline grid, default is (-1, 1).
    t : int, optional
        Number of knots for B-spline, default is 5.
    d : int, optional
        Degree of the B-spline, default is 3.
    name : str, optional
        Name of the KAN head, default is 'kan_head'.
    enable_bias : bool, optional
        Whether to enable bias in reconciliation functions, default is False.
    with_lorr : bool, optional
        Whether to use LORR reconciliation, default is False.
    r : int, optional
        Parameter for reconciliation functions, default is 3.
    channel_num : int, optional
        Number of channels, default is 1.
    with_batch_norm : bool, optional
        Whether to include batch normalization, default is False.
    with_softmax : bool, optional
        Whether to use softmax activation, default is False.
    parameters_init_method : str, optional
        Initialization method for parameters, default is 'xavier_normal'.
    device : str, optional
        Device to host the head, default is 'cpu'.

    Returns
    -------
    None
    """
    data_transformation = bspline_expansion(
        grid_range=grid_range,
        t=t, d=d,
        device=device,
    )

    if with_lorr:
        parameter_fabrication = lorr_reconciliation(
            r=r,
            enable_bias=enable_bias,
            device=device,
        )
    else:
        parameter_fabrication = identity_reconciliation(
            enable_bias=enable_bias,
            device=device,
        )

    remainder = linear_remainder(
        require_remainder_parameters=True,
        activation_functions=[torch.nn.SiLU()],
        device=device,
    )

    output_process_functions = []
    if with_batch_norm:
        output_process_functions.append(torch.nn.BatchNorm1d(num_features=n, device=device))
    if with_softmax:
        output_process_functions.append(torch.nn.Softmax(dim=-1))

    super().__init__(
        m=m, n=n, name=name,
        data_transformation=data_transformation,
        parameter_fabrication=parameter_fabrication,
        remainder=remainder,
        output_process_functions=output_process_functions,
        channel_num=channel_num,
        parameters_init_method=parameters_init_method,
        device=device, *args, **kwargs
    )
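
Putting the pieces together, the head composes its three components in the usual RPN style: the B-spline-expanded input is multiplied by the fabricated parameter matrix, and a remainder term is added, where the linear_remainder configured above is (approximately) a learnable linear map with a SiLU activation, echoing the residual SiLU branch of the original KAN design. The sketch below illustrates that composition with stand-in tensors and assumed shapes; it does not reproduce the head base class.

import torch

batch, m, n = 4, 8, 4
D = 64                                           # assumed dimension of the B-spline expansion

x = torch.randn(batch, m)
kappa_x = torch.randn(batch, D)                  # stand-in for bspline_expansion applied to x
W = torch.randn(n, D)                            # stand-in for the reconciled parameter matrix
W_pi = torch.randn(m, n)                         # stand-in for the remainder's linear parameters

remainder = torch.nn.functional.silu(x @ W_pi)   # linear remainder with SiLU (order assumed)
output = kappa_x @ W.T + remainder               # (batch, n) output before output processing
print(output.shape)                              # torch.Size([4, 4])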