Skip to content

incremental_variance_threshold

Bases: feature_selection

Source code in tinybig/koala/machine_learning/feature_selection/incremental_variance_threshold.py
class incremental_variance_threshold(feature_selection):
    """
    Variance-based feature selector with an optional running-average estimate.

    ``fit`` computes the per-column variance of a batch with
    ``batch_variance(X, dim=0)`` and hands it to ``update_v``, which either
    averages it into ``self.v`` (incremental mode) or stores it directly.
    ``transform`` then keeps the ``n_feature`` columns with the largest stored
    variance, or, when ``n_feature`` is ``None``, every column whose stored
    variance is at least ``threshold``.

    Attributes
    ----------
    threshold : float
        Minimum variance a column needs to be kept when ``n_feature`` is None.
    v : torch.Tensor | None
        Stored per-feature variance estimate; ``None`` until ``fit`` is called.
    t : int | None
        Number of batches averaged into ``v`` in incremental mode.

    Notes
    -----
    ``self.n_feature``, ``self.incremental``, ``self.t_threshold``,
    ``self.incremental_stop_threshold`` and ``set_n_feature`` are read here but
    not defined here; they are presumably provided by the ``feature_selection``
    base class -- confirm against that class.
    """

    def __init__(self, threshold: float = 0.0, name: str = 'incremental_variance_threshold', *args, **kwargs):
        """
        Create the selector.

        Parameters
        ----------
        threshold : float
            Variance cut-off used by ``transform`` when ``n_feature`` is None.
        name : str
            Name forwarded to the base-class constructor.
        *args, **kwargs
            Forwarded unchanged to the base-class constructor.
        """
        super().__init__(name=name, *args, **kwargs)

        self.threshold = threshold
        self.v = None
        self.t = None

    @staticmethod
    def _as_tensor(X):
        """Return ``X`` as a gradient-free tensor without copying when avoidable."""
        if isinstance(X, torch.Tensor):
            # torch.tensor(tensor) would warn and copy; detach() keeps the
            # same "no autograd history" result without either.
            return X.detach()
        return torch.as_tensor(X)

    def update_n_feature(self, new_n_feature: int):
        """Set a new target feature count and discard the stored variance state."""
        assert new_n_feature > 0
        self.set_n_feature(new_n_feature)
        self.v = None
        self.t = None

    def update_threshold(self, new_threshold: float):
        """Set a new variance threshold and discard the stored variance state."""
        self.threshold = new_threshold
        self.v = None
        self.t = None

    def update_v(self, new_v: torch.Tensor):
        """
        Fold a new per-feature variance vector into the stored estimate.

        In incremental mode ``self.v`` becomes the running mean of all vectors
        seen so far. Incremental mode is switched off once ``self.t`` reaches
        ``self.t_threshold`` or the estimate moves less than
        ``self.incremental_stop_threshold`` (measured by ``euclidean_distance``).
        Outside incremental mode ``self.v`` is simply replaced by ``new_v``.

        NOTE(review): after incremental mode switches itself off, later calls
        overwrite the averaged estimate with the latest batch; confirm this is
        the intended behaviour.
        """
        if not self.incremental:
            self.v = new_v
            return

        # (Re)start the running average when there is no usable state. `t` can
        # be None while `v` is set if `v` came from non-incremental mode.
        if self.v is None or self.t is None:
            self.v = torch.zeros_like(new_v)
            self.t = 0

        assert new_v.shape == self.v.shape and self.t >= 0
        self.t += 1
        old_v = self.v
        self.v = ((self.t - 1) * old_v + new_v) / self.t

        if self.t >= self.t_threshold or euclidean_distance(x=old_v, x2=self.v) < self.incremental_stop_threshold:
            self.incremental = False

    def fit(self, X: Union[np.ndarray, torch.Tensor], device: str = 'cpu', *args, **kwargs):
        """
        Update the stored variance estimate from one batch.

        Parameters
        ----------
        X : np.ndarray | torch.Tensor
            Batch whose variance along dim 0 is computed.
        device : str
            Accepted for interface compatibility; not used by this method.
        """
        X = self._as_tensor(X)
        new_v = batch_variance(X, dim=0)
        self.update_v(new_v)

    def transform(self, X: Union[np.ndarray, torch.Tensor], device: str = 'cpu', *args, **kwargs):
        """
        Select columns of ``X`` according to the stored variances.

        Parameters
        ----------
        X : np.ndarray | torch.Tensor
            Input whose second dimension must match ``self.v.shape[0]``.
        device : str
            Accepted for interface compatibility; not used by this method.

        Returns
        -------
        np.ndarray | torch.Tensor
            The selected columns; a NumPy array when ``X`` was a NumPy array,
            otherwise a tensor. With ``n_feature`` set, columns are ordered by
            ascending stored variance.
        """
        input_X = self._as_tensor(X)

        assert self.v is not None, "fit must be called before transform"
        assert self.v.shape[0] == input_X.shape[1], "feature dimension of X does not match the fitted variance"

        n_columns = input_X.shape[1]
        expected = None
        if self.n_feature is not None:
            # An explicit feature count takes priority over the threshold.
            n = min(self.n_feature, n_columns)
            indices = torch.argsort(self.v)[-n:]
            expected = n
        else:
            indices = torch.nonzero(self.v >= self.threshold, as_tuple=True)[0]

        # Nothing survived the selection: fall back to keeping every column.
        if indices.numel() == 0:
            indices = torch.arange(n_columns)

        # The index tensor must live on the same device as the data it indexes.
        X_selected = input_X[:, indices.to(input_X.device)]

        # Only a requested feature count fixes the output width; the threshold
        # branch may legitimately keep any number of columns.
        if expected is not None:
            assert X_selected.shape[1] == expected

        if isinstance(X, np.ndarray):
            return X_selected.detach().cpu().numpy()
        return X_selected