This is part of the ["stackless"](#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types.
Changes:
1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal types that's already a vector space itself.
2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion.
3. Add `to_tangent_type` calls in various other places they're missing.
4. Remove non-support for float0 in custom deriviatives?
5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.)
PiperOrigin-RevId: 676115753
There will be more improvements and semantics clarification coming in the future as we integrate it more into JAX.
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
PiperOrigin-RevId: 668991384
We can use lax.broadcast_to_rank instead of the considerably more complicated _broadcast_to.
Add a fast path to broadcast_to_rank and broadcast to avoid emitting an equation if the rank is already correct.
The numerical accuracy test is perfect against the reference implementation, and somewhat loose against the alt grad implementation used for testing.
PiperOrigin-RevId: 654381378
There are 2 reasons for doing this:
* Avoid an extra allocation by putting the output on the correct sharding that the user specified. If you device_put the output of `_convert_element_type`, then you pay the cost of 2 transfers which is not ideal at all since this path would be critical (when users use `device`) and we should avoid doing extra transfers at all costs.
* This will allow us to streamline `device` arguments being added to all `jnp` functions as we will have one place (`_convert_element_type`) which will handle the logic of putting things on the right device.
Also fixes: https://github.com/google/jax/issues/17422
PiperOrigin-RevId: 650621659
In general a JAX value might correspond to multiple HLO values, which is why the HLO lowering represents each value as a tuple of zero or more ir.Values. However, the common case is that there is exactly one value, and almost all such lists are singletons.
To reduce the number of singleton list and tuple objects allocated during MLIR lowering, instead represent singleton values as unwrapped ir.Values, and only use a tuple if there is not exactly one ir.Value backing a JAX value.
- Make it clear that this only impacts FP32 computations.
- Remove incorrect aliases, eg. 'bfloat16' for default. This does not do as advertised for GPU.
- explicitly specify GPU and TPU device-dependent behaviour.
PiperOrigin-RevId: 647342888
platforms supporting float16 matmul computation for performance optimization.
With this PR change, JAX will allow dot float16 HLO being created. When the
HLO modules are processed during cpu compile stage in open xla, the
ChangeOpDataType pass will upcast the dot to float type if the CPU platform
does not support float16 computation, but for the platform supporting float16
computation, dot will stay as float16 type for execution.