class grid_compression_head(head):
    """Head that compresses a multi-channel grid by packing patches over it.

    The constructor does the following:

    * builds a ``grid`` of shape ``(h, w, d)`` with ``channel_num`` universes;
    * builds a patch structure (``cuboid``, ``cylinder`` or ``sphere``);
    * combines both in a ``geometric_compression`` data transformation. That
      transformation uses the requested packing strategy and a ``metric``
      selected by ``pooling_metric``.

    The head dimensions come from the transformation:

    * ``m`` is ``get_grid_size(across_universe=True)``;
    * ``n`` is ``get_patch_num(across_universe=True)``.

    A ``zero_remainder`` is used. Optionally, a ``torch.nn.Dropout`` is
    registered as an output process function.
    """

    def __init__(
        self,
        h: int, w: int, channel_num: int,
        d: int = 1, name: str = 'grid_compression_head',
        pooling_metric: str = 'batch_max',
        patch_shape: str = 'cuboid',
        p_h: int = None, p_h_prime: int = None,
        p_w: int = None, p_w_prime: int = None,
        p_d: int = 0, p_d_prime: int = None,
        p_r: int = None,
        cd_h: int = None, cd_w: int = None, cd_d: int = 1,
        packing_strategy: str = 'densest_packing',
        with_dropout: bool = True, p: float = 0.5,
        # other parameters
        parameters_init_method: str = 'xavier_normal',
        device: str = 'cpu', *args, **kwargs
    ):
        """Build the grid, patch and compression pipeline, then init ``head``.

        Args:
            h, w, d: Grid dimensions, passed to ``grid``.
            channel_num: Number of channels. It is passed to ``grid`` as
                ``universe_num`` and stored as ``self.channel_num``.
            name: Name forwarded to ``head``.
            pooling_metric: Passed as ``metric_name`` to ``metric`` via
                ``functools.partial``. The default is ``'batch_max'``.
            patch_shape: Patch shape. Must be ``'cuboid'``, ``'cylinder'`` or
                ``'sphere'``.
            p_h, p_h_prime, p_w, p_w_prime, p_d, p_d_prime: Cuboid patch
                parameters. ``p_w`` defaults to ``p_h`` when omitted.
                ``p_d`` and ``p_d_prime`` are also used by the cylinder shape.
            p_r: Size parameter for cylinder and sphere patches.
            cd_h, cd_w, cd_d: Passed unchanged to ``geometric_compression``.
            packing_strategy: Passed unchanged to ``geometric_compression``.
            with_dropout: If true, append ``torch.nn.Dropout(p=p)`` to the
                output process functions.
            p: Dropout probability.
            parameters_init_method, device, *args, **kwargs: Forwarded to
                ``head.__init__``.

        Raises:
            ValueError: If ``channel_num`` is missing or non-positive, or if
                any of ``h``/``w``/``d`` is ``None``. Also raised when
                ``patch_shape`` is unknown, or when the size argument the shape
                needs (``p_h`` for cuboid, ``p_r`` for cylinder/sphere) is
                ``None``.
        """
        if channel_num is None or channel_num <= 0:
            raise ValueError(f'positive channel number={channel_num} must be specified...')
        self.channel_num = channel_num
        if h is None or w is None or d is None:
            raise ValueError(f'h={h} and w={w} and d={d} must be specified...')

        grid_structure = grid(
            h=h, w=w, d=d, universe_num=channel_num
        )
        patch_structure = self._build_patch_structure(
            patch_shape=patch_shape,
            p_h=p_h, p_h_prime=p_h_prime,
            p_w=p_w, p_w_prime=p_w_prime,
            p_d=p_d, p_d_prime=p_d_prime,
            p_r=p_r,
        )
        data_transformation = geometric_compression(
            grid=grid_structure,
            patch=patch_structure,
            packing_strategy=packing_strategy,
            cd_h=cd_h, cd_w=cd_w, cd_d=cd_d,
            metric=partial(metric, metric_name=pooling_metric),
            device=device,
        )
        remainder = zero_remainder(
            device=device,
        )

        output_process_functions = []
        if with_dropout:
            output_process_functions.append(torch.nn.Dropout(p=p))

        m = data_transformation.get_grid_size(across_universe=True)
        n = data_transformation.get_patch_num(across_universe=True)
        # Positional *args are bound before keywords whatever their position
        # in the call, so listing them first here matches the actual binding.
        super().__init__(
            *args,
            m=m, n=n,
            name=name,
            data_transformation=data_transformation,
            remainder=remainder,
            output_process_functions=output_process_functions,
            parameters_init_method=parameters_init_method,
            device=device,
            **kwargs
        )

    @staticmethod
    def _build_patch_structure(
        patch_shape, p_h, p_h_prime, p_w, p_w_prime, p_d, p_d_prime, p_r,
    ):
        """Create the patch structure that ``patch_shape`` names.

        Raises:
            ValueError: If ``patch_shape`` is unknown, or if the size argument
                that shape needs (``p_h`` for cuboid, ``p_r`` for
                cylinder/sphere) is ``None``.
        """
        # Use explicit checks rather than ``assert``, which is removed when
        # Python runs with -O.
        if patch_shape == 'cuboid':
            if p_h is None:
                raise ValueError(f'p_h must be specified for patch_shape={patch_shape}...')
            # A missing patch width falls back to the patch height.
            p_w = p_w if p_w is not None else p_h
            return cuboid(
                p_h=p_h, p_w=p_w, p_d=p_d,
                p_h_prime=p_h_prime, p_w_prime=p_w_prime, p_d_prime=p_d_prime,
            )
        if patch_shape == 'cylinder':
            if p_r is None:
                raise ValueError(f'p_r must be specified for patch_shape={patch_shape}...')
            return cylinder(p_r=p_r, p_d=p_d, p_d_prime=p_d_prime)
        if patch_shape == 'sphere':
            if p_r is None:
                raise ValueError(f'p_r must be specified for patch_shape={patch_shape}...')
            return sphere(p_r=p_r)
        raise ValueError(f'patch_shape={patch_shape} must be either cuboid, cylinder or sphere...')

    def get_patch_size(self):
        """Return ``self.data_transformation.get_patch_size()``."""
        return self.data_transformation.get_patch_size()

    def get_input_grid_shape(self):
        """Return the grid shape reported by ``self.data_transformation``."""
        return self.data_transformation.get_grid_shape()

    def get_output_grid_shape(self):
        """Return the ``(h, w, d)`` grid shape after patch packing.

        The shape is unpacked into three values so that a result of any other
        length raises ``ValueError``.
        """
        output_h, output_w, output_d = self.data_transformation.get_grid_shape_after_packing()
        return output_h, output_w, output_d