def kernel(
kernel_name: str = 'pearson_correlation',
x: torch.Tensor = None, x2: torch.Tensor = None,
*args, **kwargs
):
if 'batch' in kernel_name:
assert x is not None and x2 is None
else:
assert x is not None and x2 is not None
match kernel_name:
case 'pearson_correlation_kernel' | 'pearson_correlation' | 'pearson': return instance_pearson_correlation_kernel(x1=x, x2=x2)
case 'batch_pearson_correlation_kernel' | 'batch_pearson_correlation' | 'batch_pearson': return batch_pearson_correlation_kernel(x=x, *args, **kwargs)
case 'kl_divergence_kernel' | 'kl_divergence': return instance_kl_divergence_kernel(x1=x, x2=x2)
case 'batch_kl_divergence_kernel' | 'batch_kl_divergence': return batch_kl_divergence_kernel(x=x, *args, **kwargs)
case 'rv_coefficient_kernel' | 'rv_coefficient': return instance_rv_coefficient_kernel(x1=x, x2=x2)
case 'batch_rv_coefficient_kernel' | 'batch_rv_coefficient': return batch_rv_coefficient_kernel(x=x, *args, **kwargs)
case 'mutual_information_kernel' | 'mutual_information': return instance_mutual_information_kernel(x1=x, x2=x2)
case 'batch_mutual_information_kernel' | 'batch_mutual_information': return batch_mutual_information_kernel(x=x, *args, **kwargs)
case 'custom_hybrid_kernel' | 'custom_hybrid': return instance_custom_hybrid_kernel(x1=x, x2=x2, *args, **kwargs)
case 'batch_custom_hybrid_kernel' | 'batch_custom_hybrid': return batch_custom_hybrid_kernel(x=x, *args, **kwargs)
case _: raise ValueError(f'kernel {kernel_name} not supported')