"""
TreeNode class and member functions
"""
#-----------------------------------------------------------------------------
# Copyright (c) ytree development team. All rights reserved.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file COPYING.txt, distributed with this software.
#-----------------------------------------------------------------------------
import numpy as np
import weakref
from ytree.data_structures.fields import \
FieldContainer
from ytree.utilities.exceptions import \
ArborUnsettableField
class TreeNode:
    """
    Class for objects stored in Arbors.

    Each TreeNode represents a halo in a merger tree. A TreeNode knows
    its unique ID (``uid``), its index in its tree's field arrays
    (``tree_id``), the Arbor that holds it, and how to reach its
    descendent and its ancestors. Fields can be queried for the node
    itself, its progenitor list, its tree, or its whole forest.
    """

    # Link into the root's NodeLink structure; only used by conventional
    # (non-Catalog) Arbors.
    _link = None

    def __init__(self, uid, arbor=None, root=False):
        """
        Initialize a TreeNode.

        Parameters
        ----------
        uid
            Unique ID of the halo represented by this node.
        arbor : optional, Arbor
            The Arbor holding this node. Only a weak proxy to it is
            stored so that nodes do not keep the Arbor alive. A value
            that is already a weak proxy (or None) is stored as-is.
        root : optional, bool
            If True, this node is the root of its tree and gets its own
            FieldContainer holding field data for the whole tree.
            Default: False.
        """
        self.uid = uid
        # weakref.proxy() raises TypeError for None and for objects that
        # are already weak proxies, so only wrap a real Arbor.
        if arbor is None or isinstance(arbor, weakref.ProxyTypes):
            self.arbor = arbor
        else:
            self.arbor = weakref.proxy(arbor)
        if root:
            # -1 marks a root node (see is_root).
            self.root = -1
            self.field_data = FieldContainer(arbor)
        else:
            self.root = None

    # Index in the tree's field arrays, assigned directly by CatalogArbor.
    _tree_id = None

    @property
    def tree_id(self):
        """
        Return the index of this node in a list of all nodes in the tree.

        The root is always 0. Other nodes take the value from their
        NodeLink (conventional Arbors) or from the value assigned via
        the setter (CatalogArbors).
        """
        if self.is_root:
            return 0
        if self._link is not None:
            return self._link.tree_id
        return self._tree_id

    @tree_id.setter
    def tree_id(self, value):
        """
        Set the tree_id manually in CatalogArbors.
        """
        self._tree_id = value

    @property
    def is_root(self):
        """
        True if this node is the root (last descendent) of its tree.

        A root node has ``root`` set either to -1 or to itself.
        """
        return self.root in [-1, self]

    def find_root(self):
        """
        Find the root node.

        Returns this node if it is a root, the stored ``root`` pointer if
        one has been assigned, and otherwise walks descendents until the
        root is reached.
        """
        if self.is_root:
            return self
        root = self.root
        if root is not None:
            return root
        return self.walk_to_root()

    def walk_to_root(self):
        """
        Walk descendents until root.

        Stops early, returning the last node reached, if a node without a
        descendent is found before reaching a root.
        """
        my_node = self
        while not my_node.is_root:
            # Evaluate the property once per step: for conventional
            # Arbors it asks the Arbor to generate a node each time.
            desc = my_node.descendent
            if desc in (-1, None):
                break
            my_node = desc
        return my_node

    def clear_fields(self):
        """
        If a root node, delete field data.
        If not root node, do nothing.
        """
        if not self.is_root:
            return
        self.field_data.clear()

    # Descendent node, assigned directly by CatalogArbor.
    _descendent = None

    @property
    def descendent(self):
        """
        Return the descendent node, or None for a root.
        """
        if self.is_root:
            return None

        # set in CatalogArbor._plant_trees
        if self._descendent is not None:
            return self._descendent

        # conventional Arbor object
        desc_link = self._link.descendent
        if desc_link is None:
            return None
        return self.arbor._generate_tree_node(self.root, desc_link)

    # Ancestor nodes, assigned directly by CatalogArbor.
    _ancestors = None

    @property
    def ancestors(self):
        """
        Return a generator of ancestor nodes.

        Because this property is a generator function, each access
        returns a fresh generator and no work is done until it is
        iterated. A node with no ancestors yields nothing.
        """
        self.arbor._grow_tree(self)

        # conventional Arbor object
        if self._link is not None:
            for link in self._link.ancestors:
                yield self.arbor._generate_tree_node(self.root, link)
            return

        # CatalogArbor: ancestors are set in CatalogArbor._plant_trees.
        if self._ancestors is None:
            return

        # If tree is not setup yet, the ancestor nodes will not have
        # root pointers yet.
        need_root = not self.arbor.is_setup(self)
        if need_root:
            root = self.walk_to_root()
        for ancestor in self._ancestors:
            if need_root:
                ancestor.root = root
            yield ancestor

    _uids = None

    @property
    def uids(self):
        """
        Array of uids for all nodes in the tree.

        Only available on root nodes; returns None otherwise.
        """
        if not self.is_root:
            return None
        if self._uids is None:
            self.arbor._build_attr("_uids", self)
        return self._uids

    _desc_uids = None

    @property
    def desc_uids(self):
        """
        Array of descendent uids for all nodes in the tree.

        Only available on root nodes; returns None otherwise.
        """
        if not self.is_root:
            return None
        if self._desc_uids is None:
            self.arbor._build_attr("_desc_uids", self)
        return self._desc_uids

    _tree_size = None

    @property
    def tree_size(self):
        """
        Number of nodes in the tree.

        For a root, the Arbor's tree setup is expected to fill in the
        size. For other nodes, the nodes of this node's tree are counted.
        """
        if self._tree_size is not None:
            return self._tree_size
        if self.is_root:
            self.arbor._setup_tree(self)
            # pass back to the arbor to avoid calculating again
            self.arbor._store_node_info(self, '_tree_size')
        else:
            # count without materializing a throwaway list of nodes
            self._tree_size = sum(1 for _ in self["tree"])
        return self._tree_size

    _link_storage = None

    @property
    def _links(self):
        """
        Array of NodeLink objects with the ancestor/descendent structure.

        This is only used by conventional Arbor objects, i.e., not
        CatalogArbor objects. Returns None for non-root nodes.
        """
        if not self.is_root:
            return None
        if self._link_storage is None:
            self.arbor._build_attr("_link_storage", self)
        return self._link_storage

    def __setitem__(self, key, value):
        """
        Set analysis field value for this node.

        Parameters
        ----------
        key : str
            Name of an "analysis" or "analysis_saved" field.
        value
            The value to store for this node.

        Raises
        ------
        ArborUnsettableField
            If ``key`` is not an analysis field.
        """
        fi = self.arbor.field_info[key]
        ftype = fi.get('type')
        if ftype not in ['analysis', 'analysis_saved']:
            raise ArborUnsettableField(key, self.arbor)

        vector_fieldname = fi.get("vector_fieldname", None)
        has_vector_field = vector_fieldname is not None

        is_root = self.is_root
        if is_root:
            # if root, also set the value in the arbor field storage
            # (_arbor_index is assigned outside this class)
            self.arbor[key][self._arbor_index] = value
            # presumably the vector field is derived from this component,
            # so drop the stale cached copy
            if has_vector_field and vector_fieldname in self.arbor.field_data:
                del self.arbor.field_data[vector_fieldname]

        # make sure the tree's field array for this key is loaded
        self.arbor._node_io.get_fields(self, fields=[key],
                                       root_only=False)

        # Resolve the root the same way query() does, after loading, so a
        # node whose root pointer has not been assigned still works.
        if is_root:
            root = self
            tree_id = 0
        else:
            root = self.find_root()
            tree_id = self.tree_id
        data = root.field_data[key]
        data[tree_id] = value
        if has_vector_field and vector_fieldname in root.field_data:
            del root.field_data[vector_fieldname]

    def __getitem__(self, key):
        """
        Return field values or forest/tree/prog generators.

        See ``query`` for the accepted keys.
        """
        return self.query(key)

    def query(self, key):
        """
        Return field values for this TreeNode, progenitor list, tree, or forest.

        Parameters
        ----------
        key : string or tuple
            If a single string, it can be either a field to be queried or
            one of "forest", "tree", or "prog". If a field, then return
            the value of the field for this TreeNode. If "forest", "tree",
            or "prog", then return a generator over the TreeNodes in the
            forest, tree, or progenitor list.
            If a tuple, it must be (string, string), where the first
            element is one of "forest", "tree", or "prog" and the second
            is a field name. The field values for the forest, tree, or
            progenitor list are returned. To get the nth TreeNode of a
            selection, use ``get_node`` instead.

        Raises
        ------
        SyntaxError
            If ``key`` does not have one of the forms above.

        Examples
        --------
        >>> # virial mass for this halo
        >>> print(my_tree["mvir"].to("Msun/h"))
        >>> # all TreeNodes in the progenitor list
        >>> print(list(my_tree["prog"]))
        >>> # all TreeNodes in the entire tree
        >>> print(list(my_tree["tree"]))
        >>> # virial masses for the progenitor list
        >>> print(my_tree["prog", "mvir"].to("Msun/h"))
        >>> # the 3rd TreeNode in the progenitor list
        >>> print(my_tree.get_node("prog", 2))

        Returns
        -------
        float, ndarray/unyt_array, or generator of TreeNodes
        """
        arr_types = ("forest", "prog", "tree")
        if isinstance(key, tuple):
            if len(key) != 2:
                raise SyntaxError(
                    "Must be either 1 or 2 arguments.")
            ftype, field = key
            if ftype not in arr_types:
                raise SyntaxError(
                    f"First argument must be one of {str(arr_types)}.")
            if not isinstance(field, str):
                raise SyntaxError("Second argument must be a string.")

            self.arbor._node_io.get_fields(self, fields=[field], root_only=False)
            indices = getattr(self, f"_{ftype}_field_indices")
            data_object = self.find_root()
            return data_object.field_data[field][indices]

        if not isinstance(key, str):
            raise SyntaxError("Single argument must be a string.")

        # return the progenitor list or tree nodes as a generator
        if key in arr_types:
            self.arbor._setup_tree(self)
            return getattr(self, f"_{key}_nodes")

        # return field value for this node
        self.arbor._node_io.get_fields(self, fields=[key],
                                       root_only=self.is_root)
        data_object = self.find_root()
        return data_object.field_data[key][self.tree_id]

    def __repr__(self):
        """
        Call me TreeNode.
        """
        return f"TreeNode[{self.uid}]"

    def get_node(self, selector, index):
        """
        Get a single TreeNode from a tree.

        Use this to get the nth TreeNode from a forest, tree, or
        progenitor list for which the calling TreeNode is the head.

        Parameters
        ----------
        selector : str ("forest", "tree", or "prog")
            The tree selector from which to get the TreeNode. This
            should be "forest", "tree", or "prog".
        index : int
            The index of the desired TreeNode in the forest, tree,
            or progenitor list.

        Returns
        -------
        node: :class:`~ytree.data_structures.tree_node.TreeNode`

        Raises
        ------
        RuntimeError
            If ``selector`` is not a recognized selector.

        Examples
        --------
        >>> import ytree
        >>> a = ytree.load("tiny_ctrees/locations.dat")
        >>> my_tree = a[0]
        >>> # get 6th TreeNode in the progenitor list
        >>> my_node = my_tree.get_node('prog', 5)
        """
        self.arbor._setup_tree(self)
        self.arbor._grow_tree(self)
        indices = getattr(self, f"_{selector}_field_indices", None)
        if indices is None:
            raise RuntimeError("Bad selector.")
        # find_root() also covers a root whose pointer is still -1
        root = self.find_root()
        my_link = root._links[indices][index]
        return self.arbor._generate_tree_node(root, my_link)

    def get_leaf_nodes(self, selector=None):
        """
        Get all leaf nodes from the tree of which this is the head.

        This returns a generator of all leaf nodes belonging to this
        tree. A leaf node is a node that has no ancestors.

        Parameters
        ----------
        selector : optional, str ("forest", "tree", or "prog")
            The tree selector from which leaf nodes will be found.
            If none given, this will be set to "forest" if the
            calling node is a root node and "tree" otherwise.

        Returns
        -------
        leaf_nodes : a generator of
            :class:`~ytree.data_structures.tree_node.TreeNode` objects.

        Examples
        --------
        >>> import ytree
        >>> a = ytree.load("tiny_ctrees/locations.dat")
        >>> my_tree = a[0]
        >>> for leaf in my_tree.get_leaf_nodes():
        ...     print(leaf["mass"])
        """
        if selector is None:
            selector = "forest" if self.is_root else "tree"

        uids = self[selector, "uid"]
        desc_uids = self[selector, "desc_uid"]
        # a leaf is a node that is not the descendent of any other node
        # (np.isin replaces np.in1d, which is deprecated in NumPy 2.0)
        lids = np.where(~np.isin(uids, desc_uids))[0]
        for lid in lids:
            yield self.get_node(selector, lid)

    def get_root_nodes(self):
        """
        Get all root nodes from the forest to which this node belongs.

        This returns a generator of all root nodes in the forest. A root
        node is a node that has no descendents.

        Returns
        -------
        root_nodes : a generator of
            :class:`~ytree.data_structures.tree_node.TreeNode` objects.

        Examples
        --------
        >>> import ytree
        >>> a = ytree.load("consistent_trees_hdf5/soa/forest.h5",
        ...                access="forest")
        >>> my_tree = a[0]
        >>> for root in my_tree.get_root_nodes():
        ...     print(root["mass"])
        """
        selector = "forest"
        desc_uids = self[selector, "desc_uid"]
        # nodes without a descendent carry a desc_uid of -1
        rids = np.where(desc_uids == -1)[0]
        for rid in rids:
            yield self.get_node(selector, rid)

    # Slice selecting every entry of the forest's field arrays.
    _ffi = slice(None)

    @property
    def _forest_field_indices(self):
        """
        Return default slice to select the whole forest.
        """
        return self._ffi

    @property
    def _forest_nodes(self):
        """
        An iterator over all TreeNodes in the forest.

        This is different from _tree_nodes in that we don't walk
        through the ancestors lists. We just yield every TreeNode
        there is.
        """
        self.arbor._grow_tree(self)
        root = self.find_root()
        for link in root._links:
            yield self.arbor._generate_tree_node(root, link)

    @property
    def _tree_nodes(self):
        """
        An iterator over all TreeNodes in the tree beneath,
        starting with this TreeNode (depth-first, pre-order).

        For internal use only. Use the following instead:

        >>> for my_node in my_tree['tree']:
        ...     print(my_node)

        Examples
        --------
        >>> for my_node in my_tree._tree_nodes:
        ...     print(my_node)
        """
        self.arbor._grow_tree(self)
        # Use an explicit stack rather than recursive generators. Nested
        # generators re-yield every node through one frame per tree level,
        # costing O(depth) per node and raising RecursionError on
        # branches deeper than the interpreter's recursion limit.
        stack = [self]
        while stack:
            node = stack.pop()
            yield node
            # Push in reverse so the first ancestor is visited next,
            # preserving the original depth-first, pre-order sequence.
            stack.extend(reversed(list(node.ancestors)))

    # Cached tree field indices.
    _tfi = None

    @property
    def _tree_field_indices(self):
        """
        Return the field array indices for all TreeNodes in
        the tree beneath, starting with this TreeNode.

        The result is computed once and cached on the node.
        """
        if self._tfi is not None:
            return self._tfi
        self.arbor._grow_tree(self)
        self._tfi = np.array([node.tree_id for node in self._tree_nodes])
        return self._tfi

    @property
    def _prog_nodes(self):
        """
        An iterator over all TreeNodes in the progenitor list,
        starting with this TreeNode.

        For internal use only. Use the following instead:

        >>> for my_node in my_tree['prog']:
        ...     print(my_node)

        Examples
        --------
        >>> for my_node in my_tree._prog_nodes:
        ...     print(my_node)
        """
        self.arbor._grow_tree(self)
        my_node = self
        while my_node is not None:
            yield my_node
            ancestors = list(my_node.ancestors)
            if ancestors:
                # the Arbor's selector chooses which ancestor continues
                # the progenitor line
                my_node = my_node.arbor.selector(ancestors)
            else:
                my_node = None

    # Cached progenitor-list field indices.
    _pfi = None

    @property
    def _prog_field_indices(self):
        """
        Return the field array indices for all TreeNodes in
        the progenitor list, starting with this TreeNode.

        The result is computed once and cached on the node.
        """
        if self._pfi is not None:
            return self._pfi
        self.arbor._grow_tree(self)
        self._pfi = np.array([node.tree_id for node in self._prog_nodes])
        return self._pfi

    def save_tree(self, filename=None, fields=None):
        r"""
        Save the tree to a file.

        The saved tree can be re-loaded as an arbor.

        Parameters
        ----------
        filename : optional, string
            Output file keyword. Main header file will be named
            <filename>/<filename>.h5.
            Default: "tree_<uid>".
        fields : optional, list of strings
            The fields to be saved. If not given, all
            fields will be saved.

        Returns
        -------
        filename : string
            The filename of the saved arbor.

        Examples
        --------
        >>> import ytree
        >>> a = ytree.load("rockstar_halos/trees/tree_0_0_0.dat")
        >>> # save the first tree
        >>> fn = a[0].save_tree()
        >>> # reload it
        >>> a2 = ytree.load(fn)
        """
        if filename is None:
            filename = f"tree_{self.uid}"
        return self.arbor.save_arbor(
            filename=filename, fields=fields,
            trees=[self])