Skip to content

graph_interdependence_head

Bases: head

A head class for implementing graph-based interdependence mechanisms.

This class supports various graph-based interdependence strategies, data transformations, parameter reconciliations, and customizable output processing functions.

Attributes:

Name Type Description
m int

Input dimension of the head.

n int

Output dimension of the head.

name str

Name of the head.

channel_num int

Number of channels for multi-channel processing.

graph graph

Graph structure used for interdependence, provided as an instance of graph_class.

graph_file_path str

Path to load the graph structure.

nodes list

List of nodes in the graph.

links list

List of links (edges) in the graph.

directed bool

Whether the graph is directed.

normalization bool

Whether to normalize the adjacency matrix.

normalization_mode str

Mode of normalization for the adjacency matrix, e.g., 'row' or 'column'.

self_dependence bool

Whether to include self-loops in the graph structure.

parameters_init_method str

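Initialization method for parameters. Note: this value is currently not forwarded to the base head; parameters are always initialized with 'fanout_std_uniform'.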

device str

Device to host the head (e.g., 'cpu' or 'cuda').

Source code in tinybig/head/graph_based_heads.py
class graph_interdependence_head(head):
    """
    A head class for implementing graph-based interdependence mechanisms.

    This class supports various graph-based interdependence strategies, data transformations,
    parameter reconciliations, and customizable output processing functions.

    The constructor assembles the head from these components:

    * instance interdependence: PageRank-based (``with_pagerank``), multi-hop
      (``with_multihop``), or plain graph interdependence -- checked in that order, so
      ``with_pagerank`` wins if both flags are set;
    * data transformation: always ``identity_expansion``;
    * parameter fabrication: dual LPHM (``with_dual_lphm``), LORR (``with_lorr``), or identity
      reconciliation -- checked in that order;
    * remainder: ``linear_remainder`` if ``with_residual`` else ``zero_remainder``;
    * output processing: optional ``BatchNorm1d`` -> ``ReLU`` -> ``Dropout`` -> ``LogSoftmax``,
      applied in that order.

    Attributes
    ----------
    m : int
        Input dimension of the head.
    n : int
        Output dimension of the head.
    name : str
        Name of the head.
    channel_num : int
        Number of channels for multi-channel processing.
    graph : graph_class
        Graph structure used for interdependence, provided as an instance of `graph_class`.
    graph_file_path : str
        Path to load the graph structure.
    nodes : list
        List of nodes in the graph.
    links : list
        List of links (edges) in the graph.
    directed : bool
        Whether the graph is directed.
    normalization : bool
        Whether to normalize the adjacency matrix.
    normalization_mode : str
        Mode of normalization for the adjacency matrix, e.g., 'row' or 'column'.
    self_dependence : bool
        Whether to include self-loops in the graph structure.
    parameters_init_method : str
        Initialization method for parameters. NOTE: currently not forwarded to the base head;
        the head is always initialized with 'fanout_std_uniform'.
    device : str
        Device to host the head (e.g., 'cpu' or 'cuda').
    """

    def __init__(
        self,
        m: int, n: int,
        name: str = 'graph_interdependence_head',
        channel_num: int = 1,
        # graph structure parameters
        graph: graph_class = None,
        graph_file_path: str = None,
        nodes: list = None,
        links: list = None,
        directed: bool = False,
        # graph interdependence function parameters
        with_multihop: bool = False, h: int = 1, accumulative: bool = False,
        with_pagerank: bool = False, c: float = 0.15,
        require_data: bool = False,
        require_parameters: bool = False,
        # adj matrix processing parameters
        normalization: bool = True,
        normalization_mode: str = 'column',
        self_dependence: bool = True,
        # parameter reconciliation and remainder functions
        with_dual_lphm: bool = False,
        with_lorr: bool = False, r: int = 3,
        with_residual: bool = False,
        enable_bias: bool = False,
        # output processing parameters
        with_batch_norm: bool = False,
        with_relu: bool = True,
        with_softmax: bool = True,
        with_dropout: bool = True, p: float = 0.5,
        # other parameters
        parameters_init_method: str = 'xavier_normal',
        device: str = 'cpu', *args, **kwargs
    ):
        """
        Initializes the graph_interdependence_head class.

        This method sets up the graph structure, interdependence mechanisms, data transformations,
        parameter reconciliation, remainder functions, and output processing pipeline, then
        delegates to the base ``head`` constructor.

        Parameters
        ----------
        m : int
            Input dimension of the head.
        n : int
            Output dimension of the head (also the feature count of the optional BatchNorm1d).
        name : str
            Name of the head.
        channel_num : int
            Number of channels for multi-channel processing.
        graph : graph_class, optional
            Predefined graph structure. Takes precedence over ``graph_file_path`` and
            ``nodes``/``links``.
        graph_file_path : str, optional
            Path to the file containing the graph structure, loaded via ``graph_class.load``.
            Used only when ``graph`` is None.
        nodes : list, optional
            List of nodes for the graph structure. Used together with ``links`` only when
            neither ``graph`` nor ``graph_file_path`` is given.
        links : list, optional
            List of links for the graph structure (see ``nodes``).
        directed : bool
            Whether the graph is directed. Only used when the graph is built from
            ``nodes``/``links``.
        with_multihop : bool, optional
            Whether to use multi-hop graph interdependence (ignored if ``with_pagerank``).
        h : int, optional
            Number of hops for multi-hop interdependence.
        accumulative : bool, optional
            Whether multi-hop connections are accumulative.
        with_pagerank : bool, optional
            Whether to use PageRank-based interdependence.
        c : float, optional
            PageRank coefficient passed through to the PageRank interdependence, default is 0.15.
        require_data : bool, optional
            Whether data input is required for interdependence.
        require_parameters : bool, optional
            Whether parameters are required for interdependence.
        normalization : bool, optional
            Whether to normalize the adjacency matrix.
        normalization_mode : str, optional
            Mode of normalization for the adjacency matrix.
        self_dependence : bool, optional
            Whether self-dependence is included in interdependence.
        with_dual_lphm : bool, optional
            Whether to use dual LPHM for parameter reconciliation (takes precedence over LORR).
        with_lorr : bool, optional
            Whether to use LORR for parameter reconciliation.
        r : int, optional
            Rank for parameter reconciliation (dual LPHM / LORR only).
        with_residual : bool, optional
            Whether to include a residual connection (``linear_remainder``).
        enable_bias : bool, optional
            Whether to include bias in the model.
        with_batch_norm : bool, optional
            Whether to include batch normalization in output processing.
        with_relu : bool, optional
            Whether to include ReLU activation in output processing.
        with_softmax : bool, optional
            Whether to include a log-softmax (``torch.nn.LogSoftmax(dim=-1)``) in output processing.
        with_dropout : bool, optional
            Whether to include dropout in output processing.
        p : float, optional
            Dropout probability.
        parameters_init_method : str, optional
            Initialization method for parameters. NOTE: currently ignored; the base head is
            always initialized with 'fanout_std_uniform'.
        device : str, optional
            Device to host the head (e.g., 'cpu', 'cuda').

        Raises
        ------
        ValueError
            If none of ``graph``, ``graph_file_path``, or both ``nodes`` and ``links`` is provided.
        """
        import logging
        logger = logging.getLogger(__name__)

        # Resolve the graph: explicit object > file on disk > nodes/links.
        if graph is not None:
            graph_structure = graph
        elif graph_file_path is not None:
            graph_structure = graph_class.load(complete_path=graph_file_path)
        elif nodes is not None and links is not None:
            graph_structure = graph_class(
                nodes=nodes,
                links=links,
                directed=directed,
                device=device,
            )
        else:
            raise ValueError('You must provide a graph, a graph_file_path, or both nodes and links...')

        # Keyword arguments shared by every graph interdependence variant.
        interdependence_kwargs = dict(
            b=graph_structure.get_node_num(), m=m,
            interdependence_type='instance',
            graph=graph_structure,
            normalization=normalization,
            normalization_mode=normalization_mode,
            self_dependence=self_dependence,
            require_data=require_data,
            require_parameters=require_parameters,
            device=device,
        )
        if with_pagerank:
            instance_interdependence = pagerank_multihop_graph_interdependence(
                c=c, **interdependence_kwargs
            )
        elif with_multihop:
            instance_interdependence = multihop_graph_interdependence(
                h=h, accumulative=accumulative, **interdependence_kwargs
            )
        else:
            instance_interdependence = graph_interdependence(**interdependence_kwargs)
        logger.debug('** instance_interdependence %s', instance_interdependence)

        data_transformation = identity_expansion(
            device=device
        )
        logger.debug('** data_transformation %s', data_transformation)

        if with_dual_lphm:
            parameter_fabrication = dual_lphm_reconciliation(
                r=r,
                device=device,
                enable_bias=enable_bias,
            )
        elif with_lorr:
            parameter_fabrication = lorr_reconciliation(
                r=r,
                enable_bias=enable_bias,
                device=device
            )
        else:
            parameter_fabrication = identity_reconciliation(
                enable_bias=enable_bias,
                device=device
            )
        logger.debug('** parameter_fabrication %s', parameter_fabrication)

        if with_residual:
            remainder = linear_remainder(
                device=device
            )
        else:
            remainder = zero_remainder(
                device=device,
            )
        logger.debug('** remainder %s', remainder)

        # Output pipeline order matters: BatchNorm -> ReLU -> Dropout -> LogSoftmax.
        output_process_functions = []
        if with_batch_norm:
            output_process_functions.append(torch.nn.BatchNorm1d(num_features=n, device=device))
        if with_relu:
            output_process_functions.append(torch.nn.ReLU())
        if with_dropout:
            output_process_functions.append(torch.nn.Dropout(p=p))
        if with_softmax:
            output_process_functions.append(torch.nn.LogSoftmax(dim=-1))
        logger.debug('** output_process_functions %s', output_process_functions)

        # NOTE(review): `parameters_init_method` is deliberately not forwarded; the head keeps its
        # historical 'fanout_std_uniform' initialization. Confirm before wiring the argument through.
        # NOTE(review): non-empty positional *args are bound to head.__init__'s leading positional
        # parameters, which are also passed by keyword here -- confirm this is intended.
        super().__init__(
            m=m, n=n, name=name,
            instance_interdependence=instance_interdependence,
            data_transformation=data_transformation,
            parameter_fabrication=parameter_fabrication,
            remainder=remainder,
            output_process_functions=output_process_functions,
            channel_num=channel_num,
            parameters_init_method='fanout_std_uniform',
            device=device, *args, **kwargs
        )

__init__(m, n, name='graph_interdependence_head', channel_num=1, graph=None, graph_file_path=None, nodes=None, links=None, directed=False, with_multihop=False, h=1, accumulative=False, with_pagerank=False, c=0.15, require_data=False, require_parameters=False, normalization=True, normalization_mode='column', self_dependence=True, with_dual_lphm=False, with_lorr=False, r=3, with_residual=False, enable_bias=False, with_batch_norm=False, with_relu=True, with_softmax=True, with_dropout=True, p=0.5, parameters_init_method='xavier_normal', device='cpu', *args, **kwargs)

Initializes the graph_interdependence_head class.

This method sets up the graph structure, interdependence mechanisms, data transformations, parameter reconciliation, remainder functions, and output processing pipeline.

Parameters:

Name Type Description Default
m int

Input dimension of the head.

required
n int

Output dimension of the head.

required
name str

Name of the head.

'graph_interdependence_head'
channel_num int

Number of channels for multi-channel processing.

1
graph graph

Predefined graph structure.

None
graph_file_path str

Path to the file containing the graph structure.

None
nodes list

List of nodes for the graph structure.

None
links list

List of links for the graph structure.

None
directed bool

Whether the graph is directed.

False
with_multihop bool

Whether to use multi-hop graph interdependence.

False
h int

Number of hops for multi-hop interdependence.

1
accumulative bool

Whether multi-hop connections are accumulative.

False
with_pagerank bool

Whether to use PageRank-based interdependence.

False
c float

Damping factor for PageRank, default is 0.15.

0.15
require_data bool

Whether data input is required for interdependence.

False
require_parameters bool

Whether parameters are required for interdependence.

False
normalization bool

Whether to normalize the adjacency matrix.

True
normalization_mode str

Mode of normalization for the adjacency matrix.

'column'
self_dependence bool

Whether self-dependence is included in interdependence.

True
with_dual_lphm bool

Whether to use dual LPHM for parameter reconciliation.

False
with_lorr bool

Whether to use LORR for parameter reconciliation.

False
r int

Rank for parameter reconciliation.

3
with_residual bool

Whether to include a residual connection.

False
enable_bias bool

Whether to include bias in the model.

False
with_batch_norm bool

Whether to include batch normalization in output processing.

False
with_relu bool

Whether to include ReLU activation in output processing.

True
with_softmax bool

Whether to include a log-softmax activation (torch.nn.LogSoftmax over the last dimension) in output processing.

True
with_dropout bool

Whether to include dropout in output processing.

True
p float

Dropout probability.

0.5
parameters_init_method str

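Initialization method for parameters. Note: this value is currently not forwarded to the base head; parameters are always initialized with 'fanout_std_uniform'.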

'xavier_normal'
device str

Device to host the head (e.g., 'cpu', 'cuda').

'cpu'

Raises:

Type Description
ValueError

If none of graph, graph_file_path, or both nodes and links is provided.

Source code in tinybig/head/graph_based_heads.py
def __init__(
    self,
    m: int, n: int,
    name: str = 'graph_interdependence_head',
    channel_num: int = 1,
    # graph structure parameters
    graph: graph_class = None,
    graph_file_path: str = None,
    nodes: list = None,
    links: list = None,
    directed: bool = False,
    # graph interdependence function parameters
    with_multihop: bool = False, h: int = 1, accumulative: bool = False,
    with_pagerank: bool = False, c: float = 0.15,
    require_data: bool = False,
    require_parameters: bool = False,
    # adj matrix processing parameters
    normalization: bool = True,
    normalization_mode: str = 'column',
    self_dependence: bool = True,
    # parameter reconciliation and remainder functions
    with_dual_lphm: bool = False,
    with_lorr: bool = False, r: int = 3,
    with_residual: bool = False,
    enable_bias: bool = False,
    # output processing parameters
    with_batch_norm: bool = False,
    with_relu: bool = True,
    with_softmax: bool = True,
    with_dropout: bool = True, p: float = 0.5,
    # other parameters
    parameters_init_method: str = 'xavier_normal',
    device: str = 'cpu', *args, **kwargs
):
    """
    Initializes the graph_interdependence_head class.

    This method sets up the graph structure, interdependence mechanisms, data transformations,
    parameter reconciliation, remainder functions, and output processing pipeline, then
    delegates to the base ``head`` constructor.

    Parameters
    ----------
    m : int
        Input dimension of the head.
    n : int
        Output dimension of the head (also the feature count of the optional BatchNorm1d).
    name : str
        Name of the head.
    channel_num : int
        Number of channels for multi-channel processing.
    graph : graph_class, optional
        Predefined graph structure. Takes precedence over ``graph_file_path`` and
        ``nodes``/``links``.
    graph_file_path : str, optional
        Path to the file containing the graph structure, loaded via ``graph_class.load``.
        Used only when ``graph`` is None.
    nodes : list, optional
        List of nodes for the graph structure. Used together with ``links`` only when
        neither ``graph`` nor ``graph_file_path`` is given.
    links : list, optional
        List of links for the graph structure (see ``nodes``).
    directed : bool
        Whether the graph is directed. Only used when the graph is built from
        ``nodes``/``links``.
    with_multihop : bool, optional
        Whether to use multi-hop graph interdependence (ignored if ``with_pagerank``).
    h : int, optional
        Number of hops for multi-hop interdependence.
    accumulative : bool, optional
        Whether multi-hop connections are accumulative.
    with_pagerank : bool, optional
        Whether to use PageRank-based interdependence.
    c : float, optional
        PageRank coefficient passed through to the PageRank interdependence, default is 0.15.
    require_data : bool, optional
        Whether data input is required for interdependence.
    require_parameters : bool, optional
        Whether parameters are required for interdependence.
    normalization : bool, optional
        Whether to normalize the adjacency matrix.
    normalization_mode : str, optional
        Mode of normalization for the adjacency matrix.
    self_dependence : bool, optional
        Whether self-dependence is included in interdependence.
    with_dual_lphm : bool, optional
        Whether to use dual LPHM for parameter reconciliation (takes precedence over LORR).
    with_lorr : bool, optional
        Whether to use LORR for parameter reconciliation.
    r : int, optional
        Rank for parameter reconciliation (dual LPHM / LORR only).
    with_residual : bool, optional
        Whether to include a residual connection (``linear_remainder``).
    enable_bias : bool, optional
        Whether to include bias in the model.
    with_batch_norm : bool, optional
        Whether to include batch normalization in output processing.
    with_relu : bool, optional
        Whether to include ReLU activation in output processing.
    with_softmax : bool, optional
        Whether to include a log-softmax (``torch.nn.LogSoftmax(dim=-1)``) in output processing.
    with_dropout : bool, optional
        Whether to include dropout in output processing.
    p : float, optional
        Dropout probability.
    parameters_init_method : str, optional
        Initialization method for parameters. NOTE: currently ignored; the base head is
        always initialized with 'fanout_std_uniform'.
    device : str, optional
        Device to host the head (e.g., 'cpu', 'cuda').

    Raises
    ------
    ValueError
        If none of ``graph``, ``graph_file_path``, or both ``nodes`` and ``links`` is provided.
    """
    import logging
    logger = logging.getLogger(__name__)

    # Resolve the graph: explicit object > file on disk > nodes/links.
    if graph is not None:
        graph_structure = graph
    elif graph_file_path is not None:
        graph_structure = graph_class.load(complete_path=graph_file_path)
    elif nodes is not None and links is not None:
        graph_structure = graph_class(
            nodes=nodes,
            links=links,
            directed=directed,
            device=device,
        )
    else:
        raise ValueError('You must provide a graph, a graph_file_path, or both nodes and links...')

    # Keyword arguments shared by every graph interdependence variant.
    interdependence_kwargs = dict(
        b=graph_structure.get_node_num(), m=m,
        interdependence_type='instance',
        graph=graph_structure,
        normalization=normalization,
        normalization_mode=normalization_mode,
        self_dependence=self_dependence,
        require_data=require_data,
        require_parameters=require_parameters,
        device=device,
    )
    if with_pagerank:
        instance_interdependence = pagerank_multihop_graph_interdependence(
            c=c, **interdependence_kwargs
        )
    elif with_multihop:
        instance_interdependence = multihop_graph_interdependence(
            h=h, accumulative=accumulative, **interdependence_kwargs
        )
    else:
        instance_interdependence = graph_interdependence(**interdependence_kwargs)
    logger.debug('** instance_interdependence %s', instance_interdependence)

    data_transformation = identity_expansion(
        device=device
    )
    logger.debug('** data_transformation %s', data_transformation)

    if with_dual_lphm:
        parameter_fabrication = dual_lphm_reconciliation(
            r=r,
            device=device,
            enable_bias=enable_bias,
        )
    elif with_lorr:
        parameter_fabrication = lorr_reconciliation(
            r=r,
            enable_bias=enable_bias,
            device=device
        )
    else:
        parameter_fabrication = identity_reconciliation(
            enable_bias=enable_bias,
            device=device
        )
    logger.debug('** parameter_fabrication %s', parameter_fabrication)

    if with_residual:
        remainder = linear_remainder(
            device=device
        )
    else:
        remainder = zero_remainder(
            device=device,
        )
    logger.debug('** remainder %s', remainder)

    # Output pipeline order matters: BatchNorm -> ReLU -> Dropout -> LogSoftmax.
    output_process_functions = []
    if with_batch_norm:
        output_process_functions.append(torch.nn.BatchNorm1d(num_features=n, device=device))
    if with_relu:
        output_process_functions.append(torch.nn.ReLU())
    if with_dropout:
        output_process_functions.append(torch.nn.Dropout(p=p))
    if with_softmax:
        output_process_functions.append(torch.nn.LogSoftmax(dim=-1))
    logger.debug('** output_process_functions %s', output_process_functions)

    # NOTE(review): `parameters_init_method` is deliberately not forwarded; the head keeps its
    # historical 'fanout_std_uniform' initialization. Confirm before wiring the argument through.
    # NOTE(review): non-empty positional *args are bound to head.__init__'s leading positional
    # parameters, which are also passed by keyword here -- confirm this is intended.
    super().__init__(
        m=m, n=n, name=name,
        instance_interdependence=instance_interdependence,
        data_transformation=data_transformation,
        parameter_fabrication=parameter_fabrication,
        remainder=remainder,
        output_process_functions=output_process_functions,
        channel_num=channel_num,
        parameters_init_method='fanout_std_uniform',
        device=device, *args, **kwargs
    )