chain_interdependence_head

Bases: head

Source code in tinybig/head/chain_based_heads.py
class chain_interdependence_head(head):
    def __init__(
        self,
        m: int, n: int,
        chain_length: int,
        channel_num: int = 1,
        name: str = 'chain_interdependence_head',
        # interdependence function parameters
        bi_directional: bool = False,
        with_multihop: bool = False, h: int = 1, accumulative: bool = False,
        with_inverse_approx: bool = False,
        with_exponential_approx: bool = False,
        self_dependence: bool = True,
        self_scaling: float = 1.0,
        # data transformation function parameters
        with_taylor: bool = False, d: int = 2,
        # parameter reconciliation function parameters
        with_dual_lphm: bool = False,
        with_lorr: bool = False, r: int = 3,
        enable_bias: bool = False,
        # remainder function parameters
        with_residual: bool = False,
        # output processing parameters
        with_batch_norm: bool = False,
        with_relu: bool = True,
        with_softmax: bool = True,
        with_dropout: bool = False, p: float = 0.25,
        # other parameters
        parameters_init_method: str = 'xavier_normal',
        device: str = 'cpu', *args, **kwargs
    ):
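        # Record the chain length and build the chain topology that the
        # interdependence functions below will operate on.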
        self.chain_length = chain_length
        chain_structure = chain(
            length=chain_length,
            name=name,
            bi_directional=bi_directional,
            device=device,
        )

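        # Select the instance interdependence function over the chain; the flags
        # are checked in order: exponential approx > inverse approx > multi-hop > one-hop.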
        if with_exponential_approx:
            instance_interdependence = exponential_approx_multihop_chain_interdependence(
                b=chain_length, m=m,
                chain=chain_structure,
                interdependence_type='instance',
                normalization=False,
                self_dependence=self_dependence,
                self_scaling=self_scaling,
                require_data=False,
                require_parameters=False,
            )
        elif with_inverse_approx:
            instance_interdependence = inverse_approx_multihop_chain_interdependence(
                b=chain_length, m=m,
                chain=chain_structure,
                interdependence_type='instance',
                normalization=False,
                self_dependence=self_dependence,
                self_scaling=self_scaling,
                require_data=False,
                require_parameters=False,
            )
        elif with_multihop:
            instance_interdependence = multihop_chain_interdependence(
                h=h, accumulative=accumulative,
                b=chain_length, m=m,
                chain=chain_structure,
                interdependence_type='instance',
                normalization=False,
                self_dependence=self_dependence,
                self_scaling=self_scaling,
                require_data=False,
                require_parameters=False,
            )
        else:
            instance_interdependence = chain_interdependence(
                b=chain_length, m=m,
                chain=chain_structure,
                interdependence_type='instance',
                normalization=False,
                self_dependence=self_dependence,
                self_scaling=self_scaling,
                require_data=False,
                require_parameters=False,
            )

        data_transformation = identity_expansion(
            device=device,
        )

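        # Select the parameter reconciliation function: dual LPHM and LoRR are
        # low-rank alternatives; the identity reconciliation is the default.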
        if with_dual_lphm:
            parameter_fabrication = dual_lphm_reconciliation(
                r=r,
                enable_bias=enable_bias,
                device=device
            )
        elif with_lorr:
            parameter_fabrication = lorr_reconciliation(
                r=r,
                enable_bias=enable_bias,
                device=device,
            )
        else:
            parameter_fabrication = identity_reconciliation(
                enable_bias=enable_bias,
                device=device,
            )
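        # l: the number of learnable parameters the reconciliation function
        # requires to fabricate the (n x D) weight matrix, where D is the
        # dimension of the expanded data.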
        l = parameter_fabrication.calculate_l(
            n=n, D=data_transformation.calculate_D(m=m)
        )

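        # Remainder function: a linear residual connection when requested,
        # otherwise a zero remainder.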
        if with_residual:
            remainder = linear_remainder(
                device=device
            )
        else:
            remainder = zero_remainder(
                device=device,
            )

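        # Optional output post-processing, applied in this order:
        # batch norm -> ReLU -> dropout -> softmax.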
        output_process_functions = []
        if with_batch_norm:
            output_process_functions.append(torch.nn.BatchNorm1d(num_features=n, device=device))
        if with_relu:
            output_process_functions.append(torch.nn.ReLU())
        if with_dropout:
            output_process_functions.append(torch.nn.Dropout(p=p))
        if with_softmax:
            output_process_functions.append(torch.nn.Softmax(dim=-1))

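        # Assemble the head from the configured components; note that batch_num
        # is set to the chain length.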
        super().__init__(
            m=m, n=n,
            name=name, channel_num=channel_num,
            batch_num=chain_length,
            instance_interdependence=instance_interdependence,
            data_transformation=data_transformation,
            parameter_fabrication=parameter_fabrication,
            remainder=remainder,
            output_process_functions=output_process_functions,
            parameters_init_method=parameters_init_method,
            device=device, *args, **kwargs
        )

    def calculate_instance_xi_x(self, x: torch.Tensor, channel_index: int = 0, kappa_x: torch.Tensor = None, device='cpu', *args, **kwargs):
        if self.instance_interdependence is not None:
            if self.instance_interdependence.device != device:
                self.instance_interdependence.to(device)

            if self.w_instance_interdependence is not None and 0 <= channel_index < self.w_instance_interdependence.size(0):
                w_chunks = self.w_instance_interdependence[channel_index:channel_index+1, :]
            else:
                w_chunks = None
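            # Regroup the features by chain node, (b, m) -> (chain_length, b * m / chain_length),
            # so the interdependence is applied along the chain dimension.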
            b, m = kappa_x.shape
            kappa_x = kappa_x.view(b, self.chain_length, -1).permute(1, 0, 2).reshape(self.chain_length, -1)
            xi_x = self.instance_interdependence(x=x, w=w_chunks, kappa_x=kappa_x, device=device)
            xi_x = xi_x.view(self.chain_length, b, -1).permute(1, 0, 2)
            return xi_x
        else:
            return kappa_x if kappa_x is not None else x

    def calculate_inner_product(self, kappa_xi_x: torch.Tensor, phi_w: torch.Tensor, device: str = 'cpu', *args, **kwargs):
        if phi_w is not None:
            assert phi_w.ndim == 2 and kappa_xi_x.size(-1) == phi_w.size(-1)
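            # Use torch.sparse.mm when either operand is sparse; on the 'mps'
            # backend, always fall back to the dense F.linear path.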
            if device != 'mps' and (kappa_xi_x.is_sparse or phi_w.is_sparse):
                inner_prod = torch.sparse.mm(kappa_xi_x, phi_w.T)
                if self.b is not None:
                    inner_prod += self.b
            else:
                inner_prod = F.linear(kappa_xi_x, phi_w, bias=self.b)
            inner_prod = inner_prod.view(inner_prod.size(0), -1)
        else:
            inner_prod = kappa_xi_x
        return inner_prod

    def fusion(self, inner_products: list[torch.Tensor], device: str = 'cpu', *args, **kwargs):
        if self.channel_fusion is not None:
            assert self.channel_fusion.get_dims() is None or self.channel_fusion.get_num() == len(inner_products)
            result = self.channel_fusion(x=inner_products, w=self.w_channel_fusion, device=device)
            n = self.channel_fusion.calculate_n(dims=[ip.size(-1) for ip in inner_products])
        else:
            assert len(inner_products) == 1
            result = inner_products[0]
            n = self.n
        assert result.size(-1) == n*self.chain_length
        return result

    def calculate_pi_x(self, x: torch.Tensor, device='cpu', *args, **kwargs):
        if self.remainder is not None:
            if isinstance(self.remainder, zero_remainder):
                return None
            if self.remainder.device != device:
                self.remainder.to(device)
            b, m = x.shape
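            # Apply the remainder per chain node, (b, m) -> (b * chain_length, m / chain_length),
            # then flatten back to one row per instance.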
            x = x.view(b * self.chain_length, -1)
            pi_x = self.remainder(x=x, w=self.w_remainder, b=self.b_remainder, m=self.m, n=self.n, device=device)
            pi_x = pi_x.view(b, -1)
            return pi_x
        else:
            return None
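
Example

A minimal usage sketch, not taken from the library's documentation: the dimensions are illustrative, and it assumes the input is a flat tensor of shape (batch, chain_length * m) and that the base head class makes the head callable on such a tensor (e.g., as a torch.nn.Module).

import torch
from tinybig.head.chain_based_heads import chain_interdependence_head

# Hypothetical configuration: a 4-node chain with 32 features per node.
head = chain_interdependence_head(
    m=32, n=16,
    chain_length=4,
    with_lorr=True, r=3,    # low-rank (LoRR) parameter reconciliation
    with_softmax=False,     # skip the default softmax for this sketch
    device='cpu',
)

x = torch.randn(8, 4 * 32)  # 8 instances, chain_length * m features each
y = head(x)                 # expected shape: (8, 16 * 4) = (batch, n * chain_length)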