def metric(
metric_name: str,
x: torch.Tensor,
*args, **kwargs
):
assert x is not None and metric_name is not None
match metric_name:
case 'mean': return mean(x=x)
case 'batch_mean': return batch_mean(x=x, *args, **kwargs)
case 'weighted_mean' | 'wmean': return weighted_mean(x=x, *args, **kwargs)
case 'batch_weighted_mean' | 'batch_wmean': return batch_weighted_mean(x=x, *args, **kwargs)
case 'geometric_mean' | 'gmean': return geometric_mean(x=x)
case 'batch_geometric_mean' | 'batch_gmean': return batch_geometric_mean(x=x, *args, **kwargs)
case 'harmonic_mean' | 'hmean': return harmonic_mean(x=x, *args, **kwargs)
case 'batch_harmonic_mean' | 'batch_hmean': return batch_harmonic_mean(x=x, *args, **kwargs)
case 'median': return median(x=x)
case 'batch_median': return batch_median(x=x, *args, **kwargs)
case 'mode': return mode(x=x)
case 'batch_mode': return batch_mode(x=x, *args, **kwargs)
case 'entropy': return entropy(x=x)
case 'batch_entropy': return batch_entropy(x=x, *args, **kwargs)
case 'variance' | 'var': return variance(x=x)
case 'batch_variance' | 'batch_var': return batch_variance(x=x, *args, **kwargs)
case 'std' | 'standard_deviation': return std(x=x)
case 'batch_std' | 'batch_standard_deviation': return batch_std(x=x, *args, **kwargs)
case 'skewness' | 'skew': return skewness(x=x)
case 'batch_skewness' | 'batch_skew': return batch_skewness(x=x, *args, **kwargs)
case _: raise ValueError(f'Unknown metric name: {metric_name}...')