mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Update JEP-12049 implementation discussion
This commit is contained in:
parent
fc2902c6ac
commit
fce1099997
@ -33,7 +33,7 @@ def slice(operand: Array, start_indices: Sequence[int],
|
||||
...
|
||||
```
|
||||
|
||||
For the purposes of static type checking, this use of `Array = Any` for array annotations puts no constraint on the argument values (`Any` is equivalent to no annotation at all), but it does serve as a form of useful in-code documentation for the developer.
|
||||
For the purposes of static type checking, this use of `Array = Any` for array type annotations puts no constraint on the argument values (`Any` is equivalent to no annotation at all), but it does serve as a form of useful in-code documentation for the developer.
|
||||
|
||||
For the sake of generated documentation, the name of the alias gets lost (the [HTML docs](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.slice.html) for `jax.lax.slice` report operand as type `Any`), so the documentation benefit does not go beyond the source code (though we could enable some `sphinx-autodoc` options to improve this: See [autodoc_type_aliases](https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autodoc_type_aliases)).
|
||||
|
||||
@ -195,7 +195,7 @@ One impact of this is that for the time being, when functions are decorated by J
|
||||
While this is unfortunate, at the time of this writing mypy has a laundry-list of incompatibilities with the potential solution offered by `ParamSpec` (see [`ParamSpec` mypy bug tracker](https://github.com/python/mypy/issues?q=is%3Aissue+is%3Aopen++label%3Atopic-paramspec+)), and we therefore judge it as not ready for full adoption in JAX at this time.
|
||||
We will revisit this question in the future once support for such features stabilizes.
|
||||
|
||||
Similarly, for the time being we will avoid adding the more complex & granular array annotations offered by the [jaxtyping](http://github.com/google/jaxtyping) project. This is a decision we could revisit at a future date.
|
||||
Similarly, for the time being we will avoid adding the more complex & granular array type annotations offered by the [jaxtyping](http://github.com/google/jaxtyping) project. This is a decision we could revisit at a future date.
|
||||
|
||||
### `Array` Type Design Considerations
|
||||
|
||||
@ -277,31 +277,34 @@ Here, `jax.numpy.ndarray` would become a simple alias `jax.Array` for backward c
|
||||
Finally, we could opt for full unification via restructuring of the class hierarchy and replacing duck-typing with OOP object hierarchies:
|
||||
|
||||
- `jax.Array` is the actual type of on-device arrays
|
||||
- `jax.Array` is also the object used for array annotations, by ensuring that `Tracer` inherits from `jax.Array`
|
||||
- `jax.Array` is also the object used for array type annotations, by ensuring that `Tracer` inherits from `jax.Array`
|
||||
- `jax.Array` is also the object used for instance checks, via the same mechanism
|
||||
|
||||
Here `jnp.ndarray` could be an alias for `jax.Array`.
|
||||
This final approach is in some senses the most pure, but it is somewhat forced from an OOP design standpoint (`Tracer` *is an* `Array`?).
|
||||
|
||||
##### Option 4: Parial unification via class hierarchy
|
||||
We could appease OOP pedants by instead making `Tracer` and `Array` derive from a common `ArrayBase` base class:
|
||||
We could make the class hierarchy more sensible by making `Tracer` and the class for
|
||||
on-device arrays inherit from a common base class. So, for example:
|
||||
|
||||
- `jax.Array` is the actual type of on-device arrays
|
||||
- `ArrayBase` is the object used for array annotations
|
||||
- `ArrayBase` is also the object used for instance checks
|
||||
- `jax.Array` is a base class for `Tracer` as well as the actual type of on-device arrays,
|
||||
which might be `jax._src.ArrayImpl` or similar.
|
||||
- `jax.Array` is the object used for array type annotations
|
||||
- `jax.Array` is also the object used for instance checks
|
||||
|
||||
Here `jnp.ndarray` would be an alias for `ArrayBase`.
|
||||
This may be purer from an OOP perspective, but it reintroduces a bifurcation and the distinction between `Array` and `ArrayBase` for annotation and instance checks may become confusing.
|
||||
Here `jnp.ndarray` would be an alias for `Array`.
|
||||
This may be purer from an OOP perspective, but compared to Options 2 and 3 it drops the notion
|
||||
that `type(x) is jax.Array` will evaluate to True.
|
||||
|
||||
##### Evaluation
|
||||
|
||||
Considering the overall strengths and weaknesses of each potential approach:
|
||||
|
||||
- From a user perspective, the unified approaches (options 2 and 3) are arguably best, because they remove the cognitive overhead involved in remembering which objects to use for instance checks or annotations: `jax.Array` is all you need to know
|
||||
- Between Option 2 and Option 3, the purer (in an OOP sense) apporach is arguably Option 3 (`Tracer` as a subclass of `Array`), but in other senses it breaks the inheritance model, because it would require `Tracer` objects to carry all the baggage of `Array` objects (data buffers, sharding, devices, etc.)
|
||||
- Option 2 is less pure in an OOP sense, but it aligns more closely with the spirit of how JAX is designed, with `Tracer` objects duck-typing as `Arrays`, and using mechanisms that Python provides to support this kind of duck typing. There is one minor technical hurdle involved; that is that `jax.Array` will be defined in C++ via pybind11, and pybind11 currently [does not support](https://github.com/pybind/pybind11/issues/2696) custom metaclasses required for overriding `__instancecheck__`; but it is likely possible to work around this via Python's C API.
|
||||
- From a user perspective, the unified approaches (options 2 and 3) are arguably best, because they remove the cognitive overhead involved in remembering which objects to use for instance checks or annotations: `jax.Array` is all you need to know.
|
||||
- However, both options 2 and 3 introduce some strange and/or confusing behavior. Option 2 depends on potentially confusing overrides of instance checks, which are [not well supported](https://github.com/pybind/pybind11/issues/2696) for classes defined in pybind11. Option 3 requires `Tracer` to be a subclass array. This breaks the inheritance model, because it would require `Tracer` objects to carry all the baggage of `Array` objects (data buffers, sharding, devices, etc.)
|
||||
- Option 4 is purer in an OOP sense, and avoids the need for any overrides of typical instance check or type annotation behavior. The tradeoff is that the actual type of on-device arrays becomes something separate (here `jax._src.ArrayImpl`). But the vast majority of users would never have to touch this private implementation directly.
|
||||
|
||||
With this in mind, we conclude that Option 2 presents the best path forward.
|
||||
There are different tradeoffs here, but after discussion we've landed on Option 4 as our way forward.
|
||||
|
||||
### Implementation Plan
|
||||
|
||||
@ -318,16 +321,15 @@ To move forward with type annotations, we will do the following:
|
||||
|
||||
The beginnings of this are done in {jax-issue}`#12300`.
|
||||
|
||||
3. Begin work on `jax._src.typing.Array` that follows Option 2 from the previous section.
|
||||
Because `jax.Array` is not yet out of experimental, we'll start by define `jax._src.typing.Array` with the type annotation and instance-checking features that will eventually move to `jax.Array` when it has fully landed.
|
||||
3. Begin work on a `jax.Array` base class that follows Option 4 from the previous section. Initially this will be defined in Python, and use the dynamic registration mechanism currently found in the `jnp.ndarray` implementation to ensure correct behavior of `isinstance` checks. A `pyi` override for each tracer and array-like class would ensure correct behavior for type annotations. `jnp.ndarray` could then be make into an alias of `jax.Array`
|
||||
|
||||
4. When this is implemented, remove the existing definition of `jnp.ndarray` and set is as an alias of `jax._src.typing.Array`.
|
||||
|
||||
5. As a test, use these new typing definitions to comprehensively annotate functions within `jax.lax` according to the guidelines above.
|
||||
4. As a test, use these new typing definitions to comprehensively annotate functions within `jax.lax` according to the guidelines above.
|
||||
|
||||
6. Continue adding additional annotations one module at a time, focusing on public API functions.
|
||||
|
||||
7. Once `jax.Array` has fully landed, migrate the features of `jax._src.typing.Array` to this class, and let `jax.numpy.ndarray = Array`.
|
||||
5. In parallel, begin re-implementing a `jax.Array` base class in pybind11, so that `ArrayImpl` and `Tracer` can inherit from it. Use a `pyi` definition to ensure static type checkers recognize the appropriate attributes of the class.
|
||||
|
||||
7. Once `jax.Array` and `jax._src.ArrayImpl` have fully landed, remove these temporary Python implementations.
|
||||
|
||||
8. When all is finalized, create a public `jax.typing` module that makes the above types available to users, along with documentation of annotation best practices for code using JAX.
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user