The `jax.experimental.stax` and `jax.experimental.optimizers` modules are standalone examples libraries. By contrast, the remaining modules in `jax.experimental` are experimental features of the JAX core system. This change moves the two example libraries, and the README that describes them, to `jax.example_libraries` to reflect this distinction.
PiperOrigin-RevId: 404405186
* Don't wrap static arguments in hashable wrappers in pmap.
* Delete wrap_hashably().
* In argnums_partial, either enforce hashability or wrap values with an explicitly unhashable wrapper. The intent here is that either we should check for hashability early or we should make sure it's clear that it's not something we intended..
* Delete argnames_partial, which appears unused.
This PR changes `jax.numpy.array()` to avoid creating any on-device arrays during tracing. As a consequence, calls to `jnp.array()` in a traced context, such as `jax.jit` will always be staged into the trace.
This change may break code that depends on the current (undocumented and unintentional) behavior of `jnp.array()` to perform shape or index calculations that must be known statically (at trace time). The workaround for such cases is to use classic NumPy to perform shape/index calculations.
PiperOrigin-RevId: 398008511
By wrapping common operators in `jit`, we get a number of benefits:
* `jit` has a faster, more optimized dispatch path compared to the primitive dispatch path in JAX. It's faster to dispatch a `jit` computation than a single primitive.
* `jit` allows us to cache and reuse logic such as broadcasting and type promotion.
One downside is that we now report an error when large Python integer scalars (e.g. `2**32 - 1`) are passed as arguments to JAX array operators. The workaround to this is to use explicitly typed constants instead of Python scalars.
On my laptop, this benchmark improves from 95us to 4us:
```
In [1]: import jax.numpy as jnp, jax
In [2]: x = jax.device_put(7)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
In [3]: %timeit jnp.add(x, x).block_until_ready()
4.18 µs ± 159 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```
PiperOrigin-RevId: 389871450
The goal of this change is to support shape polymorphism for operations
such as average (which needs to divide by the size of a dimension) or
indexing (which needs to normalize indices by comparing them with 0 and
adding dimension size for negative indices). In both of these cases
the size of a dimenion needs to be used as a value in the array
computation. In general, the size of a dimension is used only to
customize primitives.
This change introduces `core.dim_as_value` which must be used on
a dimension size before using it as a value in the array computation.
E.g.,
```
def average(x):
return jnp.sum(x, axis=0) / core.dim_as_value(x.shape[0])
```
This function is the identity function if the dimension size is
constant, otherwise it uses a new primitive `shape_poly.dim_as_value_p`.
Note that this does not change fundamentally the flavor of shape
polymorphism supported in jax2tf: intermediate shapes and their values
may depend on the input shapes, but never does a shape depend on the
input values. In fact, one could have expressed the `dim_as_value`
already:
```
def dim_as_value(d):
jnp.sum(jnp.broadcast_to(jnp.array(1), shape=(d,)))
```
We were able to suppot `jnp.mean`, `jnp.average`, `jnp.take`,
`lax.dynamic_slice`, `lax.dynamic_update_slice` by using
`core.dim_as_value` internally, but to fully roll-up the solution
we need to make `core.dim_as_value` a public API and teach the
users how to use it when they want to use shape polymorphism.
Alternatively, perhaps there is a way to automatically convert
dimension polynomials to values when passed to the lax primitives.
Previously there was one thread per device for receiving the outfeed from
devices, but there was a single global thread that was calling into the Python
callbacks. This meant that if one of the callbacks was slow, it was blocking
processing of all other callbacks.
One situation when this created difficulties was if one wanted to break a host_callback into two operations: a quick one to enqueue work on a threadpool,
and a subsequent slow one to wait for and retreive the result. The first slow callback would block all other callbacks, including possibly some quick ones, thus missing the opportunity to start the slow work.
With this change there is a separate queue of outfeeds for each device and a
separate thread per device to call into Python. This allows for concurrency
between callbacks from different devices, although the callbacks from one
device are still sequential. If the programmer wants more concurrency, they can use a threadpool. Having more concurrency by default is tricky, because it may mean that the Python callbacks for one device may be seen out of order.
PiperOrigin-RevId: 385493070
The previous conversion for argmin/argmax simply used tf.argmin and tf.argmax.
Those ops behave differently than JAX when the inputs contain NaN and Inf. Added
a few test cases in primitive_harness to expose the failures.
In order to implement an accurate conversion of argmin/argmax, we need to use the
XLA Reduce op.
Also tightened the shape checks for lax.argmin and lax.argmax, to ensure they are
not used with an empty reduced dimension. E.g., if the axis=-1, previously we got
an internal error:
```
RuntimeError: Invalid argument: Reducing out-of-bounds dimension -1 in shape f32[2,0,3].:
This is a bug in JAX's shape-checking rules; please report it!
```
PiperOrigin-RevId: 384182794