Bases: combinatorial_compression
Source code in tinybig/compression/combinatorial_compression.py
class combinatorial_probabilistic_compression(combinatorial_compression):
    def __init__(
        self,
        name: str = 'combinatorial_probabilistic_compression',
        d: int = 1, k: int = 1,
        metric: Callable[[torch.Tensor], torch.Tensor] = None,
        with_replacement: bool = False,
        require_normalization: bool = True,
        *args, **kwargs
    ):
        super().__init__(
            name=name,
            d=d, k=k,
            metric=metric,
            # Fixed by this subclass; not exposed as constructor parameters.
            simply_sampling=False,
            log_prob=True,
            with_replacement=with_replacement,
            require_normalization=require_normalization,
            *args, **kwargs
        )
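The constructor above only pins two of the parent's flags (`simply_sampling=False`, `log_prob=True`). The selection logic itself lives in `combinatorial_compression`, which is not shown here. As a rough, hypothetical illustration of metric-driven probabilistic selection over feature combinations, the pure-Python sketch below assumes that the compression enumerates all size-`d` index combinations of an input vector, scores each with a log-probability `metric`, and draws `k` of them with probability proportional to `exp(log_prob)`. The names `probabilistic_combination_sample` and `gaussian_log_prob` are illustrative and not part of tinybig; the real implementation operates on `torch.Tensor` batches, also handles `require_normalization`, and may differ in its details.

```python
import itertools
import math
import random
from typing import Callable, List, Sequence, Tuple


def probabilistic_combination_sample(
    x: Sequence[float],
    d: int,
    k: int,
    log_prob: Callable[[Tuple[float, ...]], float],
    with_replacement: bool = False,
    seed: int = 0,
) -> List[Tuple[int, ...]]:
    """Hypothetical sketch: draw k size-d index combinations, weighted by exp(log_prob)."""
    combos = list(itertools.combinations(range(len(x)), d))
    if not with_replacement and k > len(combos):
        raise ValueError("k exceeds the number of available combinations")
    # Score each combination by the log-probability of its feature values.
    scores = [log_prob(tuple(x[i] for i in c)) for c in combos]
    # Subtract the maximum score before exponentiating, for numerical stability.
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    rng = random.Random(seed)
    if with_replacement:
        # The same combination may be drawn more than once.
        return rng.choices(combos, weights=weights, k=k)
    # Without replacement: draw one at a time and remove it from the pool.
    pool = list(zip(combos, weights))
    chosen = []
    for _ in range(k):
        idx = rng.choices(range(len(pool)), weights=[w for _, w in pool], k=1)[0]
        chosen.append(pool.pop(idx)[0])
    return chosen


def gaussian_log_prob(values: Tuple[float, ...]) -> float:
    """Hypothetical metric: log density of independent standard normals."""
    return sum(-0.5 * v * v - 0.5 * math.log(2 * math.pi) for v in values)


x = [0.1, -1.2, 0.5, 2.0]
picked = probabilistic_combination_sample(x, d=2, k=3, log_prob=gaussian_log_prob)
print(picked)
```

Combinations of feature values with higher log-probability under the metric are more likely to be kept. Passing `with_replacement=True` allows repeated draws, which is assumed (not confirmed by the source) to mirror the constructor parameter of the same name.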