Skip to content

fusion

Bases: Module, function

Source code in tinybig/module/base_fusion.py
class fusion(Module, function):
    """
    Base class for fusion functions.

    Stores the dimension list of the inputs to be fused together with optional
    pre- and post-processing functions, and declares the abstract interface
    (``calculate_n``, ``calculate_l``, ``forward``) for subclasses to implement.

    Attributes
    ----------
    dims: list[int] | tuple[int]
        The dimension list handed to the constructor (may be None).
        NOTE(review): presumably one entry per input to be fused -- confirm.
    require_parameters: bool
        Flag stored as given; presumably tells callers whether ``forward``
        expects a parameter tensor ``w`` -- confirm against subclasses.
    preprocess_functions
        Result of ``config.instantiation_functions`` for the pre-processing
        functions/configs.
    postprocess_functions
        Result of ``config.instantiation_functions`` for the post-processing
        functions/configs.
    """

    def __init__(
        self,
        dims: list[int] | tuple[int] = None,
        name: str = 'base_fusion',
        require_parameters: bool = False,
        preprocess_functions=None,
        postprocess_functions=None,
        preprocess_function_configs=None,
        postprocess_function_configs=None,
        device: str = 'cpu',
        *args, **kwargs
    ):
        """
        Initialize both base classes and record the fusion configuration.

        Parameters
        ----------
        dims: list[int] | tuple[int], default = None
            Dimension list stored on the instance.
        name: str, default = 'base_fusion'
            Name forwarded to ``function.__init__``.
        require_parameters: bool, default = False
            Flag stored on the instance.
        preprocess_functions, postprocess_functions: default = None
            Pre-/post-processing functions, passed to
            ``config.instantiation_functions``.
        preprocess_function_configs, postprocess_function_configs: default = None
            Configs for the pre-/post-processing functions, passed to
            ``config.instantiation_functions``.
        device: str, default = 'cpu'
            Device forwarded to ``function.__init__``.
        *args, **kwargs
            Accepted for signature compatibility; not used here.
        """
        # Both bases are initialized explicitly because their constructors
        # take different arguments.
        Module.__init__(self)
        function.__init__(self, name=name, device=device)

        self.dims = dims
        self.require_parameters = require_parameters

        # self.device is read here, so it is expected to have been set by
        # function.__init__ above.
        self.preprocess_functions = config.instantiation_functions(
            preprocess_functions, preprocess_function_configs, device=self.device
        )
        self.postprocess_functions = config.instantiation_functions(
            postprocess_functions, postprocess_function_configs, device=self.device
        )

    def get_name(self):
        """Return the name of this fusion function."""
        return self.name

    def get_dims(self):
        """Return the stored dimension list (may be None)."""
        return self.dims

    def get_num(self) -> int:
        """Return the number of entries in ``dims``, or 0 when ``dims`` is None."""
        return len(self.dims) if self.dims is not None else 0

    def get_dim(self, index: int):
        """
        Return the dimension stored at position ``index``.

        Parameters
        ----------
        index: int
            Zero-based position in ``dims``; negative indices are rejected.

        Returns
        -------
        int | None
            ``dims[index]``, or None when ``dims`` is None.

        Raises
        ------
        ValueError
            If ``index`` is None or outside ``[0, len(dims))``.
        """
        if self.dims is None:
            return None
        # Strict upper bound: index == len(dims) is already out of range and
        # must be reported as ValueError rather than leaking an IndexError.
        if index is None or not 0 <= index < len(self.dims):
            raise ValueError(f'Index {index} is out of dim list bounds...')
        return self.dims[index]

    def pre_process(self, x: torch.Tensor, device='cpu', *args, **kwargs):
        """
        Apply the pre-processing functions to ``x``.

        Delegates to ``function.func_x`` with ``self.preprocess_functions``.
        The ``device`` argument (default 'cpu') is forwarded as given; it is
        not taken from ``self.device``.
        """
        return function.func_x(x, self.preprocess_functions, device=device)

    def post_process(self, x: torch.Tensor, device='cpu', *args, **kwargs):
        """
        Apply the post-processing functions to ``x``.

        Delegates to ``function.func_x`` with ``self.postprocess_functions``.
        The ``device`` argument (default 'cpu') is forwarded as given; it is
        not taken from ``self.device``.
        """
        return function.func_x(x, self.postprocess_functions, device=device)

    def to_config(self):
        """
        Export this object as a configuration dictionary.

        Returns
        -------
        dict
            ``{"function_class": <module.ClassName>, "function_parameters": {...}}``
            where the parameters are the instance attributes, with the
            instantiated pre-/post-processing functions replaced by their
            configs (when they are not None).
        """
        class_name = f"{self.__class__.__module__}.{self.__class__.__name__}"
        # NOTE(review): this snapshots every entry of the instance __dict__,
        # which includes whatever bookkeeping the base classes keep there --
        # confirm that consumers of the config tolerate the extra keys.
        attributes = {attr: getattr(self, attr) for attr in self.__dict__}

        # Instantiated functions are exported as configs instead. Pop with a
        # default so a value that is not held in the instance __dict__ cannot
        # raise KeyError here.
        attributes.pop('preprocess_functions', None)
        attributes.pop('postprocess_functions', None)

        if self.preprocess_functions is not None:
            attributes['preprocess_function_configs'] = function.functions_to_configs(self.preprocess_functions)
        if self.postprocess_functions is not None:
            attributes['postprocess_function_configs'] = function.functions_to_configs(self.postprocess_functions)

        return {
            "function_class": class_name,
            "function_parameters": attributes
        }

    @abstractmethod
    def calculate_n(self, dims: list[int] | tuple[int] = None, *args, **kwargs):
        """
        Abstract; to be implemented by subclasses.

        NOTE(review): presumably returns the output dimension after fusing
        inputs of the given ``dims`` -- confirm against subclasses.
        """
        pass

    @abstractmethod
    def calculate_l(self, *args, **kwargs):
        """
        Abstract; to be implemented by subclasses.

        NOTE(review): presumably returns the number of learnable parameters
        required by the fusion -- confirm against subclasses.
        """
        pass

    @abstractmethod
    def forward(self, x: list[torch.Tensor] | tuple[torch.Tensor], w: torch.nn.Parameter = None, device: str = 'cpu', *args, **kwargs):
        """
        Abstract; to be implemented by subclasses.

        Parameters
        ----------
        x: list[torch.Tensor] | tuple[torch.Tensor]
            The input tensors to be fused.
        w: torch.nn.Parameter, default = None
            Optional parameter tensor.
        device: str, default = 'cpu'
            Device argument.
        """
        pass