import torch

# Imports assume tinybig's module layout (inferred from the class names used
# below); adjust the paths if your installed version differs.
from tinybig.module.base_head import head
from tinybig.koala.topology.graph import graph as graph_class
from tinybig.expansion.basic_expansion import identity_expansion
from tinybig.interdependence.topological_interdependence import (
    graph_interdependence,
    multihop_graph_interdependence,
    pagerank_multihop_graph_interdependence,
)
from tinybig.reconciliation.basic_reconciliation import identity_reconciliation
from tinybig.reconciliation.lowrank_reconciliation import lorr_reconciliation, dual_lphm_reconciliation
from tinybig.remainder.basic_remainder import zero_remainder, linear_remainder


class graph_interdependence_head(head):
def __init__(
self,
m: int, n: int,
name: str = 'graph_interdependence_head',
channel_num: int = 1,
# graph structure parameters
graph: graph_class = None,
graph_file_path: str = None,
nodes: list = None,
links: list = None,
directed: bool = False,
# graph interdependence function parameters
with_multihop: bool = False, h: int = 1, accumulative: bool = False,
with_pagerank: bool = False, c: float = 0.15,
require_data: bool = False,
require_parameters: bool = False,
# adj matrix processing parameters
normalization: bool = True,
normalization_mode: str = 'column',
self_dependence: bool = True,
# parameter reconciliation and remainder functions
with_dual_lphm: bool = False,
with_lorr: bool = False, r: int = 3,
with_residual: bool = False,
enable_bias: bool = False,
# output processing parameters
with_batch_norm: bool = False,
with_relu: bool = True,
with_softmax: bool = True,
with_dropout: bool = True, p: float = 0.5,
# other parameters
parameters_init_method: str = 'xavier_normal',
device: str = 'cpu', *args, **kwargs
):
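        # Resolve the graph structure: an explicit graph object takes
        # precedence, then a saved graph file, then raw nodes/links.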
if graph is not None:
graph_structure = graph
elif graph_file_path is not None:
graph_structure = graph_class.load(complete_path=graph_file_path)
elif nodes is not None and links is not None:
graph_structure = graph_class(
nodes=nodes,
links=links,
directed=directed,
device=device,
)
else:
            raise ValueError('You must provide a graph, a graph_file_path, or both nodes and links to define the graph structure.')
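
        # Choose the instance interdependence function: PageRank-based
        # multi-hop takes precedence over plain multi-hop; the default is
        # one-hop graph interdependence.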
if with_pagerank:
instance_interdependence = pagerank_multihop_graph_interdependence(
b=graph_structure.get_node_num(), m=m,
c=c,
interdependence_type='instance',
graph=graph_structure,
normalization=normalization,
normalization_mode=normalization_mode,
self_dependence=self_dependence,
require_data=require_data,
require_parameters=require_parameters,
device=device
)
elif with_multihop:
instance_interdependence = multihop_graph_interdependence(
b=graph_structure.get_node_num(), m=m,
h=h, accumulative=accumulative,
interdependence_type='instance',
graph=graph_structure,
normalization=normalization,
normalization_mode=normalization_mode,
self_dependence=self_dependence,
require_data=require_data,
require_parameters=require_parameters,
device=device
)
else:
instance_interdependence = graph_interdependence(
b=graph_structure.get_node_num(), m=m,
interdependence_type='instance',
graph=graph_structure,
normalization=normalization,
normalization_mode=normalization_mode,
self_dependence=self_dependence,
require_data=require_data,
require_parameters=require_parameters,
device=device
)
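
        # Attribute transformation: the identity expansion passes input
        # features through unchanged.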
data_transformation = identity_expansion(
device=device
)
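
        # Parameter reconciliation: dual LPHM takes precedence over LoRR (both
        # low-rank, with rank r); the default is identity reconciliation.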
if with_dual_lphm:
parameter_fabrication = dual_lphm_reconciliation(
r=r,
device=device,
enable_bias=enable_bias,
)
elif with_lorr:
parameter_fabrication = lorr_reconciliation(
r=r,
enable_bias=enable_bias,
device=device
)
else:
parameter_fabrication = identity_reconciliation(
enable_bias=enable_bias,
device=device
)
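
        # Remainder function: a linear residual connection if requested,
        # otherwise the zero remainder (no residual term).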
if with_residual:
remainder = linear_remainder(
device=device
)
else:
remainder = zero_remainder(
device=device,
)
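
        # Output processing, applied in order: optional batch norm, ReLU,
        # dropout, and log-softmax (note that with_softmax appends
        # torch.nn.LogSoftmax, not torch.nn.Softmax).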
output_process_functions = []
if with_batch_norm:
output_process_functions.append(torch.nn.BatchNorm1d(num_features=n, device=device))
if with_relu:
output_process_functions.append(torch.nn.ReLU())
if with_dropout:
output_process_functions.append(torch.nn.Dropout(p=p))
if with_softmax:
output_process_functions.append(torch.nn.LogSoftmax(dim=-1))
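
        # Assemble the head from the components selected above.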
super().__init__(
m=m, n=n, name=name,
instance_interdependence=instance_interdependence,
data_transformation=data_transformation,
parameter_fabrication=parameter_fabrication,
remainder=remainder,
output_process_functions=output_process_functions,
channel_num=channel_num,
            parameters_init_method=parameters_init_method,
device=device, *args, **kwargs
)
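

# A minimal usage sketch, kept as a comment so it does not run on import. The
# toy graph and the dimensions m=16, n=4 are illustrative assumptions, not
# values from the library:
#
#   nodes = [0, 1, 2]
#   links = [(0, 1), (1, 2)]
#   head_inst = graph_interdependence_head(
#       m=16, n=4,
#       nodes=nodes, links=links, directed=False,
#       with_pagerank=True, c=0.15,
#       with_lorr=True, r=3,
#       device='cpu',
#   )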