chain_interdependence_layer

Bases: layer

Source code in tinybig/layer/chain_based_layers.py
class chain_interdependence_layer(layer):
    def __init__(
        self,
        m: int, n: int,
        chain_length: int,
        channel_num: int = 1,
        width: int = 1,
        name: str = 'chain_interdependence_layer',
        # interdependence function parameters
        bi_directional: bool = False,
        with_multihop: bool = False, h: int = 1, accumulative: bool = False,
        with_inverse_approx: bool = False,
        with_exponential_approx: bool = False,
        self_dependence: bool = True,
        self_scaling: float = 1.0,
        # parameter reconciliation function parameters
        with_dual_lphm: bool = False,
        with_lorr: bool = False, r: int = 3,
        enable_bias: bool = False,
        # remainder function parameters
        with_residual: bool = False,
        # output processing parameters
        with_batch_norm: bool = False,
        with_relu: bool = True,
        with_dropout: bool = False, p: float = 0.25,
        with_softmax: bool = True,
        # other parameters
        parameters_init_method: str = 'xavier_normal',
        device: str = 'cpu', *args, **kwargs
    ):
        print('* chain_interdependence_layer, width:', width)
        # build the list of heads; list multiplication repeats the same
        # configured chain_interdependence_head instance width times
        heads = [
            chain_interdependence_head(
                m=m, n=n,
                chain_length=chain_length,
                channel_num=channel_num,
                # -----------------------
                bi_directional=bi_directional,
                with_multihop=with_multihop, h=h, accumulative=accumulative,
                with_inverse_approx=with_inverse_approx,
                with_exponential_approx=with_exponential_approx,
                self_dependence=self_dependence,
                self_scaling=self_scaling,
                # -----------------------
                with_dual_lphm=with_dual_lphm,
                with_lorr=with_lorr, r=r,
                enable_bias=enable_bias,
                # -----------------------
                with_residual=with_residual,
                # -----------------------
                with_batch_norm=with_batch_norm,
                with_relu=with_relu,
                with_dropout=with_dropout, p=p,
                with_softmax=with_softmax,
                # -----------------------
                parameters_init_method=parameters_init_method,
                device=device, *args, **kwargs
            )
        ] * width
        print('--------------------------')
        # register the heads and layer dimensions with the parent layer class
        super().__init__(name=name, m=m, n=n, heads=heads, device=device, *args, **kwargs)


    def forward(self, x: torch.Tensor, fusion_strategy: str = 'average', device: str = 'cpu', *args, **kwargs):
        # the layer expects a 2D input tensor of shape [batch_size, feature_dim]
        assert x is not None and x.ndim == 2

        # propagate the input through every head in the layer
        results = []
        for head in self.heads:
            results.append(head(x=x, device=device))
        # all heads must return outputs of identical shape
        assert results != [] and [results[0].shape] * len(results) == [result.shape for result in results]

        if self.head_fusion is not None:
            # fuse the per-head outputs with the configured fusion function
            assert self.head_fusion.get_num() == len(results) and [results[0].shape] * len(results) == [result.shape for result in results]
            result = self.head_fusion(x=results, w=self.w_head_fusion, device=device)
        else:
            # without a fusion function the layer must contain exactly one head
            assert len(results) == 1
            result = results[0]

        return result
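
A minimal usage sketch, not part of the library source: it assumes the class is importable from tinybig.layer.chain_based_layers as listed above, and the concrete sizes, flag choices, and flattened input layout below are placeholder assumptions rather than documented requirements.

import torch

from tinybig.layer.chain_based_layers import chain_interdependence_layer

# hypothetical sizes; they only need to be consistent with the data fed in below
m, n, chain_length, batch_size = 8, 4, 5, 32

layer = chain_interdependence_layer(
    m=m, n=n,
    chain_length=chain_length,
    channel_num=1,
    width=1,                  # a single head
    bi_directional=False,
    with_lorr=True, r=2,      # low-rank parameter reconciliation
    with_softmax=False,
    device='cpu',
)

# forward() only checks that the input is 2D; the exact expected feature width
# depends on how chain_interdependence_head expands m and chain_length internally
input_dim = m * chain_length  # assumed layout: chain_length links with m attributes each
x = torch.randn(batch_size, input_dim)
output = layer(x, device='cpu')
print(output.shape)

Note that with width greater than 1, forward() requires self.head_fusion to be configured by the parent layer class, since the branch without a fusion function asserts that exactly one head output is present.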