from typing import Callable

import torch

# base_fusion is assumed to be the package's fusion base class, accepting the
# dims, name, and require_parameters keyword arguments used below.


class metric_fusion(base_fusion):
    def __init__(
        self,
        dims: list[int] | tuple[int, ...],
        metric: Callable[[torch.Tensor], torch.Tensor],
        name: str = "metric_fusion",
        *args, **kwargs
    ):
        # The fusion is parameter-free: inputs are combined purely by the metric function.
        super().__init__(dims=dims, name=name, require_parameters=False, *args, **kwargs)
        self.metric = metric

    def calculate_n(self, dims: list[int] | tuple[int, ...] = None, *args, **kwargs):
        # All inputs must share the same last dimension, which the fused output keeps.
        dims = dims if dims is not None else self.dims
        assert dims is not None and all(dim == dims[0] for dim in dims), "all fused inputs must share the same dimension"
        return dims[0]

    def calculate_l(self, *args, **kwargs):
        # This fusion introduces no learnable parameters.
        return 0

    def forward(self, x: list[torch.Tensor] | tuple[torch.Tensor, ...], w: torch.nn.Parameter = None, device: str = 'cpu', *args, **kwargs):
        if not x:
            raise ValueError("The input x cannot be empty...")
        if not all(x[0].shape == t.shape for t in x):
            raise ValueError("The input x must have the same shape.")
        # Stack the k equal-shaped inputs along a new leading dimension: (k, *shape).
        x = torch.stack(x, dim=0)
        x = self.pre_process(x=x, device=device)
        assert self.metric is not None and callable(self.metric)
        x_shape = x.shape
        # Move the stacking dimension to the end, giving (*shape, k), so the metric
        # can reduce across the k inputs at every position independently.
        x_permuted = x.permute(*range(1, x.ndim), 0)
        # reshape (not view) is required because permute returns a non-contiguous tensor.
        fused_x = self.metric(x_permuted.reshape(-1, x_shape[0])).reshape(x_shape[1:])
        # The fused output must keep the shared last dimension of the inputs.
        assert fused_x.size(-1) == self.calculate_n([element.size(-1) for element in x])
        return self.post_process(x=fused_x, device=device)
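

# A minimal usage sketch (an illustrative assumption, not part of the class above):
# it presumes base_fusion's pre_process/post_process pass tensors through unchanged
# and uses a mean over the stacked-inputs dimension as the metric. Shapes are illustrative.
if __name__ == "__main__":
    fusion = metric_fusion(
        dims=[4, 4, 4],
        metric=lambda t: torch.mean(t, dim=-1),  # reduce across the stacked inputs
    )
    inputs = [torch.randn(2, 4) for _ in range(3)]  # three equal-shaped inputs
    fused = fusion.forward(x=inputs)
    print(fused.shape)  # expected: torch.Size([2, 4]); calculate_n([4, 4, 4]) == 4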