Source code for meerkat.ops.sliceby.clusterby

from __future__ import annotations

from typing import Dict, List, Sequence, Tuple, Union

import numpy as np
from sklearn.base import ClusterMixin

from meerkat.dataframe import DataFrame
from meerkat.ops.cluster import cluster

from .sliceby import SliceBy


class ClusterBy(SliceBy):
    def __init__(
        self,
        data: DataFrame,
        by: Union[List[str], str],
        sets: Dict[Union[str, Tuple[str]], np.ndarray] = None,
    ):
        super().__init__(data=data, by=by, sets=sets)


[docs]def clusterby( data: DataFrame, by: Union[str, Sequence[str]], method: Union[str, ClusterMixin] = "KMeans", encoder: str = "clip", # add support for auto selection of encoder modality: str = None, **kwargs, ) -> ClusterBy: """Perform a clusterby operation on a DataFrame. Args: data (DataFrame): The dataframe to cluster. by (Union[str, Sequence[str]]): The column(s) to cluster by. These columns will be embedded using the ``encoder`` and the resulting embedding will be used. method (Union[str, ClusterMixin]): The clustering method to use. encoder (str): The encoder to use for the embedding. Defaults to ``clip``. modality (Union[str, Sequence[str])): The modality to of the **kwargs: Additional keyword arguments to pass to the clustering method. Returns: ClusterBy: A ClusterBy object. """ out_col = f"{method}({by})" # TODO (sabri): Give the user the option to specify the output and remove this guard # once caching is supported if not isinstance(by, str): raise NotImplementedError if out_col not in data: data, _ = cluster( data=data, input=by, method=method, encoder=encoder, modality=modality, **kwargs, ) clusters = data[out_col] cluster_indices = {key: np.where(clusters == key)[0] for key in np.unique(clusters)} if isinstance(by, str): by = [by] return ClusterBy(data=data, sets=cluster_indices, by=by)