Source code for meerkat.cells.spacy
from __future__ import annotations
from meerkat.cells.abstract import AbstractCell
from meerkat.tools.lazy_loader import LazyLoader
spacy = LazyLoader("spacy")
spacy_attrs = LazyLoader("spacy.attrs")
spacy_tokens = LazyLoader("spacy.tokens")
[docs]class SpacyCell(AbstractCell):
def __init__(
self,
doc: spacy_tokens.Doc,
*args,
**kwargs,
):
super(SpacyCell, self).__init__(*args, **kwargs)
# Put some data into the cell
self.doc = doc
[docs] def default_loader(self, *args, **kwargs):
return self
[docs] def get(self, *args, **kwargs):
return self.doc
[docs] def get_state(self):
attrs = [name for name in spacy_attrs.NAMES if name != "HEAD"]
arr = self.get().to_array(attrs)
return {
"arr": arr.flatten(),
"shape": list(arr.shape),
"words": [t.text for t in self.get()],
}
[docs] @classmethod
def from_state(cls, state, nlp: spacy.language.Language):
doc = spacy_tokens.Doc(nlp.vocab, words=state["words"])
attrs = [name for name in spacy_attrs.NAMES if name != "HEAD"]
return cls(doc.from_array(attrs, state["arr"].reshape(state["shape"])))
def __getitem__(self, index):
return self.get()[index]
def __getattr__(self, item):
try:
return getattr(self.get(), item)
except AttributeError:
raise AttributeError(f"Attribute {item} not found.")
def __repr__(self):
return f"{self.__class__.__name__}({self.get().__repr__()})"
[docs]class LazySpacyCell(AbstractCell):
def __init__(
self,
text: str,
nlp: spacy.language.Language,
*args,
**kwargs,
):
super(LazySpacyCell, self).__init__(*args, **kwargs)
# Put some data into the cell
self.text = text
self.nlp = nlp
[docs] def default_loader(self, *args, **kwargs):
return self
[docs] def get(self, *args, **kwargs):
return self.nlp(self.text)
[docs] def get_state(self):
return {
"text": self.text,
}
[docs] @classmethod
def from_state(cls, state, nlp: spacy.language.Language):
return cls(text=state["text"], nlp=nlp)
def __getitem__(self, index):
return self.get()[index]
def __getattr__(self, item):
try:
return getattr(self.get(), item)
except AttributeError:
raise AttributeError(f"Attribute {item} not found.")
def __repr__(self):
snippet = f"{self.text[:15]}..." if len(self.text) > 20 else self.text
return f"{self.__class__.__name__}({snippet})"