Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 20:06:05 +00:00)
(pallas-changelog)=

# Pallas Changelog

This is the list of changes specific to {class}`jax.experimental.pallas`.
For the overall JAX change log, see the main JAX changelog.

## Released with JAX 0.4.31

- Changes
  - {class}`jax.experimental.pallas.BlockSpec` now expects `block_shape` to be passed before `index_map`. The old argument order is deprecated and will be removed in a future release.
  - Fixed the interpreter mode to work with `BlockSpec`s that involve padding ({jax-issue}`#22275`). In interpreter mode, padding is filled with NaN to help debug out-of-bounds errors. This behavior is not present when running in custom kernel mode and should not be depended on.
- Deprecations
- New Functionality
  - Added documentation for BlockSpec: {ref}`pallas_grids_and_blockspecs`.
  - Improved error messages for the {func}`jax.experimental.pallas.pallas_call` API.
  - Added lowering rules for TPU for `lax.shift_right_arithmetic` ({jax-issue}`#22279`) and `lax.erf_inv` ({jax-issue}`#22310`).
  - Added initial support for shape polymorphism for the Pallas TPU custom kernels ({jax-issue}`#22084`).
  - Added TPU support for checkify ({jax-issue}`#22480`).

## Released with JAX 0.4.30 (June 18, 2024)

- New Functionality
  - Added checkify support for {func}`jax.experimental.pallas.pallas_call` in interpret mode ({jax-issue}`#21862`).
  - Improved support for PRNG keys for TPU kernels ({jax-issue}`#21773`).