class grid_compression_layer(layer):
    """
    A layer built from exactly one ``grid_compression_head``.

    Every configuration argument is forwarded to that head. The layer's
    dimensions ``m`` and ``n`` are read back from the head (``get_m()`` /
    ``get_n()``) and handed to the ``layer`` base class together with the
    one-element head list.
    """

    def __init__(
            self,
            h: int, w: int, channel_num: int,
            d: int = 1,
            name: str = 'grid_compression_layer',
            pooling_metric: str = 'batch_max',
            patch_shape: str = 'cuboid',
            p_h: int = None, p_h_prime: int = None,
            p_w: int = None, p_w_prime: int = None,
            p_d: int = 0, p_d_prime: int = None,
            p_r: int = None,
            cd_h: int = None, cd_w: int = None, cd_d: int = 1,
            with_dropout: bool = False, p: float = 0.5,
            packing_strategy: str = 'densest_packing',
            # other parameters
            parameters_init_method: str = 'xavier_normal',
            device: str = 'cpu', *args, **kwargs
    ):
        """
        Create the single compression head and initialise the base layer.

        Parameters
        ----------
        h, w, d, channel_num :
            Grid geometry forwarded to ``grid_compression_head``
            (presumably height, width, depth and channel count of the
            input grid — confirm against the head's documentation).
        name :
            Layer name passed only to the ``layer`` base class.
        pooling_metric, patch_shape, p_h, p_h_prime, p_w, p_w_prime,
        p_d, p_d_prime, p_r, cd_h, cd_w, cd_d, packing_strategy,
        with_dropout, p, parameters_init_method :
            Forwarded unchanged to ``grid_compression_head``; their
            meaning is defined by that class.
        device :
            Passed to both the head and the ``layer`` base class.
        *args, **kwargs :
            Passed to both the head constructor and ``layer.__init__``.
        """
        print('* grid_compression_layer')
        # NOTE: ``*args`` is written first in both calls below. Python binds
        # star-unpacked positionals before keyword arguments regardless of
        # where they appear, so this ordering only makes the real binding
        # visible (avoids flake8-bugbear B026); behavior is unchanged.
        heads = [
            grid_compression_head(
                *args,
                h=h, w=w, d=d,
                channel_num=channel_num,
                pooling_metric=pooling_metric,
                patch_shape=patch_shape,
                p_h=p_h, p_h_prime=p_h_prime,
                p_w=p_w, p_w_prime=p_w_prime,
                p_d=p_d, p_d_prime=p_d_prime,
                p_r=p_r,
                cd_h=cd_h, cd_w=cd_w, cd_d=cd_d,
                packing_strategy=packing_strategy,
                with_dropout=with_dropout, p=p,
                parameters_init_method=parameters_init_method,
                device=device,
                **kwargs
            )
        ]
        # ``heads`` is a one-element literal, so it can never be empty here;
        # the layer's dimensions come straight from its only head.
        m, n = heads[0].get_m(), heads[0].get_n()
        print('--------------------------')
        super().__init__(*args, name=name, m=m, n=n, heads=heads, device=device, **kwargs)

    def get_output_grid_shape(self):
        """
        Return the output grid shape reported by the layer's first head.

        Assumes the ``layer`` base class stores the head list on
        ``self.heads`` (TODO confirm). When assertions are enabled, an empty
        head list raises ``AssertionError``.
        """
        assert len(self.heads) >= 1
        return self.heads[0].get_output_grid_shape()