Skip to content

Commit

Permalink
mypy typing. (terrible?)
Browse files Browse the repository at this point in the history
There must be a better way, but I don't know it.
particularly the list comprehension casts.
  • Loading branch information
flamingbear committed Feb 15, 2024
1 parent 32053b6 commit 32e7453
Showing 1 changed file with 65 additions and 59 deletions.
124 changes: 65 additions & 59 deletions xarray/tests/datatree/test_treenode.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

from collections.abc import Iterator
from typing import cast

import pytest

from xarray.core.treenode import InvalidTreeError, NamedNode, NodePath, TreeNode
Expand All @@ -8,29 +11,29 @@

class TestFamilyTree:
def test_lonely(self):
root = TreeNode()
root: TreeNode = TreeNode()
assert root.parent is None
assert root.children == {}

def test_parenting(self):
john = TreeNode()
mary = TreeNode()
john: TreeNode = TreeNode()
mary: TreeNode = TreeNode()
mary._set_parent(john, "Mary")

assert mary.parent == john
assert john.children["Mary"] is mary

def test_no_time_traveller_loops(self):
john = TreeNode()
john: TreeNode = TreeNode()

with pytest.raises(InvalidTreeError, match="cannot be a parent of itself"):
john._set_parent(john, "John")

with pytest.raises(InvalidTreeError, match="cannot be a parent of itself"):
john.children = {"John": john}

mary = TreeNode()
rose = TreeNode()
mary: TreeNode = TreeNode()
rose: TreeNode = TreeNode()
mary._set_parent(john, "Mary")
rose._set_parent(mary, "Rose")

Expand All @@ -41,36 +44,36 @@ def test_no_time_traveller_loops(self):
rose.children = {"John": john}

def test_parent_swap(self):
john = TreeNode()
mary = TreeNode()
john: TreeNode = TreeNode()
mary: TreeNode = TreeNode()
mary._set_parent(john, "Mary")

steve = TreeNode()
steve: TreeNode = TreeNode()
mary._set_parent(steve, "Mary")

assert mary.parent == steve
assert steve.children["Mary"] is mary
assert "Mary" not in john.children

def test_multi_child_family(self):
mary = TreeNode()
kate = TreeNode()
john = TreeNode(children={"Mary": mary, "Kate": kate})
mary: TreeNode = TreeNode()
kate: TreeNode = TreeNode()
john: TreeNode = TreeNode(children={"Mary": mary, "Kate": kate})
assert john.children["Mary"] is mary
assert john.children["Kate"] is kate
assert mary.parent is john
assert kate.parent is john

def test_disown_child(self):
mary = TreeNode()
john = TreeNode(children={"Mary": mary})
mary: TreeNode = TreeNode()
john: TreeNode = TreeNode(children={"Mary": mary})
mary.orphan()
assert mary.parent is None
assert "Mary" not in john.children

def test_doppelganger_child(self):
kate = TreeNode()
john = TreeNode()
kate: TreeNode = TreeNode()
john: TreeNode = TreeNode()

with pytest.raises(TypeError):
john.children = {"Kate": 666}
Expand All @@ -79,22 +82,22 @@ def test_doppelganger_child(self):
john.children = {"Kate": kate, "Evil_Kate": kate}

john = TreeNode(children={"Kate": kate})
evil_kate = TreeNode()
evil_kate: TreeNode = TreeNode()
evil_kate._set_parent(john, "Kate")
assert john.children["Kate"] is evil_kate

def test_sibling_relationships(self):
mary = TreeNode()
kate = TreeNode()
ashley = TreeNode()
mary: TreeNode = TreeNode()
kate: TreeNode = TreeNode()
ashley: TreeNode = TreeNode()
TreeNode(children={"Mary": mary, "Kate": kate, "Ashley": ashley})
assert kate.siblings["Mary"] is mary
assert kate.siblings["Ashley"] is ashley
assert "Kate" not in kate.siblings

def test_ancestors(self):
tony = TreeNode()
michael = TreeNode(children={"Tony": tony})
tony: TreeNode = TreeNode()
michael: TreeNode = TreeNode(children={"Tony": tony})
vito = TreeNode(children={"Michael": michael})
assert tony.root is vito
assert tony.parents == (michael, vito)
Expand All @@ -103,7 +106,7 @@ def test_ancestors(self):

class TestGetNodes:
def test_get_child(self):
steven = TreeNode()
steven: TreeNode = TreeNode()
sue = TreeNode(children={"Steven": steven})
mary = TreeNode(children={"Sue": sue})
john = TreeNode(children={"Mary": mary})
Expand All @@ -126,8 +129,8 @@ def test_get_child(self):
assert mary._get_item("Sue/Steven") is steven

def test_get_upwards(self):
sue = TreeNode()
kate = TreeNode()
sue: TreeNode = TreeNode()
kate: TreeNode = TreeNode()
mary = TreeNode(children={"Sue": sue, "Kate": kate})
john = TreeNode(children={"Mary": mary})

Expand All @@ -138,7 +141,7 @@ def test_get_upwards(self):
assert sue._get_item("../Kate") is kate

def test_get_from_root(self):
sue = TreeNode()
sue: TreeNode = TreeNode()
mary = TreeNode(children={"Sue": sue})
john = TreeNode(children={"Mary": mary}) # noqa

Expand All @@ -147,8 +150,8 @@ def test_get_from_root(self):

class TestSetNodes:
def test_set_child_node(self):
john = TreeNode()
mary = TreeNode()
john: TreeNode = TreeNode()
mary: TreeNode = TreeNode()
john._set_item("Mary", mary)

assert john.children["Mary"] is mary
Expand All @@ -157,16 +160,16 @@ def test_set_child_node(self):
assert mary.parent is john

def test_child_already_exists(self):
mary = TreeNode()
john = TreeNode(children={"Mary": mary})
mary_2 = TreeNode()
mary: TreeNode = TreeNode()
john: TreeNode = TreeNode(children={"Mary": mary})
mary_2: TreeNode = TreeNode()
with pytest.raises(KeyError):
john._set_item("Mary", mary_2, allow_overwrite=False)

def test_set_grandchild(self):
rose = TreeNode()
mary = TreeNode()
john = TreeNode()
rose: TreeNode = TreeNode()
mary: TreeNode = TreeNode()
john: TreeNode = TreeNode()

john._set_item("Mary", mary)
john._set_item("Mary/Rose", rose)
Expand All @@ -177,8 +180,8 @@ def test_set_grandchild(self):
assert rose.parent is mary

def test_create_intermediate_child(self):
john = TreeNode()
rose = TreeNode()
john: TreeNode = TreeNode()
rose: TreeNode = TreeNode()

# test intermediate children not allowed
with pytest.raises(KeyError, match="Could not reach"):
Expand All @@ -194,12 +197,12 @@ def test_create_intermediate_child(self):
assert rose.parent == mary

def test_overwrite_child(self):
john = TreeNode()
mary = TreeNode()
john: TreeNode = TreeNode()
mary: TreeNode = TreeNode()
john._set_item("Mary", mary)

# test overwriting not allowed
marys_evil_twin = TreeNode()
marys_evil_twin: TreeNode = TreeNode()
with pytest.raises(KeyError, match="Already a node object"):
john._set_item("Mary", marys_evil_twin, allow_overwrite=False)
assert john.children["Mary"] is mary
Expand All @@ -214,8 +217,8 @@ def test_overwrite_child(self):

class TestPruning:
def test_del_child(self):
john = TreeNode()
mary = TreeNode()
john: TreeNode = TreeNode()
mary: TreeNode = TreeNode()
john._set_item("Mary", mary)

del john["Mary"]
Expand All @@ -226,7 +229,7 @@ def test_del_child(self):
del john["Mary"]


def create_test_tree():
def create_test_tree() -> tuple[NamedNode, NamedNode]:
# a
# ├── b
# │ ├── d
Expand All @@ -236,16 +239,15 @@ def create_test_tree():
# └── c
# └── h
# └── i

a = NamedNode(name="a")
b = NamedNode()
c = NamedNode()
d = NamedNode()
e = NamedNode()
f = NamedNode()
g = NamedNode()
h = NamedNode()
i = NamedNode()
a: NamedNode = NamedNode(name="a")
b: NamedNode = NamedNode()
c: NamedNode = NamedNode()
d: NamedNode = NamedNode()
e: NamedNode = NamedNode()
f: NamedNode = NamedNode()
g: NamedNode = NamedNode()
h: NamedNode = NamedNode()
i: NamedNode = NamedNode()

a.children = {"b": b, "c": c}
b.children = {"d": d, "e": e}
Expand All @@ -259,7 +261,9 @@ def create_test_tree():
class TestIterators:
def test_preorderiter(self):
root, _ = create_test_tree()
result = [node.name for node in PreOrderIter(root)]
result: list[str | None] = [
node.name for node in cast(Iterator[NamedNode], PreOrderIter(root))
]
expected = [
"a",
"b",
Expand All @@ -275,7 +279,9 @@ def test_preorderiter(self):

def test_levelorderiter(self):
root, _ = create_test_tree()
result = [node.name for node in LevelOrderIter(root)]
result: list[str | None] = [
node.name for node in cast(Iterator[NamedNode], LevelOrderIter(root))
]
expected = [
"a", # root Node is unnamed
"b",
Expand Down Expand Up @@ -369,11 +375,11 @@ def test_levels(self):

class TestRenderTree:
def test_render_nodetree(self):
sam = NamedNode()
ben = NamedNode()
mary = NamedNode(children={"Sam": sam, "Ben": ben})
kate = NamedNode()
john = NamedNode(children={"Mary": mary, "Kate": kate})
sam: NamedNode = NamedNode()
ben: NamedNode = NamedNode()
mary: NamedNode = NamedNode(children={"Sam": sam, "Ben": ben})
kate: NamedNode = NamedNode()
john: NamedNode = NamedNode(children={"Mary": mary, "Kate": kate})

printout = john.__str__()
expected_nodes = [
Expand Down

0 comments on commit 32e7453

Please sign in to comment.