Source code for ytree.frontends.csv.arbor

"""
CSVArbor class and member functions



"""

#-----------------------------------------------------------------------------
# Copyright (c) ytree development team. All rights reserved.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file COPYING.txt, distributed with this software.
#-----------------------------------------------------------------------------

from collections import defaultdict
import numpy as np

from yt.funcs import get_pbar

from ytree.data_structures.arbor import CatalogArbor
from ytree.data_structures.tree_node import TreeNode

from ytree.frontends.csv.io import CSVDataFile

from ytree.utilities.io import f_text_block

from numpy.dtypes import StringDType

field_data_types = {
    "FLOAT": float,
    "INT": int,
    "STR": StringDType(),
}

[docs] class CSVArbor(CatalogArbor): """ Arbors loaded from consistent-trees tree_*.dat files. """ _data_file_class = CSVDataFile _has_uids = True _default_dtype = np.float32 _node_con_attrs = () _node_io_attrs = () _ancestors = None
[docs] def __init__(self, filename, sep=","): self.sep = sep super().__init__(filename)
def _get_data_files(self): self.data_files = [CSVDataFile(self.filename, self)] def _parse_parameter_file(self, filename=None): if filename is None: filename = self.filename with open(filename, mode="r") as f: ldata = [] for i in range(3): ldata.append(f.readline().strip()[1:].split(self.sep)) self._hoffset = f.tell() lens = [len(_) for _ in ldata] if min(lens) != max(lens): raise RuntimeError( "Header lines must have same number of values.") fields, dtypes, units = ldata fi = {} for i, (field, dtype, unit) in enumerate(zip(fields, dtypes, units)): my_unit = "" if unit == "None" else unit fi[field] = {"column": i, "units": my_unit, "dtype": field_data_types[dtype]} try: fi["uid"]["dtype"] = np.int64 fi["desc_uid"]["dtype"] = np.int64 except KeyError: raise IOError(f"{filename} is missing either uid or desc_uid fields.") self.field_list = list(fi.keys()) self.field_info.update(fi) def _plant_trees(self): if self.is_planted or self._size == 0: return data_file = self.data_files[0] data_file.open() data_file.fh.seek(self._hoffset) fi = self.field_info col_uid = fi["uid"]["column"] typ_uid = fi["uid"]["dtype"] col_des = fi["desc_uid"]["column"] typ_des = fi["desc_uid"]["dtype"] roots = [] root_offsets = [] # these two are for nonroots only desc_uids = {} offsets = {} for line, loc in f_text_block( data_file.fh, pbar_string="Loading tree roots"): online = line.split(self.sep) uid = typ_uid(online[col_uid]) desc_uid = typ_des(online[col_des]) if desc_uid == -1: roots.append(uid) root_offsets.append(loc) else: desc_uids[uid] = desc_uid offsets[uid] = loc n_nonroots = len(desc_uids) uids = np.array(list(desc_uids.keys()) + roots) nr_desc_uids = np.array(list(desc_uids.values())) # look for any nodes with missing descendents missing = np.isin(nr_desc_uids, uids, invert=True) # need to address only up to n_nonroots since # uids has everything and desc_uids has only nonroots missing_uids = uids[:n_nonroots][missing] for uid in missing_uids: del desc_uids[uid] trees = [] for uid, offset in zip(roots, root_offsets): my_node = TreeNode(uid, arbor=self, root=True) my_node._offset = offset my_node.data_file = self.data_files[0] trees.append(my_node) for uid in missing_uids: offset = offsets.pop(uid) my_node = TreeNode(uid, arbor=self, root=True) my_node._offset = offset my_node.data_file = self.data_files[0] trees.append(my_node) self._size = len(trees) self._trees = np.array(trees) # these are for the nonroot nodes # we will construct them later when they are needed self._desc_uids = desc_uids self._offsets = offsets def _build_trees(self): """ Resolve the dictionary of desc_uids into trees. """ # If we have already done this, get out of here. if self._ancestors is not None: return if self.size < 1: return desc_uids = getattr(self, "_desc_uids", None) if desc_uids is None: return offsets = getattr(self, "_offsets") ancestors = defaultdict(list) pbar = get_pbar("Building trees", len(desc_uids)) for i, (uid, desc_uid) in enumerate(desc_uids.items()): my_node = TreeNode(uid, arbor=self) my_node._offset = offsets[uid] my_node.data_file = self.data_files[0] ancestors[desc_uid].append(my_node) pbar.update(i+1) pbar.finish() self._ancestors = ancestors del self._desc_uids, self._offsets def _grow_tree(self, tree_node): """ We don't actually attach the tree's _ancestors during planting, so do it here. """ self._get_ancestors(tree_node) def _get_ancestors(self, tree_node): """ Get this tree's ancestors from the main ancestor dictionary. This arbor stores every tree's ancestors in a dictionary. When we finally need it, grab it from the arbor and attach it to the tree. """ self._build_trees() if tree_node._ancestors is None: tree_node._ancestors = self._ancestors.pop(tree_node.uid, []) for node in tree_node._ancestors: node._descendent = tree_node self._get_ancestors(node) def _setup_tree(self, tree_node): if self.is_setup(tree_node): return self._get_ancestors(tree_node) super()._setup_tree(tree_node) @classmethod def _is_valid(self, *args, **kwargs): """ File should end in .csv and start with three lines of the following format. #<fieldname>,<fieldname>,<fieldname> #<type>,<type>,<type> #<units>,<units>,<units> For example: #uid,desc_uid,mass,redshift #int,int,float,float #None,None,"Msun",None """ fn = args[0] sep = kwargs.get("sep", ",") if not fn.endswith(".csv"): return False with open(fn, "r") as f: for i in range(3): line = f.readline() if not line.startswith("#"): return False if sep not in line: return False return True