# Source code for ytree.frontends.moria.io

"""
MoriaArbor io classes and member functions



"""

#-----------------------------------------------------------------------------
# Copyright (c) ytree development team. All rights reserved.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file COPYING.txt, distributed with this software.
#-----------------------------------------------------------------------------

import h5py
import numpy as np
import re

from ytree.data_structures.io import \
    DefaultRootFieldIO, \
    DataFile, \
    TreeFieldIO

class MoriaDataFile(DataFile):
    """
    Read-only wrapper around the HDF5 file backing a Moria arbor.

    Data is normally read straight from the open file handle. When
    ``full_read`` is set to True, each requested dataset is instead loaded
    in full once, kept in ``field_cache``, and indexed in memory on every
    subsequent request.
    """

    # Dict of fully loaded datasets keyed by field name. Created lazily by
    # read_data and only populated when full_read is True.
    field_cache = None
    # Switch between per-request reads (False) and whole-dataset caching (True).
    full_read = False
    # Open h5py.File handle, or None while the file is closed.
    fh = None

    def open(self):
        """
        Open the file in read-only mode.

        Does nothing if a handle is already open, so a repeated call cannot
        replace (and thereby leak) the existing handle.
        """
        if self.fh is None:
            self.fh = h5py.File(self.filename, mode="r")

    def close(self):
        """
        Close the file handle and reset it to None.

        Safe to call when the file is not open.
        """
        if self.fh is not None:
            self.fh.close()
            self.fh = None

    def read_data(self, field, index):
        """
        Return ``field[index]`` from the file.

        Parameters
        ----------
        field : str
            Name of the dataset in the open file.
        index : index expression
            Anything accepted by dataset/array indexing (the callers in
            this module pass tuples of ints and slices).

        The file must already be open unless the field is in the cache.
        """
        if not self.full_read:
            return self.fh[field][index]

        if self.field_cache is None:
            self.field_cache = {}
        if field not in self.field_cache:
            # [()] pulls the entire dataset into memory.
            self.field_cache[field] = self.fh[field][()]
        return self.field_cache[field][index]
class MoriaTreeFieldIO(TreeFieldIO):
    """
    Field reader for individual trees of a Moria arbor.
    """

    def get_fields(self, data_object, fields=None, **kwargs):
        """
        Call _setup_tree if asking for desc_uid so we can correct it.

        All other work is delegated to the parent implementation, whose
        return value is passed through unchanged.
        """
        if fields is not None and "desc_uid" in fields:
            self.arbor._setup_tree(data_object)
        return super().get_fields(data_object, fields=fields, **kwargs)

    def _read_fields(self, root_node, fields, dtypes=None, root_only=False):
        """
        Read fields from disk for a single tree.

        Parameters
        ----------
        root_node : tree node
            Node providing ``_si`` / ``_ei``, the column range of this tree
            in the on-disk arrays. A ``_status`` attribute is cached on it
            the first time the full tree is read.
        fields : list of str
            Names of the fields to load.
        dtypes : dict, optional
            Accepted for interface compatibility.
            NOTE(review): not applied in this reader, unlike
            MoriaRootFieldIO._read_fields -- confirm this is intended.
        root_only : bool
            If True, read only the last row for the root's column and apply
            no filter.

        Returns
        -------
        dict
            Field name -> 1D array of values.
        """
        fi = self.arbor.field_info

        # Split the request into fields generated from the arbor itself and
        # fields read from disk, then pull in any on-disk dependencies of
        # the generated fields.
        afields = [field for field in fields
                   if fi[field].get("source") == "arbor"]
        rfields = list(set(fields).difference(afields))
        for afield in afields:
            rfields.extend(
                [dfield for dfield in fi[afield].get("dependencies", [])
                 if dfield not in rfields])

        data_file = self.arbor.data_files[0]
        # Only close the file afterwards if this call opened it.
        close = data_file.fh is None
        if close:
            data_file.open()

        try:
            if root_only:
                index = (-1, slice(root_node._si, root_node._si+1))
                dfilter = None
            else:
                index = (slice(None), slice(root_node._si, root_node._ei))
                if not hasattr(root_node, "_status"):
                    status = data_file.read_data("status_sparta", index)
                    status = self._transform_data(status)
                    # Keep only entries whose status_sparta is nonzero.
                    root_node._status = np.where(status != 0)[0]
                dfilter = root_node._status

            # this field cache is for temporarily storing vector field data
            field_cache = {}
            field_data = {}
            # Vector components are named "<dataset>_<component index>".
            freg = re.compile(r"(^.+)_(\d+$)")
            for field in rfields:
                if fi[field].get("vector", False):
                    fs = freg.search(field)
                    fieldname, ifield = fs.groups()
                    ifield = int(ifield)
                    if fieldname not in field_cache:
                        field_cache[fieldname] = \
                            data_file.read_data(fieldname, index)
                    data = field_cache[fieldname][..., ifield]
                else:
                    data = data_file.read_data(field, index)
                field_data[field] = self._transform_data(
                    data, my_filter=dfilter)

            if afields:
                field_data.update(self._get_arbor_fields(
                    root_node, field_data, fields, afields, root_only,
                    my_filter=dfilter))
        finally:
            # Release the handle even if a read above raised.
            if close:
                data_file.close()

        self._apply_units(rfields, field_data)
        return field_data

    def _transform_data(self, data, my_filter=None):
        """
        Reverse ``data`` along its first axis, flatten it to 1D, and, if
        ``my_filter`` is given, keep only the entries it selects.
        """
        data = np.flip(data, axis=0).flatten()
        if my_filter is not None:
            data = data[my_filter]
        return data

    def _get_arbor_fields(self, root_node, field_data, fields, afields,
                          root_only, my_filter=None):
        """
        Generate special fields from the arbor/treenode.

        Only "snap_index" is produced here. For a root-only request it is
        the single value ``arbor._redshifts.size - 1``; otherwise it is the
        first-axis index of every entry in the tree's block, passed through
        the same flip/flatten/filter as the on-disk fields.
        """
        adata = {}

        if "snap_index" in fields:
            if root_only:
                adata["snap_index"] = \
                    np.array([self.arbor._redshifts.size-1], dtype=int)
            else:
                data, _ = np.mgrid[:self.arbor._redshifts.size,
                                   root_node._si:root_node._ei]
                adata["snap_index"] = self._transform_data(
                    data, my_filter=my_filter)

        return adata
class MoriaRootFieldIO(DefaultRootFieldIO):
    """
    Field reader for the root nodes of all trees in a Moria arbor.
    """

    def _read_fields(self, storage_object, fields, dtypes=None):
        """
        Read fields from disk for every root in the arbor.

        Parameters
        ----------
        storage_object : object
            Unused by this implementation.
        fields : list of str
            Names of the fields to load.
        dtypes : dict, optional
            Field name -> dtype; matching fields are cast with ``astype``.

        Returns
        -------
        dict
            Field name -> array with one value per root.
        """
        self.arbor._plant_trees()

        if dtypes is None:
            dtypes = {}

        fi = self.arbor.field_info

        # Split the request into fields generated from the arbor itself and
        # fields read from disk, then pull in any on-disk dependencies of
        # the generated fields.
        afields = [field for field in fields
                   if fi[field].get("source") == "arbor"]
        rfields = list(set(fields).difference(afields))
        for afield in afields:
            rfields.extend(
                [dfield for dfield in fi[afield].get("dependencies", [])
                 if dfield not in rfields])

        data_file = self.arbor.data_files[0]
        # Only open/close the file if nobody else is holding it open, the
        # same convention MoriaTreeFieldIO._read_fields follows.
        close = data_file.fh is None
        if close:
            data_file.open()

        try:
            fh = data_file.fh
            index = self.arbor._node_info['_si']

            # Temporary store so each vector dataset is read only once.
            field_cache = {}
            field_data = {}
            # Vector components are named "<dataset>_<component index>".
            freg = re.compile(r"(^.+)_(\d+$)")
            for field in rfields:
                if fi[field].get("vector", False):
                    fs = freg.search(field)
                    fieldname, ifield = fs.groups()
                    ifield = int(ifield)
                    if fieldname not in field_cache:
                        # Last row of the dataset, restricted to the roots.
                        field_cache[fieldname] = fh[fieldname][-1][index]
                    data = field_cache[fieldname][:, ifield]
                else:
                    data = fh[field][-1][index]

                dtype = dtypes.get(field)
                if dtype is not None:
                    data = data.astype(dtype)
                field_data[field] = data

            self._apply_units(rfields, field_data)

            if afields:
                field_data.update(self._get_arbor_fields(
                    field_data, fields, afields))
        finally:
            # Release the handle even if a read above raised.
            if close:
                data_file.close()

        return field_data

    def _get_arbor_fields(self, field_data, fields, afields):
        """
        Generate special fields from the arbor.

        Only "snap_index" is produced: an array of length ``arbor.size``
        filled with ``arbor._redshifts.size - 1``.
        """
        adata = {}

        if "snap_index" in fields:
            adata["snap_index"] = \
                np.full(self.arbor.size, self.arbor._redshifts.size-1)

        return adata