Update JEP-12049 implementation discussion

This commit is contained in:
Jake VanderPlas 2022-09-20 09:44:29 -07:00
parent fc2902c6ac
commit fce1099997

View File

@ -33,7 +33,7 @@ def slice(operand: Array, start_indices: Sequence[int],
...
```
For the purposes of static type checking, this use of `Array = Any` for array annotations puts no constraint on the argument values (`Any` is equivalent to no annotation at all), but it does serve as a form of useful in-code documentation for the developer.
For the purposes of static type checking, this use of `Array = Any` for array type annotations puts no constraint on the argument values (`Any` is equivalent to no annotation at all), but it does serve as a form of useful in-code documentation for the developer.
For the sake of generated documentation, the name of the alias gets lost (the [HTML docs](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.slice.html) for `jax.lax.slice` report operand as type `Any`), so the documentation benefit does not go beyond the source code (though we could enable some `sphinx-autodoc` options to improve this: See [autodoc_type_aliases](https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autodoc_type_aliases)).
@ -195,7 +195,7 @@ One impact of this is that for the time being, when functions are decorated by J
While this is unfortunate, at the time of this writing mypy has a laundry-list of incompatibilities with the potential solution offered by `ParamSpec` (see [`ParamSpec` mypy bug tracker](https://github.com/python/mypy/issues?q=is%3Aissue+is%3Aopen++label%3Atopic-paramspec+)), and we therefore judge it as not ready for full adoption in JAX at this time.
We will revisit this question in the future once support for such features stabilizes.
Similarly, for the time being we will avoid adding the more complex & granular array annotations offered by the [jaxtyping](http://github.com/google/jaxtyping) project. This is a decision we could revisit at a future date.
Similarly, for the time being we will avoid adding the more complex & granular array type annotations offered by the [jaxtyping](http://github.com/google/jaxtyping) project. This is a decision we could revisit at a future date.
### `Array` Type Design Considerations
@ -277,31 +277,34 @@ Here, `jax.numpy.ndarray` would become a simple alias `jax.Array` for backward c
Finally, we could opt for full unification via restructuring of the class hierarchy and replacing duck-typing with OOP object hierarchies:
- `jax.Array` is the actual type of on-device arrays
- `jax.Array` is also the object used for array annotations, by ensuring that `Tracer` inherits from `jax.Array`
- `jax.Array` is also the object used for array type annotations, by ensuring that `Tracer` inherits from `jax.Array`
- `jax.Array` is also the object used for instance checks, via the same mechanism
Here `jnp.ndarray` could be an alias for `jax.Array`.
This final approach is in some senses the most pure, but it is somewhat forced from an OOP design standpoint (`Tracer` *is an* `Array`?).
##### Option 4: Parial unification via class hierarchy
We could appease OOP pedants by instead making `Tracer` and `Array` derive from a common `ArrayBase` base class:
We could make the class hierarchy more sensible by making `Tracer` and the class for
on-device arrays inherit from a common base class. So, for example:
- `jax.Array` is the actual type of on-device arrays
- `ArrayBase` is the object used for array annotations
- `ArrayBase` is also the object used for instance checks
- `jax.Array` is a base class for `Tracer` as well as the actual type of on-device arrays,
which might be `jax._src.ArrayImpl` or similar.
- `jax.Array` is the object used for array type annotations
- `jax.Array` is also the object used for instance checks
Here `jnp.ndarray` would be an alias for `ArrayBase`.
This may be purer from an OOP perspective, but it reintroduces a bifurcation and the distinction between `Array` and `ArrayBase` for annotation and instance checks may become confusing.
Here `jnp.ndarray` would be an alias for `Array`.
This may be purer from an OOP perspective, but compared to Options 2 and 3 it drops the notion
that `type(x) is jax.Array` will evaluate to True.
##### Evaluation
Considering the overall strengths and weaknesses of each potential approach:
- From a user perspective, the unified approaches (options 2 and 3) are arguably best, because they remove the cognitive overhead involved in remembering which objects to use for instance checks or annotations: `jax.Array` is all you need to know
- Between Option 2 and Option 3, the purer (in an OOP sense) apporach is arguably Option 3 (`Tracer` as a subclass of `Array`), but in other senses it breaks the inheritance model, because it would require `Tracer` objects to carry all the baggage of `Array` objects (data buffers, sharding, devices, etc.)
- Option 2 is less pure in an OOP sense, but it aligns more closely with the spirit of how JAX is designed, with `Tracer` objects duck-typing as `Arrays`, and using mechanisms that Python provides to support this kind of duck typing. There is one minor technical hurdle involved; that is that `jax.Array` will be defined in C++ via pybind11, and pybind11 currently [does not support](https://github.com/pybind/pybind11/issues/2696) custom metaclasses required for overriding `__instancecheck__`; but it is likely possible to work around this via Python's C API.
- From a user perspective, the unified approaches (options 2 and 3) are arguably best, because they remove the cognitive overhead involved in remembering which objects to use for instance checks or annotations: `jax.Array` is all you need to know.
- However, both options 2 and 3 introduce some strange and/or confusing behavior. Option 2 depends on potentially confusing overrides of instance checks, which are [not well supported](https://github.com/pybind/pybind11/issues/2696) for classes defined in pybind11. Option 3 requires `Tracer` to be a subclass array. This breaks the inheritance model, because it would require `Tracer` objects to carry all the baggage of `Array` objects (data buffers, sharding, devices, etc.)
- Option 4 is purer in an OOP sense, and avoids the need for any overrides of typical instance check or type annotation behavior. The tradeoff is that the actual type of on-device arrays becomes something separate (here `jax._src.ArrayImpl`). But the vast majority of users would never have to touch this private implementation directly.
With this in mind, we conclude that Option 2 presents the best path forward.
There are different tradeoffs here, but after discussion we've landed on Option 4 as our way forward.
### Implementation Plan
@ -318,16 +321,15 @@ To move forward with type annotations, we will do the following:
The beginnings of this are done in {jax-issue}`#12300`.
3. Begin work on `jax._src.typing.Array` that follows Option 2 from the previous section.
Because `jax.Array` is not yet out of experimental, we'll start by define `jax._src.typing.Array` with the type annotation and instance-checking features that will eventually move to `jax.Array` when it has fully landed.
3. Begin work on a `jax.Array` base class that follows Option 4 from the previous section. Initially this will be defined in Python, and use the dynamic registration mechanism currently found in the `jnp.ndarray` implementation to ensure correct behavior of `isinstance` checks. A `pyi` override for each tracer and array-like class would ensure correct behavior for type annotations. `jnp.ndarray` could then be make into an alias of `jax.Array`
4. When this is implemented, remove the existing definition of `jnp.ndarray` and set is as an alias of `jax._src.typing.Array`.
5. As a test, use these new typing definitions to comprehensively annotate functions within `jax.lax` according to the guidelines above.
4. As a test, use these new typing definitions to comprehensively annotate functions within `jax.lax` according to the guidelines above.
6. Continue adding additional annotations one module at a time, focusing on public API functions.
7. Once `jax.Array` has fully landed, migrate the features of `jax._src.typing.Array` to this class, and let `jax.numpy.ndarray = Array`.
5. In parallel, begin re-implementing a `jax.Array` base class in pybind11, so that `ArrayImpl` and `Tracer` can inherit from it. Use a `pyi` definition to ensure static type checkers recognize the appropriate attributes of the class.
7. Once `jax.Array` and `jax._src.ArrayImpl` have fully landed, remove these temporary Python implementations.
8. When all is finalized, create a public `jax.typing` module that makes the above types available to users, along with documentation of annotation best practices for code using JAX.