- Show `jax.numpy.vectorize` explicitly in the JAX docs, rather than the
original `numpy.vectorize`.
- Updated the regex for identifying function signatures in NumPy. It now
correctly parses `np.vectorize` and `np.einsum`.
- Removed docs for `jax.experimental.vectorize`. There's still some good
narrative content in the docstring but it should go somewhere else.
This is useful for building higher-level array libraries around JAX, because it
makes it possible to override operations like `jax_array + other`.
I think I covered all the array types that JAX should be able to handle:
- Python builtin numbers: int, float, and complex
- NumPy scalars
- NumPy arrays
- JAX array types and tracers
Did I miss anything? Maybe bfloat16 scalars?
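A rough sketch of the kinds of mixed expressions this is meant to cover (a minimal example; the exact result dtypes follow the promotion rules discussed further below):
```
import numpy as onp
from jax import numpy as jnp

x = jnp.ones(3)
# Each of these should dispatch to JAX's overridden operators rather
# than NumPy's:
print((x + 2).dtype)               # Python builtin int
print((x + 2.5).dtype)             # Python builtin float
print((x + 2j).dtype)              # Python builtin complex
print((x + onp.float32(2)).dtype)  # NumPy scalar
print((x + onp.arange(3)).dtype)   # NumPy array
print((x + jnp.ones(3)).dtype)     # JAX array
```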
Before this commit, this computation would avoid materializing the iota
array at trace time:
```
@jit
def f(x):
    m, n = x.shape
    return x + np.arange(n)
```
But this one would materialize the iota array at trace time and stage it
into the computation as a potentially large array constant:
```
@jit
def f(x):
    m, n = x.shape
    return x + np.arange(m)[:, None]
```
The difference is that previously, operations like broadcasts, transposes,
and reshapes that add singleton dimensions (as above) would force
otherwise-lazy values to be materialized. After this commit, broadcasts,
transposes, and reshapes are all lazy operations that only update metadata
on their input, rather than compiling and executing XLA computations and
producing new buffers.
Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, and np.full).
This commit replaces the ad-hoc "lazy device constant" system, which was
used to get the simpler behavior in the first example above.
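A rough sketch of the resulting behavior, combining the two examples above (here `np` is `jax.numpy`, as in the snippets above):
```
from jax import jit, numpy as np

@jit
def f(x):
    m, n = x.shape
    # Both iotas (and the eye constant) now stay lazy at trace time
    # instead of being staged in as materialized array constants.
    return x + np.arange(n) + np.arange(m)[:, None] + np.eye(m, n)

f(np.ones((3, 4)))
```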
Incidentally fixes #1431.
See https://github.com/google/jax/pull/1668 for more.
* Implement jax.numpy.nonzero.
* Implement the one-argument form of np.where (see the sketch after this list).
* Fix output type and error message.
* Add lax_description strings to where and nonzero.
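A short usage sketch of the two new functions (run outside of `jit`, since their output shapes are value-dependent):
```
from jax import numpy as jnp

x = jnp.array([0, 3, 0, 5])
# nonzero returns a tuple of index arrays, one per dimension; here
# the nonzero entries are at indices 1 and 3:
print(jnp.nonzero(x))
# The one-argument form of where is equivalent to nonzero:
print(jnp.where(x))
```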
* Change jax.numpy scalar types to return 0D JAX arrays rather than NumPy scalars when instantiated.
jax.numpy and numpy have slightly different promotion behaviors. For consistency with JAX arrays, we would like the result of, say, `jax.numpy.int32(7)` to have the same promotion behavior as `jax.numpy.array(7, dtype=jax.numpy.int32)`. The easiest way to do this is to have the jax.numpy scalars return 0D arrays when instantiated; the difference between NumPy scalars and arrays is not a fundamental one and we do not need to distinguish between them in JAX.
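A minimal sketch of the new behavior:
```
from jax import numpy as jnp

x = jnp.int32(7)
print(x.shape, x.dtype)  # () int32 -- a 0D JAX array, not a NumPy scalar
# Promotion now matches an explicitly constructed 0D array:
y = jnp.array(7, dtype=jnp.int32)
print((x + 0.5).dtype == (y + 0.5).dtype)  # True
```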
* Implement bool_ support for jnp.add, jnp.multiply, jnp.einsum, lax.dot, and lax.dot_general (see the sketch below).
Fix dtype rules for `lax._reduce_sum` and `lax._reduce_prod` to check for number inputs.
Improve error messages for type mismatches to correctly describe scalar type categories (e.g., 'floating') rather than what `onp.dtype(...).name` returns (e.g., 'float64').
Remove redundant `bfloat16` type in `lax._float`, which has been redundant since `dtypes.issubdtype` was taught about `bfloat16` support.
This change prepares for switching the default types in JAX's NumPy to be 32-bit types. In particular, it makes the JAX tests pass in the event that jax.numpy.int_, jax.numpy.float_, and jax.numpy.complex_ are defined to be 32-bit types instead of 64-bit types, but does not yet change the defaults.
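A minimal sketch of the boolean arithmetic (assuming NumPy-compatible semantics, where add acts as logical OR and multiply as logical AND on booleans):
```
from jax import numpy as jnp

a = jnp.array([True, False, True])
b = jnp.array([True, True, False])
print(jnp.add(a, b))       # [ True  True  True]
print(jnp.multiply(a, b))  # [ True False False]
# einsum also accepts boolean inputs now:
print(jnp.einsum('i,i->i', a, b))
```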
* Fixes to type handling.
* Specify exactly which types to test in lax_test.py, rather than relying on non-x64 mode to squash unsupported types.
* Fix some excessive promotions in jax.numpy.
* Fix some buggy RNGs that returned the wrong type for complex inputs.
Change the JAX type promotion table to prefer inexact types during type promotion.
NumPy's type promotion rules tend to promote aggressively to float64, which is not an accelerator-friendly behavior: not all accelerators (e.g., TPUs) support 64-bit floating point types, and even on accelerators that do (e.g., GPUs), promoting to a 64-bit type comes at a significant performance cost.
This change makes JAX type promotion between inexact and exact types closer to PyTorch's promotion semantics, which are a better fit for modern accelerators:
e.g.,
```
In [1]: import numpy as onp

In [2]: from jax import numpy as np

In [3]: onp.promote_types(onp.float32, onp.int32)
Out[3]: dtype('float64')

In [4]: onp.promote_types(onp.float16, onp.int64)
Out[4]: dtype('float64')

In [5]: np.promote_types(onp.float32, onp.int32)
Out[5]: dtype('float32')

In [6]: np.promote_types(onp.float16, onp.int64)
Out[6]: dtype('float16')
```
This change is in preparation for enabling x64 mode by default on all platforms.
* Remove `np._promote_args_like`, and replace its users with the newer `_promote_args_inexact`.
We no longer want to promote arguments exactly like NumPy; NumPy has a bad habit of promoting integer types to float64, whereas we want to promote to jax.numpy.float_, which may not be the same.
For example:
```
import numpy as onp
onp.sin(3).dtype
```
returns `dtype('float64')`.
However, it turns out that all of the users of `_promote_args_like` were using it for exactly one behavior: promoting integers or bools to inexact types like float. Implement that behavior explicitly rather than mimicking the behavior of NumPy.
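Concretely, a minimal sketch of the intended behavior (the comment assumes the default non-x64 configuration, where jnp.float_ is float32):
```
from jax import numpy as jnp

# Integer input is promoted to jnp.float_, not unconditionally to
# float64 as NumPy would do:
print(jnp.sin(3).dtype)  # float32 under the default (non-x64) config
```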
* Relax test tolerances.