Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Add threaded executor #187

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 0 additions & 40 deletions python/Untitled.ipynb

This file was deleted.

6 changes: 4 additions & 2 deletions python/deno.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
]
},
"fmt": {
"useTabs": true
"useTabs": true,
"exclude": [".venv", "notebooks"]
},
"lint": {
"rules": {
"exclude": [
"prefer-const"
]
}
},
"exclude": [".venv", "notebooks"]
}
}
7 changes: 2 additions & 5 deletions python/notebooks/mandelbrot.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,7 @@
"# Initialize the store\n",
"store = MandlebrotStore(levels=50, tilesize=512, compressor=numcodecs.Blosc())\n",
"# Wrap in a cache so that tiles don't need to be computed as often\n",
"store = zarr.LRUStoreCache(store, max_size=1e9)\n",
"\n",
"# This store implements the 'multiscales' zarr specfiication which is recognized by vizarr\n",
"grp = zarr.open(store, mode=\"r\")"
"store = zarr.LRUStoreCache(store, max_size=1e9)"
]
},
{
Expand All @@ -182,7 +179,7 @@
"import vizarr\n",
"\n",
"viewer = vizarr.Viewer()\n",
"viewer.add_image(source=grp, name=\"mandelbrot\")\n",
"viewer.add_image(source=store, name=\"mandelbrot\")\n",
"viewer"
]
}
Expand Down
23 changes: 20 additions & 3 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,23 @@ dependencies = ["anywidget", "zarr"]
[project.optional-dependencies]
dev = ["watchfiles", "jupyterlab"]

# automatically add the dev feature to the default env (e.g., hatch shell)
[tool.hatch.envs.default]
features = ["dev"]
[tool.ruff.lint]
pydocstyle = { convention = "numpy" }
select = [
"E", # style errors
"W", # style warnings
"F", # flakes
"D", # pydocstyle
"D417", # Missing argument descriptions in Docstrings
"I", # isort
"UP", # pyupgrade
"C4", # flake8-comprehensions
"B", # flake8-bugbear
"A001", # flake8-builtins
"RUF", # ruff-specific rules
"TCH", # flake8-type-checking
"TID", # flake8-tidy-imports
]

[tool.ruff.lint.per-file-ignores]
"tests/*.py" = ["D", "S"]
84 changes: 43 additions & 41 deletions python/src/vizarr/_widget.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import * as vizarr from "https://hms-dbmi.github.io/vizarr/index.js";
import debounce from "https://esm.sh/just-debounce-it@3";
import debounce from "https://esm.sh/just-debounce-it@3.2.0";

/**
* @template T
Expand All @@ -9,24 +9,24 @@ import debounce from "https://esm.sh/just-debounce-it@3";
* @returns {Promise<{ data: T, buffers: DataView[] }>}
*/
function send(model, payload, { timeout = 3000 } = {}) {
let uuid = globalThis.crypto.randomUUID();
let id = Math.random().toString(36).substring(7);
return new Promise((resolve, reject) => {
let timer = setTimeout(() => {
reject(new Error(`Promise timed out after ${timeout} ms`));
model.off("msg:custom", handler);
}, timeout);
/**
* @param {{ uuid: string, payload: T }} msg
* @param {{ id: string, payload: T }} msg
* @param {DataView[]} buffers
*/
function handler(msg, buffers) {
if (!(msg.uuid === uuid)) return;
if (!(msg.id === id)) return;
clearTimeout(timer);
resolve({ data: msg.payload, buffers });
model.off("msg:custom", handler);
}
model.on("msg:custom", handler);
model.send({ payload, uuid });
model.send({ payload, id });
});
}

Expand Down Expand Up @@ -71,41 +71,43 @@ function get_source(model, source) {
* @property {[x: number, y: number]} target
*/

/** @type {import("npm:@anywidget/types").Render<Model>} */
export async function render({ model, el }) {
let div = document.createElement("div");
{
div.style.height = model.get("height");
div.style.backgroundColor = "black";
model.on("change:height", () => {
export default {
/** @type {import("npm:@anywidget/types").Render<Model>} */
async render({ model, el }) {
let div = document.createElement("div");
{
div.style.height = model.get("height");
});
}
let viewer = await vizarr.createViewer(div);
{
model.on("change:view_state", () => {
viewer.setViewState(model.get("view_state"));
});
viewer.on(
"viewStateChange",
debounce((/** @type {ViewState} */ update) => {
model.set("view_state", update);
model.save_changes();
}, 200),
);
}
{
// sources are append-only now
for (const config of model.get("_configs")) {
const source = get_source(model, config.source);
viewer.addImage({ ...config, source });
div.style.backgroundColor = "black";
model.on("change:height", () => {
div.style.height = model.get("height");
});
}
model.on("change:_configs", () => {
const last = model.get("_configs").at(-1);
if (!last) return;
const source = get_source(model, last.source);
viewer.addImage({ ...last, source });
});
}
el.appendChild(div);
}
let viewer = await vizarr.createViewer(div);
{
model.on("change:view_state", () => {
viewer.setViewState(model.get("view_state"));
});
viewer.on(
"viewStateChange",
debounce((/** @type {ViewState} */ update) => {
model.set("view_state", update);
model.save_changes();
}, 200),
);
}
{
// sources are append-only now
for (const config of model.get("_configs")) {
const source = get_source(model, config.source);
viewer.addImage({ ...config, source });
}
model.on("change:_configs", () => {
const last = model.get("_configs").at(-1);
if (!last) return;
const source = get_source(model, last.source);
viewer.addImage({ ...last, source });
});
}
el.appendChild(div);
},
};
98 changes: 66 additions & 32 deletions python/src/vizarr/_widget.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,81 @@
from __future__ import annotations

import concurrent.futures
import os
import pathlib
from typing import TYPE_CHECKING, TypeGuard

import anywidget
import traitlets
import pathlib

import zarr
import numpy as np
if TYPE_CHECKING:
import numpy as np
import zarr
import zarr.storage

__all__ = ["Viewer"]

THREAD_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count())


def is_zarr_node(obj: object) -> TypeGuard[zarr.Array | zarr.Group]:
return hasattr(obj, "store") and hasattr(obj, "_key_prefix")


def is_readable_store(obj: object) -> TypeGuard[zarr.storage.BaseStore]:
return hasattr(obj, "__getitem__") and hasattr(obj, "__contains__")


def has_array_protocol(obj: object) -> bool:
return hasattr(obj, "__array__") or hasattr(obj, "__array_interface__")


def handle_custom_message(widget: Viewer, msg: dict, _buffers: list[bytes]):
store, key_prefix = widget._store_paths[msg["payload"]["source_id"]]
key = key_prefix + msg["payload"]["key"].lstrip("/")

if msg["payload"]["type"] == "has":
widget.send({"id": msg["id"], "payload": key in store})
return

if msg["payload"]["type"] == "get":

def target():
try:
buffers = [store[key]]
except KeyError:
buffers = []
widget.send(
{"id": msg["id"], "payload": {"success": len(buffers) == 1}},
buffers,
)

THREAD_EXECUTOR.submit(target)
return

def _store_keyprefix(obj):
# Just grab the store and key_prefix from zarr.Array and zarr.Group objects
if isinstance(obj, (zarr.Array, zarr.Group)):
raise ValueError(f"Unknown message type: {msg['payload']['type']}")


def get_store_keyprefix(obj: zarr.Array | zarr.Group | np.ndarray | dict):
if is_zarr_node(obj):
# Just grab the store and key_prefix from zarr.Array and zarr.Group objects
return obj.store, obj._key_prefix

if isinstance(obj, np.ndarray):
if has_array_protocol(obj):
# Create an in-memory store, and write array as as single chunk
store = {}
import numpy as np
import zarr
import zarr.storage

store = zarr.storage.MemoryStore()
data = np.asarray(obj)
arr = zarr.create(
store=store, shape=obj.shape, chunks=obj.shape, dtype=obj.dtype
store=store, shape=data.shape, chunks=data.shape, dtype=data.dtype
)
arr[:] = obj
return store, ""

if hasattr(obj, "__getitem__") and hasattr(obj, "__contains__"):
if is_readable_store(obj):
return obj, ""

raise TypeError("Cannot normalize store path")
Expand All @@ -37,31 +90,12 @@ class Viewer(anywidget.AnyWidget):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._store_paths = []
self.on_msg(self._handle_custom_msg)

def _handle_custom_msg(self, msg, buffers):
store, key_prefix = self._store_paths[msg["payload"]["source_id"]]
key = key_prefix + msg["payload"]["key"].lstrip("/")

if msg["payload"]["type"] == "has":
self.send({"uuid": msg["uuid"], "payload": key in store})
return

if msg["payload"]["type"] == "get":
try:
buffers = [store[key]]
except KeyError:
buffers = []
self.send(
{"uuid": msg["uuid"], "payload": {"success": len(buffers) == 1}},
buffers,
)
return
self.on_msg(handle_custom_message)

def add_image(self, source, **config):
if not isinstance(source, str):
store, key_prefix = _store_keyprefix(source)
store, key_prefix = get_store_keyprefix(source)
source = {"id": len(self._store_paths)}
self._store_paths.append((store, key_prefix))
config["source"] = source
self._configs = self._configs + [config]
self._configs = [*self._configs, config]
Loading