from __future__ import annotations
import warnings
import weakref
from copy import copy
from functools import wraps
from inspect import getcallargs
from typing import Any, Dict, List, Mapping, Sequence, Tuple, Union
import meerkat as mk
from meerkat.errors import ExperimentalWarning
# Module-wide switch reported by `is_provenance_enabled`; capture starts disabled.
_provenance_enabled = False
def set_provenance(enabled=True):
    """Set the module-level provenance flag.

    Args:
        enabled: New value stored in the module-level
            ``_provenance_enabled`` flag. Defaults to ``True``.
    """
    global _provenance_enabled
    _provenance_enabled = enabled
class provenance:
    """Context manager that temporarily sets the provenance flag.

    On entry the module-level ``_provenance_enabled`` flag is set to
    ``enabled``; on exit the value it held at entry is restored. The same
    instance may be entered more than once (nested): each exit restores the
    value saved by the matching entry.

    Example:
        ```python
        import meerkat as mk
        with mk.provenance():
            dp = mk.DataPanel.from_batch({...})
        ```

    Args:
        enabled: Value the flag takes inside the ``with`` block.
    """

    def __init__(self, enabled: bool = True):
        # `prev` mirrors the most recently saved/restored flag value.
        self.prev = None
        self.enabled = enabled
        # One saved value per active entry, so re-entering the same instance
        # does not clobber the value the outer exit must restore.
        self._saved = []

    def __enter__(self):
        global _provenance_enabled
        self.prev = _provenance_enabled
        self._saved.append(self.prev)
        _provenance_enabled = self.enabled
        # Returning the manager makes `with provenance() as p:` usable.
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
        global _provenance_enabled
        # An exit without a matching entry falls back to the current `prev`.
        if self._saved:
            self.prev = self._saved.pop()
        _provenance_enabled = self.prev
def is_provenance_enabled():
    """Return the current value of the module-level provenance flag."""
    return _provenance_enabled
class ProvenanceMixin:
    """Mixin that gives an object a node in the provenance graph.

    After the cooperative ``__init__`` chain runs, a ``ProvenanceObjNode``
    wrapping ``self`` is created and stored on ``self._node``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._init_node()

    def _init_node(self):
        """Create a fresh provenance node for this object."""
        self._node = ProvenanceObjNode(self)

    @property
    def node(self):
        """The provenance node attached to this object."""
        return self._node

    def get_provenance(
        self, include_columns: bool = False, last_parent_only: bool = False
    ):
        """Collect the provenance graph leading to this object.

        Args:
            include_columns: When ``True``, return the graph exactly as the
                node reports it. When ``False``, keep only nodes whose
                ``type`` is a ``mk.DataPanel`` subclass, plus op nodes that
                have at least one such child, and only the edges between
                kept nodes.
            last_parent_only: Forwarded to the node's ``get_provenance``.

        Returns:
            ``None`` if this object has no node. Otherwise a
            ``(nodes, edges)`` pair: the node's own sets when
            ``include_columns`` is true, two lists otherwise.
        """
        if self._node is None:
            return None

        nodes, edges = self._node.get_provenance(last_parent_only=last_parent_only)
        if include_columns:
            return nodes, edges

        def _is_datapanel(candidate) -> bool:
            return issubclass(candidate.type, mk.DataPanel)

        kept_nodes = [
            node
            for node in nodes
            if (
                isinstance(node, ProvenanceOpNode)
                and any(_is_datapanel(child) for child, _ in node.children)
            )
            or _is_datapanel(node)
        ]

        # Membership is tested once per edge endpoint; a set keeps this
        # linear in the number of edges instead of nodes * edges.
        kept_lookup = set(kept_nodes)
        kept_edges = [
            edge
            for edge in edges
            if edge[0] in kept_lookup and edge[1] in kept_lookup
        ]
        return kept_nodes, kept_edges
class ProvenanceNode:
    """Base class for nodes of the provenance graph.

    This class defines no ``__init__``; its methods read ``self.ref`` (a
    zero-argument callable returning the wrapped object or ``None``),
    ``self._parents`` and ``self._children`` (lists of ``(node, key)``
    pairs), which subclasses are expected to set, and call ``_repr()``,
    which subclasses must implement.
    """

    def add_child(self, node: ProvenanceNode, key: Tuple):
        """Record ``node`` as a child reached via ``key`` and refresh the cached repr."""
        self._children.append((node, key))
        # Computing the repr must not itself be captured as provenance.
        with provenance(enabled=False):
            self.cache_repr()

    def add_parent(self, node: ProvenanceNode, key: Tuple):
        """Record ``node`` as a parent reached via ``key`` and refresh the cached repr."""
        self._parents.append((node, key))
        # Computing the repr must not itself be captured as provenance.
        with provenance(enabled=False):
            self.cache_repr()

    @property
    def parents(self):
        """List of ``(node, key)`` pairs for this node's parents."""
        return self._parents

    @property
    def last_parent(self):
        """The most recently added ``(node, key)`` parent pair, or ``None``."""
        return self._parents[-1] if self._parents else None

    @property
    def children(self):
        """List of ``(node, key)`` pairs for this node's children."""
        return self._children

    def get_provenance(self, last_parent_only: bool = False):
        """Collect every ancestor of this node and the edges between them.

        The graph is walked iteratively and each node is expanded once, so
        ancestors shared by several paths are not re-traversed and long
        chains do not hit the recursion limit.

        Args:
            last_parent_only: When ``True``, nodes whose class is exactly
                ``ProvenanceObjNode`` contribute only their last parent.

        Returns:
            ``(nodes, edges)``: a set of nodes (including this one) and a
            set of ``(parent, child, key)`` tuples.
        """
        nodes = set()
        edges = set()
        pending = [self]
        while pending:
            current = pending.pop()
            if current in nodes:
                continue
            nodes.add(current)

            current_parents = current.parents
            if last_parent_only and current.__class__ == ProvenanceObjNode:
                current_parents = current_parents[-1:]

            for parent, key in current_parents:
                edges.add((parent, current, key))
                pending.append(parent)
        return nodes, edges

    def cache_repr(self):
        """Store the repr to show once the wrapped object is gone."""
        self._cached_repr = self._repr() + "(deleted)"

    def _repr(self):
        raise NotImplementedError()

    def __repr__(self):
        if self.ref() is not None:
            return self._repr()
        # The referent is gone. A node that never had an edge added has no
        # cached repr, so fall back to a generic label instead of raising.
        cached = getattr(self, "_cached_repr", None)
        if cached is None:
            return f"{type(self).__name__}(deleted)"
        return cached

    def __getstate__(self):
        # The weak reference is dropped from the pickled state.
        state = copy(self.__dict__)
        state["ref"] = None
        return state

    def __setstate__(self, state: dict):
        self.__dict__.update(state)
        # A restored node has no live referent; `ref()` reports it as gone.
        # Assigned on the instance so the caller's dict is left untouched.
        self.ref = lambda: None
class ProvenanceObjNode(ProvenanceNode):
    """Provenance node standing for an object.

    Holds only a weak reference to the object, plus the object's type, so
    the node does not keep the object alive.

    Args:
        obj: The object this node represents.
    """

    def __init__(self, obj: ProvenanceMixin):
        self.ref = weakref.ref(obj)
        self.type = type(obj)
        self._parents: List[Tuple[ProvenanceOpNode, tuple]] = []
        self._children: List[Tuple[ProvenanceOpNode, tuple]] = []

    def _repr(self):
        return f"ObjNode({str(self.ref())})"
class ProvenanceOpNode(ProvenanceNode):
    """Provenance node standing for one call of a function.

    On construction the node links itself into the graph: every input's
    node becomes a parent, and every output's node becomes a child unless
    that output is the very same object as one of the inputs.

    Args:
        fn: The function that was called. Only a weak reference is kept,
            along with its type, ``__qualname__`` and ``__module__``.
        inputs: Mapping from key to input object; each value must expose
            a ``node`` attribute.
        outputs: Mapping from key to output object; each value must expose
            a ``node`` attribute.
        captured_args: Argument values stored on the node as-is.
    """

    def __init__(
        self, fn: callable, inputs: dict, outputs: dict, captured_args: dict
    ):
        self.ref = weakref.ref(fn)
        self.type = type(fn)
        self.name = fn.__qualname__
        self.module = fn.__module__
        self.captured_args = captured_args
        self._parents: List[Tuple[ProvenanceObjNode, tuple]] = []
        self._children: List[Tuple[ProvenanceObjNode, tuple]] = []

        for key, inp in inputs.items():
            self.add_parent(inp.node, key)
            inp.node.add_child(self, key)

        # Identity, not equality, decides whether an output was passed in.
        input_ids = {id(value) for value in inputs.values()}
        for key, output in outputs.items():
            if id(output) in input_ids:
                continue
            self.add_child(output.node, key)
            output.node.add_parent(self, key)

    def _repr(self):
        return f"OpNode({self.module}.{self.name})"
def capture_provenance(capture_args: Sequence[str] = None):
    """Build a decorator that records calls to a function as provenance.

    While provenance is disabled the decorated function is called straight
    through. While it is enabled, the call's arguments are resolved by name,
    the function is run, and a ``ProvenanceOpNode`` is created that links
    the objects found in the arguments to the objects found in the result.

    Args:
        capture_args: Names of arguments whose values are stored on the op
            node. ``None`` means no arguments are stored.

    Returns:
        A decorator preserving the wrapped function's metadata.
    """
    if capture_args is None:
        capture_args = []

    def _provenance(fn: callable):
        @wraps(fn)
        def _wrapper(*args, **kwargs):
            if is_provenance_enabled():
                call_args = getcallargs(fn, *args, **kwargs)
                # Flatten an argument named "kwargs" into the top level.
                if "kwargs" in call_args:
                    call_args.update(call_args.pop("kwargs"))
                recorded = {name: call_args[name] for name in capture_args}

                result = fn(*args, **kwargs)

                # Inputs are gathered after the call, then outputs.
                found_inputs = get_nested_objs(call_args)
                found_outputs = get_nested_objs(result)
                ProvenanceOpNode(
                    fn=fn,
                    inputs=found_inputs,
                    outputs=found_outputs,
                    captured_args=recorded,
                )
                return result

            return fn(*args, **kwargs)

        return _wrapper

    return _provenance
def get_nested_objs(data):
    """Recursively get DataPanels and Columns from nested collections.

    Returns:
        A dict mapping the tuple path of indices/keys at which each object
        was found to that object.
    """
    found = {}
    _get_nested_objs(found, (), data)
    return found


def _get_nested_objs(objs: Dict, key: Tuple[str], data: object):
    """Walk ``data`` and record DataPanels and columns into ``objs``.

    ``key`` is the path to ``data``; each level of nesting appends the
    index or mapping key of the child being visited.
    """
    if isinstance(data, Sequence) and not isinstance(data, str):
        children = enumerate(data)
    elif isinstance(data, Mapping):
        children = data.items()
    elif isinstance(data, mk.DataPanel):
        # A DataPanel is recorded itself and then descended into.
        objs[key] = data
        children = data.items()
    elif isinstance(data, mk.AbstractColumn):
        objs[key] = data
        return
    else:
        return

    for child_key, child in children:
        _get_nested_objs(objs, key=(*key, child_key), data=child)
def visualize_provenance(
    obj: ProvenanceMixin,
    show_columns: bool = False,
    last_parent_only: bool = False,
):
    """Render an object's provenance graph as a Cytoscape widget.

    Emits an ``ExperimentalWarning`` on every call.

    Args:
        obj: Object whose ``get_provenance`` accepts ``include_columns``
            and ``last_parent_only`` keyword arguments.
        show_columns: Passed to ``obj.get_provenance`` as
            ``include_columns``.
        last_parent_only: Passed through to ``obj.get_provenance``.

    Returns:
        A ``cyjupyter.Cytoscape`` widget using the "breadthfirst" layout.

    Raises:
        ImportError: If ``cyjupyter`` is not installed.
    """
    warnings.warn(  # pragma: no cover
        ExperimentalWarning(
            "The function `meerkat.provenance.visualize_provenance` is experimental and"
            " has limited test coverage. Proceed with caution."
        )
    )
    try:  # pragma: no cover
        import cyjupyter
    except ImportError as err:  # pragma: no cover
        raise ImportError(
            "`visualize_provenance` requires the `cyjupyter` dependency. "
            "See https://github.com/cytoscape/cytoscape-jupyter-widget"
        ) from err

    nodes, edges = obj.get_provenance(  # pragma: no cover
        include_columns=show_columns, last_parent_only=last_parent_only
    )

    # Nodes are identified by `id()`, so edges can refer to their endpoints.
    cy_nodes = [  # pragma: no cover
        {
            "data": {
                "id": id(node),
                "name": str(node),
                "type": "obj" if isinstance(node, ProvenanceObjNode) else "op",
            }
        }
        for node in nodes
    ]
    cy_edges = [  # pragma: no cover
        {
            "data": {
                "source": id(edge[0]),
                "target": id(edge[1]),
                "key": edge[2],
            }
        }
        for edge in edges
    ]
    cy_data = {"elements": {"nodes": cy_nodes, "edges": cy_edges}}  # pragma: no cover

    style = [  # pragma: no cover
        {
            "selector": "node",
            "css": {
                "content": "data(name)",
                "background-color": "#fc8d62",
                "border-color": "#252525",
                "border-opacity": 1.0,
                "border-width": 3,
            },
        },
        {
            "selector": "node[type = 'op']",
            "css": {
                "shape": "barrel",
                "background-color": "#8da0cb",
            },
        },
        {
            # need to double index to access metadata (degree etc.)
            "selector": "node[[degree = 0]]",
            "css": {
                "visibility": "hidden",
            },
        },
        {
            "selector": "edge",
            "css": {
                "content": "data(key)",
                "line-color": "#252525",
                "mid-target-arrow-color": "#8da0cb",
                "mid-target-arrow-shape": "triangle",
                "arrow-scale": 2.5,
                "text-margin-x": 10,
                "text-margin-y": 10,
            },
        },
    ]
    return cyjupyter.Cytoscape(  # pragma: no cover
        data=cy_data, visual_style=style, layout_name="breadthfirst"
    )