naive_probabilistic_compression

Bases: transformation

Compresses a (b, m) input batch into a (b, k) batch by sampling k of the m feature columns per instance, with sampling weights derived from a torch.distributions density (uniform by default) over the optionally metric-mapped and normalized inputs.

Source code in tinybig/compression/probabilistic_compression.py
# Module-level imports implied by the code below; the tinybig-internal paths are
# assumed from the package layout rather than copied from the source file.
import torch
from typing import Callable

from tinybig.module.base_transformation import transformation
from tinybig.config.base_config import config
from tinybig.module.base_function import function


class naive_probabilistic_compression(transformation):
    def __init__(
        self,
        k: int,
        name: str = 'probabilistic_compression',
        simply_sampling: bool = True,
        distribution_function: torch.distributions = None,
        distribution_function_configs: dict = None,
        metric: Callable[[torch.Tensor], torch.Tensor] = None,
        with_replacement: bool = False,
        require_normalization: bool = True,
        log_prob: bool = False,
        *args, **kwargs
    ):
        super().__init__(name=name, *args, **kwargs)
        self.k = k
        self.metric = metric
        self.simply_sampling = simply_sampling
        self.with_replacement = with_replacement
        self.require_normalization = require_normalization
        self.log_prob = log_prob

        if self.simply_sampling:
            # Simple sampling returns raw input values, so log-probability output is disabled.
            self.log_prob = False

        # Resolve the sampling distribution: use the provided object directly,
        # otherwise instantiate one from the supplied configs.
        if distribution_function is not None:
            self.distribution_function = distribution_function
        elif distribution_function_configs is not None:
            function_class = distribution_function_configs['function_class']
            function_parameters = distribution_function_configs.get('function_parameters', {})
            self.distribution_function = config.get_obj_from_str(function_class)(**function_parameters)
        else:
            self.distribution_function = None

        # Fall back to a uniform density over [0, 1] when no distribution is given.
        if self.distribution_function is None:
            self.distribution_function = torch.distributions.uniform.Uniform(low=0.0, high=1.0)

    # The compressed output dimension D is simply the number of sampled columns k.
    def calculate_D(self, m: int):
        assert self.k is not None and 0 <= self.k <= m
        return self.k

    # Serialize this transformation; the live distribution object is dropped from
    # the parameters and re-exported as reconstructable configs.
    def to_config(self):
        configs = super().to_config()
        configs['function_parameters'].pop('distribution_function')
        if self.distribution_function is not None:
            configs['function_parameters']['distribution_function_configs'] = function.functions_to_configs(self.distribution_function)
        return configs

    def calculate_weights(self, x: torch.Tensor):
        if self.distribution_function is not None:
            # Convert log-densities to densities and normalize each row into a
            # probability vector over the m columns.
            x = torch.exp(self.distribution_function.log_prob(x))
            weights = x / x.sum(dim=-1, keepdim=True)
        else:
            # Without a distribution, fall back to uniform weights; create them on
            # the same device as x so torch.multinomial and torch.gather line up.
            b, m = x.shape
            weights = torch.ones((b, m), device=x.device) / m
        return weights

    # Compress a (b, m) input batch into (b, k) by probabilistic column sampling.
    def forward(self, x: torch.Tensor, device='cpu', *args, **kwargs):
        b, m = x.shape
        x = self.pre_process(x=x, device=device)

        # When simply sampling, keep a copy of the raw inputs: the transformed
        # values below are used only to derive sampling weights, not as output.
        data_x = None
        if self.simply_sampling:
            data_x = x.clone()

        # Optionally map inputs through a user-supplied metric, then squash them
        # into (0.001, 0.991) so they fall inside the default Uniform(0, 1) support.
        if self.metric is not None:
            x = self.metric(x)
        if self.require_normalization:
            x = 0.99 * torch.sigmoid(x) + 0.001

        # Draw k column indices per row from the weights, then sort them so the
        # sampled features keep their original left-to-right order.
        weights = self.calculate_weights(x)
        sampled_indices = torch.multinomial(weights, self.calculate_D(m=m), replacement=self.with_replacement)
        sampled_indices, _ = torch.sort(sampled_indices, dim=1)

        if self.simply_sampling:
            # Return the original, untransformed values at the sampled positions.
            compression = torch.gather(data_x, 1, sampled_indices)
        else:
            # Return the metric-mapped / normalized values at the sampled positions.
            compression = torch.gather(x, 1, sampled_indices)

        # Optionally return the log-density of each sampled value rather than the
        # value itself (disabled automatically when simply_sampling is True).
        if self.log_prob:
            compression = self.distribution_function.log_prob(compression)

        assert compression.shape == (b, self.calculate_D(m=m))
        return self.post_process(x=compression, device=device)
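
Example

A minimal usage sketch (not taken from the tinybig documentation), assuming the class is importable from the module path shown above and that the base transformation's default pre_process and post_process leave the tensor unchanged:

import torch

from tinybig.compression.probabilistic_compression import naive_probabilistic_compression

# Keep k = 4 of the m = 10 feature columns of each instance.
compressor = naive_probabilistic_compression(k=4)

x = torch.randn(2, 10)      # batch of b = 2 instances with m = 10 features
z = compressor.forward(x)   # sampled columns, sorted by original index
print(z.shape)              # torch.Size([2, 4])

With the defaults (simply_sampling=True and the Uniform(0, 1) density), every column is equally likely to be kept, so this reduces to random column sub-sampling without replacement; supplying a metric or a peaked distribution_function biases the sampling toward columns with higher density values.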