# Source code for ytree.arbor.frontends.tree_farm.io

"""
TreeFarmArbor I/O classes and member functions.



"""

#-----------------------------------------------------------------------------
# Copyright (c) ytree development team. All rights reserved.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file COPYING.txt, distributed with this software.
#-----------------------------------------------------------------------------

import h5py
import numpy as np

from ytree.arbor.io import \
    CatalogDataFile, \
    TreeFieldIO

class TreeFarmDataFile(CatalogDataFile):
    """
    Catalog data file for TreeFarm output, stored as HDF5.

    Each dataset in the file is one field. The file attributes
    "current_redshift" and "num_halos" provide header information.
    """

    _default_dtype = np.float64

    def open(self):
        """Open the HDF5 file read-only and store the handle in ``self.fh``."""
        self.fh = h5py.File(self.filename, "r")

    def _parse_header(self):
        """
        Read the redshift and halo count from the file attributes.

        If the arbor does not yet have a ``field_list``, also try to set
        up field metadata from this file.
        """
        self.open()
        try:
            fh = self.fh
            self.redshift = fh.attrs["current_redshift"]
            self.nhalos = fh.attrs["num_halos"]

            # Files with no halos won't have the units.
            # Keep trying until we get one.
            if not hasattr(self.arbor, "field_list"):
                self._setup_field_info(fh)
        finally:
            # Close even if an attribute is missing so the handle is not leaked.
            self.close()

    def _setup_field_info(self, fh):
        """
        Populate the arbor's field list and field info from an open file.

        Every dataset becomes a "file" field whose units come from the
        dataset's "units" attribute, and "redshift" is added as a
        "header" field. If any dataset is empty, nothing is recorded so
        that a later file can supply the metadata instead.

        Parameters
        ----------
        fh : h5py.File
            Open file handle.
        """
        fields = list(fh.keys())
        fi = {}
        for field in fields:
            if fh[field].size == 0:
                # Zero-sized arrays won't have units, so don't bother.
                return
            units = fh[field].attrs["units"]
            if isinstance(units, bytes):
                units = units.decode("utf")
            fi[field] = {"source": "file", "units": units}

        fields.append("redshift")
        fi["redshift"] = {"source": "header", "units": ""}

        self.arbor.field_list = fields
        self.arbor.field_info.update(fi)

    def _read_datasets(self, rfields, file_ids=None):
        """
        Read whole datasets from the file, optionally selecting rows.

        Parameters
        ----------
        rfields : iterable of str
            Names of datasets to read.
        file_ids : array of int, optional
            Row indices to keep from each dataset. If None, all rows are
            returned.

        Returns
        -------
        dict
            Mapping of dataset name to numpy array.
        """
        rfields = list(rfields)
        if not rfields:
            return {}

        data = {}
        self.open()
        try:
            fh = self.fh
            for field in rfields:
                # Dataset[()] reads the full dataset (Dataset.value was
                # removed in h5py 3.0). Rows are selected in memory because
                # h5py list-based selection requires increasing indices.
                values = fh[field][()]
                data[field] = values if file_ids is None else values[file_ids]
        finally:
            self.close()
        return data

    def _read_fields(self, fields, tree_nodes=None, dtypes=None):
        """
        Read field data for every halo in the file or for selected nodes.

        Parameters
        ----------
        fields : list of str
            Field names. Each must be a key of ``self.arbor.field_info``.
        tree_nodes : sequence, optional
            Nodes to read. Each node's ``_fi`` attribute is its row index
            in the file. If None, every halo in the file is read. In that
            case fields whose source is "arbor" are not filled.
        dtypes : dict, optional
            Mapping of field name to output dtype. Arrays allocated here
            (arbor and header fields) fall back to ``_default_dtype``.
            Fields read from the file keep their stored dtype unless
            listed here.

        Returns
        -------
        dict
            Mapping of field name to numpy array.
        """
        if dtypes is None:
            dtypes = {}

        fi = self.arbor.field_info
        afields = [field for field in fields
                   if fi[field].get("source") == "arbor"]
        hfields = [field for field in fields
                   if fi[field].get("source") == "header"]
        rfields = set(fields).difference(afields + hfields)

        if tree_nodes is None:
            ntrees = self.nhalos
            field_data = self._read_datasets(rfields)
        else:
            ntrees = len(tree_nodes)
            field_data = {}

            # fields from arbor-related info, stored as node attributes
            for field in afields:
                field_data[field] = np.empty(
                    ntrees, dtype=dtypes.get(field, self._default_dtype))
            for i, node in enumerate(tree_nodes):
                for field in afields:
                    field_data[field][i] = getattr(node, field)

            # fields from the file, selected by each node's row index
            if rfields:
                # An explicit integer dtype keeps an empty selection a valid
                # index array (np.array([]) would be float64).
                file_ids = np.array([node._fi for node in tree_nodes],
                                    dtype=np.int64)
                field_data.update(
                    self._read_datasets(rfields, file_ids=file_ids))

        # fields from the file header: the same value for every halo
        for field in hfields:
            # np.full honors the requested dtype; multiplying by np.ones
            # could promote to the header value's dtype instead.
            field_data[field] = np.full(
                ntrees, getattr(self, field),
                dtype=dtypes.get(field, self._default_dtype))

        for field, dtype in dtypes.items():
            # Header fields were already built with the requested dtype;
            # dtypes for fields that were not read are ignored.
            if field in hfields or field not in field_data:
                continue
            field_data[field] = field_data[field].astype(dtype)

        return field_data
class TreeFarmTreeFieldIO(TreeFieldIO):
    """
    Tree field I/O for TreeFarm arbors.

    Only the class-level default dtype is overridden; all behavior is
    inherited from TreeFieldIO.
    """

    _default_dtype = np.float64