def metric(
x: torch.Tensor,
metric_name: str,
*args, **kwargs
):
assert x is not None and metric_name is not None
match metric_name:
case 'norm': return norm(x=x, *args, **kwargs)
case 'batch_norm': return batch_norm(x=x, *args, **kwargs)
case 'l2_norm': return l2_norm(x=x)
case 'batch_l2_norm': return batch_l2_norm(x=x, *args, **kwargs)
case 'l1_norm': return l1_norm(x=x)
case 'batch_l1_norm': return batch_l1_norm(x=x, *args, **kwargs)
case 'max': return max(x=x)
case 'batch_max': return batch_max(x=x, *args, **kwargs)
case 'min': return min(x=x)
case 'batch_min': return batch_min(x=x, *args, **kwargs)
case 'sum': return sum(x=x)
case 'batch_sum': return batch_sum(x=x, *args, **kwargs)
case 'prod': return prod(x=x)
case 'batch_prod': return batch_prod(x=x, *args, **kwargs)
case _: raise ValueError(f'Unknown metric name: {metric_name}...')