
Data type extension API: friction and recommendations for third-party data type authors #4


@d-v-b-agent

Context

As an exercise to evaluate the data type extension API (the ZDType interface a third party uses to add a new data type), I implemented bfloat16 — a real, standardized zarr-extensions data type that zarr-python doesn't ship with — as an out-of-tree extension, and separately had it reviewed by someone familiar with zarrs (Rust) and zarrita (Python) but not zarr-python. This issue records what the API does well and a prioritized list of improvements.

For reference, zarrs and zarrita bracket the design space. zarrs is powerful and fully decomposed: it has separate traits for data type/configuration, element byte (de)serialization, and per-codec compatibility; dual compile-time + runtime registries; and a central fill-value parser that decodes "NaN"/"Infinity"/hex floats. zarrita's data types are a closed ~11-member enum with no extension mechanism at all. ZDType sits close to zarrs in capability but is much lighter, because it delegates byte-level concerns to NumPy.

(Related: #3 simplifies the registry and shares the dtype-JSON handling. This issue is about the author-facing extension API, which that PR does not change.)

What the API does well

  • Clear core abstraction: ZDType[native_dtype_cls, scalar_cls] makes the two bridged types explicit.
  • Instance-carries-parameters model generalizes to parametrized types and beats zarrita's closed enum.
  • The data-type-JSON vs scalar-JSON split maps 1:1 to zarrs' configuration() vs fill_value().
  • Parameter-free types need zero data-type-JSON methods (after #3, "Simplify the data type registry and share data-type JSON handling"), so the common case is tiny.
  • Leaning on NumPy removes a large surface: unlike in zarrs, there is no byte layout, element encoding, codec compatibility, or size to implement.

Recommendations

Ordered by value. [both] = independently raised by both the author experience and the external reviewer (strongest signals).

  1. [both] Provide a reusable, width/dtype-generic float-scalar JSON helper. (Additive, non-breaking; highest value.) Hand-writing the spec-mandated "NaN"/"Infinity"/"-Infinity"/hex-float encoding is the single biggest friction point: every float-extension author re-implements it, and the edge cases are easy to get wrong. zarr-python already has float_to_json_v3/float_from_json_v3, but they are private and hardcoded to the IEEE float16/32/64 layouts (the hex branch picks >e/>f/>d by hex-string length), so they cannot decode bfloat16 correctly. A public, width-aware helper would shrink a float type's scalar methods from ~20 lines to ~3.

  2. [both] Make the override surface consistently public. (Breaking: changes the public extension interface.) _check_scalar, _check_native_dtype, _to_json_v3/_from_json_v3, and the _zarr_v3_name classvar are underscore-prefixed, yet they are exactly what an author must implement or use, and they sit right next to the public to_json_scalar/from_native_dtype/dtype_cls. Drop the underscores from the intended extension points (or clearly document them as such).

  3. Use one version-dispatch mechanism. Scalar JSON threads a runtime zarr_format param while data-type JSON uses separate _to_json_v2/_to_json_v3 methods. Pick one (zarrs uses a uniform version param). For version-independent scalar encodings the threaded zarr_format is pure noise.

  4. Remove or document the duplicated dtype-class declaration — the native dtype class is written twice (generic arg ZDType[X, ...] and dtype_cls = X).

  5. Document the V2-name derivation and the scalar round-trip contract. The base reuses the V3 name as the V2 name, which works only because the example's NumPy dtype has a name alias; np.dtype(bfloat16).str is "<V2", which does not round-trip. Likewise, from_json_scalar reads hex while to_json_scalar writes only canonical values: this is an intentional write-canonical/read-lenient contract that should be stated (a reviewer reasonably flagged it as a bug).

  6. Nudge authors toward range/precision validation. A cast_scalar that simply calls scalar_cls(data) silently turns a huge int into inf or drops precision; zarrs forces range validation (its uint10 example rejects out-of-range values).

  7. Add a parametrized example and a variable-length/object example. Every hard question (config JSON, endianness, variable size, validation, registration semantics) lives in the cases the parameter-free example skips.
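For recommendation 1, a minimal sketch of what a public, width-aware float-scalar helper could look like, assuming only NumPy. The names float_scalar_to_json and float_scalar_from_json are hypothetical (they are not zarr-python's private float_to_json_v3/float_from_json_v3), and the hex branch assumes the "0x" string holds the scalar's bytes in big-endian order. The width comes from np.dtype(dtype).itemsize rather than from the hex-string length, so every float dtype NumPy can describe, including ml_dtypes.bfloat16, takes the same path.

```python
import math

import numpy as np

_SPECIALS = {"NaN": math.nan, "Infinity": math.inf, "-Infinity": -math.inf}


def float_scalar_to_json(value, dtype):
    """Encode a float scalar as JSON: a number, or one of the special strings.

    Hypothetical helper. Writes the canonical form only (never hex).
    """
    f = float(np.dtype(dtype).type(value))
    if math.isnan(f):
        return "NaN"
    if math.isinf(f):
        return "Infinity" if f > 0 else "-Infinity"
    return f


def float_scalar_from_json(data, dtype):
    """Decode a JSON number, a special string, or a big-endian "0x" hex string."""
    dt = np.dtype(dtype)
    if isinstance(data, str):
        if data in _SPECIALS:
            return dt.type(_SPECIALS[data])
        if data.startswith("0x"):
            digits = data[2:]
            # The hex string must encode exactly one scalar of this width.
            if len(digits) != 2 * dt.itemsize:
                raise ValueError(f"{data!r} does not encode {dt.itemsize} bytes")
            bits = int(digits, 16)
            # Reinterpret the integer's bits as a float of the same width.
            return np.array(bits, dtype=f"u{dt.itemsize}").view(dt)[()]
        raise ValueError(f"Unrecognized float string: {data!r}")
    if isinstance(data, (int, float)) and not isinstance(data, bool):
        return dt.type(data)
    raise TypeError(f"Cannot decode {data!r} as {dt}")
```

With ml_dtypes installed, the scalar methods of the condensed bfloat16 extension at the end of this issue could become one-line calls such as float_scalar_from_json(data, ml_dtypes.bfloat16); "0x3f80" is bfloat16's bit pattern for 1.0.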
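For recommendation 2, a sketch of a consistently public override surface, assuming a plain abc.ABC base. The class names and the public hook names (zarr_v3_name, check_native_dtype, check_scalar) are hypothetical, and only two hooks are shown. Declaring each hook public and abstract documents the extension points in one place, and a subclass that forgets one fails with a TypeError when it is instantiated rather than at first use.

```python
from abc import ABC, abstractmethod
from typing import Any, ClassVar

import numpy as np


class PublicZDTypeSketch(ABC):
    """Hypothetical base class: every author-facing hook is public."""

    # Was _zarr_v3_name; a public name makes it obvious that authors set it.
    zarr_v3_name: ClassVar[str]

    @classmethod
    @abstractmethod
    def check_native_dtype(cls, dtype: np.dtype) -> bool:
        """Was _check_native_dtype."""

    @abstractmethod
    def check_scalar(self, data: Any) -> bool:
        """Was _check_scalar."""

    @classmethod
    def from_native_dtype(cls, dtype: np.dtype) -> "PublicZDTypeSketch":
        # Concrete base logic built only on the public hooks.
        if cls.check_native_dtype(dtype):
            return cls()
        raise ValueError(f"Invalid data type: {dtype}")


class Float64Public(PublicZDTypeSketch):
    zarr_v3_name = "float64"

    @classmethod
    def check_native_dtype(cls, dtype):
        return np.dtype(dtype) == np.float64

    def check_scalar(self, data):
        return isinstance(data, (int, float, np.floating))
```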
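For recommendation 3, one way to use a single version-dispatch mechanism without burdening version-independent types, sketched with hypothetical names (VersionIndependentScalarJSON, encode_scalar, decode_scalar, Int8Sketch) and a local stand-in for ZarrFormat that is assumed to match zarr's Literal[2, 3]. Both data-type JSON and scalar JSON take the same keyword-only zarr_format, and a mixin absorbs it for scalar encodings that do not depend on the format. The v2 value is simplified to the bare NumPy type string.

```python
from typing import Literal

# Assumption: stands in for zarr's ZarrFormat alias (Literal[2, 3]).
ZarrFormat = Literal[2, 3]


class VersionIndependentScalarJSON:
    """Hypothetical mixin for scalar encodings that are identical in v2 and v3."""

    def encode_scalar(self, data):
        raise NotImplementedError

    def decode_scalar(self, data):
        raise NotImplementedError

    # zarr_format is accepted for interface uniformity, then ignored.
    def to_json_scalar(self, data, *, zarr_format: ZarrFormat):
        return self.encode_scalar(data)

    def from_json_scalar(self, data, *, zarr_format: ZarrFormat):
        return self.decode_scalar(data)


class Int8Sketch(VersionIndependentScalarJSON):
    # Data-type JSON threads the same parameter instead of using separate
    # _to_json_v2/_to_json_v3 methods.
    def to_json(self, *, zarr_format: ZarrFormat):
        if zarr_format == 2:
            return "|i1"  # simplified: the NumPy type string used in v2 metadata
        if zarr_format == 3:
            return "int8"  # the v3 data type name
        raise ValueError(f"Unsupported zarr_format: {zarr_format}")

    def encode_scalar(self, data):
        return int(data)

    def decode_scalar(self, data):
        return int(data)
```

The author of a version-independent type then writes encode_scalar/decode_scalar once and never sees zarr_format, while types whose encoding does differ can still override to_json_scalar directly.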
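For recommendation 5, the scalar half of the round-trip contract can be stated as an executable check. A sketch that uses float16 through the standard struct module's "e" (IEEE half-precision) format so it needs no third-party packages; the function names are hypothetical, and the "0x" form is assumed to hold the value's bytes in big-endian order. It does not cover the V2-name question, which depends on the dtype's NumPy string form.

```python
import math
import struct

_SPECIALS = {"NaN": math.nan, "Infinity": math.inf, "-Infinity": -math.inf}


def to_json_f16(x: float):
    """Write-canonical: a JSON number or a special string, never hex."""
    if math.isnan(x):
        return "NaN"
    if math.isinf(x):
        return "Infinity" if x > 0 else "-Infinity"
    return x


def from_json_f16(data):
    """Read-lenient: everything the writer emits, plus big-endian 0x hex."""
    if isinstance(data, str):
        if data in _SPECIALS:
            return _SPECIALS[data]
        if data.startswith("0x"):
            return struct.unpack(">e", bytes.fromhex(data[2:]))[0]
        raise ValueError(f"Unrecognized float string: {data!r}")
    return float(data)


def same_float(a: float, b: float) -> bool:
    # NaN never equals itself and 0.0 == -0.0, so compare NaN-ness and sign too.
    if math.isnan(a) or math.isnan(b):
        return math.isnan(a) and math.isnan(b)
    return a == b and math.copysign(1.0, a) == math.copysign(1.0, b)


def check_scalar_json_contract(values):
    """Assert write-canonical / read-lenient for float16-representable values."""
    for v in values:
        written = to_json_f16(v)
        assert not (isinstance(written, str) and written.startswith("0x"))
        assert same_float(from_json_f16(written), v)  # canonical round trip
        hex_form = "0x" + struct.pack(">e", v).hex()
        assert same_float(from_json_f16(hex_form), v)  # lenient hex read agrees
```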
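For recommendation 6, a sketch of the kind of validation authors could be nudged toward, assuming NumPy only; checked_cast and its exact flag are hypothetical. It rejects finite inputs that overflow to inf and, optionally, inputs that the target type would round, while explicit infinities and NaN pass through. Because it goes through np.dtype(dtype).type, the same function should also apply to ml_dtypes.bfloat16.

```python
import math

import numpy as np


def checked_cast(value, dtype, *, exact: bool = False):
    """Cast `value` to the scalar type of `dtype`, refusing silent overflow.

    Hypothetical helper. With exact=True, values that would be rounded are
    rejected too.
    """
    dt = np.dtype(dtype)
    try:
        with np.errstate(over="ignore"):
            cast = dt.type(value)
    except OverflowError as e:  # the conversion itself may refuse huge ints
        raise ValueError(f"{value!r} is out of range for {dt}") from e
    result = float(cast)
    source_is_inf = isinstance(value, (float, np.floating)) and math.isinf(value)
    if math.isinf(result) and not source_is_inf:
        raise ValueError(f"{value!r} is out of range for {dt}")
    if exact and not math.isnan(result) and result != value:
        raise ValueError(f"{value!r} is not exactly representable as {dt}")
    return cast
```

In the bfloat16 example, cast_scalar could then return checked_cast(data, bfloat16_scalar_cls) after its _check_scalar test; whether rounding should be an error by default is a separate design choice.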
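For recommendation 7, a parametrized example could be as small as the following sketch. It is standalone (it does not subclass ZDType) so it runs without zarr, and the name "example.fixed_bytes" and the length_bytes configuration key are invented for illustration. It exercises what the parameter-free bfloat16 example never does: a parameter carried by the instance, validated once at construction, written into the configuration object of the data-type JSON, and recovered from a NumPy dtype.

```python
from dataclasses import dataclass
from typing import Any, ClassVar

import numpy as np


@dataclass(frozen=True)
class FixedBytesSketch:
    """Hypothetical parametrized data type: the instance carries `length`."""

    length: int
    name: ClassVar[str] = "example.fixed_bytes"  # hypothetical extension name

    def __post_init__(self):
        # Validate the parameter once, at construction.
        if not isinstance(self.length, int) or isinstance(self.length, bool) or self.length <= 0:
            raise ValueError(f"length must be a positive int, got {self.length!r}")

    def to_native_dtype(self) -> np.dtype:
        return np.dtype(f"V{self.length}")

    @classmethod
    def from_native_dtype(cls, dtype) -> "FixedBytesSketch":
        dtype = np.dtype(dtype)
        if dtype.kind != "V" or dtype.fields is not None:
            raise ValueError(f"Not a plain fixed-size void dtype: {dtype}")
        return cls(length=dtype.itemsize)

    def to_json(self) -> dict[str, Any]:
        # The parameter lives in the configuration object of the data-type JSON.
        return {"name": self.name, "configuration": {"length_bytes": self.length}}

    @classmethod
    def from_json(cls, data: Any) -> "FixedBytesSketch":
        if not isinstance(data, dict) or data.get("name") != cls.name:
            raise ValueError(f"Not a {cls.name} data type: {data!r}")
        config = data.get("configuration")
        if not isinstance(config, dict) or set(config) != {"length_bytes"}:
            raise ValueError(f"Expected configuration {{'length_bytes': int}}, got {config!r}")
        return cls(length=config["length_bytes"])
```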

Cross-validation

The author experience and the external reviewer — with no knowledge of each other or of zarr-python internals — independently converged on items 1–5. That convergence is the clearest signal of where the API actually hurts: the placement of special-float encoding on authors, and the naming/versioning/duplication inconsistencies in the override surface.

The bfloat16 extension this is based on (condensed)
```python
import math
from typing import ClassVar, Literal, Self

import ml_dtypes
import numpy as np
from zarr.dtype import ZDType, register_data_type
from zarr.errors import DataTypeValidationError
from zarr.types import JSON, ZarrFormat

bfloat16_dtype_cls = type(np.dtype(ml_dtypes.bfloat16))
bfloat16_scalar_cls = ml_dtypes.bfloat16


class BFloat16(ZDType[bfloat16_dtype_cls, bfloat16_scalar_cls]):
    # parameter-free: no to_json/from_json needed; the base handles V2/V3.
    _zarr_v3_name: ClassVar[Literal["bfloat16"]] = "bfloat16"
    dtype_cls = bfloat16_dtype_cls

    @classmethod
    def from_native_dtype(cls, dtype) -> Self:
        if cls._check_native_dtype(dtype):
            return cls()
        raise DataTypeValidationError(f"Invalid data type: {dtype}")

    def to_native_dtype(self):
        return self.dtype_cls()

    def _check_scalar(self, data):
        return isinstance(data, (int, float, np.integer, np.floating, bfloat16_scalar_cls))

    def cast_scalar(self, data):
        if self._check_scalar(data):
            return bfloat16_scalar_cls(data)
        raise TypeError(...)

    def default_scalar(self):
        return bfloat16_scalar_cls(0)

    # ~20 lines of hand-written special-float JSON (the friction in rec. #1):
    def to_json_scalar(self, data, *, zarr_format: ZarrFormat) -> JSON:
        f = float(self.cast_scalar(data))
        if math.isnan(f):
            return "NaN"
        if f == math.inf:
            return "Infinity"
        if f == -math.inf:
            return "-Infinity"
        return f

    def from_json_scalar(self, data: JSON, *, zarr_format: ZarrFormat):
        if isinstance(data, str):
            specials = {"NaN": math.nan, "Infinity": math.inf, "-Infinity": -math.inf}
            if data in specials:
                return bfloat16_scalar_cls(specials[data])
            if data.startswith("0x"):
                # Reinterpret the 16 bits spelled by the hex string as a bfloat16.
                return np.array(int(data, 16), dtype=np.uint16).view(bfloat16_scalar_cls)[()]
            raise TypeError(...)
        if isinstance(data, (int, float)):
            return bfloat16_scalar_cls(data)
        raise TypeError(...)


register_data_type(BFloat16)
```

🤖 Generated with Claude Code
