I realized it is unnecessary and is no different from listing the parameters
in `__init__` with relaxed types (to allow the old argument order).
PiperOrigin-RevId: 648696510
So, instead of

```
pl.BlockSpec(lambda i, j: ..., (42, 24))
```

``pl.BlockSpec`` now expects

```
pl.BlockSpec((42, 24), lambda i, j: ...)
```
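For context, a minimal sketch of the new argument order inside a ``pl.pallas_call``; the kernel, shapes, and grid here are illustrative only, not part of this change:

```
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    o_ref[...] = x_ref[...] + y_ref[...]

x = jnp.ones((42, 24), jnp.float32)
out = pl.pallas_call(
    add_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    grid=(1, 1),
    # block_shape first, then index_map:
    in_specs=[pl.BlockSpec((42, 24), lambda i, j: (i, j))] * 2,
    out_specs=pl.BlockSpec((42, 24), lambda i, j: (i, j)),
)(x, x)
```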
I will update the Pallas tests in a follow-up.
PiperOrigin-RevId: 648486321
The main change is to pass the `result_shapes` to the
hlo.CustomCallOp when the output shapes contain dimension
variables. Everything else is already handled by the existing
support for dynamic bound sizes on TPU.
Note that this CL only adds limited support for shape
polymorphism: only on TPU, and only when the block
sizes are static.
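As a user-visible sketch of what this enables (hedged: `f` stands in for a function that calls a Pallas TPU kernel with static block sizes, and the shape string is illustrative):

```
import jax
import jax.numpy as jnp
from jax import export

def f(x):
    return jnp.sin(x)  # stands in for a function calling a Pallas TPU kernel

# "b" is a dimension variable; the exported artifact then works for any b.
x_spec = jax.ShapeDtypeStruct(export.symbolic_shape("b, 128"), jnp.float32)
exported = export.export(jax.jit(f))(x_spec)
```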
PiperOrigin-RevId: 648409699
As far as I can tell, the `linear` parameter of the `lax.cond_p`
primitive exists only for historical reasons. It could be used for type
checking in `_cond_transpose`, but that was removed because of #14026.
With this in mind, we could stop tracking this parameter, as implemented
in this PR, unless we expect to re-introduce the type checking in the
future.
This could be useful for supporting the most common use cases for FFI custom
calls. It has several benefits over the `Primitive`-based approach, but
the biggest one (in my opinion) is that it doesn't require interacting with
`mlir` at all. It does have the limitation that transforms would need to be
registered using interfaces like `custom_vjp`, but many users of custom calls
already do that.
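For illustration, here is a hedged sketch of a call through this interface; `rms_norm` is a hypothetical FFI target that would have to be registered with XLA separately and is not part of this change:

```
import jax
import jax.extend as jex
import jax.numpy as jnp
import numpy as np

def rms_norm(x):
    # "rms_norm" is a hypothetical, already-registered FFI target.
    return jex.ffi.ffi_call(
        "rms_norm",
        jax.ShapeDtypeStruct(x.shape, x.dtype),  # output shape and dtype
        x,
        eps=np.float32(1e-5),  # extra attributes are passed by keyword
        vectorized=True,       # batching behaves like pure_callback's vectorized
    )
```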
~~The easiest to-do item (I think) is to implement batching using a
`vectorized` parameter like `pure_callback`, but we could also think about more
sophisticated vmapping interfaces in the future.~~ Done.
The more difficult to-do is to think about how to support sharding, and we
might actually want to expose an interface similar to the one from
`custom_partitioning`. I have less experience with this part so I'll have to
think some more about it, and feedback would be appreciated!
We are getting the following errors:
```
Duplicate FFI handler registration for cu_threefry2x32_ffi on a platform CUDA
Duplicate FFI handler registration for cu_lu_pivots_to_permutation on a platform CUDA
```
It seems that with the FFI registration mechanism based on `XLA_FFI_REGISTER_HANDLER` it is no longer
possible to register a call target twice.
The fix here is to roll back the changes in https://github.com/google/jax/pull/22178
and disable the changes from https://github.com/google/jax/pull/20997.
PiperOrigin-RevId: 647993991
The starting point was the text in pipelining.md, which I have now
replaced with a reference to the separate grid and BlockSpec
documentation.
Grids and BlockSpecs are also documented in quickstart.md,
which I mostly left alone because it was good enough for a
simple example.
I have also added a few docstrings.
For example, if you have two `lax.linalg.qr` calls (one on TPU and another on `device_host`), the `device_host` QR decomposition should be lowered to CPU.
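A hedged sketch of the scenario, assuming the `compute_on('device_host')` context manager from `jax.experimental` is how host placement is requested:

```
import jax
import jax.numpy as jnp
from jax.experimental.compute_on import compute_on

@jax.jit
def f(x):
    q1, r1 = jax.lax.linalg.qr(x)      # lowered for the default device, e.g. TPU
    with compute_on('device_host'):
        q2, r2 = jax.lax.linalg.qr(x)  # should pick up the CPU lowering
    return q1 @ r1 + q2 @ r2
```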
PiperOrigin-RevId: 647705015
This change removes the following deprecated APIs:
- `jax.core.canonicalize_shape`
- `jax.core.dimension_as_value`
- `jax.core.definitely_equal`
- `jax.core.symbolic_equal_dim`
These have been raising deprecation warnings since JAX v0.4.24, released Feb 6 2024.
PiperOrigin-RevId: 647671122
We previously took a `logical_and` of a mix of boolean and integer inputs, which isn't allowed
under some of the strict dtype modes. This has been causing some JAX tests to fail.
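A hedged illustration of the failure mode (the arrays are made up; the point is the explicit cast to bool):

```
import jax
import jax.numpy as jnp

jax.config.update("jax_numpy_dtype_promotion", "strict")

mask = jnp.array([True, False, True])
count = jnp.array([1, 0, 2])

# Mixing bool and int inputs in logical_and can raise a promotion error
# under strict mode; casting the integer input to bool first avoids it.
ok = jnp.logical_and(mask, count.astype(bool))
```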
PiperOrigin-RevId: 647669850
This allows lowering of threefry2x32 for GPU even on a machine without GPUs.
For the next 3 weeks, we use the new custom call implementation only when
we are not in "export" mode and only with a new jaxlib.
PiperOrigin-RevId: 647657084
It seems like the nvgpu dialect bakes a number of overly restrictive checks into its verifiers
and doesn't really buy us much in this case; nvvm works just fine.
PiperOrigin-RevId: 647653684
With this kernel we're able to significantly improve the performance
of large `head_dim` kernels, reaching ~62% utilization for 4k sequence
length and ~71% for 32k.
TODO: the two kernels are quite similar, and it should be possible to
collapse them into one.
PiperOrigin-RevId: 647597865
`sub_byte_element_size_in_bits` is a lowering-only thing for now (since we know the dtype of the aval, JAX can add the appropriate value). We can expose it in the user API if required.
Memory space is exposed via the JAX memories API, so it doesn't have to be in the layout API.
Also expose `_xla_layout` as a private API from `PJRTLayout` so that we can access fields to create JAX layouts.
Add constructors to `xla::Layout` so that JAX can create Layouts with minor_to_major and tiling information.
PiperOrigin-RevId: 647487510