Skip to content

grid_compression_head

Bases: head

Source code in tinybig/head/grid_based_heads.py
class grid_compression_head(head):
    """
    Grid-based compression head.

    Builds a ``geometric_compression`` data transformation over a ``grid``
    input structure. The grid is covered by patches of the requested shape
    ('cuboid', 'cylinder' or 'sphere'), and each patch is reduced with the
    named pooling metric. The head uses a ``zero_remainder`` and, optionally,
    a ``torch.nn.Dropout`` output processing function. Its ``m`` and ``n``
    sizes come from the transformation's grid size and patch count, both
    computed with ``across_universe=True``.
    """

    def __init__(
        self,
        h: int, w: int, channel_num: int,
        d: int = 1, name: str = 'grid_compression_head',
        pooling_metric: str = 'batch_max',
        patch_shape: str = 'cuboid',
        p_h: int = None, p_h_prime: int = None,
        p_w: int = None, p_w_prime: int = None,
        p_d: int = 0, p_d_prime: int = None,
        p_r: int = None,
        cd_h: int = None, cd_w: int = None, cd_d: int = 1,
        packing_strategy: str = 'densest_packing',
        with_dropout: bool = True, p: float = 0.5,
        # other parameters
        parameters_init_method: str = 'xavier_normal',
        device: str = 'cpu', *args, **kwargs
    ):
        """
        Construct the compression head.

        Parameters
        ----------
        h, w, d : int
            Grid dimensions passed to ``grid(h=..., w=..., d=...)``. None of
            them may be None.
        channel_num : int
            Number of channels, passed to ``grid`` as ``universe_num``. Must
            be a positive number.
        name : str
            Name of the head, forwarded to the base ``head``.
        pooling_metric : str
            Metric name bound into ``metric`` via ``functools.partial`` and
            used by the compression transformation.
        patch_shape : str
            One of 'cuboid', 'cylinder' or 'sphere'.
        p_h, p_h_prime, p_w, p_w_prime, p_d, p_d_prime : int
            Patch parameters forwarded to ``cuboid`` (all of them) or to
            ``cylinder`` (``p_d``/``p_d_prime`` only). ``p_h`` is required
            for 'cuboid', and ``p_w`` defaults to ``p_h`` when omitted.
        p_r : int
            Patch size parameter (presumably the radius) forwarded to
            ``cylinder``/``sphere``. It is required for those shapes.
        cd_h, cd_w, cd_d : int
            Forwarded unchanged to ``geometric_compression``. TODO confirm
            their exact meaning (presumably per-axis packing distances).
        packing_strategy : str
            Forwarded to ``geometric_compression``.
        with_dropout : bool
            When True, a ``torch.nn.Dropout(p=p)`` is appended to the output
            processing functions.
        p : float
            Dropout probability, used only when ``with_dropout`` is True.
        parameters_init_method, device, *args, **kwargs
            Forwarded to the base ``head`` constructor (``device`` is also
            passed to the transformation and the remainder).

        Raises
        ------
        ValueError
            If ``channel_num`` is missing or non-positive, if any of
            ``h``/``w``/``d`` is None, if ``patch_shape`` is unknown, or if
            the parameter required by the chosen patch shape (``p_h`` or
            ``p_r``) is missing.
        """
        if channel_num is None or channel_num <= 0:
            raise ValueError(f'positive channel number={channel_num} must be specified...')
        self.channel_num = channel_num
        if h is None or w is None or d is None:
            raise ValueError(f'h={h} and w={w} and d={d} must be specified...')
        grid_structure = grid(
            h=h, w=w, d=d, universe_num=channel_num
        )

        # Explicit exceptions instead of `assert`: asserts are stripped under
        # `python -O`, which would let a missing size reach the shape classes.
        if patch_shape == 'cuboid':
            if p_h is None:
                raise ValueError(f'p_h={p_h} must be specified for the cuboid patch shape...')
            # Without an explicit width, the patch gets equal height and width.
            p_w = p_w if p_w is not None else p_h
            patch_structure = cuboid(p_h=p_h, p_w=p_w, p_d=p_d, p_h_prime=p_h_prime, p_w_prime=p_w_prime, p_d_prime=p_d_prime)
        elif patch_shape == 'cylinder':
            if p_r is None:
                raise ValueError(f'p_r={p_r} must be specified for the cylinder patch shape...')
            patch_structure = cylinder(p_r=p_r, p_d=p_d, p_d_prime=p_d_prime)
        elif patch_shape == 'sphere':
            if p_r is None:
                raise ValueError(f'p_r={p_r} must be specified for the sphere patch shape...')
            patch_structure = sphere(p_r=p_r)
        else:
            raise ValueError(f'patch_shape={patch_shape} must be either cuboid, cylinder or sphere...')

        data_transformation = geometric_compression(
            grid=grid_structure,
            patch=patch_structure,
            packing_strategy=packing_strategy,
            cd_h=cd_h, cd_w=cd_w, cd_d=cd_d,
            metric=partial(metric, metric_name=pooling_metric),
            device=device,
        )

        remainder = zero_remainder(
            device=device,
        )

        output_process_functions = []
        if with_dropout:
            output_process_functions.append(torch.nn.Dropout(p=p))

        # Head sizes derived from the compression transformation.
        m = data_transformation.get_grid_size(across_universe=True)
        n = data_transformation.get_patch_num(across_universe=True)

        # `*args` binds positionally wherever it appears in the call; it is
        # written first so the binding order is visible (flake8-bugbear B026).
        super().__init__(
            *args,
            m=m, n=n,
            name=name,
            data_transformation=data_transformation,
            remainder=remainder,
            output_process_functions=output_process_functions,
            parameters_init_method=parameters_init_method,
            device=device, **kwargs
        )

    def get_patch_size(self):
        """Return the patch size reported by the compression transformation."""
        return self.data_transformation.get_patch_size()

    def get_input_grid_shape(self):
        """Return the input grid shape reported by the compression transformation."""
        return self.data_transformation.get_grid_shape()

    def get_output_grid_shape(self):
        """Return the ``(h, w, d)`` tuple of the grid after patch packing."""
        output_h, output_w, output_d = self.data_transformation.get_grid_shape_after_packing()
        return output_h, output_w, output_d