class geometric_interdependence(interdependence):
    """
    Interdependence function whose matrix is built from a grid structure and a patch shape.

    The interdependence matrix is produced by ``grid.to_matrix(...)`` using the
    configured patch, packing strategy, patch-center distances (``cd_h``,
    ``cd_w``, ``cd_d``) and matrix mode, and is then passed through the
    inherited ``post_process``. When neither data nor parameters are required,
    the matrix is computed once and cached in ``self.A``.

    Attributes
    ----------
    grid : grid_structure
        The grid the patches are packed into.
    patch : cuboid | cylinder | sphere
        The patch shape.
    packing_strategy : str
        Name of the packing strategy handed to the patch and the grid.
    cd_h, cd_w, cd_d : int
        Patch-center distances forwarded to the grid.
    interdependence_matrix_mode : str
        Either 'padding' or 'aggregation'; selects the matrix layout.
    normalization : bool
        Forwarded to ``grid.to_matrix``.
    normalization_mode : str
        Forwarded to ``grid.to_matrix``.
    """

    def __init__(
        self,
        b: int = 0, m: int = 0,
        interdependence_type: str = 'attribute',
        name: str = 'geometric_interdependence',
        # grid structure initialization options
        grid: grid_structure = None,
        grid_configs: dict = None,
        h: int = None, w: int = None, d: int = 1, channel_num: int = 1,
        # patch structure initialization options
        patch: Union[cuboid, cylinder, sphere] = None,
        patch_configs: dict = None,
        # packing options
        packing_strategy: str = 'densest_packing',
        cd_h: int = None, cd_w: int = None, cd_d: int = None,
        interdependence_matrix_mode: str = 'padding',
        # interdependence matrix processing options
        normalization: bool = False,
        normalization_mode: str = 'row_column',
        # by default, the matrix requires neither input data nor parameters
        require_data: bool = False, require_parameters: bool = False,
        device: str = 'cpu', *args, **kwargs
    ):
        """
        Initialize the grid, the patch and the packing parameters.

        The grid is taken from ``grid``, else instantiated from
        ``grid_configs``, else built from ``h``, ``w``, ``d`` and
        ``channel_num``. The patch is taken from ``patch``, else instantiated
        from ``patch_configs``. If ``b`` (instance type) or ``m`` (attribute
        type) is missing or non-positive it is set to the grid volume across
        universes; otherwise it must equal that volume.

        Raises
        ------
        ValueError
            If no grid can be determined, if no patch can be determined, or if
            ``interdependence_type`` is not one of the supported names.
        AssertionError
            If the provided ``b``/``m`` disagrees with the grid volume.
        """
        super().__init__(b=b, m=m, name=name, interdependence_type=interdependence_type, require_data=require_data, require_parameters=require_parameters, device=device, *args, **kwargs)

        # Resolve the grid: explicit object > configs > raw dimensions.
        if grid is not None:
            self.grid = grid
        elif grid_configs is not None:
            self.grid = config.instantiation_from_configs(configs=grid_configs, class_name='grid_class', parameter_name='grid_parameters')
        elif h is not None and w is not None and d is not None:
            grid_parameters = {'h': h, 'w': w, 'd': d, 'universe_num': channel_num}
            self.grid = grid_structure(**grid_parameters)
        else:
            raise ValueError('the grid structure is not specified yet...')

        # The dimension the interdependence acts on must match the grid volume.
        if interdependence_type in ['column', 'right', 'attribute', 'attribute_interdependence']:
            if self.m is None or self.m <= 0:
                self.m = self.grid.get_volume(across_universe=True)
            assert self.grid.get_volume(across_universe=True) == self.m
        elif interdependence_type in ['row', 'left', 'instance', 'instance_interdependence']:
            if self.b is None or self.b <= 0:
                self.b = self.grid.get_volume(across_universe=True)
            assert self.grid.get_volume(across_universe=True) == self.b
        else:
            raise ValueError('the interdependence_type is not supported yet...')

        # Resolve the patch: explicit object > configs.
        if patch is not None:
            self.patch = patch
        elif patch_configs is not None:
            self.patch = config.instantiation_from_configs(configs=patch_configs, class_name='patch_class', parameter_name='patch_parameters')
        else:
            raise ValueError('the patch structure is not specified yet...')

        self.packing_strategy = packing_strategy
        # Explicit center distances win only when all three are given;
        # otherwise they are derived from the patch and the packing strategy.
        if cd_h is not None and cd_w is not None and cd_d is not None:
            self.cd_h, self.cd_w, self.cd_d = cd_h, cd_w, cd_d
        else:
            self.cd_h, self.cd_w, self.cd_d = self.patch.packing_strategy_parameters(packing_strategy=self.packing_strategy)

        self.interdependence_matrix_mode = interdependence_matrix_mode
        self.normalization = normalization
        self.normalization_mode = normalization_mode

    def _reset_cached_matrix(self):
        """Drop the matrix cached by ``calculate_A`` so that it is rebuilt on next use."""
        self.A = None

    def update_grid(self, new_grid: grid_structure):
        """Replace the grid and invalidate the cached interdependence matrix."""
        self.grid = new_grid
        self._reset_cached_matrix()

    def update_packing_strategy(self, new_packing_strategy: str):
        """Switch the packing strategy, re-derive the center distances from the patch, and invalidate the cache."""
        self.packing_strategy = new_packing_strategy
        self.cd_h, self.cd_w, self.cd_d = self.patch.packing_strategy_parameters(packing_strategy=self.packing_strategy)
        self._reset_cached_matrix()

    def update_patch(self, new_patch: Union[cuboid, cylinder, sphere]):
        """Replace the patch, re-derive the center distances from it, and invalidate the cache."""
        self.patch = new_patch
        self.cd_h, self.cd_w, self.cd_d = self.patch.packing_strategy_parameters(packing_strategy=self.packing_strategy)
        self._reset_cached_matrix()

    def update_packing_parameters(self, new_cd_h: int, new_cd_w: int, new_cd_d: int):
        """Set the three patch-center distances directly and invalidate the cache."""
        self.cd_h = new_cd_h
        self.cd_w = new_cd_w
        self.cd_d = new_cd_d
        self._reset_cached_matrix()

    def get_channel_num(self):
        """Return the number of universes (channels) of the grid."""
        return self.grid.get_universe_num()

    def get_patch_size(self):
        """Return the volume of the patch."""
        return self.patch.get_volume()

    def get_patch_num(self, across_universe: bool = False):
        """Return the number of patches the grid reports for the current center distances."""
        return self.grid.get_patch_num(cd_h=self.cd_h, cd_w=self.cd_w, cd_d=self.cd_d, across_universe=across_universe)

    def get_grid_size(self, across_universe: bool = False):
        """Return the volume of the grid."""
        return self.grid.get_volume(across_universe=across_universe)

    def get_grid_shape(self):
        """Return the grid shape as reported by the grid."""
        return self.grid.get_grid_shape()

    def get_grid_shape_after_packing(self):
        """Return the grid shape after packing with the current center distances."""
        return self.grid.get_grid_shape_after_packing(cd_h=self.cd_h, cd_w=self.cd_w, cd_d=self.cd_d)

    def calculate_b_prime(self, b: int = None, across_universe: bool = True):
        """
        Return the instance dimension after applying the interdependence function.

        For instance-type interdependence, ``b`` (default ``self.b``) must
        equal the grid volume; the result is ``patch_num * patch_volume`` in
        'padding' mode and ``b`` in 'aggregation' mode. For any other
        interdependence type ``b`` is returned unchanged.
        """
        b = b if b is not None else self.b
        if self.interdependence_type in ['row', 'left', 'instance', 'instance_interdependence']:
            assert self.grid.get_volume(across_universe=across_universe) == b
            if self.interdependence_matrix_mode == 'padding':
                return self.grid.get_patch_num(cd_h=self.cd_h, cd_w=self.cd_w, cd_d=self.cd_d, across_universe=across_universe) * self.patch.get_volume()
            elif self.interdependence_matrix_mode == 'aggregation':
                return b
            # NOTE(review): any other matrix mode falls through and returns None.
        else:
            return b

    def calculate_m_prime(self, m: int = None, across_universe: bool = True):
        """
        Return the attribute dimension after applying the interdependence function.

        For attribute-type interdependence, ``m`` (default ``self.m``) must
        equal the grid volume; the result is ``patch_num * patch_volume`` in
        'padding' mode and ``m`` in 'aggregation' mode. For any other
        interdependence type ``m`` is returned unchanged.
        """
        m = m if m is not None else self.m
        if self.interdependence_type in ['column', 'right', 'attribute', 'attribute_interdependence']:
            assert self.grid.get_volume(across_universe=across_universe) == m
            if self.interdependence_matrix_mode == 'padding':
                return self.grid.get_patch_num(cd_h=self.cd_h, cd_w=self.cd_w, cd_d=self.cd_d, across_universe=across_universe) * self.patch.get_volume()
            elif self.interdependence_matrix_mode == 'aggregation':
                return m
            # NOTE(review): any other matrix mode falls through and returns None.
        else:
            return m

    def calculate_A(self, x: torch.Tensor = None, w: torch.nn.Parameter = None, across_universe: bool = False, device: str = 'cpu', *args, **kwargs):
        """
        Build (or fetch from cache) the interdependence matrix.

        ``x`` and ``w`` are accepted for interface compatibility but are not
        used here. When neither data nor parameters are required, the first
        computed matrix is stored in ``self.A`` and returned by later calls.

        NOTE(review): the cached matrix is returned regardless of the
        ``across_universe`` and ``device`` arguments of the later call.
        """
        cacheable = not self.require_data and not self.require_parameters
        if cacheable and self.A is not None:
            return self.A

        A = self.grid.to_matrix(
            patch=self.patch, packing_strategy=self.packing_strategy,
            cd_h=self.cd_h, cd_w=self.cd_w, cd_d=self.cd_d,
            interdependence_matrix_mode=self.interdependence_matrix_mode,
            normalization=self.normalization, normalization_mode=self.normalization_mode,
            across_universe=across_universe, device=device, *args, **kwargs
        )
        A = self.post_process(x=A, device=device)

        if cacheable and self.A is None:
            self.A = A
        return A

    def forward(self, x: torch.Tensor = None, w: torch.nn.Parameter = None, kappa_x: torch.Tensor = None, device: str = 'cpu', *args, **kwargs):
        """
        Apply the interdependence matrix to ``kappa_x`` (or to ``x`` when ``kappa_x`` is None).

        The 2-D input of shape (b, m) is first reshaped to
        (b * universe_num, -1). Instance-type interdependence returns
        ``A^T @ data``; attribute-type interdependence returns ``data @ A``
        rearranged so that the universe axis follows the patch axis, flattened
        back to (b, -1).

        Raises
        ------
        ValueError
            If neither ``x`` nor ``kappa_x`` is provided, or if the
            interdependence type is not supported.
        AssertionError
            If required inputs are missing or not 2-D, or if the matrix does
            not match the data dimension.
        """
        if self.require_data:
            assert x is not None and x.ndim == 2
        if self.require_parameters:
            assert w is not None and w.ndim == 2

        data_x = kappa_x if kappa_x is not None else x
        if data_x is None:
            raise ValueError('either x or kappa_x needs to be provided...')

        b, m = data_x.shape
        universe_num = self.grid.get_universe_num()
        # reshape (rather than view) also accepts non-contiguous inputs
        data_x = data_x.reshape(b * universe_num, -1)

        if self.interdependence_type in ['row', 'left', 'instance', 'instance_interdependence']:
            # A shape: b * b'
            # x is optional when data is not required, so only transpose it when given
            x_t = x.transpose(0, 1) if x is not None else None
            A = self.calculate_A(x_t, w, device=device)
            assert A is not None and A.size(0) == data_x.size(0)
            if data_x.is_sparse or A.is_sparse:
                xi_x = torch.sparse.mm(A.t(), data_x)
            else:
                xi_x = torch.matmul(A.t(), data_x)
            return xi_x

        elif self.interdependence_type in ['column', 'right', 'attribute', 'attribute_interdependence']:
            # A shape: m * m'
            A = self.calculate_A(x, w, device=device)
            assert A is not None and A.size(0) == data_x.size(-1)
            if data_x.is_sparse or A.is_sparse:
                xi_x = torch.sparse.mm(data_x, A)
            else:
                xi_x = torch.matmul(data_x, A)

            if self.interdependence_matrix_mode == 'padding':
                # shape [b, c, g, p] -> shape [b, g, c, p]
                xi_x = xi_x.view(b, universe_num, self.get_patch_num(), self.get_patch_size())
                xi_x = xi_x.permute(0, 2, 1, 3)
            elif self.interdependence_matrix_mode == 'aggregation':
                # shape [b, c, g] -> shape [b, g, c]
                xi_x = xi_x.view(b, universe_num, self.get_patch_num())
                xi_x = xi_x.permute(0, 2, 1)
            return xi_x.reshape(b, -1)

        else:
            raise ValueError(f"Invalid interdependence type: {self.interdependence_type}")