diff --git a/doc/whats-new.rst b/doc/whats-new.rst index e57be1df177..aa499731f4a 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -49,7 +49,9 @@ Documentation Internal Changes ~~~~~~~~~~~~~~~~ - +- Migrates ``treenode`` functionality into ``xarray/core`` (:pull:`8757`) + By `Matt Savoie `_ and `Tom Nicholas + `_. .. _whats-new.2024.02.0: @@ -145,9 +147,11 @@ Internal Changes ``xarray/namedarray``. (:pull:`8319`) By `Tom Nicholas `_ and `Anderson Banihirwe `_. - Imports ``datatree`` repository and history into internal location. (:pull:`8688`) - By `Matt Savoie `_ and `Justus Magin `_. + By `Matt Savoie `_, `Justus Magin `_ + and `Tom Nicholas `_. - Adds :py:func:`open_datatree` into ``xarray/backends`` (:pull:`8697`) - By `Matt Savoie `_. + By `Matt Savoie `_ and `Tom Nicholas + `_. - Refactor :py:meth:`xarray.core.indexing.DaskIndexingAdapter.__getitem__` to remove an unnecessary rewrite of the indexer key (:issue: `8377`, :pull:`8758`) By `Anderson Banihirwe `_. diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 6245b3442a3..7d3cc00a52d 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -137,8 +137,8 @@ def _open_datatree_netcdf( **kwargs, ) -> DataTree: from xarray.backends.api import open_dataset + from xarray.core.treenode import NodePath from xarray.datatree_.datatree import DataTree - from xarray.datatree_.datatree.treenode import NodePath ds = open_dataset(filename_or_obj, **kwargs) tree_root = DataTree.from_dict({"/": ds}) @@ -159,7 +159,7 @@ def _open_datatree_netcdf( def _iter_nc_groups(root, parent="/"): - from xarray.datatree_.datatree.treenode import NodePath + from xarray.core.treenode import NodePath parent = NodePath(parent) for path, group in root.groups.items(): diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index ac208da097a..e9465dc0ba0 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -1048,8 +1048,8 @@ def open_datatree( import zarr from xarray.backends.api import open_dataset + from xarray.core.treenode import NodePath from xarray.datatree_.datatree import DataTree - from xarray.datatree_.datatree.treenode import NodePath zds = zarr.open_group(filename_or_obj, mode="r") ds = open_dataset(filename_or_obj, engine="zarr", **kwargs) @@ -1075,7 +1075,7 @@ def open_datatree( def _iter_zarr_groups(root, parent="/"): - from xarray.datatree_.datatree.treenode import NodePath + from xarray.core.treenode import NodePath parent = NodePath(parent) for path, group in root.groups(): diff --git a/xarray/datatree_/datatree/treenode.py b/xarray/core/treenode.py similarity index 90% rename from xarray/datatree_/datatree/treenode.py rename to xarray/core/treenode.py index 1689d261c34..b3e6e43f306 100644 --- a/xarray/datatree_/datatree/treenode.py +++ b/xarray/core/treenode.py @@ -1,17 +1,12 @@ from __future__ import annotations import sys -from collections import OrderedDict +from collections.abc import Iterator, Mapping from pathlib import PurePosixPath from typing import ( TYPE_CHECKING, Generic, - Iterator, - Mapping, - Optional, - Tuple, TypeVar, - Union, ) from xarray.core.utils import Frozen, is_dict_like @@ -25,7 +20,7 @@ class InvalidTreeError(Exception): class NotFoundInTreeError(ValueError): - """Raised when operation can't be completed because one node is part of the expected tree.""" + """Raised when operation can't be completed because one node is not part of the expected tree.""" class NodePath(PurePosixPath): @@ -55,8 +50,8 @@ class TreeNode(Generic[Tree]): This class stores no data, it has only parents and children attributes, and various methods. - Stores child nodes in an Ordered Dictionary, which is necessary to ensure that equality checks between two trees - also check that the order of child nodes is the same. + Stores child nodes in an dict, ensuring that equality checks between trees + and order of child nodes is preserved (since python 3.7). Nodes themselves are intrinsically unnamed (do not possess a ._name attribute), but if the node has a parent you can find the key it is stored under via the .name property. @@ -73,15 +68,16 @@ class TreeNode(Generic[Tree]): Also allows access to any other node in the tree via unix-like paths, including upwards referencing via '../'. (This class is heavily inspired by the anytree library's NodeMixin class.) + """ - _parent: Optional[Tree] - _children: OrderedDict[str, Tree] + _parent: Tree | None + _children: dict[str, Tree] - def __init__(self, children: Optional[Mapping[str, Tree]] = None): + def __init__(self, children: Mapping[str, Tree] | None = None): """Create a parentless node.""" self._parent = None - self._children = OrderedDict() + self._children = {} if children is not None: self.children = children @@ -91,7 +87,7 @@ def parent(self) -> Tree | None: return self._parent def _set_parent( - self, new_parent: Tree | None, child_name: Optional[str] = None + self, new_parent: Tree | None, child_name: str | None = None ) -> None: # TODO is it possible to refactor in a way that removes this private method? @@ -127,17 +123,15 @@ def _detach(self, parent: Tree | None) -> None: if parent is not None: self._pre_detach(parent) parents_children = parent.children - parent._children = OrderedDict( - { - name: child - for name, child in parents_children.items() - if child is not self - } - ) + parent._children = { + name: child + for name, child in parents_children.items() + if child is not self + } self._parent = None self._post_detach(parent) - def _attach(self, parent: Tree | None, child_name: Optional[str] = None) -> None: + def _attach(self, parent: Tree | None, child_name: str | None = None) -> None: if parent is not None: if child_name is None: raise ValueError( @@ -167,7 +161,7 @@ def children(self: Tree) -> Mapping[str, Tree]: @children.setter def children(self: Tree, children: Mapping[str, Tree]) -> None: self._check_children(children) - children = OrderedDict(children) + children = {**children} old_children = self.children del self.children @@ -242,7 +236,7 @@ def _iter_parents(self: Tree) -> Iterator[Tree]: yield node node = node.parent - def iter_lineage(self: Tree) -> Tuple[Tree, ...]: + def iter_lineage(self: Tree) -> tuple[Tree, ...]: """Iterate up the tree, starting from the current node.""" from warnings import warn @@ -254,7 +248,7 @@ def iter_lineage(self: Tree) -> Tuple[Tree, ...]: return tuple((self, *self.parents)) @property - def lineage(self: Tree) -> Tuple[Tree, ...]: + def lineage(self: Tree) -> tuple[Tree, ...]: """All parent nodes and their parent nodes, starting with the closest.""" from warnings import warn @@ -266,12 +260,12 @@ def lineage(self: Tree) -> Tuple[Tree, ...]: return self.iter_lineage() @property - def parents(self: Tree) -> Tuple[Tree, ...]: + def parents(self: Tree) -> tuple[Tree, ...]: """All parent nodes and their parent nodes, starting with the closest.""" return tuple(self._iter_parents()) @property - def ancestors(self: Tree) -> Tuple[Tree, ...]: + def ancestors(self: Tree) -> tuple[Tree, ...]: """All parent nodes and their parent nodes, starting with the most distant.""" from warnings import warn @@ -306,7 +300,7 @@ def is_leaf(self) -> bool: return self.children == {} @property - def leaves(self: Tree) -> Tuple[Tree, ...]: + def leaves(self: Tree) -> tuple[Tree, ...]: """ All leaf nodes. @@ -315,20 +309,18 @@ def leaves(self: Tree) -> Tuple[Tree, ...]: return tuple([node for node in self.subtree if node.is_leaf]) @property - def siblings(self: Tree) -> OrderedDict[str, Tree]: + def siblings(self: Tree) -> dict[str, Tree]: """ Nodes with the same parent as this node. """ if self.parent: - return OrderedDict( - { - name: child - for name, child in self.parent.children.items() - if child is not self - } - ) + return { + name: child + for name, child in self.parent.children.items() + if child is not self + } else: - return OrderedDict() + return {} @property def subtree(self: Tree) -> Iterator[Tree]: @@ -341,12 +333,12 @@ def subtree(self: Tree) -> Iterator[Tree]: -------- DataTree.descendants """ - from . import iterators + from xarray.datatree_.datatree import iterators return iterators.PreOrderIter(self) @property - def descendants(self: Tree) -> Tuple[Tree, ...]: + def descendants(self: Tree) -> tuple[Tree, ...]: """ Child nodes and all their child nodes. @@ -431,7 +423,7 @@ def _post_attach(self: Tree, parent: Tree) -> None: """Method call after attaching to `parent`.""" pass - def get(self: Tree, key: str, default: Optional[Tree] = None) -> Optional[Tree]: + def get(self: Tree, key: str, default: Tree | None = None) -> Tree | None: """ Return the child node with the specified key. @@ -445,7 +437,7 @@ def get(self: Tree, key: str, default: Optional[Tree] = None) -> Optional[Tree]: # TODO `._walk` method to be called by both `_get_item` and `_set_item` - def _get_item(self: Tree, path: str | NodePath) -> Union[Tree, T_DataArray]: + def _get_item(self: Tree, path: str | NodePath) -> Tree | T_DataArray: """ Returns the object lying at the given path. @@ -488,24 +480,26 @@ def _set(self: Tree, key: str, val: Tree) -> None: def _set_item( self: Tree, path: str | NodePath, - item: Union[Tree, T_DataArray], + item: Tree | T_DataArray, new_nodes_along_path: bool = False, allow_overwrite: bool = True, ) -> None: """ Set a new item in the tree, overwriting anything already present at that path. - The given value either forms a new node of the tree or overwrites an existing item at that location. + The given value either forms a new node of the tree or overwrites an + existing item at that location. Parameters ---------- path item new_nodes_along_path : bool - If true, then if necessary new nodes will be created along the given path, until the tree can reach the - specified location. + If true, then if necessary new nodes will be created along the + given path, until the tree can reach the specified location. allow_overwrite : bool - Whether or not to overwrite any existing node at the location given by path. + Whether or not to overwrite any existing node at the location given + by path. Raises ------ @@ -580,9 +574,9 @@ class NamedNode(TreeNode, Generic[Tree]): Implements path-like relationships to other nodes in its tree. """ - _name: Optional[str] - _parent: Optional[Tree] - _children: OrderedDict[str, Tree] + _name: str | None + _parent: Tree | None + _children: dict[str, Tree] def __init__(self, name=None, children=None): super().__init__(children=children) @@ -603,8 +597,14 @@ def name(self, name: str | None) -> None: raise ValueError("node names cannot contain forward slashes") self._name = name + def __repr__(self, level=0): + repr_value = "\t" * level + self.__str__() + "\n" + for child in self.children: + repr_value += self.get(child).__repr__(level + 1) + return repr_value + def __str__(self) -> str: - return f"NamedNode({self.name})" if self.name else "NamedNode()" + return f"NamedNode('{self.name}')" if self.name else "NamedNode()" def _post_attach(self: NamedNode, parent: NamedNode) -> None: """Ensures child has name attribute corresponding to key under which it has been stored.""" diff --git a/xarray/datatree_/datatree/__init__.py b/xarray/datatree_/datatree/__init__.py index f9fd419bddc..071dcbecf8c 100644 --- a/xarray/datatree_/datatree/__init__.py +++ b/xarray/datatree_/datatree/__init__.py @@ -2,7 +2,7 @@ from .datatree import DataTree from .extensions import register_datatree_accessor from .mapping import TreeIsomorphismError, map_over_subtree -from .treenode import InvalidTreeError, NotFoundInTreeError +from xarray.core.treenode import InvalidTreeError, NotFoundInTreeError __all__ = ( diff --git a/xarray/datatree_/datatree/datatree.py b/xarray/datatree_/datatree/datatree.py index 13cca7de80d..10133052185 100644 --- a/xarray/datatree_/datatree/datatree.py +++ b/xarray/datatree_/datatree/datatree.py @@ -50,7 +50,7 @@ MappedDataWithCoords, ) from .render import RenderTree -from .treenode import NamedNode, NodePath, Tree +from xarray.core.treenode import NamedNode, NodePath, Tree try: from xarray.core.variable import calculate_dimensions diff --git a/xarray/datatree_/datatree/iterators.py b/xarray/datatree_/datatree/iterators.py index 52ed8d22422..68e75c4f612 100644 --- a/xarray/datatree_/datatree/iterators.py +++ b/xarray/datatree_/datatree/iterators.py @@ -2,7 +2,7 @@ from collections import abc from typing import Callable, Iterator, List, Optional -from .treenode import Tree +from xarray.core.treenode import Tree """These iterators are copied from anytree.iterators, with minor modifications.""" diff --git a/xarray/datatree_/datatree/mapping.py b/xarray/datatree_/datatree/mapping.py index 34e227d349d..355149060a9 100644 --- a/xarray/datatree_/datatree/mapping.py +++ b/xarray/datatree_/datatree/mapping.py @@ -9,10 +9,10 @@ from xarray import DataArray, Dataset from .iterators import LevelOrderIter -from .treenode import NodePath, TreeNode +from xarray.core.treenode import NodePath, TreeNode if TYPE_CHECKING: - from .datatree import DataTree + from xarray.core.datatree import DataTree class TreeIsomorphismError(ValueError): diff --git a/xarray/datatree_/datatree/tests/test_formatting.py b/xarray/datatree_/datatree/tests/test_formatting.py index 8726c95fe62..b58c02282e7 100644 --- a/xarray/datatree_/datatree/tests/test_formatting.py +++ b/xarray/datatree_/datatree/tests/test_formatting.py @@ -108,13 +108,13 @@ def test_diff_node_data(self): Data in nodes at position '/a' do not match: Data variables only on the left object: - v int64 1 + v int64 8B 1 Data in nodes at position '/a/b' do not match: Differing data variables: - L w int64 5 - R w int64 6""" + L w int64 8B 5 + R w int64 8B 6""" ) actual = diff_tree_repr(dt_1, dt_2, "equals") assert actual == expected diff --git a/xarray/tests/conftest.py b/xarray/tests/conftest.py index 7c30759e499..8590c9fb4e7 100644 --- a/xarray/tests/conftest.py +++ b/xarray/tests/conftest.py @@ -6,6 +6,7 @@ import xarray as xr from xarray import DataArray, Dataset +from xarray.datatree_.datatree import DataTree from xarray.tests import create_test_data, requires_dask @@ -136,3 +137,64 @@ def d(request, backend, type) -> DataArray | Dataset: return result else: raise ValueError + + +@pytest.fixture(scope="module") +def create_test_datatree(): + """ + Create a test datatree with this structure: + + + |-- set1 + | |-- + | | Dimensions: () + | | Data variables: + | | a int64 0 + | | b int64 1 + | |-- set1 + | |-- set2 + |-- set2 + | |-- + | | Dimensions: (x: 2) + | | Data variables: + | | a (x) int64 2, 3 + | | b (x) int64 0.1, 0.2 + | |-- set1 + |-- set3 + |-- + | Dimensions: (x: 2, y: 3) + | Data variables: + | a (y) int64 6, 7, 8 + | set0 (x) int64 9, 10 + + The structure has deliberately repeated names of tags, variables, and + dimensions in order to better check for bugs caused by name conflicts. + """ + + def _create_test_datatree(modify=lambda ds: ds): + set1_data = modify(xr.Dataset({"a": 0, "b": 1})) + set2_data = modify(xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])})) + root_data = modify(xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})) + + # Avoid using __init__ so we can independently test it + root: DataTree = DataTree(data=root_data) + set1: DataTree = DataTree(name="set1", parent=root, data=set1_data) + DataTree(name="set1", parent=set1) + DataTree(name="set2", parent=set1) + set2: DataTree = DataTree(name="set2", parent=root, data=set2_data) + DataTree(name="set1", parent=set2) + DataTree(name="set3", parent=root) + + return root + + return _create_test_datatree + + +@pytest.fixture(scope="module") +def simple_datatree(create_test_datatree): + """ + Invoke create_test_datatree fixture (callback). + + Returns a DataTree. + """ + return create_test_datatree() diff --git a/xarray/tests/datatree/conftest.py b/xarray/tests/datatree/conftest.py deleted file mode 100644 index b341f3007aa..00000000000 --- a/xarray/tests/datatree/conftest.py +++ /dev/null @@ -1,65 +0,0 @@ -import pytest - -import xarray as xr -from xarray.datatree_.datatree import DataTree - - -@pytest.fixture(scope="module") -def create_test_datatree(): - """ - Create a test datatree with this structure: - - - |-- set1 - | |-- - | | Dimensions: () - | | Data variables: - | | a int64 0 - | | b int64 1 - | |-- set1 - | |-- set2 - |-- set2 - | |-- - | | Dimensions: (x: 2) - | | Data variables: - | | a (x) int64 2, 3 - | | b (x) int64 0.1, 0.2 - | |-- set1 - |-- set3 - |-- - | Dimensions: (x: 2, y: 3) - | Data variables: - | a (y) int64 6, 7, 8 - | set0 (x) int64 9, 10 - - The structure has deliberately repeated names of tags, variables, and - dimensions in order to better check for bugs caused by name conflicts. - """ - - def _create_test_datatree(modify=lambda ds: ds): - set1_data = modify(xr.Dataset({"a": 0, "b": 1})) - set2_data = modify(xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])})) - root_data = modify(xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})) - - # Avoid using __init__ so we can independently test it - root: DataTree = DataTree(data=root_data) - set1: DataTree = DataTree(name="set1", parent=root, data=set1_data) - DataTree(name="set1", parent=set1) - DataTree(name="set2", parent=set1) - set2: DataTree = DataTree(name="set2", parent=root, data=set2_data) - DataTree(name="set1", parent=set2) - DataTree(name="set3", parent=root) - - return root - - return _create_test_datatree - - -@pytest.fixture(scope="module") -def simple_datatree(create_test_datatree): - """ - Invoke create_test_datatree fixture (callback). - - Returns a DataTree. - """ - return create_test_datatree() diff --git a/xarray/tests/datatree/test_io.py b/xarray/tests/test_backends_datatree.py similarity index 70% rename from xarray/tests/datatree/test_io.py rename to xarray/tests/test_backends_datatree.py index 4f32e19de4a..7bdb2b532d9 100644 --- a/xarray/tests/datatree/test_io.py +++ b/xarray/tests/test_backends_datatree.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import pytest from xarray.backends.api import open_datatree @@ -8,69 +12,66 @@ requires_zarr, ) +if TYPE_CHECKING: + from xarray.backends.api import T_NetcdfEngine + + +class DatatreeIOBase: + engine: T_NetcdfEngine | None = None -class TestIO: - @requires_netCDF4 def test_to_netcdf(self, tmpdir, simple_datatree): - filepath = str( - tmpdir / "test.nc" - ) # casting to str avoids a pathlib bug in xarray + filepath = tmpdir / "test.nc" original_dt = simple_datatree - original_dt.to_netcdf(filepath, engine="netcdf4") + original_dt.to_netcdf(filepath, engine=self.engine) - roundtrip_dt = open_datatree(filepath) + roundtrip_dt = open_datatree(filepath, engine=self.engine) assert_equal(original_dt, roundtrip_dt) - @requires_netCDF4 def test_netcdf_encoding(self, tmpdir, simple_datatree): - filepath = str( - tmpdir / "test.nc" - ) # casting to str avoids a pathlib bug in xarray + filepath = tmpdir / "test.nc" original_dt = simple_datatree # add compression comp = dict(zlib=True, complevel=9) enc = {"/set2": {var: comp for var in original_dt["/set2"].ds.data_vars}} - original_dt.to_netcdf(filepath, encoding=enc, engine="netcdf4") - roundtrip_dt = open_datatree(filepath) + original_dt.to_netcdf(filepath, encoding=enc, engine=self.engine) + roundtrip_dt = open_datatree(filepath, engine=self.engine) assert roundtrip_dt["/set2/a"].encoding["zlib"] == comp["zlib"] assert roundtrip_dt["/set2/a"].encoding["complevel"] == comp["complevel"] enc["/not/a/group"] = {"foo": "bar"} # type: ignore with pytest.raises(ValueError, match="unexpected encoding group.*"): - original_dt.to_netcdf(filepath, encoding=enc, engine="netcdf4") + original_dt.to_netcdf(filepath, encoding=enc, engine=self.engine) - @requires_h5netcdf - def test_to_h5netcdf(self, tmpdir, simple_datatree): - filepath = str( - tmpdir / "test.nc" - ) # casting to str avoids a pathlib bug in xarray - original_dt = simple_datatree - original_dt.to_netcdf(filepath, engine="h5netcdf") - roundtrip_dt = open_datatree(filepath) - assert_equal(original_dt, roundtrip_dt) +@requires_netCDF4 +class TestNetCDF4DatatreeIO(DatatreeIOBase): + engine: T_NetcdfEngine | None = "netcdf4" + + +@requires_h5netcdf +class TestH5NetCDFDatatreeIO(DatatreeIOBase): + engine: T_NetcdfEngine | None = "h5netcdf" + + +@requires_zarr +class TestZarrDatatreeIO: + engine = "zarr" - @requires_zarr def test_to_zarr(self, tmpdir, simple_datatree): - filepath = str( - tmpdir / "test.zarr" - ) # casting to str avoids a pathlib bug in xarray + filepath = tmpdir / "test.zarr" original_dt = simple_datatree original_dt.to_zarr(filepath) roundtrip_dt = open_datatree(filepath, engine="zarr") assert_equal(original_dt, roundtrip_dt) - @requires_zarr def test_zarr_encoding(self, tmpdir, simple_datatree): import zarr - filepath = str( - tmpdir / "test.zarr" - ) # casting to str avoids a pathlib bug in xarray + filepath = tmpdir / "test.zarr" original_dt = simple_datatree comp = {"compressor": zarr.Blosc(cname="zstd", clevel=3, shuffle=2)} @@ -85,13 +86,10 @@ def test_zarr_encoding(self, tmpdir, simple_datatree): with pytest.raises(ValueError, match="unexpected encoding group.*"): original_dt.to_zarr(filepath, encoding=enc, engine="zarr") - @requires_zarr def test_to_zarr_zip_store(self, tmpdir, simple_datatree): from zarr.storage import ZipStore - filepath = str( - tmpdir / "test.zarr.zip" - ) # casting to str avoids a pathlib bug in xarray + filepath = tmpdir / "test.zarr.zip" original_dt = simple_datatree store = ZipStore(filepath) original_dt.to_zarr(store) @@ -99,7 +97,6 @@ def test_to_zarr_zip_store(self, tmpdir, simple_datatree): roundtrip_dt = open_datatree(store, engine="zarr") assert_equal(original_dt, roundtrip_dt) - @requires_zarr def test_to_zarr_not_consolidated(self, tmpdir, simple_datatree): filepath = tmpdir / "test.zarr" zmetadata = filepath / ".zmetadata" @@ -114,7 +111,6 @@ def test_to_zarr_not_consolidated(self, tmpdir, simple_datatree): roundtrip_dt = open_datatree(filepath, engine="zarr") assert_equal(original_dt, roundtrip_dt) - @requires_zarr def test_to_zarr_default_write_mode(self, tmpdir, simple_datatree): import zarr diff --git a/xarray/datatree_/datatree/tests/test_treenode.py b/xarray/tests/test_treenode.py similarity index 69% rename from xarray/datatree_/datatree/tests/test_treenode.py rename to xarray/tests/test_treenode.py index 3c75f3ac8a4..b0e737bd317 100644 --- a/xarray/datatree_/datatree/tests/test_treenode.py +++ b/xarray/tests/test_treenode.py @@ -1,25 +1,30 @@ +from __future__ import annotations + +from collections.abc import Iterator +from typing import cast + import pytest +from xarray.core.treenode import InvalidTreeError, NamedNode, NodePath, TreeNode from xarray.datatree_.datatree.iterators import LevelOrderIter, PreOrderIter -from xarray.datatree_.datatree.treenode import InvalidTreeError, NamedNode, NodePath, TreeNode class TestFamilyTree: def test_lonely(self): - root = TreeNode() + root: TreeNode = TreeNode() assert root.parent is None assert root.children == {} def test_parenting(self): - john = TreeNode() - mary = TreeNode() + john: TreeNode = TreeNode() + mary: TreeNode = TreeNode() mary._set_parent(john, "Mary") assert mary.parent == john assert john.children["Mary"] is mary def test_no_time_traveller_loops(self): - john = TreeNode() + john: TreeNode = TreeNode() with pytest.raises(InvalidTreeError, match="cannot be a parent of itself"): john._set_parent(john, "John") @@ -27,8 +32,8 @@ def test_no_time_traveller_loops(self): with pytest.raises(InvalidTreeError, match="cannot be a parent of itself"): john.children = {"John": john} - mary = TreeNode() - rose = TreeNode() + mary: TreeNode = TreeNode() + rose: TreeNode = TreeNode() mary._set_parent(john, "Mary") rose._set_parent(mary, "Rose") @@ -39,11 +44,11 @@ def test_no_time_traveller_loops(self): rose.children = {"John": john} def test_parent_swap(self): - john = TreeNode() - mary = TreeNode() + john: TreeNode = TreeNode() + mary: TreeNode = TreeNode() mary._set_parent(john, "Mary") - steve = TreeNode() + steve: TreeNode = TreeNode() mary._set_parent(steve, "Mary") assert mary.parent == steve @@ -51,24 +56,24 @@ def test_parent_swap(self): assert "Mary" not in john.children def test_multi_child_family(self): - mary = TreeNode() - kate = TreeNode() - john = TreeNode(children={"Mary": mary, "Kate": kate}) + mary: TreeNode = TreeNode() + kate: TreeNode = TreeNode() + john: TreeNode = TreeNode(children={"Mary": mary, "Kate": kate}) assert john.children["Mary"] is mary assert john.children["Kate"] is kate assert mary.parent is john assert kate.parent is john def test_disown_child(self): - mary = TreeNode() - john = TreeNode(children={"Mary": mary}) + mary: TreeNode = TreeNode() + john: TreeNode = TreeNode(children={"Mary": mary}) mary.orphan() assert mary.parent is None assert "Mary" not in john.children def test_doppelganger_child(self): - kate = TreeNode() - john = TreeNode() + kate: TreeNode = TreeNode() + john: TreeNode = TreeNode() with pytest.raises(TypeError): john.children = {"Kate": 666} @@ -77,22 +82,22 @@ def test_doppelganger_child(self): john.children = {"Kate": kate, "Evil_Kate": kate} john = TreeNode(children={"Kate": kate}) - evil_kate = TreeNode() + evil_kate: TreeNode = TreeNode() evil_kate._set_parent(john, "Kate") assert john.children["Kate"] is evil_kate def test_sibling_relationships(self): - mary = TreeNode() - kate = TreeNode() - ashley = TreeNode() + mary: TreeNode = TreeNode() + kate: TreeNode = TreeNode() + ashley: TreeNode = TreeNode() TreeNode(children={"Mary": mary, "Kate": kate, "Ashley": ashley}) assert kate.siblings["Mary"] is mary assert kate.siblings["Ashley"] is ashley assert "Kate" not in kate.siblings def test_ancestors(self): - tony = TreeNode() - michael = TreeNode(children={"Tony": tony}) + tony: TreeNode = TreeNode() + michael: TreeNode = TreeNode(children={"Tony": tony}) vito = TreeNode(children={"Michael": michael}) assert tony.root is vito assert tony.parents == (michael, vito) @@ -101,7 +106,7 @@ def test_ancestors(self): class TestGetNodes: def test_get_child(self): - steven = TreeNode() + steven: TreeNode = TreeNode() sue = TreeNode(children={"Steven": steven}) mary = TreeNode(children={"Sue": sue}) john = TreeNode(children={"Mary": mary}) @@ -124,8 +129,8 @@ def test_get_child(self): assert mary._get_item("Sue/Steven") is steven def test_get_upwards(self): - sue = TreeNode() - kate = TreeNode() + sue: TreeNode = TreeNode() + kate: TreeNode = TreeNode() mary = TreeNode(children={"Sue": sue, "Kate": kate}) john = TreeNode(children={"Mary": mary}) @@ -136,7 +141,7 @@ def test_get_upwards(self): assert sue._get_item("../Kate") is kate def test_get_from_root(self): - sue = TreeNode() + sue: TreeNode = TreeNode() mary = TreeNode(children={"Sue": sue}) john = TreeNode(children={"Mary": mary}) # noqa @@ -145,8 +150,8 @@ def test_get_from_root(self): class TestSetNodes: def test_set_child_node(self): - john = TreeNode() - mary = TreeNode() + john: TreeNode = TreeNode() + mary: TreeNode = TreeNode() john._set_item("Mary", mary) assert john.children["Mary"] is mary @@ -155,16 +160,16 @@ def test_set_child_node(self): assert mary.parent is john def test_child_already_exists(self): - mary = TreeNode() - john = TreeNode(children={"Mary": mary}) - mary_2 = TreeNode() + mary: TreeNode = TreeNode() + john: TreeNode = TreeNode(children={"Mary": mary}) + mary_2: TreeNode = TreeNode() with pytest.raises(KeyError): john._set_item("Mary", mary_2, allow_overwrite=False) def test_set_grandchild(self): - rose = TreeNode() - mary = TreeNode() - john = TreeNode() + rose: TreeNode = TreeNode() + mary: TreeNode = TreeNode() + john: TreeNode = TreeNode() john._set_item("Mary", mary) john._set_item("Mary/Rose", rose) @@ -175,8 +180,8 @@ def test_set_grandchild(self): assert rose.parent is mary def test_create_intermediate_child(self): - john = TreeNode() - rose = TreeNode() + john: TreeNode = TreeNode() + rose: TreeNode = TreeNode() # test intermediate children not allowed with pytest.raises(KeyError, match="Could not reach"): @@ -192,12 +197,12 @@ def test_create_intermediate_child(self): assert rose.parent == mary def test_overwrite_child(self): - john = TreeNode() - mary = TreeNode() + john: TreeNode = TreeNode() + mary: TreeNode = TreeNode() john._set_item("Mary", mary) # test overwriting not allowed - marys_evil_twin = TreeNode() + marys_evil_twin: TreeNode = TreeNode() with pytest.raises(KeyError, match="Already a node object"): john._set_item("Mary", marys_evil_twin, allow_overwrite=False) assert john.children["Mary"] is mary @@ -212,8 +217,8 @@ def test_overwrite_child(self): class TestPruning: def test_del_child(self): - john = TreeNode() - mary = TreeNode() + john: TreeNode = TreeNode() + mary: TreeNode = TreeNode() john._set_item("Mary", mary) del john["Mary"] @@ -224,16 +229,25 @@ def test_del_child(self): del john["Mary"] -def create_test_tree(): - a = NamedNode(name="a") - b = NamedNode() - c = NamedNode() - d = NamedNode() - e = NamedNode() - f = NamedNode() - g = NamedNode() - h = NamedNode() - i = NamedNode() +def create_test_tree() -> tuple[NamedNode, NamedNode]: + # a + # ├── b + # │ ├── d + # │ └── e + # │ ├── f + # │ └── g + # └── c + # └── h + # └── i + a: NamedNode = NamedNode(name="a") + b: NamedNode = NamedNode() + c: NamedNode = NamedNode() + d: NamedNode = NamedNode() + e: NamedNode = NamedNode() + f: NamedNode = NamedNode() + g: NamedNode = NamedNode() + h: NamedNode = NamedNode() + i: NamedNode = NamedNode() a.children = {"b": b, "c": c} b.children = {"d": d, "e": e} @@ -247,7 +261,9 @@ def create_test_tree(): class TestIterators: def test_preorderiter(self): root, _ = create_test_tree() - result = [node.name for node in PreOrderIter(root)] + result: list[str | None] = [ + node.name for node in cast(Iterator[NamedNode], PreOrderIter(root)) + ] expected = [ "a", "b", @@ -263,7 +279,9 @@ def test_preorderiter(self): def test_levelorderiter(self): root, _ = create_test_tree() - result = [node.name for node in LevelOrderIter(root)] + result: list[str | None] = [ + node.name for node in cast(Iterator[NamedNode], LevelOrderIter(root)) + ] expected = [ "a", # root Node is unnamed "b", @@ -279,19 +297,20 @@ def test_levelorderiter(self): class TestAncestry: + def test_parents(self): - _, leaf = create_test_tree() + _, leaf_f = create_test_tree() expected = ["e", "b", "a"] - assert [node.name for node in leaf.parents] == expected + assert [node.name for node in leaf_f.parents] == expected def test_lineage(self): - _, leaf = create_test_tree() + _, leaf_f = create_test_tree() expected = ["f", "e", "b", "a"] - assert [node.name for node in leaf.lineage] == expected + assert [node.name for node in leaf_f.lineage] == expected def test_ancestors(self): - _, leaf = create_test_tree() - ancestors = leaf.ancestors + _, leaf_f = create_test_tree() + ancestors = leaf_f.ancestors expected = ["a", "b", "e", "f"] for node, expected_name in zip(ancestors, expected): assert node.name == expected_name @@ -356,22 +375,28 @@ def test_levels(self): class TestRenderTree: def test_render_nodetree(self): - sam = NamedNode() - ben = NamedNode() - mary = NamedNode(children={"Sam": sam, "Ben": ben}) - kate = NamedNode() - john = NamedNode(children={"Mary": mary, "Kate": kate}) - - printout = john.__str__() + sam: NamedNode = NamedNode() + ben: NamedNode = NamedNode() + mary: NamedNode = NamedNode(children={"Sam": sam, "Ben": ben}) + kate: NamedNode = NamedNode() + john: NamedNode = NamedNode(children={"Mary": mary, "Kate": kate}) expected_nodes = [ "NamedNode()", - "NamedNode('Mary')", - "NamedNode('Sam')", - "NamedNode('Ben')", - "NamedNode('Kate')", + "\tNamedNode('Mary')", + "\t\tNamedNode('Sam')", + "\t\tNamedNode('Ben')", + "\tNamedNode('Kate')", ] - for expected_node, printed_node in zip(expected_nodes, printout.splitlines()): - assert expected_node in printed_node + expected_str = "NamedNode('Mary')" + john_repr = john.__repr__() + mary_str = mary.__str__() + + assert mary_str == expected_str + + john_nodes = john_repr.splitlines() + assert len(john_nodes) == len(expected_nodes) + for expected_node, repr_node in zip(expected_nodes, john_nodes): + assert expected_node == repr_node def test_nodepath():