
Adopt the array API standard for the multi-backend numeric core #132

@frazane


Background

scoringrules runs every score on numpy, numba, jax, and torch through a hand-written ArrayBackend abstraction (an ABC with one concrete subclass per framework, ~1400 lines). It works, but it is a maintenance burden: every new numeric primitive has to be added to the ABC and then re-implemented in each backend, and users must explicitly select or register a backend.

The Python array API standard (via array-api-compat) now provides a portable, framework-agnostic surface that covers most of what we hand-rolled. Adopting it lets us infer the framework directly from the input arrays, drop most of the custom backend code, and return results in the caller's own framework with no configuration.

Goals

  • Infer the array framework from input arrays instead of a global/registry selection.
  • Replace the ArrayBackend ABC with array-api-compat plus a thin extension layer for special functions the standard doesn't cover (erf, gamma, betainc, …).
  • Keep numba's compiled gufuncs as an explicit, opt-in fast path (backend="numba").
  • Deprecate registry-based backend selection; long-term, remove the backend= argument entirely.

Approach

This touches all ten score families, so we're doing it in phases, family by family, with the shared infrastructure landing alongside the first pilot family.

  • Phase 1 — Foundation + CRPS pilot. Add the array-api-compat dependency; build the inference and augmented-namespace layer plus the native-first special-function extension layer; migrate the entire CRPS family end-to-end as the pilot. Deprecate backend="numpy"/"jax"/"torch" and register_backend / set_active for array-API frameworks; mixed-framework inputs now raise. Tests exercise native arrays per framework and assert the inferred backend so a misrouted score can't silently fall back to numpy.
  • Phase 2 — Migrate the remaining families. Port the other nine families (energy, logs, brier, kernels, variogram, dss, interval, quantile, error_spread) and their weighted variants onto the same infrastructure, folding the parallel _xp helper modules back into the shared core as each family moves over.
  • Phase 3 — Remove the old abstraction. Once every family is migrated, delete the ArrayBackend ABC, the per-framework subclasses, and the backend registry.
  • Phase 4 (long-term) — Remove backend= entirely. Explore transparent auto-numba so the compiled fast path is selected automatically for numpy inputs, removing the need for the backend argument at all.
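A sketch of how the backend= argument might be resolved during the deprecation period. It assumes (the issue does not specify this) that a deprecated value emits a DeprecationWarning and then falls through to inferring the framework from the inputs, while "numba" keeps selecting the compiled path. resolve_backend is a hypothetical helper name; in Phase 4 it would disappear together with the argument.

```python
import warnings

# Values deprecated in Phase 1: the framework is inferred from the inputs instead.
_DEPRECATED_BACKENDS = {"numpy", "jax", "torch"}


def resolve_backend(backend=None):
    """Hypothetical helper mapping backend= to a code path.

    Returns "numba" for the opt-in compiled gufunc path, or None for the
    default array-API path, where the framework comes from the input arrays.
    """
    if backend is None:
        return None
    if backend == "numba":
        return "numba"
    if backend in _DEPRECATED_BACKENDS:
        # Assumption: warn, then ignore the value and infer from the inputs.
        warnings.warn(
            f'backend="{backend}" is deprecated; the array framework is now '
            "inferred from the input arrays.",
            DeprecationWarning,
            stacklevel=2,
        )
        return None
    raise ValueError(f"unknown backend: {backend!r}")
```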

Behavior changes from Phase 1

  • The default path no longer routes through numba: interval_score and crps_quantile previously defaulted to the numba gufuncs and now require backend="numba" to opt in.
  • torch has no native betainc / hyp2f1 / expi; scores needing those raise NotImplementedError for torch inputs (pass numpy or jax instead).
