diff --git a/spec/draft/index.rst b/spec/draft/index.rst
index 9b1c1a0eb..5df9b31cb 100644
--- a/spec/draft/index.rst
+++ b/spec/draft/index.rst
@@ -30,6 +30,13 @@ Contents
    verification_test_suite
    benchmark_suite
 
+.. toctree::
+   :caption: Guides and Tutorials
+   :maxdepth: 1
+
+   migration_guide
+   tutorial_basic
+
 .. toctree::
    :caption: Other
    :maxdepth: 1
diff --git a/spec/draft/migration_guide.md b/spec/draft/migration_guide.md
new file mode 100644
index 000000000..9babf2f35
--- /dev/null
+++ b/spec/draft/migration_guide.md
@@ -0,0 +1,236 @@
+(migration-guide)=
+
+# Migration Guide
+
+This page is meant to help migrate your codebase to an Array API compliant
+implementation. The guide is divided into two parts and, depending on your
+exact use-case, you should look thoroughly into at least one of them.
+
+The first part is dedicated to {ref}`array-producers`. If your library
+mimics, for example, NumPy's or Dask's functionality, then you can find in
+the first part additional instructions and guidance on how to ensure
+downstream users can easily pick your solution as an array provider for
+their system/algorithm.
+
+The second part delves into details of Array API compatibility for
+{ref}`array-consumers`. This pertains to any software that performs
+multidimensional array manipulation in Python, such as may be found in
+scikit-learn, SciPy, or statsmodels. If your software relies on a certain
+array-producing library, such as NumPy or JAX, then you can use the second
+part to learn how to make it library-agnostic and interchange array
+namespaces with significantly less friction.
+
+## Ecosystem
+
+Apart from the documented standard, the Array API ecosystem also provides
+a set of tools and packages to help you with the migration process:
+
+
+(array-api-compat)=
+
+### Array API Compat
+
+GitHub: [array-api-compat](https://github.com/data-apis/array-api-compat)
+
+User group: Array Consumers
+
+Although NumPy, Dask, CuPy, and PyTorch support the Array API Standard, there
+are still some corner cases where their behavior diverges from the standard.
+`array-api-compat` provides a compatibility layer to cover these cases.
+It also comes with a few utility functions for easier introspection
+into array objects. As an array consumer, you can still rely on the original
+API while having access to the standard-compatible one.
+
+
+(array-api-strict)=
+
+### Array API Strict
+
+GitHub: [array-api-strict](https://github.com/data-apis/array-api-strict)
+
+User group: Array Consumers, Array Producers (for testing)
+
+`array-api-strict` is a library that provides a strict and minimal
+implementation of the Array API Standard. For array producers, it is designed
+to be used as a reference implementation for testing and development purposes.
+You can compare your API calls with their `array-api-strict` counterparts to
+ensure that your library is fully compliant with the standard and can
+serve as a reliable reference for other developers in the ecosystem.
+As a consumer, you can use `array-api-strict` during development as an
+array provider to ensure your code uses APIs compliant with the standard.
+
+
+(array-api-tests)=
+
+### Array API Test
+
+GitHub: [array-api-tests](https://github.com/data-apis/array-api-tests)
+
+User group: Array Producers
+
+`array-api-tests` is a collection of tests that can be used to verify the
+compliance of your library with the Array API Standard. It includes tests
+for array producers, covering a wide range of functionalities and use cases.
+By running these tests, you can ensure that your library adheres to the
+standard and can be used with compatible array consumer libraries.
+
+
+(array-api-extra)=
+
+### Array API Extra
+
+GitHub: [array-api-extra](https://github.com/data-apis/array-api-extra)
+
+User group: Array Consumers
+
+`array-api-extra` is a collection of additional utilities and tools that are
+missing from the Array API Standard but can be useful for compliant array
+consumers. It includes additional array manipulation and statistical functions.
+It is already used by SciPy and scikit-learn.
+
+The sections below explain when and how to use these tools.
+
+
+(array-producers)=
+
+## Array Producers
+
+For array producers, the central task during the development/migration process
+is ensuring that the user-facing API adheres to the Array API Standard.
+
+The complete API of the standard is documented in the
+[API specification](https://data-apis.org/array-api/latest/API_specification/index.html).
+
+There, each function, constant, and object is described with details
+on parameters, return values, and special cases.
+
+### Testing against Array API
+
+There are two main ways to test your API for compliance: either using the
+`array-api-tests` suite or testing your API manually against the
+`array-api-strict` reference implementation.
+
+#### Array API Test suite (Recommended)
+
+{ref}`array-api-tests` is a test suite which verifies that your API
+adheres to the standard. For each function or method, it confirms
+it's importable, verifies the signature, generates multiple test
+cases with the [hypothesis](https://hypothesis.readthedocs.io/en/latest/)
+package, and runs assertions on the outputs.
+
+The setup details are described in the GitHub repository, so here we
+cover only the minimal workflow:
+
+1. Install your package (e.g., in editable mode).
+2. Clone `array-api-tests`, and set the `ARRAY_API_TESTS_MODULE` environment
+   variable to your package import name.
+3. Inside the `array-api-tests` directory, run `pytest`. The test suite
+   provides several useful options. A few worth mentioning:
+   - `--max-examples=1000` - maximal number of test cases to generate when using
+     hypothesis. This allows you to balance between execution time of the test
+     suite and thoroughness of the testing. It's advised to use as many examples
+     as the time budget allows. Each test case is a random combination of
+     possible inputs: the more cases, the higher the chance of finding an
+     unsupported edge case.
+   - With the `--xfails-file` option, you can describe which tests are expected
+     to fail. It's impossible to get the whole API perfectly implemented on a
+     first try, so tracking what still fails gives you more control over the
+     state of your API.
+   - `-o xfail_strict=True` (or `False`) is often used with the previous option.
+     If a test expected to fail actually passes (`XPASS`), it decides whether
+     that is raised as an error or ignored.
+   - `--skips-file` for skipping tests. At times, some failing tests might stall
+     the execution of the test suite. In that case, the most convenient
+     option is to skip these for the time being.
+
+We strongly advise you to embed this setup in your CI as well. This will allow
+you to continuously monitor Array API coverage, and make sure new changes don't break existing
+APIs. As a reference, see [NumPy's Array API Tests CI setup](https://github.com/numpy/numpy/blob/581d10f43b539a189a2d37856e5130464de9e5f6/.github/workflows/linux.yml#L296).
+
+
+#### Array API Strict
+
+A simpler, and more manual, way of testing Array API coverage is to
+run your API calls alongside the {ref}`array-api-strict` Python implementation.
+
+This way, you can ensure that the outputs coming from your API match the minimal
+reference implementation. Bear in mind, however, that you need to write
+the test cases yourself, so you also need to take into account any applicable edge
+cases.
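As a minimal sketch of such a manual check, the snippet below calls the same function in the library under test and in the reference namespace, then compares shapes and values. `check_against_reference` is a hypothetical helper and not part of any package, NumPy merely stands in for your library, and the sketch assumes it is acceptable to fall back to comparing NumPy with itself when `array-api-strict` is not installed.

```python
import importlib.util

import numpy as np  # stands in for the library under test in this sketch


def check_against_reference(name, data, xp, ref_xp, **kwargs):
    """Call function `name` in both namespaces and compare shapes and values.

    Hypothetical helper for illustration; it is not part of any package.
    """
    actual = getattr(xp, name)(xp.asarray(data), **kwargs)
    expected = getattr(ref_xp, name)(ref_xp.asarray(data), **kwargs)
    assert tuple(actual.shape) == tuple(expected.shape), name
    # Flatten with each namespace's own `reshape`, then compare element-wise.
    flat_actual = xp.reshape(actual, (-1,))
    flat_expected = ref_xp.reshape(expected, (-1,))
    for i in range(flat_actual.shape[0]):
        assert float(flat_actual[i]) == float(flat_expected[i]), (name, i)
    return True


# Use array-api-strict as the reference when it is installed. Otherwise
# (assumption of this sketch) compare the library with itself so it still runs.
if importlib.util.find_spec("array_api_strict") is not None:
    import array_api_strict as ref_xp
else:
    ref_xp = np

results = {
    "sum": check_against_reference("sum", [[1.0, 2.0], [3.0, 4.0]], np, ref_xp, axis=0),
    "abs": check_against_reference("abs", [-1.5, 0.0, 2.5], np, ref_xp),
}
```

Exact equality is used here only because the inputs are small and exactly representable; a real check would likely need tolerances and explicit handling of NaN, empty arrays, and dtypes.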
+
+
+(array-consumers)=
+
+## Array Consumers
+
+For array consumers, the main premise is to keep in mind that your **array
+manipulation operations should not be locked in to a particular array-producing
+library**. For instance, if you use NumPy for arrays, then your code could
+contain:
+
+```python
+import numpy as np
+
+# ...
+b = np.full(shape, val, dtype=dtype) @ a
+c = np.mean(a, axis=0)
+return np.dot(c, b)
+```
+
+The first step should be as simple as assigning the `np` namespace to a dedicated
+namespace variable. The convention used in the ecosystem is to name it `xp`. Then,
+it is vital to ensure that each method and function call is something that the Array API
+supports. For example, `dot` is present in NumPy's API, but the standard
+doesn't define it. For the sake of simplicity, let's assume both `c` and `b`
+are `ndim=2`; therefore, we select `tensordot` instead, as both NumPy and the
+standard define it:
+
+```python
+import numpy as np
+
+xp = np
+
+# ...
+b = xp.full(shape, val, dtype=dtype) @ a
+c = xp.mean(a, axis=0)
+return xp.tensordot(c, b, axes=1)
+```
+
+At this point, replacing one backend with another one should only require providing a different
+namespace, such as `xp = torch` (e.g., via an environment variable). This can be useful
+if you're writing a script or your own custom software. The other alternatives are:
+
+- If you are building a library where the backend is determined by input arrays,
+  and your function accepts array arguments, then a recommended way is to ask
+  your input arrays for a namespace to use: `xp = arr.__array_namespace__()`.
+  If the given library doesn't have it, then [`array_api_compat.array_namespace()`](https://data-apis.org/array-api-compat/helper-functions.html#array_api_compat.array_namespace)
+  should be used instead:
+  ```python
+  def func(array1, scalar1, scalar2):
+      xp = array1.__array_namespace__()  # or array_namespace(array1)
+      return xp.arange(scalar1, scalar2) @ array1
+  ```
+- For a function that accepts scalars and returns arrays, use namespace `xp` as
+  a parameter in the signature. Ensuring that inputs have the array type of the
+  provided backend can then be achieved with `arg1 = xp.asarray(arg1)` for each input:
+  ```python
+  def func(s1, s2, xp):
+      return xp.arange(s1, s2)
+  ```
+
+If you're relying on NumPy, CuPy, PyTorch, Dask, or JAX then
+{ref}`array-api-compat` can come in handy for the transition. The compat layer
+allows you to still rely on your preferred array-producing library, while
+making sure you're already using a standard-compatible API. Additionally, it
+offers a set of useful utility functions, such as:
+
+- [array_namespace()](https://data-apis.org/array-api-compat/helper-functions.html#array_api_compat.array_namespace)
+  for fetching the namespace based on input arrays.
+- [is_array_api_obj()](https://data-apis.org/array-api-compat/helper-functions.html#array_api_compat.is_array_api_obj)
+  for inspecting whether a given object is Array API compatible.
+- [device()](https://data-apis.org/array-api-compat/helper-functions.html#array_api_compat.device)
+  for retrieving the device on which an array resides.
+
+For now, the migration from a specific library (e.g., NumPy) to a
+standard-compatible setup requires manual intervention for each failing API call,
+but, in the future, we're hoping to provide tools for automating the migration process.
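The namespace lookup described in this section could be wrapped in a small helper, sketched below. `get_namespace` and `standardize` are hypothetical names used only for illustration. The sketch assumes that an explicit `xp` argument should take precedence, and that `array_api_compat.array_namespace()` is consulted only when the array lacks `__array_namespace__` and the compat package is installed.

```python
import importlib.util

import numpy as np


def get_namespace(arr):
    """Return the Array API namespace of `arr` (hypothetical helper)."""
    if hasattr(arr, "__array_namespace__"):
        # Standard-compliant arrays report their own namespace.
        return arr.__array_namespace__()
    if importlib.util.find_spec("array_api_compat") is not None:
        # Fall back to the compat layer for libraries without the method.
        from array_api_compat import array_namespace

        return array_namespace(arr)
    raise TypeError(f"cannot infer an Array API namespace for {type(arr).__name__}")


def standardize(x, xp=None):
    """Shift `x` to zero mean and scale it to unit standard deviation."""
    if xp is None:
        xp = get_namespace(x)
    x = xp.asarray(x)
    return (x - xp.mean(x)) / xp.std(x)


x = np.asarray([1.0, 2.0, 3.0, 4.0])
# Explicit namespace; calling `standardize(x)` would infer it from `x` instead.
y = standardize(x, xp=np)
```

Only functions defined by the standard (`asarray`, `mean`, `std`) are called through `xp`, so the same `standardize` should work with any compliant backend.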
diff --git a/spec/draft/tutorial_basic.md b/spec/draft/tutorial_basic.md
new file mode 100644
index 000000000..ea8a3b453
--- /dev/null
+++ b/spec/draft/tutorial_basic.md
@@ -0,0 +1,158 @@
+(tutorial-basic)=
+
+# Array API Tutorial
+
+In this tutorial, we're going to demonstrate how to migrate to the Array API from the array consumer's
+point of view for a simple graph algorithm.
+
+The example presented here comes from the [`graphblas-algorithms`](https://github.com/python-graphblas/graphblas-algorithms)
+library. In particular, we'll be migrating [the HITS algorithm](https://github.com/python-graphblas/graphblas-algorithms/blob/35dbc90e808c6bf51b63d51d8a63f59238c02975/graphblas_algorithms/algorithms/link_analysis/hits_alg.py#L9), which is
+used in link analysis to estimate prominence in sparse networks, to be Array API compliant.
+
+The inlined and slightly simplified (without "authority" feature)
+implementation looks similar to the following:
+
+```python
+def hits(G, max_iter=100, tol=1.0e-8, normalized=True):
+    N = len(G)
+    h = Vector(float, N, name="h")
+    a = Vector(float, N, name="a")
+    h << 1.0 / N
+    # Power iteration: make up to max_iter iterations
+    A = G._A
+    hprev = Vector(float, N, name="h_prev")
+    for _i in range(max_iter):
+        hprev, h = h, hprev
+        a << hprev @ A
+        h << A @ a
+        h *= 1.0 / h.reduce(monoid.max).get(0)
+        if is_converged(hprev, h, tol):
+            break
+    else:
+        raise ConvergenceFailure(max_iter)
+    if normalized:
+        h *= 1.0 / h.reduce().get(0)
+        a *= 1.0 / a.reduce().get(0)
+    return h, a
+
+def is_converged(xprev, x, tol):
+    xprev << binary.minus(xprev | x)
+    xprev << unary.abs(xprev)
+    return xprev.reduce().get(0) < xprev.size * tol
+```
+
+We can see that the API is specific to the GraphBLAS array object.
+There is a `Vector` constructor, overloaded `<<` for assigning new values,
+and `reduce`/`get` for reductions. We need to replace them, and, by convention,
+we will use the `xp` namespace for calling the respective functions.
+
+First, we want to make sure we construct arrays in an agnostic way:
+
+```python
+h = xp.full(N, 1.0 / N)
+A = xp.asarray(G.A)
+```
+
+Then, instead of `reduce` calls, we will use appropriate reduction
+functions from the Array API:
+
+```python
+h = h / xp.max(h)
+# ...
+h = h / xp.sum(xp.abs(h))
+a = a / xp.sum(xp.abs(a))
+# ...
+err = xp.sum(xp.abs(...))
+```
+
+We replace the custom binary operation with the Array API counterpart:
+
+```python
+...(x - xprev)
+```
+
+And finally, let's ensure that the result of the convergence
+condition is a scalar coming from our API:
+
+```python
+err < xp.asarray(N * tol)
+```
+
+The rewrite is now complete; we can assemble all constituent parts into
+a full implementation:
+
+```python
+def hits(G, max_iter=100, tol=1.0e-8, normalized=True):
+    N = len(G)
+    h = xp.full(N, 1.0 / N)
+    A = xp.asarray(G.A)
+    # Power iteration: make up to max_iter iterations
+    for _i in range(max_iter):
+        hprev = h
+        a = hprev @ A
+        h = A @ a
+        h = h / xp.max(h)
+        if is_converged(hprev, h, N, tol):
+            break
+    else:
+        raise Exception("Didn't converge")
+    if normalized:
+        h = h / xp.sum(xp.abs(h))
+        a = a / xp.sum(xp.abs(a))
+    return h, a
+
+def is_converged(xprev, x, N, tol):
+    err = xp.sum(xp.abs(x - xprev))
+    return err < xp.asarray(N * tol)
+```
+
+At this point, the actual execution depends only on the `xp` namespace,
+and replacing that one variable will allow us to switch from, e.g., NumPy arrays on the CPU
+to JAX arrays for running on a GPU. This lets us be more flexible and, for example,
+use lazy evaluation and JIT-compile the loop body with JAX:
+
+```python
+import jax
+import jax.numpy as jnp
+
+xp = jnp
+
+def hits(G, max_iter=100, tol=1.0e-8, normalized=True):
+    N = len(G)
+    h = xp.full((N, 1), 1.0 / N)
+    A = xp.asarray(G.A)
+    # Power iteration: make up to max_iter iterations
+    for _i in range(max_iter):
+        h, a, conv = loop_body(h, A, N, tol)
+        if conv:
+            break
+    else:
+        raise Exception("Didn't converge")
+    if normalized:
+        h = h / xp.sum(xp.abs(h))
+        a = a / xp.sum(xp.abs(a))
+    return h, a
+
+@jax.jit
+def loop_body(hprev, A, N, tol):
+    a = hprev.mT @ A
+    h = A @ a.mT
+    h = h / xp.max(h)
+    conv = is_converged(hprev, h, N, tol)
+    return h, a, conv
+
+def is_converged(xprev, x, N, tol):
+    err = xp.sum(xp.abs(x - xprev))
+    return err < xp.asarray(N * tol)
+
+if __name__ == "__main__":
+
+    class Graph():
+        def __init__(self):
+            self.A = xp.ones((10, 10))
+        def __len__(self):
+            return len(self.A)
+
+    G = Graph()
+    h, a = hits(G)
+```
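For comparison, the same power iteration can be run on the CPU with NumPy as the backend. The sketch below is a variant of the tutorial's function rather than the code above. It assumes that the adjacency matrix is passed in directly instead of a graph object and that the namespace is an explicit `xp` parameter; both choices are made only for illustration.

```python
import numpy as np


def hits(A, xp, max_iter=100, tol=1.0e-8, normalized=True):
    """HITS power iteration on adjacency matrix `A` using namespace `xp`."""
    A = xp.asarray(A)
    N = A.shape[0]
    h = xp.full((N,), 1.0 / N)
    # Power iteration: make up to max_iter iterations
    for _i in range(max_iter):
        hprev = h
        a = hprev @ A
        h = A @ a
        h = h / xp.max(h)
        if is_converged(hprev, h, N, tol, xp):
            break
    else:
        raise RuntimeError("Didn't converge")
    if normalized:
        h = h / xp.sum(xp.abs(h))
        a = a / xp.sum(xp.abs(a))
    return h, a


def is_converged(xprev, x, N, tol, xp):
    # Total absolute change between two iterations, compared to N * tol.
    err = xp.sum(xp.abs(x - xprev))
    return bool(err < N * tol)


# Small symmetric graph with self-loops; node 1 is linked to both other nodes.
adjacency = [[1.0, 1.0, 0.0], [1.0, 1.0, 1.0], [0.0, 1.0, 1.0]]
h, a = hits(adjacency, xp=np)
```

Passing another compliant namespace as `xp` should be the only change needed to run the same function on a different backend.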