class kan_head(head):
    """KAN (Kolmogorov-Arnold Network) head.

    Configures the parent ``head`` with a B-spline expansion as the data
    transformation, an optional low-rank (LoRR) or identity parameter
    reconciliation, a SiLU-activated linear remainder, and optional
    batch-norm / softmax output post-processing.
    """

    def __init__(
        self, m: int, n: int,
        grid_range=(-1, 1), t: int = 5, d: int = 3,
        name: str = 'kan_head',
        enable_bias: bool = False,
        # optional parameters
        with_lorr: bool = False, r: int = 3,
        channel_num: int = 1,
        with_batch_norm: bool = False,
        with_softmax: bool = False,
        # other parameters
        parameters_init_method: str = 'xavier_normal',
        device: str = 'cpu', *args, **kwargs
    ):
        """Build the KAN head components and delegate to ``head.__init__``.

        Parameters
        ----------
        m, n : int
            Input and output dimensions of the head.
        grid_range : tuple
            Value range covered by the B-spline grid.
        t, d : int
            Grid size and spline degree for the B-spline expansion.
        enable_bias : bool
            Whether the reconciliation function carries a bias term.
        with_lorr : bool
            Use low-rank reconciliation (rank ``r``) instead of identity.
        with_batch_norm, with_softmax : bool
            Append BatchNorm1d / Softmax to the output processing pipeline.
        parameters_init_method : str
            Parameter initialization scheme forwarded to the parent.
        device : str
            Device the components are created on.
        """
        # B-spline basis expansion over the configured grid.
        expansion = bspline_expansion(
            grid_range=grid_range,
            t=t, d=d,
            device=device,
        )

        # Low-rank reconciliation when requested; identity otherwise.
        if with_lorr:
            fabrication = lorr_reconciliation(
                r=r,
                enable_bias=enable_bias,
                device=device,
            )
        else:
            fabrication = identity_reconciliation(
                enable_bias=enable_bias,
                device=device,
            )

        # SiLU-activated linear remainder (residual-style) term.
        residual = linear_remainder(
            require_remainder_parameters=True,
            activation_functions=[torch.nn.SiLU()],
            device=device,
        )

        # Optional output post-processing, applied in list order.
        post_processing = []
        if with_batch_norm:
            post_processing.append(
                torch.nn.BatchNorm1d(num_features=n, device=device)
            )
        if with_softmax:
            post_processing.append(torch.nn.Softmax(dim=-1))

        super().__init__(
            m=m, n=n, name=name,
            data_transformation=expansion,
            parameter_fabrication=fabrication,
            remainder=residual,
            output_process_functions=post_processing,
            channel_num=channel_num,
            parameters_init_method=parameters_init_method,
            device=device, *args, **kwargs
        )