Skip to content

Commit

Permalink
Merge branch 'main' into mhs/migrate_treenode
Browse files Browse the repository at this point in the history
  • Loading branch information
flamingbear authored Feb 26, 2024
2 parents 9cbaf3b + e47eb92 commit 9db7040
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 6 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ New Features
- Add the ``.oindex`` property to Explicitly Indexed Arrays for orthogonal indexing functionality. (:issue:`8238`, :pull:`8750`)
By `Anderson Banihirwe <https://github.com/andersy005>`_.

- Add the ``.vindex`` property to Explicitly Indexed Arrays for vectorized indexing functionality. (:issue:`8238`, :pull:`8780`)
By `Anderson Banihirwe <https://github.com/andersy005>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
42 changes: 39 additions & 3 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,10 +488,17 @@ def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray:
def _oindex_get(self, key):
raise NotImplementedError("This method should be overridden")

def _vindex_get(self, key):
raise NotImplementedError("This method should be overridden")

@property
def oindex(self):
return IndexCallable(self._oindex_get)

@property
def vindex(self):
return IndexCallable(self._vindex_get)


class ImplicitToExplicitIndexingAdapter(NDArrayMixin):
"""Wrap an array, converting tuples into the indicated explicit indexer."""
Expand Down Expand Up @@ -585,6 +592,10 @@ def transpose(self, order):
def _oindex_get(self, indexer):
return type(self)(self.array, self._updated_key(indexer))

def _vindex_get(self, indexer):
array = LazilyVectorizedIndexedArray(self.array, self.key)
return array[indexer]

def __getitem__(self, indexer):
if isinstance(indexer, VectorizedIndexer):
array = LazilyVectorizedIndexedArray(self.array, self.key)
Expand Down Expand Up @@ -644,6 +655,12 @@ def get_duck_array(self):
def _updated_key(self, new_key):
return _combine_indexers(self.key, self.shape, new_key)

def _oindex_get(self, indexer):
return type(self)(self.array, self._updated_key(indexer))

def _vindex_get(self, indexer):
return type(self)(self.array, self._updated_key(indexer))

def __getitem__(self, indexer):
# If the indexed array becomes a scalar, return LazilyIndexedArray
if all(isinstance(ind, integer_types) for ind in indexer.tuple):
Expand Down Expand Up @@ -691,6 +708,9 @@ def get_duck_array(self):
def _oindex_get(self, key):
return type(self)(_wrap_numpy_scalars(self.array[key]))

def _vindex_get(self, key):
return type(self)(_wrap_numpy_scalars(self.array[key]))

def __getitem__(self, key):
return type(self)(_wrap_numpy_scalars(self.array[key]))

Expand Down Expand Up @@ -727,6 +747,9 @@ def get_duck_array(self):
def _oindex_get(self, key):
return type(self)(_wrap_numpy_scalars(self.array[key]))

def _vindex_get(self, key):
return type(self)(_wrap_numpy_scalars(self.array[key]))

def __getitem__(self, key):
return type(self)(_wrap_numpy_scalars(self.array[key]))

Expand Down Expand Up @@ -1364,8 +1387,12 @@ def transpose(self, order):
return self.array.transpose(order)

def _oindex_get(self, key):
array, key = self._indexing_array_and_key(key)
return array[key]
key = _outer_to_numpy_indexer(key, self.array.shape)
return self.array[key]

def _vindex_get(self, key):
array = NumpyVIndexAdapter(self.array)
return array[key.tuple]

def __getitem__(self, key):
array, key = self._indexing_array_and_key(key)
Expand Down Expand Up @@ -1419,6 +1446,9 @@ def _oindex_get(self, key):
value = value[(slice(None),) * axis + (subkey, Ellipsis)]
return value

def _vindex_get(self, key):
raise TypeError("Vectorized indexing is not supported")

def __getitem__(self, key):
if isinstance(key, BasicIndexer):
return self.array[key.tuple]
Expand Down Expand Up @@ -1465,11 +1495,14 @@ def _oindex_get(self, key):
value = value[(slice(None),) * axis + (subkey,)]
return value

def _vindex_get(self, key):
return self.array.vindex[key.tuple]

def __getitem__(self, key):
if isinstance(key, BasicIndexer):
return self.array[key.tuple]
elif isinstance(key, VectorizedIndexer):
return self.array.vindex[key.tuple]
return self.vindex[key]
else:
assert isinstance(key, OuterIndexer)
return self.oindex[key]
Expand Down Expand Up @@ -1551,6 +1584,9 @@ def _convert_scalar(self, item):
def _oindex_get(self, key):
return self.__getitem__(key)

def _vindex_get(self, key):
return self.__getitem__(key)

def __getitem__(
self, indexer
) -> (
Expand Down
9 changes: 6 additions & 3 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,10 +759,10 @@ def __getitem__(self, key) -> Self:
dims, indexer, new_order = self._broadcast_indexes(key)
indexable = as_indexable(self._data)

if isinstance(indexer, BasicIndexer):
data = indexable[indexer]
elif isinstance(indexer, OuterIndexer):
if isinstance(indexer, OuterIndexer):
data = indexable.oindex[indexer]
elif isinstance(indexer, VectorizedIndexer):
data = indexable.vindex[indexer]
else:
data = indexable[indexer]
if new_order:
Expand Down Expand Up @@ -801,6 +801,9 @@ def _getitem_with_mask(self, key, fill_value=dtypes.NA):

if isinstance(indexer, OuterIndexer):
data = indexable.oindex[indexer]

elif isinstance(indexer, VectorizedIndexer):
data = indexable.vindex[indexer]
else:
data = indexable[actual_indexer]
mask = indexing.create_mask(indexer, self.shape, data)
Expand Down

0 comments on commit 9db7040

Please sign in to comment.