ENH: Implemented __getitem__ logic
mtsokol committed Sep 24, 2024
1 parent c81d2e2 commit a8ded71
Showing 3 changed files with 154 additions and 8 deletions.
6 changes: 6 additions & 0 deletions sparse/mlir_backend/_constructors.py
@@ -348,6 +348,12 @@ def __del__(self):
        for field in self._obj.get__fields_():
            free_memref(field)

    def __getitem__(self, key) -> "Tensor":
        # imported lazily to avoid cyclic dependency
        from ._ops import getitem

        return getitem(self, key)

    @_hold_self_ref_in_ret
    def to_scipy_sparse(self) -> sps.sparray | np.ndarray:
        return self._obj.to_sps(self.shape)
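
With this method in place, `Tensor` supports NumPy-style basic indexing that dispatches to the compiled slicing path in `_ops.py`. A minimal usage sketch, assuming the imports used in the tests below; values are illustrative, and it mirrors `test_indexing_2d`, which is currently skipped pending an upstream MLIR fix:

import scipy.sparse as sps

import sparse

arr = sps.random_array((20, 30), density=0.5, format="csr")
arr.sum_duplicates()

tensor = sparse.asarray(arr)
sliced = tensor[..., 0:4:2]        # dispatches to _ops.getitem
result = sliced.to_scipy_sparse()  # convert back to a SciPy sparse array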
117 changes: 111 additions & 6 deletions sparse/mlir_backend/_ops.py
@@ -1,4 +1,5 @@
import ctypes
from types import EllipsisType

import mlir.execution_engine
import mlir.passmanager
@@ -85,12 +86,39 @@ def get_reshape_module(
            def reshape(a, shape):
                return tensor.reshape(out_tensor_type, a, shape)

        reshape.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
        if DEBUG:
            (CWD / "reshape_module.mlir").write_text(str(module))
        pm.run(module.operation)
        if DEBUG:
            (CWD / "reshape_module_opt.mlir").write_text(str(module))

    return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])


@fn_cache
def get_slice_module(
    in_tensor_type: ir.RankedTensorType,
    out_tensor_type: ir.RankedTensorType,
    offsets: tuple[int, ...],
    sizes: tuple[int, ...],
    strides: tuple[int, ...],
) -> mlir.execution_engine.ExecutionEngine:
    with ir.Location.unknown(ctx):
        module = ir.Module.create()

        with ir.InsertionPoint(module.body):

            @func.FuncOp.from_py_func(in_tensor_type)
            def getitem(a):
                return tensor.extract_slice(out_tensor_type, a, [], [], [], offsets, sizes, strides)

        getitem.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
        if DEBUG:
            (CWD / "getitem_module.mlir").write_text(str(module))
        pm.run(module.operation)
        if DEBUG:
            (CWD / "getitem_module_opt.mlir").write_text(str(module))

    return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
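
As with `get_reshape_module`, the builder is wrapped in `fn_cache`; assuming that decorator memoizes on its arguments, as its use here suggests, each distinct combination of tensor types, offsets, sizes, and strides is lowered and JIT-compiled only once:

# hypothetical illustration: identical slices hit the cached kernel
a = tensor[0:10:2]  # first use compiles an ExecutionEngine
b = tensor[0:10:2]  # same types and slice parameters: the cached module is reused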

@@ -135,3 +163,80 @@ def reshape(x: Tensor, /, shape: tuple[int, ...]) -> Tensor:
    )

    return Tensor(ret_obj, shape=out_tensor_type.shape)


def _add_missing_dims(key: tuple, ndim: int) -> tuple:
    if len(key) < ndim and Ellipsis not in key:
        return key + (...,)
    return key


def _expand_ellipsis(key: tuple, ndim: int) -> tuple:
    if Ellipsis in key:
        if len([e for e in key if e is Ellipsis]) > 1:
            raise Exception(f"Ellipsis should be used once: {key}")
        to_expand = ndim - len(key) + 1
        if to_expand <= 0:
            raise Exception(f"Invalid use of Ellipsis in {key}")
        idx = key.index(Ellipsis)
        return key[:idx] + tuple(slice(None) for _ in range(to_expand)) + key[idx + 1 :]
    return key
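
A quick trace of the two normalization helpers for a 2-D tensor; the results match what the code above returns:

key = _add_missing_dims((2,), 2)  # (2, Ellipsis) -- short keys are padded with ...
key = _expand_ellipsis(key, 2)    # (2, slice(None, None, None))

# an explicit Ellipsis expands to as many full slices as needed
_expand_ellipsis((..., slice(0, 4, 2)), 2)  # (slice(None, None, None), slice(0, 4, 2))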


def _decompose_slices(
    key: tuple,
    shape: tuple[int, ...],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
    offsets = []
    sizes = []
    strides = []

    for key_elem, size in zip(key, shape, strict=False):
        if isinstance(key_elem, slice):
            offset = key_elem.start if key_elem.start is not None else 0
            size = key_elem.stop - offset if key_elem.stop is not None else size - offset
            stride = key_elem.step if key_elem.step is not None else 1
        elif isinstance(key_elem, int):
            offset = key_elem
            size = key_elem + 1
            stride = 1
        offsets.append(offset)
        sizes.append(size)
        strides.append(stride)

    return tuple(offsets), tuple(sizes), tuple(strides)


def _get_new_shape(sizes, strides) -> tuple[int, ...]:
    return tuple(size // stride for size, stride in zip(sizes, strides, strict=False))
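
Concretely, for shape `(20, 30)` and the normalized key `(slice(0, 4, 2), slice(1, None, 1))`, the decomposition corresponds, in NumPy terms, to reading `x[offset : offset + size : stride]` along each dimension:

offsets, sizes, strides = _decompose_slices((slice(0, 4, 2), slice(1, None, 1)), (20, 30))
# offsets == (0, 1), sizes == (4, 29), strides == (2, 1)
_get_new_shape(sizes, strides)  # (2, 29) -- the shape of x[0:4:2, 1:30]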


def getitem(
    x: Tensor,
    key: int | slice | EllipsisType | tuple[int | slice | EllipsisType, ...],
) -> Tensor:
    if not isinstance(key, tuple):
        key = (key,)
    if None in key:
        raise Exception(f"Lazy indexing isn't supported: {key}")

    ret_obj = x._format_class()

    key = _add_missing_dims(key, x.ndim)
    key = _expand_ellipsis(key, x.ndim)
    offsets, sizes, strides = _decompose_slices(key, x.shape)

    new_shape = _get_new_shape(sizes, strides)
    out_tensor_type = x._obj.get_tensor_definition(new_shape)

    slice_module = get_slice_module(
        x._obj.get_tensor_definition(x.shape),
        out_tensor_type,
        offsets,
        sizes,
        strides,
    )

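    # With llvm.emit_c_interface set, the compiled function returns its result
    # through an out-parameter: the engine writes the result memref descriptor
    # into ret_obj via the double pointer.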
    slice_module.invoke("getitem", ctypes.pointer(ctypes.pointer(ret_obj)), *x._obj.to_module_arg())

    return Tensor(ret_obj, shape=out_tensor_type.shape)
39 changes: 37 additions & 2 deletions sparse/mlir_backend/tests/test_simple.py
@@ -217,8 +217,7 @@ def test_reshape(rng, dtype):
        arr = sps.random_array(
            shape, density=DENSITY, format=format, dtype=dtype, random_state=rng, data_sampler=sampler
        )
        arr.sum_duplicates()

        tensor = sparse.asarray(arr)

@@ -264,3 +263,39 @@
    # DENSE
    # NOTE: dense reshape is probably broken in MLIR
    # dense = np.arange(math.prod(SHAPE), dtype=dtype).reshape(SHAPE)


@pytest.mark.skip(reason="https://discourse.llvm.org/t/illegal-operation-when-slicing-csr-csc-coo-tensor/81404")
@parametrize_dtypes
@pytest.mark.parametrize(
    "index",
    [
        0,
        (2,),
        (2, 3),
        (..., slice(0, 4, 2)),
        (1, slice(1, None, 1)),
        # TODO: The cases below need an update to the ownership mechanism.
        # `tensor[:, :]` returns the same memref that was passed in. The
        # mechanism treats the result as MLIR-allocated and frees it, even
        # though it may still be owned by SciPy/NumPy, so it ends up freeing
        # SciPy/NumPy-managed memory and segfaults.
        # ...,
        # slice(None),
        # (slice(None), slice(None)),
    ],
)
def test_indexing_2d(rng, dtype, index):
    SHAPE = (20, 30)
    DENSITY = 0.5

    for format in ["csr", "csc", "coo"]:
        arr = sps.random_array(SHAPE, density=DENSITY, format=format, dtype=dtype, random_state=rng)
        arr.sum_duplicates()

        tensor = sparse.asarray(arr)

        actual = tensor[index].to_scipy_sparse()
        expected = arr.todense()[index]

        np.testing.assert_array_equal(actual.todense(), expected)
