
function

Base class for defining and handling functions in the RPN model.

This class provides mechanisms for managing, applying, and serializing functions, including custom callable functions, predefined operations, and string-based function recognition.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| name | str | The name of the function. | 'base_function' |
| device | str | The computational device for the function. | 'cpu' |
| *args | tuple | Additional positional arguments. | () |
| **kwargs | dict | Additional keyword arguments. | {} |

Source code in tinybig/module/base_function.py
class function:
    """
    Base class for defining and handling functions in the RPN model.

    This class provides mechanisms for managing, applying, and serializing functions,
    including custom callable functions, predefined operations, and string-based function recognition.

    Parameters
    ----------
    name : str, optional
        The name of the function, by default 'base_function'.
    device : str, optional
        The computational device for the function, by default 'cpu'.
    *args : tuple
        Additional positional arguments.
    **kwargs : dict
        Additional keyword arguments.
    """
    def __init__(self, name: str = 'base_function', device: str = 'cpu', *args, **kwargs):
        """
        Initializes the base function with a name and device.

        Parameters
        ----------
        name : str, optional
            The name of the function, by default 'base_function'.
        device : str, optional
            The computational device for the function, by default 'cpu'.
        *args : tuple
            Additional positional arguments.
        **kwargs : dict
            Additional keyword arguments.
        """
        self.name = name
        self.device = device

    def get_name(self):
        """
        Returns the name of the function.

        Returns
        -------
        str
            The name of the function.
        """
        return self.name

    def __call__(self, *args, **kwargs):
        """
        Calls the `forward` method of the function.

        Parameters
        ----------
        *args : tuple
            Positional arguments for the `forward` method.
        **kwargs : dict
            Keyword arguments for the `forward` method.
        """
        return self.forward(*args, **kwargs)

    def to_config(self):
        """
        Serializes the function into a dictionary configuration.

        Returns
        -------
        dict
            A dictionary containing the class name and parameters of the function.
        """
        class_name = f"{self.__class__.__module__}.{self.__class__.__name__}"
        attributes = {attr: getattr(self, attr) for attr in self.__dict__}
        return {
            "function_class": class_name,
            "function_parameters": attributes
        }

    @staticmethod
    def func_x(x, functions, device: str = 'cpu'):
        """
        Applies one or more functions to the input data.

        The functions are applied to the input vector in order, each receiving the previous
        one's output, and the final result is returned. If `functions` is None or an empty
        list or tuple, the input is returned unchanged.

        tinyBIG uses this method extensively to apply the data processing functions of its
        expansion functions, RPN heads, and remainder functions:

        * preprocess_functions in expansion functions
        * postprocess_functions in expansion functions
        * output_process_functions in rpn heads
        * activation_functions in remainder functions

        Parameters
        ----------
        x: torch.Tensor
            The input data vector.
        functions: list | tuple | callable
            The functions to apply to the input vector. Each may be a callable, a function
            name string (e.g. 'relu'), or a fully qualified path string
            (e.g. 'torch.nn.functional.relu').
        device: str, default = 'cpu'
            The device on which to apply the functions.

        Returns
        -------
        torch.Tensor
            The input vector after all functions have been applied.
        """
        if functions is None or ((isinstance(functions, list) or isinstance(functions, tuple)) and len(functions) == 0):
            return x
        elif isinstance(functions, list) or isinstance(functions, tuple):
            for f in functions:
                if callable(f):
                    x = f(x)
                elif type(f) is str:
                    x = function.str_func_x(x=x, func=f, device=device)
            return x
        else:
            if callable(functions):
                return functions(x)
            elif type(functions) is str:
                return function.str_func_x(x=x, func=functions, device=device)

    @staticmethod
    def str_func_x(x, func: str | Callable, device='cpu', *args, **kwargs):
        """
        Resolves a function from its name or path string and applies it to the input.

        A bare name such as "layer_norm" is lowercased and looked up in torch.nn.functional,
        while a dotted path such as "torch.nn.functional.layer_norm" is imported as given.
        A callable can also be passed directly.

        Because these functions differ widely in their signatures, handling them generically
        from strings is difficult. This method supports a fixed set of common functions:
        the activations sigmoid, relu, leaky_relu, tanh, softplus, silu, celu, and gelu;
        dropout; layer normalization; batch normalization; and torch.exp. Any other function
        triggers a UserWarning, and the input is returned unchanged. Users who need other
        functions can extend this method.

        Parameters
        ----------
        x: torch.Tensor
            The input data vector.
        func: str | Callable
            The function name, its fully qualified path, or an already resolved callable.
        device: str, default = 'cpu'
            The device to host and apply the recognized functions.
        **kwargs
            Optional settings for specific functions: `p` for dropout (default 0.5),
            `normalized_shape` for layer norm (default `[x.size(-1)]`), and
            `num_features` for batch norm (default `x.size(-1)`).

        Returns
        -------
        torch.Tensor
            The input data vector after applying the recognized function.
        """
        if func is None:
            return x
        elif callable(func):
            # --------------------------
            if func in [F.sigmoid, F.relu, F.leaky_relu, F.tanh, F.softplus, F.silu, F.celu, F.gelu]:
                return func(x)
            # --------------------------
            # dropout functions
            elif func in [
                F.dropout
            ]:
                # --------------------------
                if 'p' in kwargs:
                    p = kwargs['p']
                else:
                    p = 0.5
                # --------------------------
                if func in [F.dropout]:
                    return func(x, p=p)
                else:
                    return func(p=p)(x)
            # --------------------------
            # layer_norm functions
            elif func in [F.layer_norm]:
                # --------------------------
                if 'normalized_shape' in kwargs:
                    normalized_shape = kwargs['normalized_shape']
                else:
                    normalized_shape = [x.size(-1)]
                # --------------------------
                if func in [F.layer_norm]:
                    return func(x, normalized_shape=normalized_shape)
                else:
                    return func(normalized_shape=normalized_shape)(x)
                # --------------------------
            # --------------------------
            # batch_norm functions
            elif func in [
                torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d,
                torch.nn.modules.batchnorm.BatchNorm1d,
                torch.nn.modules.batchnorm.BatchNorm2d,
                torch.nn.modules.batchnorm.BatchNorm3d
            ]:
                # --------------------------
                if 'num_features' in kwargs:
                    num_features = kwargs['num_features']
                else:
                    num_features = x.size(-1)
                # ---------------------------
                return func(num_features=num_features, device=device)(x)
            # --------------------------
            # other functions
            elif func in [
                torch.exp,
            ]:
                return func(x)
            # --------------------------
            else:
                warnings.warn(
                    'input function {} not recognized, the original input x will be returned by default...'.format(
                        func),
                    UserWarning)
                return x
        # ------------------------------
        # All functions from configs will convert from str to object first
        elif type(func) is str:
            try:
                if '.' in func:
                    func = config.get_obj_from_str(func)
                else:
                    func = config.get_obj_from_str("torch.nn.functional.{}".format(func.lower()))
            except Exception as err:
                raise ValueError(
                    'function {} doesn\'t belong to "torch.nn.functional.", please provide the complete callable function path, such as "torch.nn.functional.sigmoid..."'.format(
                        func)) from err
            # Pass device positionally so extra *args cannot collide with it.
            return function.str_func_x(x, func, device, *args, **kwargs)
        else:
            warnings.warn('input function not recognized, the original input x will be returned by default...',
                          UserWarning)
            return x

    @staticmethod
    def string_to_function(formula, variable):
        """
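        Converts a formula string into a numerical function.

        The formula is parsed into a sympy expression with `sympy.sympify` and compiled
        with `sympy.lambdify` into a Python function that evaluates with NumPy.

        Parameters
        ----------
        formula: str
            The function formula as a string, e.g. "x**2 + 1".
        variable: str | list
            The variable name(s) in the formula; their order determines the argument
            order of the returned function.

        Returns
        -------
        Callable
            A function that evaluates the formula, taking one argument per variable.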
        """
        # Define the symbol
        var = sp.symbols(variable)

        # Parse the formula string into a sympy expression
        expression = sp.sympify(formula)

        # Convert the sympy expression to a lambda function
        func = sp.lambdify(var, expression, 'numpy')

        return func

    @staticmethod
    def functions_to_configs(functions: list | tuple | Callable, class_name: str = 'function_class', parameter_name: str = 'function_parameters'):
        """
        Converts a function, or a list or tuple of functions, into a serialized configuration.

        Parameters
        ----------
        functions : list | tuple | Callable
            A list or tuple of functions, or a single callable, to serialize.
        class_name : str, optional
            The key for the class name in the configuration, by default 'function_class'.
        parameter_name : str, optional
            The key for the parameters in the configuration, by default 'function_parameters'.

        Returns
        -------
        list | dict | None
            A configuration dict for a single callable, a list of configuration dicts
            for a list or tuple, or None if `functions` is None.
        """
        if functions is None:
            return None
        elif isinstance(functions, Callable):
            func_class_name = f"{functions.__class__.__module__}.{functions.__class__.__name__}"
            func_parameters = {attr: getattr(functions, attr) for attr in functions.__dict__}
            return {
                class_name: func_class_name,
                parameter_name: func_parameters
            }
        else:
            return [
                function.functions_to_configs(func, class_name=class_name, parameter_name=parameter_name)
                for func in functions
            ]

    @abstractmethod
    def forward(self, *args, **kwargs):
        """
        Abstract method for the forward pass of the function.

        This method must be implemented in subclasses.

        Parameters
        ----------
        *args : tuple
            Positional arguments for the function.
        **kwargs : dict
            Keyword arguments for the function.
        """
        pass

__call__(*args, **kwargs)

Calls the forward method of the function.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| *args | tuple | Positional arguments for the forward method. | () |
| **kwargs | dict | Keyword arguments for the forward method. | {} |

Source code in tinybig/module/base_function.py
def __call__(self, *args, **kwargs):
    """
    Calls the `forward` method of the function.

    Parameters
    ----------
    *args : tuple
        Positional arguments for the `forward` method.
    **kwargs : dict
        Keyword arguments for the `forward` method.
    """
    return self.forward(*args, **kwargs)
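The delegation performed by __call__ can be sketched with a small stand-in class, so that an instance can be used wherever a plain callable is expected, for example as an entry in the functions argument of func_x, which assigns the result of f(x) back to x. The names `base_function_sketch` and `double` are hypothetical and are not tinybig classes; the sketch returns the value produced by forward so that the caller receives it.

```python
class base_function_sketch:
    """Hypothetical stand-in for tinybig's function base class."""

    def __init__(self, name: str = 'base_function', device: str = 'cpu', *args, **kwargs):
        self.name = name
        self.device = device

    def __call__(self, *args, **kwargs):
        # Delegate to forward and hand its result back to the caller.
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError("subclasses must implement forward")


class double(base_function_sketch):
    """Hypothetical subclass that doubles its input."""

    def forward(self, x):
        return 2 * x


f = double(name='double')
print(f(21))  # 42
```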

__init__(name='base_function', device='cpu', *args, **kwargs)

Initializes the base function with a name and device.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| name | str | The name of the function. | 'base_function' |
| device | str | The computational device for the function. | 'cpu' |
| *args | tuple | Additional positional arguments. | () |
| **kwargs | dict | Additional keyword arguments. | {} |

Source code in tinybig/module/base_function.py
def __init__(self, name: str = 'base_function', device: str = 'cpu', *args, **kwargs):
    """
    Initializes the base function with a name and device.

    Parameters
    ----------
    name : str, optional
        The name of the function, by default 'base_function'.
    device : str, optional
        The computational device for the function, by default 'cpu'.
    *args : tuple
        Additional positional arguments.
    **kwargs : dict
        Additional keyword arguments.
    """
    self.name = name
    self.device = device
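A subclass typically adds its own parameters and passes name (and optionally device) up to this constructor. The sketch below uses hypothetical stand-in classes (`base_function_sketch`, `scale`). Because the base constructor stores only name and device, any extra keyword arguments that reach it are accepted but discarded.

```python
class base_function_sketch:
    """Hypothetical stand-in with the documented constructor signature."""

    def __init__(self, name: str = 'base_function', device: str = 'cpu', *args, **kwargs):
        # Only name and device are kept; *args and **kwargs are ignored.
        self.name = name
        self.device = device


class scale(base_function_sketch):
    """Hypothetical subclass with its own factor parameter."""

    def __init__(self, factor: float = 1.0, name: str = 'scale', **kwargs):
        # Forward name and any remaining keyword arguments (e.g. device) to the base.
        super().__init__(name=name, **kwargs)
        self.factor = factor


s = scale(factor=3.0, device='cuda', unused_option=True)
print(vars(s))  # {'name': 'scale', 'device': 'cuda', 'factor': 3.0}
```

Attributes stored on self in this way are exactly what to_config later serializes, since it reads the instance's __dict__.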

forward(*args, **kwargs) abstractmethod

Abstract method for the forward pass of the function.

This method must be implemented in subclasses.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| *args | tuple | Positional arguments for the function. | () |
| **kwargs | dict | Keyword arguments for the function. | {} |

Source code in tinybig/module/base_function.py
@abstractmethod
def forward(self, *args, **kwargs):
    """
    Abstract method for the forward pass of the function.

    This method must be implemented in subclasses.

    Parameters
    ----------
    *args : tuple
        Positional arguments for the function.
    **kwargs : dict
        Keyword arguments for the function.
    """
    pass
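In the class source listing above, the base class is declared as `class function:` without inheriting from `abc.ABC`, and `@abstractmethod` is only enforced by the `abc.ABCMeta` metaclass. As written, then, a subclass that does not implement forward can still be instantiated, and calling forward simply returns None. The sketch below contrasts the two cases with hypothetical stand-in classes. Whether tinybig enforces this elsewhere is not shown on this page.

```python
from abc import ABC, abstractmethod


class plain_base:
    """Mirrors `class function:`: no ABCMeta, so @abstractmethod is not enforced."""

    @abstractmethod
    def forward(self, *args, **kwargs):
        pass


class abc_base(ABC):
    """Inheriting from ABC makes unimplemented abstract methods block instantiation."""

    @abstractmethod
    def forward(self, *args, **kwargs):
        pass


plain_instance = plain_base()    # succeeds; forward() just returns None
print(plain_instance.forward())  # None

try:
    abc_base()
    abc_blocked = False
except TypeError as err:
    abc_blocked = True
    print("TypeError:", err)
```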

func_x(x, functions, device='cpu') staticmethod

Applies one or more functions to the input data.

The functions are applied to the input vector in order, each receiving the previous one's output, and the final result is returned. If functions is None or an empty list or tuple, the input is returned unchanged.

tinyBIG uses this method extensively to apply the data processing functions of its expansion functions, RPN heads, and remainder functions:

  • preprocess_functions in expansion functions
  • postprocess_functions in expansion functions
  • output_process_functions in rpn heads
  • activation_functions in remainder functions

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x | torch.Tensor | The input data vector. | required |
| functions | list, tuple, or callable | The functions to apply to the input vector, in order. Each may be a callable, a function name string (e.g. 'relu'), or a fully qualified path string (e.g. 'torch.nn.functional.relu'). | required |
| device | str | The device on which to apply the functions. | 'cpu' |

Returns:

| Type | Description |
|------|-------------|
| Tensor | The input vector after all functions have been applied. |

Source code in tinybig/module/base_function.py
@staticmethod
def func_x(x, functions, device: str = 'cpu'):
    """
    Applies one or more functions to the input data.

    The functions are applied to the input vector in order, each receiving the previous
    one's output, and the final result is returned. If `functions` is None or an empty
    list or tuple, the input is returned unchanged.

    tinyBIG uses this method extensively to apply the data processing functions of its
    expansion functions, RPN heads, and remainder functions:

    * preprocess_functions in expansion functions
    * postprocess_functions in expansion functions
    * output_process_functions in rpn heads
    * activation_functions in remainder functions

    Parameters
    ----------
    x: torch.Tensor
        The input data vector.
    functions: list | tuple | callable
        The functions to apply to the input vector. Each may be a callable, a function
        name string (e.g. 'relu'), or a fully qualified path string
        (e.g. 'torch.nn.functional.relu').
    device: str, default = 'cpu'
        The device on which to apply the functions.

    Returns
    -------
    torch.Tensor
        The input vector after all functions have been applied.
    """
    if functions is None or ((isinstance(functions, list) or isinstance(functions, tuple)) and len(functions) == 0):
        return x
    elif isinstance(functions, list) or isinstance(functions, tuple):
        for f in functions:
            if callable(f):
                x = f(x)
            elif type(f) is str:
                x = function.str_func_x(x=x, func=f, device=device)
        return x
    else:
        if callable(functions):
            return functions(x)
        elif type(functions) is str:
            return function.str_func_x(x=x, func=functions, device=device)
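The dispatch rules of func_x can be sketched in plain Python, without PyTorch. The `apply_functions` function and the `REGISTRY` dictionary are hypothetical: tinybig resolves string names through str_func_x rather than a lookup table. One deliberate difference is that this sketch raises TypeError for entries that are neither callables nor strings, whereas the source above skips such entries inside a list and returns None for a single unsupported value.

```python
from collections.abc import Callable

# Hypothetical name-to-callable table standing in for str_func_x's lookup.
REGISTRY: dict[str, Callable] = {
    "negate": lambda v: -v,
    "square": lambda v: v * v,
}


def apply_one(x, f):
    """Apply a single function given as a callable or a registered name."""
    if callable(f):
        return f(x)
    if isinstance(f, str):
        return REGISTRY[f](x)
    raise TypeError(f"unsupported function spec: {f!r}")


def apply_functions(x, functions):
    """Sketch of func_x: apply one function or a sequence of them in order."""
    if functions is None:
        return x
    if isinstance(functions, (list, tuple)):
        # An empty sequence leaves x unchanged; otherwise chain the outputs.
        for f in functions:
            x = apply_one(x, f)
        return x
    return apply_one(x, functions)


print(apply_functions(3, ["square", lambda v: v + 1]))  # 10
print(apply_functions(3, "negate"))  # -3
```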

functions_to_configs(functions, class_name='function_class', parameter_name='function_parameters') staticmethod

Converts a function, or a list or tuple of functions, into a serialized configuration.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| functions | list, tuple, or Callable | A list or tuple of functions, or a single callable, to serialize. | required |
| class_name | str | The key for the class name in the configuration. | 'function_class' |
| parameter_name | str | The key for the parameters in the configuration. | 'function_parameters' |

Returns:

| Type | Description |
|------|-------------|
| list, dict, or None | A configuration dict for a single callable, a list of configuration dicts for a list or tuple, or None if functions is None. |

Source code in tinybig/module/base_function.py
@staticmethod
def functions_to_configs(functions: list | tuple | Callable, class_name: str = 'function_class', parameter_name: str = 'function_parameters'):
    """
    Converts a function, or a list or tuple of functions, into a serialized configuration.

    Parameters
    ----------
    functions : list | tuple | Callable
        A list or tuple of functions, or a single callable, to serialize.
    class_name : str, optional
        The key for the class name in the configuration, by default 'function_class'.
    parameter_name : str, optional
        The key for the parameters in the configuration, by default 'function_parameters'.

    Returns
    -------
    list | dict | None
        A configuration dict for a single callable, a list of configuration dicts
        for a list or tuple, or None if `functions` is None.
    """
    if functions is None:
        return None
    elif isinstance(functions, Callable):
        func_class_name = f"{functions.__class__.__module__}.{functions.__class__.__name__}"
        func_parameters = {attr: getattr(functions, attr) for attr in functions.__dict__}
        return {
            class_name: func_class_name,
            parameter_name: func_parameters
        }
    else:
        return [
            function.functions_to_configs(func, class_name=class_name, parameter_name=parameter_name)
            for func in functions
        ]
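A sketch of this serialization pattern, using a hypothetical callable class `Scale` and a standalone function `functions_to_configs_sketch` that mirrors the method above. It forwards class_name and parameter_name on each recursive call, so custom key names apply to every element of a list.

```python
from collections.abc import Callable


class Scale:
    """Hypothetical callable object with one stored attribute."""

    def __init__(self, factor: float):
        self.factor = factor

    def __call__(self, x):
        return self.factor * x


def functions_to_configs_sketch(functions, class_name="function_class",
                                parameter_name="function_parameters"):
    """Serialize a callable, or a list/tuple of callables, into config dicts."""
    if functions is None:
        return None
    if isinstance(functions, Callable):
        cls = functions.__class__
        return {
            class_name: f"{cls.__module__}.{cls.__name__}",
            # Copy of the instance attributes, like the __dict__ scan in the source.
            parameter_name: dict(vars(functions)),
        }
    # Recurse over the sequence, forwarding the custom key names.
    return [functions_to_configs_sketch(f, class_name, parameter_name) for f in functions]


configs = functions_to_configs_sketch([Scale(2.0), Scale(0.5)],
                                      class_name="cls", parameter_name="params")
print(configs[0]["params"])  # {'factor': 2.0}
```

Because entries are identified by their class, this pattern suits callable objects such as Scale. A plain Python function would be recorded under its type, builtins.function, which does not identify which function it was.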

get_name()

Returns the name of the function.

Returns:

| Type | Description |
|------|-------------|
| str | The name of the function. |

Source code in tinybig/module/base_function.py
def get_name(self):
    """
    Returns the name of the function.

    Returns
    -------
    str
        The name of the function.
    """
    return self.name

str_func_x(x, func, device='cpu', *args, **kwargs) staticmethod

Resolves a function from its name or path string and applies it to the input.

A bare name such as "layer_norm" is lowercased and looked up in torch.nn.functional, while a dotted path such as "torch.nn.functional.layer_norm" is imported as given. A callable can also be passed directly.

Because these functions differ widely in their signatures, handling them generically from strings is difficult. This method supports a fixed set of common functions: the activations sigmoid, relu, leaky_relu, tanh, softplus, silu, celu, and gelu; dropout; layer normalization; batch normalization (BatchNorm1d, BatchNorm2d, BatchNorm3d); and torch.exp. Any other function triggers a UserWarning, and the input is returned unchanged. Users who need other functions can extend this method.

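Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x | torch.Tensor | The input data vector. | required |
| func | str or Callable | The function name, its fully qualified path, or an already resolved callable. | required |
| device | str | The device to host and apply the recognized functions. | 'cpu' |
| **kwargs | dict | Optional settings for specific functions: p for dropout (default 0.5), normalized_shape for layer norm (default [x.size(-1)]), num_features for batch norm (default x.size(-1)). | {} |

Returns:

| Type | Description |
|------|-------------|
| Tensor | The input data vector after applying the recognized function. |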

Source code in tinybig/module/base_function.py
@staticmethod
def str_func_x(x, func: str | Callable, device='cpu', *args, **kwargs):
    """
    Resolves a function from its name or path string and applies it to the input.

    A bare name such as "layer_norm" is lowercased and looked up in torch.nn.functional,
    while a dotted path such as "torch.nn.functional.layer_norm" is imported as given.
    A callable can also be passed directly.

    Because these functions differ widely in their signatures, handling them generically
    from strings is difficult. This method supports a fixed set of common functions:
    the activations sigmoid, relu, leaky_relu, tanh, softplus, silu, celu, and gelu;
    dropout; layer normalization; batch normalization; and torch.exp. Any other function
    triggers a UserWarning, and the input is returned unchanged. Users who need other
    functions can extend this method.

    Parameters
    ----------
    x: torch.Tensor
        The input data vector.
    func: str | Callable
        The function name, its fully qualified path, or an already resolved callable.
    device: str, default = 'cpu'
        The device to host and apply the recognized functions.
    **kwargs
        Optional settings for specific functions: `p` for dropout (default 0.5),
        `normalized_shape` for layer norm (default `[x.size(-1)]`), and
        `num_features` for batch norm (default `x.size(-1)`).

    Returns
    -------
    torch.Tensor
        The input data vector after applying the recognized function.
    """
    if func is None:
        return x
    elif callable(func):
        # --------------------------
        if func in [F.sigmoid, F.relu, F.leaky_relu, F.tanh, F.softplus, F.silu, F.celu, F.gelu]:
            return func(x)
        # --------------------------
        # dropout functions
        elif func in [
            F.dropout
        ]:
            # --------------------------
            if 'p' in kwargs:
                p = kwargs['p']
            else:
                p = 0.5
            # --------------------------
            if func in [F.dropout]:
                return func(x, p=p)
            else:
                return func(p=p)(x)
        # --------------------------
        # layer_norm functions
        elif func in [F.layer_norm]:
            # --------------------------
            if 'normalized_shape' in kwargs:
                normalized_shape = kwargs['normalized_shape']
            else:
                normalized_shape = [x.size(-1)]
            # --------------------------
            if func in [F.layer_norm]:
                return func(x, normalized_shape=normalized_shape)
            else:
                return func(normalized_shape=normalized_shape)(x)
            # --------------------------
        # --------------------------
        # batch_norm functions
        elif func in [
            torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d,
            torch.nn.modules.batchnorm.BatchNorm1d,
            torch.nn.modules.batchnorm.BatchNorm2d,
            torch.nn.modules.batchnorm.BatchNorm3d
        ]:
            # --------------------------
            if 'num_features' in kwargs:
                num_features = kwargs['num_features']
            else:
                num_features = x.size(-1)
            # ---------------------------
            return func(num_features=num_features, device=device)(x)
        # --------------------------
        # other functions
        elif func in [
            torch.exp,
        ]:
            return func(x)
        # --------------------------
        else:
            warnings.warn(
                'input function {} not recognized, the original input x will be returned by default...'.format(
                    func),
                UserWarning)
            return x
    # ------------------------------
    # All functions from configs will convert from str to object first
    elif type(func) is str:
        try:
            if '.' in func:
                func = config.get_obj_from_str(func)
            else:
                func = config.get_obj_from_str("torch.nn.functional.{}".format(func.lower()))
        except Exception as err:
            raise ValueError(
                'function {} doesn\'t belong to "torch.nn.functional.", please provide the complete callable function path, such as "torch.nn.functional.sigmoid..."'.format(
                    func)) from err
        # Pass device positionally so extra *args cannot collide with it.
        return function.str_func_x(x, func, device, *args, **kwargs)
    else:
        warnings.warn('input function not recognized, the original input x will be returned by default...',
                      UserWarning)
        return x
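The string-resolution step can be sketched with importlib. To keep the example runnable without PyTorch, it uses Python's math module as the default namespace in place of torch.nn.functional. The helpers `get_obj_from_str_sketch` and `resolve_function` are hypothetical stand-ins, since the implementation of tinybig's config.get_obj_from_str is not shown on this page.

```python
import importlib


def get_obj_from_str_sketch(path: str):
    """Hypothetical helper: import "pkg.module.attr" and return attr."""
    module_name, attr = path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), attr)


def resolve_function(name: str, default_module: str = "math"):
    """Resolve a bare or dotted function name, as str_func_x does for strings."""
    # Dotted names are used as full paths; bare names are lowercased and
    # prefixed with the default module (torch.nn.functional in tinybig).
    path = name if "." in name else f"{default_module}.{name.lower()}"
    try:
        return get_obj_from_str_sketch(path)
    except (ImportError, AttributeError) as err:
        raise ValueError(f"function {name!r} could not be resolved as {path!r}") from err


print(resolve_function("SQRT")(16.0))       # 4.0
print(resolve_function("math.floor")(2.7))  # 2
```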

string_to_function(formula, variable) staticmethod

Converts a formula string into a numerical function.

The formula is parsed into a sympy expression with sympy.sympify and compiled with sympy.lambdify into a Python function that evaluates with NumPy.

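Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| formula | str | The function formula as a string, e.g. "x**2 + 1". | required |
| variable | str or list | The variable name(s) in the formula; their order determines the argument order of the returned function. | required |

Returns:

| Type | Description |
|------|-------------|
| Callable | A function that evaluates the formula, taking one argument per variable. |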

Source code in tinybig/module/base_function.py
@staticmethod
def string_to_function(formula, variable):
    """
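    Converts a formula string into a numerical function.

    The formula is parsed into a sympy expression with `sympy.sympify` and compiled
    with `sympy.lambdify` into a Python function that evaluates with NumPy.

    Parameters
    ----------
    formula: str
        The function formula as a string, e.g. "x**2 + 1".
    variable: str | list
        The variable name(s) in the formula; their order determines the argument
        order of the returned function.

    Returns
    -------
    Callable
        A function that evaluates the formula, taking one argument per variable.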
    """
    # Define the symbol
    var = sp.symbols(variable)

    # Parse the formula string into a sympy expression
    expression = sp.sympify(formula)

    # Convert the sympy expression to a lambda function
    func = sp.lambdify(var, expression, 'numpy')

    return func
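The returned object is a plain Python function produced by sympy.lambdify. The sketch below repeats the three steps of string_to_function as a standalone function so that it runs without tinybig. It assumes sympy and NumPy are installed.

```python
import numpy as np
import sympy as sp


def string_to_function(formula, variable):
    """Standalone copy of the three steps: symbols, sympify, lambdify."""
    var = sp.symbols(variable)
    expression = sp.sympify(formula)
    return sp.lambdify(var, expression, 'numpy')


# Single variable given as a string.
f = string_to_function("x**2 + 1", "x")
print(f(2.0))  # 5.0

# Several variables: the list order fixes the argument order.
g = string_to_function("x*y + sin(x)", ["x", "y"])
print(g(np.array([0.0, np.pi / 2]), 2.0))
```

Because the NumPy backend is used, the returned function accepts arrays as well as scalars and evaluates element-wise.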

to_config()

Serializes the function into a dictionary configuration.

Returns:

| Type | Description |
|------|-------------|
| dict | A dictionary containing the class name and parameters of the function. |

Source code in tinybig/module/base_function.py
def to_config(self):
    """
    Serializes the function into a dictionary configuration.

    Returns
    -------
    dict
        A dictionary containing the class name and parameters of the function.
    """
    class_name = f"{self.__class__.__module__}.{self.__class__.__name__}"
    attributes = {attr: getattr(self, attr) for attr in self.__dict__}
    return {
        "function_class": class_name,
        "function_parameters": attributes
    }
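A sketch of what to_config produces for a subclass, using hypothetical stand-in classes (`base_function_sketch`, `scale`) that copy the method body above. Since the parameters are read from the instance's __dict__, every attribute set in __init__ appears in the config, and values are stored as-is rather than converted to a serializable form.

```python
class base_function_sketch:
    """Hypothetical stand-in reproducing the documented constructor and to_config."""

    def __init__(self, name: str = 'base_function', device: str = 'cpu', *args, **kwargs):
        self.name = name
        self.device = device

    def to_config(self):
        class_name = f"{self.__class__.__module__}.{self.__class__.__name__}"
        # Every instance attribute becomes a serialized parameter.
        attributes = {attr: getattr(self, attr) for attr in self.__dict__}
        return {"function_class": class_name, "function_parameters": attributes}


class scale(base_function_sketch):
    """Hypothetical subclass with an extra factor attribute."""

    def __init__(self, factor: float = 1.0, name: str = 'scale', **kwargs):
        super().__init__(name=name, **kwargs)
        self.factor = factor


cfg = scale(factor=2.0).to_config()
print(cfg["function_parameters"])  # {'name': 'scale', 'device': 'cpu', 'factor': 2.0}
```

Rebuilding an instance with scale(**cfg["function_parameters"]) works here only because the subclass constructor accepts each stored attribute as a keyword argument; that is a design assumption of the sketch, not something to_config guarantees.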