Source code for ytree.visualization.tree_plot

visualization imports


# Copyright (c) ytree development team. All rights reserved.
# Distributed under the terms of the Modified BSD License.
# The full license is in the file COPYING.txt, distributed with this software.

from functools import wraps
import numpy as np

from unyt import \
    unyt_quantity

try:
    import pydot
except ImportError:
    pydot = None

def clear_graph(f):
    """
    Decorate a method (typically a property setter) so that the
    instance's cached ``graph`` is discarded after the method runs.

    The first positional argument of the decorated function is taken to
    be the instance; its ``graph`` attribute is set to None once ``f``
    returns, so the plot is rebuilt the next time it is needed. If ``f``
    raises, the graph is left untouched.

    Parameters
    ----------
    f : callable
        The function to wrap.

    Returns
    -------
    callable
        The wrapped function, with ``f``'s name and docstring preserved.
    """
    @wraps(f)
    def newfunc(*args, **kwargs):
        rv = f(*args, **kwargs)
        # Invalidate the cached graph so it is regenerated on next use.
        args[0].graph = None
        return rv
    return newfunc

class TreePlot:
    """
    Make a simple merger tree plot using pydot and graphviz.

    Parameters
    ----------
    tree : merger tree node :class:`~ytree.data_structures.tree_node.TreeNode`
        The merger tree to be plotted.
    dot_kwargs : optional, dict
        A dictionary of keyword arguments to be passed to pydot.Dot.
        Default: None.
    node_function: optional, function
        A function accepting a single argument of a
        :class:`~ytree.data_structures.tree_node.TreeNode` and returning
        a dictionary of keywords to be given to pydot for creating the
        node object on the plot. This can be used to customize the
        appearance of the nodes. See examples below.
        Default: None.
    edge_function: optional, function
        A function accepting two
        :class:`~ytree.data_structures.tree_node.TreeNode` objects (the
        ancestor first, then its descendent) and returning a dictionary
        of keywords to be given to pydot for creating the edge object on
        the plot (the lines connecting halos). This can be used to
        customize the appearance of the edges. See examples below.
        Default: None.

    Attributes
    ----------
    size_field : str
        The field to determine the size of each circle.
        Default: 'mass'.
    size_log : bool
        Whether to scale circle sizes based on log of size field.
        Default: True.
    min_mass : float or unyt_quantity
        The minimum halo mass to be included in the plot. If given
        as a float, units are assumed to be Msun. Set to None to
        disable this cut.
        Default: None.
    min_mass_ratio : float
        The minimum ratio between a halo's mass and the mass of the
        root halo of its tree (``halo.root``) to be included in the
        plot.
        Default: None.

    Halos removed by ``min_mass`` or ``min_mass_ratio`` are not
    plotted, and neither are any of their ancestors.

    Examples
    --------

    >>> import ytree
    >>> a = ytree.load("tree_0_0_0.dat")
    >>> p = ytree.TreePlot(a[0])
    >>> p.min_mass = 1e6 # Msun
    >>> p.save()
    >>>
    >>> # customizing nodes
    >>> import ytree
    >>> def my_node(halo):
    ...     label = f"{halo['uid']}"
    ...     my_kwargs = {"label": label, "fontsize": 8,
    ...                  "shape": "square"}
    ...     return my_kwargs
    >>> a = ytree.load("tree_0_0_0.dat")
    >>> p = ytree.TreePlot(a[0], node_function=my_node)
    >>> p.save()
    >>>
    >>> # customizing edges
    >>> import ytree
    >>> def my_edge(ancestor, descendent):
    ...     if descendent['mass'] < ancestor['mass']:
    ...         color = 'blue'
    ...     else:
    ...         color = 'black'
    ...     my_kwargs = {"color": color, "penwidth": 5}
    ...     return my_kwargs
    >>> a = ytree.load("tree_0_0_0.dat")
    >>> p = ytree.TreePlot(a[0], edge_function=my_edge)
    >>> p.save()

    """

    # Range of dot widths (graphviz inches) that field values map onto.
    _min_dot_size = 0.2
    _max_dot_size = 2
    # Lazily computed range of size_field over the tree; None = not yet.
    _min_field_size = None
    _max_field_size = None
    _size_field = 'mass'
    _size_log = True
    _min_mass = None
    _min_mass_ratio = None
    # uids from the root's ('prog', 'uid') selection, cached per _plot.
    _prog_ids = None

    def __init__(self, tree, dot_kwargs=None, node_function=None,
                 edge_function=None):
        """
        Initialize a TreePlot.

        Raises
        ------
        RuntimeError
            If pydot is not installed, or if node_function or
            edge_function is given but is not callable.
        """
        if pydot is None:
            raise RuntimeError(
                "TreePlot requires the pydot module. " +
                "You may also need to install graphviz.")

        self.tree = tree

        # Copy so later changes to the caller's dict do not leak in.
        self.dot_kwargs = dict()
        if dot_kwargs is not None:
            self.dot_kwargs.update(dot_kwargs)

        if node_function is not None and \
          not callable(node_function):
            raise RuntimeError(
                "node_function should be a callable function.")
        self.node_function = node_function

        if edge_function is not None and \
          not callable(edge_function):
            raise RuntimeError(
                "edge_function should be a callable function.")
        self.edge_function = edge_function

        self.graph = None

    def save(self, filename=None):
        """
        Save the merger tree plot.

        Parameters
        ----------
        filename : optional, str
            The output filename. The file format is taken from the text
            after the last "." and must correspond to one of the
            ``write_<format>`` methods of the pydot graph. If none
            given, ``tree_<uid>.pdf`` is used, with the uid of the head
            node.
            Default: None.

        Returns
        -------
        filename : str
            The name of the file that was written.

        Raises
        ------
        RuntimeError
            If the graph has no writer for the requested file format.

        Examples
        --------

        >>> import ytree
        >>> a = ytree.load("tree_0_0_0.dat")
        >>> p = ytree.TreePlot(a[0])
        >>> p.save('tree.png')

        """
        if filename is None:
            filename = f"tree_{self.tree.uid:06d}.pdf"

        # The graph is built lazily and dropped whenever a plot setting
        # changes (see clear_graph), so rebuild it if needed.
        if self.graph is None:
            self._plot()

        suffix = filename[filename.rfind(".")+1:]
        func = getattr(self.graph, f"write_{suffix}", None)
        if func is None:
            raise RuntimeError(f"Cannot save to file format: {suffix}.")

        func(filename)
        return filename

    def _plot(self):
        """
        Build a new pydot graph starting from the head node.
        """
        self.graph = pydot.Dot(graph_type='graph', **self.dot_kwargs)
        # Recomputed lazily by _plot_node for this tree.
        self._prog_ids = None
        self._plot_ancestors(self.tree)

    def _plot_ancestors(self, halo):
        """
        Add a node for halo and, recursively, nodes and edges for each
        of its ancestors that passes the min_mass and min_mass_ratio
        cuts. A rejected ancestor's own ancestors are not visited.
        """
        graph = self.graph
        my_node = self._plot_node(halo)

        # Check for None before materializing; list(None) would raise.
        ancestors = halo.ancestors
        if ancestors is None:
            return

        for anc in list(ancestors):
            if self.min_mass is not None and \
              anc['mass'] < self.min_mass:
                continue

            if self.min_mass_ratio is not None and \
              anc['mass'] / anc.root['mass'] < self.min_mass_ratio:
                continue

            anc_node = self._plot_node(anc)
            if self.edge_function is not None:
                edge_kwargs = self.edge_function(anc, halo)
            else:
                edge_kwargs = {"penwidth": 5}
            graph.add_edge(pydot.Edge(my_node, anc_node, **edge_kwargs))

            self._plot_ancestors(anc)

    def _plot_node(self, halo):
        """
        Return the pydot node for halo (named by its uid), creating it
        and adding it to the graph if it is not there yet.
        """
        graph = self.graph
        node_name = f"{halo.uid}"
        existing = graph.get_node(node_name)
        if existing:
            return existing[0]

        if self.node_function is not None:
            node_kwargs = self.node_function(halo)
        else:
            # Halos in the root's ('prog', 'uid') selection (presumably
            # the main progenitor line) are drawn red, others black.
            # Every plotted halo shares the same root, so this is
            # looked up once per plot rather than once per node.
            if self._prog_ids is None:
                self._prog_ids = halo.find_root()['prog', 'uid']
            if halo['uid'] in self._prog_ids:
                color = 'red'
            else:
                color = 'black'
            node_kwargs = \
              {'style': 'filled', 'label': '',
               'fillcolor': color, 'shape': 'circle',
               # graphviz attribute is "fixedsize"; it makes "width"
               # the exact circle size rather than a minimum.
               'fixedsize': 'true',
               'width': self._size_norm(halo)}

        my_node = pydot.Node(node_name, **node_kwargs)
        graph.add_node(my_node)
        return my_node

    def _size_norm(self, halo):
        """
        Map halo's size_field value onto a dot width between
        _min_dot_size and _max_dot_size.

        The value is normalized against the field's range over the tree.
        The minimum uses only positive values when size_log is set. The
        range is cached, and the size_field and size_log setters reset
        the cache.
        """
        if self._min_field_size is None or \
          self._max_field_size is None:
            tdata = self.tree['tree', self.size_field]
            if self._min_field_size is None:
                if self.size_log:
                    self._min_field_size = tdata[tdata > 0].min()
                else:
                    self._min_field_size = tdata.min()
            if self._max_field_size is None:
                self._max_field_size = tdata.max()
        nmin = self._min_field_size
        nmax = self._max_field_size

        fval = halo[self.size_field]
        if nmax > nmin:
            if self.size_log:
                val = np.log(fval / nmin) / np.log(nmax / nmin)
            else:
                val = (fval - nmin) / (nmax - nmin)
            val = float(val)
        else:
            # All halos share one value, so there is no range to
            # normalize over; avoid 0/0 and draw every dot at full size.
            val = 1.0
        # Log-scaling a non-positive value yields -inf or nan; draw
        # such halos at the minimum size instead of passing nan on.
        if not np.isfinite(val):
            val = 0.0
        val = np.clip(val, 0, 1)

        size = val * (self._max_dot_size - self._min_dot_size) + \
          self._min_dot_size
        return size

    @property
    def min_mass(self):
        """
        The minimum halo mass to be included in the plot.
        """
        return self._min_mass

    @min_mass.setter
    @clear_graph
    def min_mass(self, val):
        # None disables the cut; bare numbers are taken to be in Msun.
        if val is not None and not isinstance(val, unyt_quantity):
            val = unyt_quantity(val, 'Msun')
        self._min_mass = val

    @property
    def min_mass_ratio(self):
        """
        The minimum ratio of halo mass to root halo mass.
        """
        return self._min_mass_ratio

    @min_mass_ratio.setter
    @clear_graph
    def min_mass_ratio(self, val):
        self._min_mass_ratio = val

    @property
    def size_field(self):
        """
        The field to determine the size of each circle.
        """
        return self._size_field

    @size_field.setter
    @clear_graph
    def size_field(self, val):
        self._size_field = val
        # The cached range belongs to the previous field.
        self._min_field_size = None
        self._max_field_size = None

    @property
    def size_log(self):
        """
        Whether to scale circle sizes based on log of size field.
        """
        return self._size_log

    @size_log.setter
    @clear_graph
    def size_log(self, val):
        self._size_log = val
        # The cached minimum depends on size_log (positive-only under
        # log scaling), so recompute it on next use.
        self._min_field_size = None