platforms supporting float16 matmul computation for performance optimization.
With this PR change, JAX will allow float16 dot HLO to be created. When the
HLO modules are processed during the CPU compilation stage in OpenXLA, the
ChangeOpDataType pass will upcast the dot to a wider float type if the CPU
platform does not support float16 computation; on platforms that do support
float16 computation, the dot stays in float16 for execution.
When we use lax.platform_dependent in eager mode, and some
of the branches contain custom calls that are not recognized on
some platforms, we must eagerly pick the required branch.
In jit mode, the constant folding that the XLA compiler already
does will eliminate the unnecessary branches.
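A minimal sketch of this behavior, with illustrative per-platform branches:
```
import jax
import jax.numpy as jnp
from jax import lax

def f(x):
  # In eager mode only the branch for the current backend is picked and
  # executed; under jit, XLA's constant folding removes the other branches.
  return lax.platform_dependent(
      x,
      cpu=lambda x: x + 1.0,       # branch taken on CPU
      default=lambda x: x * 2.0,   # branch taken on other platforms
  )

print(f(jnp.float32(3.0)))           # eager: picks the required branch
print(jax.jit(f)(jnp.float32(3.0)))  # jit: unused branches are folded away
```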
With this change, one `jax.device_put` call now corresponds to one `device_put_p.bind()` instead of one per array. This immediately improves the performance of `jax.device_put(...)` on a large pytree by amortizing the calls to `pxla.shard_args`. Also, backends that implement efficient batch transfers (https://github.com/tensorflow/tensorflow/pull/69096) will batch device-to-device transfers across arrays in a pytree.
The api_benchmark indicates that this CL makes `device_put` with 10 to 1000 arrays ~30% faster, likely because it reduces the number of `device_put_p.bind()` calls.
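A small usage sketch (the pytree below is illustrative): a single `jax.device_put` call on a pytree now issues one `device_put_p.bind()` for the whole tree, letting backends batch the transfers.
```
import jax
import jax.numpy as jnp

tree = {"w": jnp.ones((1024, 1024)), "b": jnp.zeros((1024,))}
dev = jax.devices()[0]

# One device_put call for the whole pytree -> one device_put_p.bind(),
# instead of one bind per leaf array.
tree_on_dev = jax.device_put(tree, dev)
```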
PiperOrigin-RevId: 644051624
This relies on newly introduced support for dynamic `k`
for approx_top_k, using the `stablehlo.dynamic_approx_top_k`
custom call.
We also add a backwards compatibility test.
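A hedged sketch of what dynamic `k` enables (the function and shapes are illustrative, assuming the `jax.export` API): exporting `lax.approx_max_k` where `k` is derived from a symbolic dimension, which lowers through the `stablehlo.dynamic_approx_top_k` custom call.
```
import jax
import jax.numpy as jnp
from jax import export, lax

def top_half(x):
  k = x.shape[-1] // 2     # k is a symbolic expression of the input shape
  return lax.approx_max_k(x, k)

# Export with a symbolic trailing dimension, so k is dynamic at export time.
spec = jax.ShapeDtypeStruct(export.symbolic_shape("b, n"), jnp.float32)
exported = export.export(jax.jit(top_half))(spec)
```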
PiperOrigin-RevId: 640557581
When `aggregate_to_topk=True` (the default) the output reduction
dimension size is `k`, and we do not need to invoke `ApproxTopKReductionOutputSize`.
Add a set of test cases for shape polymorphism for approx_top_k.
The case when `aggregate_to_topk=True` and `k` is symbolic will
be fixed separately.
The case when `aggregate_to_topk=False` now raises a clearer `NotImplementedError`.
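A minimal sketch of the supported shape-polymorphic case (names and shapes are illustrative, assuming the `jax.export` API): a symbolic batch dimension with a concrete `k` and `aggregate_to_topk=True`, where the output reduction dimension is simply `k`.
```
import jax
import jax.numpy as jnp
from jax import export, lax

def top8(x):
  # Concrete k with aggregate_to_topk=True: output reduction dim is k=8.
  return lax.approx_max_k(x, k=8, aggregate_to_topk=True)

spec = jax.ShapeDtypeStruct(export.symbolic_shape("b, 128"), jnp.float32)
exported = export.export(jax.jit(top8))(spec)
```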
As reported in #20481, setting `unroll=0` in `lax.scan` resulted in an
uninformative `ZeroDivisionError`. This PR adds a check which raises a
`ValueError` for `unroll<=0`.
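A short illustration of the new check (the scan body here is made up):
```
import jax.numpy as jnp
from jax import lax

def step(carry, x):
  return carry + x, carry

xs = jnp.arange(4.0)
lax.scan(step, 0.0, xs, unroll=2)    # valid: positive unroll
# lax.scan(step, 0.0, xs, unroll=0)  # now raises ValueError rather than an
#                                    # uninformative ZeroDivisionError
```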
A previous change removed the only non-constrained lowering rule, breaking lowering for platforms without explicit lowering rules.
PiperOrigin-RevId: 633297839
This hopefully should go away when XLA implements its own memory space propagation pass or JAX adds memory_kind to the type system of jaxpr, i.e. on avals.
We need to treat the following code blocks (1) and (2) as equivalent when lowering to StableHLO. In general, shardings should also be treated the same way, but we'll cross that bridge later.
1. `jit(f, out_shardings=s_host)`
2. ```
   @jax.jit
   def f(x):
     return jax.device_put(x, s_host)
   ```
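For context, a hedged sketch of how a host-memory sharding like `s_host` might be constructed (this construction is an assumption for illustration, not part of the change):
```
import jax
from jax.sharding import SingleDeviceSharding

# Assumed example: a sharding that places data in host ("pinned_host") memory.
s_host = SingleDeviceSharding(jax.devices()[0], memory_kind="pinned_host")
```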
PiperOrigin-RevId: 632621025
Rely on XLA decomposition.
# JAX GPU microbenchmarks
285us for cumsum over 1e8 elements
449us for cumsum over 1e8 elements
# JAX CPU microbenchmarks
1.8s vs. 0.7s for 50 iterations of cumsum over 1e7 elements
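For reference, a timing sketch of the CPU measurement above (the harness is illustrative, not the original benchmark code):
```
import time
import jax
import jax.numpy as jnp

x = jnp.ones((10_000_000,), dtype=jnp.float32)   # 1e7 elements
cumsum = jax.jit(jnp.cumsum)
cumsum(x).block_until_ready()                    # warm up / compile once

start = time.perf_counter()
for _ in range(50):
  cumsum(x).block_until_ready()
print(f"50 iterations: {time.perf_counter() - start:.3f}s")
```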
PiperOrigin-RevId: 632547166