Bases: `weighted_summation_fusion`
Source code in `tinybig/fusion/basic_fusion.py`
class average_fusion(weighted_summation_fusion):
    """
    Fusion that combines its inputs with fixed, equal weights.

    Implemented as a ``weighted_summation_fusion`` whose weight vector holds
    ``1 / len(dims)`` in every entry. It always passes
    ``require_parameters=False`` to the parent.
    """

    def __init__(self, dims: list[int] | tuple[int, ...], name: str = "average_fusion", require_parameters: bool = False, *args, **kwargs):
        """
        Build an average fusion over ``len(dims)`` inputs.

        Parameters
        ----------
        dims : list[int] | tuple[int, ...]
            Per-input dimensions. Only its length is used here, to size the
            uniform weight vector; the value itself is forwarded unchanged
            to the parent constructor.
        name : str
            Name of the fusion component.
        require_parameters : bool
            Accepted for interface compatibility but ignored: the parent is
            always constructed with ``require_parameters=False``.
        *args, **kwargs
            Forwarded to ``weighted_summation_fusion.__init__``.

        Raises
        ------
        ValueError
            If ``dims`` is ``None`` or empty, because no uniform weights can
            be formed (previously this surfaced as ZeroDivisionError/TypeError).
        """
        if not dims:
            raise ValueError("average_fusion requires a non-empty 'dims' to compute uniform weights.")
        num_inputs = len(dims)
        # Equal weights 1/n in the default float dtype; numerically identical
        # to the former `1.0/n * torch.ones(n)`.
        weights = torch.full((num_inputs,), 1.0 / num_inputs)
        # NOTE(review): in a call, `*args` is bound positionally before the
        # keyword arguments, so any extra positional args would likely collide
        # with the parent's leading parameters (e.g. `dims`) — confirm against
        # weighted_summation_fusion's signature.
        super().__init__(dims=dims, weights=weights, name=name, require_parameters=False, *args, **kwargs)