
grid_interdependence_head

Bases: head

Source code in tinybig/head/grid_based_heads.py
class grid_interdependence_head(head):
    def __init__(
        self,
        h: int, w: int, in_channel: int, out_channel: int,
        d: int = 1, name: str = 'grid_interdependence_head',
        # patch structure parameters
        patch_shape: str = 'cuboid',
        p_h: int = None, p_h_prime: int = None,
        p_w: int = None, p_w_prime: int = None,
        p_d: int = 0, p_d_prime: int = None,
        p_r: int = None,
        # packing parameters
        cd_h: int = None, cd_w: int = None, cd_d: int = 1,
        packing_strategy: str = 'densest_packing',
        # output processing function parameters
        with_batch_norm: bool = True,
        with_relu: bool = True,
        with_residual: bool = False,
        # other parameters
        with_dual_lphm: bool = False,
        with_lorr: bool = False, r: int = 3,
        enable_bias: bool = False,
        parameters_init_method: str = 'xavier_normal',
        device: str = 'cpu', *args, **kwargs
    ):
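        # assemble the head from its components: a geometric (grid-patch)
        # interdependence function, an identity data expansion, a parameter
        # reconciliation (fabrication) function, and a remainder function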
        if in_channel is None or out_channel is None or in_channel <= 0 or out_channel <= 0:
            raise ValueError(f'in_channel={in_channel} and out_channel={out_channel} must both be specified as positive integers...')
        self.in_channel = in_channel
        self.out_channel = out_channel

        if h is None or w is None or d is None:
            raise ValueError(f'h={h}, w={w} and d={d} must all be specified...')
        grid_structure = grid(
            h=h, w=w, d=d, universe_num=in_channel
        )

        if patch_shape == 'cuboid':
            assert p_h is not None, 'p_h must be specified for cuboid patches...'
            p_w = p_w if p_w is not None else p_h
            patch_structure = cuboid(p_h=p_h, p_w=p_w, p_d=p_d, p_h_prime=p_h_prime, p_w_prime=p_w_prime, p_d_prime=p_d_prime)
        elif patch_shape == 'cylinder':
            assert p_r is not None, 'p_r must be specified for cylinder patches...'
            patch_structure = cylinder(p_r=p_r, p_d=p_d, p_d_prime=p_d_prime)
        elif patch_shape == 'sphere':
            assert p_r is not None, 'p_r must be specified for sphere patches...'
            patch_structure = sphere(p_r=p_r)
        else:
            raise ValueError(f'patch_shape={patch_shape} must be either cuboid, cylinder or sphere...')
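
        # geometric interdependence function: links each grid cell to its
        # surrounding patch; in 'padding' mode the patch entries are
        # concatenated across all input channels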

        attribute_interdependence = geometric_interdependence(
            interdependence_type='attribute',
            grid=grid_structure,
            patch=patch_structure,
            packing_strategy=packing_strategy,
            cd_h=cd_h, cd_w=cd_w, cd_d=cd_d,
            interdependence_matrix_mode='padding',
            normalization=False,
            require_data=False, require_parameters=False,
            device=device
        )
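
        # identity data expansion: the patch vectors are used as-is, kappa(x) = x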

        data_transformation = identity_expansion(
            device=device
        )
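
        # parameter reconciliation: dual LPHM or LoRR low-rank fabrication when
        # enabled, otherwise the identity reconciliation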

        if with_dual_lphm:
            print('grid head', 'with_dual_lphm:', with_dual_lphm, 'r:', r)
            parameter_fabrication = dual_lphm_reconciliation(
                r=r,
                device=device,
                enable_bias=enable_bias,
            )
        elif with_lorr:
            print('grid head', 'with_lorr:', with_lorr, 'r:', r)
            parameter_fabrication = lorr_reconciliation(
                r=r,
                device=device,
                enable_bias=enable_bias,
            )
        else:
            parameter_fabrication = identity_reconciliation(
                device=device,
                enable_bias=enable_bias
            )
        # To save computational cost, the n provided to the parameter fabrication
        # function differs from the n of the head, so we manually calculate the
        # parameter count l for the parameter fabrication function.
        l = parameter_fabrication.calculate_l(
            n=self.out_channel, D=self.in_channel*attribute_interdependence.get_patch_size()
        )
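
        # remainder function: a linear remainder when residual connections are
        # requested, otherwise zero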

        if with_residual:
            remainder = linear_remainder(
                device=device
            )
        else:
            remainder = zero_remainder(
                device=device,
            )
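
        # m: flattened input dimension (all grid cells across the in_channel channels);
        # n: flattened output dimension (patch count per channel times out_channel)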

        m = attribute_interdependence.get_grid_size(across_universe=True)
        n = attribute_interdependence.get_patch_num(across_universe=False) * out_channel
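
        # optional post-processing applied to the n-dimensional head output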

        output_process_functions = []
        if with_batch_norm:
            output_process_functions.append(torch.nn.BatchNorm1d(num_features=n, device=device))
        if with_relu:
            output_process_functions.append(torch.nn.ReLU())

        print('grid interdependence head output processing:', output_process_functions)

        super().__init__(
            name=name,
            m=m, n=n, channel_num=1, l=l,
            data_transformation=data_transformation,
            parameter_fabrication=parameter_fabrication,
            remainder=remainder,
            attribute_interdependence=attribute_interdependence,
            output_process_functions=output_process_functions,
            parameters_init_method=parameters_init_method,
            device=device, *args, **kwargs
        )

    def get_patch_size(self):
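        # number of grid cells covered by a single patch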
        return self.attribute_interdependence.get_patch_size()

    def get_input_grid_shape(self):
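        # (h, w, d) shape of the input grid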
        return self.attribute_interdependence.get_grid_shape()

    def get_output_grid_shape(self):
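        # grid shape after patch packing, i.e., the output resolution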
        return self.attribute_interdependence.get_grid_shape_after_packing()

    def calculate_phi_w(self, channel_index: int = 0, device='cpu', *args, **kwargs):
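        # fabricate this channel's parameter matrix phi(w) of shape
        # (out_channel, in_channel * patch_size) from the learnable w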
        assert channel_index in range(self.channel_num)
        w_chunk = self.w[channel_index:channel_index + 1, :]
        n, D = self.out_channel, self.in_channel * self.attribute_interdependence.get_patch_size()
        assert w_chunk.size(1) == self.parameter_fabrication.calculate_l(n=n, D=D)
        phi_w = self.parameter_fabrication(w=w_chunk, n=n, D=D, device=device)
        return phi_w

    def calculate_inner_product(self, kappa_xi_x: torch.Tensor, phi_w: torch.Tensor, device='cpu', *args, **kwargs):
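        # patch-wise inner product: reshape the expanded input into
        # (batch, patch_num, patch_size * in_channel), multiply by phi(w)^T,
        # then flatten the result channel-major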
        assert kappa_xi_x.ndim == 2 and phi_w.ndim == 2
        b = kappa_xi_x.size(0)
        inner_prod = torch.matmul(kappa_xi_x.view(b, -1, self.get_patch_size() * self.in_channel), phi_w.T)
        inner_prod = inner_prod.permute(0, 2, 1).reshape(b, -1)
        if self.b is not None:
            inner_prod += self.b
        return inner_prod
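
A minimal usage sketch (not part of the source above): it instantiates the head with the constructor signature shown in the listing and inspects the grid and patch geometry via the helper methods. The concrete sizes (a 32x32 RGB grid, 3x3 cuboid patches) and the reliance on the default densest_packing strategy are illustrative assumptions, not values from the source.

from tinybig.head.grid_based_heads import grid_interdependence_head

# hypothetical sizes: a 32x32 RGB input grid mapped to 16 output channels
head_layer = grid_interdependence_head(
    h=32, w=32, in_channel=3, out_channel=16,
    patch_shape='cuboid', p_h=1, p_w=1,  # assumed to define a 3x3 cuboid patch
)

print(head_layer.get_patch_size())         # grid cells covered by one patch
print(head_layer.get_input_grid_shape())   # (h, w, d) of the input grid
print(head_layer.get_output_grid_shape())  # grid shape after patch packing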