# Source code for ytree.frontends.lhalotree_hdf5.io

"""
LHaloTreeHDF5Arbor io classes and member functions



"""

#-----------------------------------------------------------------------------
# Copyright (c) ytree development team. All rights reserved.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file COPYING.txt, distributed with this software.
#-----------------------------------------------------------------------------

import h5py
import numpy as np
import re

from ytree.data_structures.io import \
    DataFile, \
    TreeFieldIO

class LHaloTreeHDF5DataFile(DataFile):
    """
    Handle to a single LHaloTree HDF5 data file.

    The file is opened briefly at construction to read the number of trees
    it contains (stored in ``self._size``). After that it is reopened on
    demand with ``open`` and released with ``close``.
    """

    def __init__(self, filename, linkname):
        # NOTE(review): ``linkname`` is accepted for interface compatibility
        # but is not used by this class.
        super().__init__(filename)
        self.open()
        try:
            # Number of trees in this file, read from the Header group.
            self._size = self.fh["Header"].attrs["NtreesPerFile"]
        finally:
            # Release the handle even if the header is missing or malformed.
            self.close()

    def open(self):
        """Open the file read-only and store the h5py handle in ``self.fh``."""
        self.fh = h5py.File(self.filename, mode="r")

    def close(self):
        """Close the h5py handle if one is open and reset ``self.fh`` to None.

        Safe to call more than once.
        """
        if self.fh is not None:
            self.fh.close()
        self.fh = None
class LHaloTreeHDF5TreeFieldIO(TreeFieldIO):
    """
    Field I/O for trees stored in LHaloTree HDF5 files.

    Each tree is read from a ``Tree<N>`` group of one of the arbor's data
    files.
    """

    # Matches names of the form "<name>_<digits>". Group 1 is the base name
    # and group 2 is a column index. Such a field is only treated as a column
    # of a 2-D dataset when the base name exists in the tree group.
    _vector_field_regex = re.compile(r"(^.+)_(\d+$)")

    def _read_fields(self, root_node, fields, dtypes=None, root_only=False):
        """
        Read fields from disk for a single tree.

        Parameters
        ----------
        root_node : TreeNode
            Root of the tree. ``root_node._fi`` selects the data file and
            ``root_node._si`` selects the ``Tree<N>`` group within it.
        fields : list of str
            Fields to return. Fields whose field_info "source" is "arbor"
            are generated by ``_get_arbor_fields``. Their "dependencies" are
            read from disk in their place.
        dtypes : optional
            Not used by this implementation.
        root_only : bool, optional
            If True, read only the first entry of each field.

        Returns
        -------
        dict
            Mapping of field name to array. The fields read from disk are
            passed through ``self._apply_units`` before returning.
        """
        fi = self.arbor.field_info
        afields = [field for field in fields
                   if fi[field].get("source") == "arbor"]
        rfields = list(set(fields).difference(afields))
        for afield in afields:
            rfields.extend(
                [dfield for dfield in fi[afield].get("dependencies", [])
                 if dfield not in rfields])

        data_file = self.arbor.data_files[root_node._fi]
        # Only close the file afterwards if this call is the one opening it.
        opened_here = data_file.fh is None
        if opened_here:
            data_file.open()
        try:
            group = data_file.fh[f"Tree{root_node._si}"]
            # ``()`` reads the whole dataset; ``slice(0, 1)`` reads only the
            # first entry.
            index = slice(0, 1) if root_only else ()
            field_data = self._read_group_fields(group, rfields, index)
            if afields:
                field_data.update(self._get_arbor_fields(
                    root_node, field_data, fields, afields, root_only))
        finally:
            # Release the handle even if a dataset read fails.
            if opened_here:
                data_file.close()

        self._apply_units(rfields, field_data)

        return field_data

    def _read_group_fields(self, group, rfields, index):
        """Read each name in ``rfields`` from ``group`` using ``index``.

        A name of the form "<name>_<i>" whose base name is a dataset in the
        group is returned as column ``i`` of that dataset. Each base dataset
        is read only once per call. Any other name is read directly as a
        dataset.
        """
        dataset_cache = {}
        field_data = {}
        for field in rfields:
            match = self._vector_field_regex.search(field)
            if match and match.group(1) in group:
                fieldname, column = match.groups()
                if fieldname not in dataset_cache:
                    dataset_cache[fieldname] = group[fieldname][index]
                field_data[field] = dataset_cache[fieldname][:, int(column)]
            else:
                field_data[field] = group[field][index]
        return field_data

    def _get_arbor_fields(self, root_node, field_data, fields, afields,
                          root_only):
        """
        Generate special fields from the arbor/treenode.

        - "uid": the root's uid plus each node's position in the tree. If
          ``root_only`` is set, this is just the root's uid.
        - "desc_uid": the "Descendant" values offset by the root's uid.
          Entries equal to -1 are left as -1. If "Descendant" was not itself
          requested in ``fields``, it is removed from ``field_data`` and
          modified in place. Otherwise a copy is modified.

        Returns
        -------
        dict
            Mapping of generated field name to array.
        """
        adata = {}
        if "uid" in afields:
            if root_only:
                adata["uid"] = np.array([root_node.uid])
            else:
                adata["uid"] = root_node.uid + \
                  np.arange(root_node._tree_size)
        if "desc_uid" in afields:
            if "Descendant" in fields:
                # Keep the user-requested "Descendant" field unmodified.
                desc_uids = field_data["Descendant"].copy()
            else:
                # Only needed as a dependency; drop it from the results.
                desc_uids = field_data.pop("Descendant")
            desc_uids[desc_uids != -1] += root_node.uid
            adata["desc_uid"] = desc_uids
        return adata