# Source code for meerkat.ml.embedding_column
from __future__ import annotations
import logging
from collections.abc import Sequence
from types import SimpleNamespace
from typing import Union
import numpy as np
import pandas as pd
import torch
from meerkat.columns.tensor_column import TensorColumn
from meerkat.tools.lazy_loader import LazyLoader
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Type alias covering sequence-, array-, series- and tensor-like inputs.
Columnable = Union[Sequence, np.ndarray, pd.Series, torch.Tensor]

# Optional third-party dependencies, wrapped in ``LazyLoader``.
# NOTE(review): presumably the real import is deferred until the module is
# first used -- confirm against ``meerkat.tools.lazy_loader``.
faiss = LazyLoader("faiss")
umap = LazyLoader("umap")
umap_plot = LazyLoader("umap.plot")
sklearn_decom = LazyLoader("sklearn.decomposition")
class EmbeddingColumn(TensorColumn):
    """A ``TensorColumn`` holding float32 embeddings.

    On top of the parent column it offers a nearest-neighbour index
    (``build_faiss_index`` / ``search``) and dimensionality-reduction helpers
    (``pca``, ``umap``, ``visualize_umap``).

    Attributes:
        faiss_index: The index created by ``build_faiss_index``, or ``None``
            while no index has been built.
    """

    def __init__(
        self,
        data: Sequence = None,
        *args,
        **kwargs,
    ):
        """Create the column and cast its storage to float32.

        Args:
            data: The embeddings. When a ``TensorColumn`` is passed, its
                underlying ``_data`` is unwrapped and used instead.
            *args: Forwarded unchanged to ``TensorColumn.__init__``.
            **kwargs: Forwarded unchanged to ``TensorColumn.__init__``.
        """
        # ``isinstance`` is already False for ``None``, so no separate check.
        if isinstance(data, TensorColumn):
            data = data._data
        super().__init__(*args, data=data, **kwargs)
        # Cast to float32
        self._data = self._data.type(torch.FloatTensor)
        self.faiss_index = None

    def build_faiss_index(self, index=None, overwrite=False):
        """Build an index over this column's data and store it on the column.

        Args:
            index: Callable that takes the embedding dimension
                (``self.shape[1]``) and returns an index object exposing
                ``add``. Defaults to ``faiss.IndexFlatL2``.
            overwrite: If False (default) and an index already exists, the
                call is a no-op. If True, the existing index is replaced.

        Raises:
            ValueError: If the column has fewer than two dimensions.
        """
        if self.ndim < 2:
            raise ValueError("Building an index requires `ndim` >= 2.")
        if self.faiss_index is not None and not overwrite:
            return
        if index is None:
            index = faiss.IndexFlatL2

        # Populate a local index first and publish it only once ``add`` has
        # succeeded; otherwise a failed ``add`` would leave an empty index
        # behind and later calls would return early without rebuilding.
        new_index = index(self.shape[1])
        # Add the data: must be np.ndarray
        new_index.add(self.numpy())
        self.faiss_index = new_index

    def search(self, query, k: int):
        """Query the previously built index.

        Args:
            query: ``np.ndarray`` of query vectors; it is cast to float32
                before being handed to the index.
            k: Number of results requested per query, forwarded to the
                index's ``search``.

        Returns:
            Whatever ``self.faiss_index.search(query, k)`` returns.
            NOTE(review): for faiss indexes this should be a
            ``(distances, indices)`` pair -- confirm against the faiss docs.

        Raises:
            AssertionError: If ``build_faiss_index`` has not been called.
            TypeError: If ``query`` is not an ``np.ndarray``.
        """
        # Explicit raise (instead of ``assert``) so the check also runs under
        # ``python -O``; the exception type is kept for existing callers.
        if self.faiss_index is None:
            raise AssertionError("You must call ``build_faiss_index`` first.")
        # The previous code called the ``NotImplemented`` singleton here,
        # which itself fails with an unrelated-looking TypeError. Raise the
        # same exception type, but with a message that names the problem.
        if not isinstance(query, np.ndarray):
            raise TypeError(
                "`query` must be an np.ndarray, "
                f"got {type(query).__name__}."
            )
        return self.faiss_index.search(query.astype("float32"), k)

    def pca(self, n_components=2):
        """Fit a PCA on the column's data and project the data with it.

        Args:
            n_components: Passed to ``sklearn.decomposition.PCA``.

        Returns:
            ``SimpleNamespace`` with ``embeddings`` (the result of
            ``fit_transform`` on ``self.numpy()``) and ``reducer`` (the
            fitted PCA object).
        """
        reducer = sklearn_decom.PCA(n_components)
        projected = reducer.fit_transform(self.numpy())
        return SimpleNamespace(embeddings=projected, reducer=reducer)

    def umap(self, n_neighbors=15, n_components=2):
        """Fit a UMAP reducer on the column's data and project the data.

        Args:
            n_neighbors: Passed to ``umap.UMAP``.
            n_components: Passed to ``umap.UMAP``.

        Returns:
            ``SimpleNamespace`` with ``embeddings`` (the result of
            ``fit_transform`` on ``self.numpy()``) and ``reducer`` (the
            fitted UMAP object).
        """
        # Inside the method body ``umap`` resolves to the module-level lazy
        # loader, not to this method: class attributes are not part of a
        # method's enclosing scope.
        reducer = umap.UMAP(
            n_neighbors=n_neighbors,
            n_components=n_components,
        )
        projected = reducer.fit_transform(self.numpy())
        return SimpleNamespace(embeddings=projected, reducer=reducer)

    def visualize_umap(self, n_neighbors=15, n_components=2, point_size=4):
        """Fit a UMAP reducer and show it as an interactive plot.

        Args:
            n_neighbors: Forwarded to ``self.umap``.
            n_components: Forwarded to ``self.umap``.
                NOTE(review): ``umap.plot`` may only support 2-D
                embeddings -- confirm before passing other values.
            point_size: Forwarded to ``umap.plot.interactive``.
        """
        fitted = self.umap(n_neighbors, n_components).reducer
        figure = umap_plot.interactive(fitted, point_size=point_size)
        umap_plot.show(figure)