From 9cd94019b43709fd34474c9864fd27443559a6db Mon Sep 17 00:00:00 2001 From: George Necula Date: Thu, 11 Jul 2024 11:55:31 +0300 Subject: [PATCH] [pallas] Added a CHANGELOG for Pallas The CHANGELOG is populated with the changes since June 10th, when JAX 0.4.29 was released. --- CHANGELOG.md | 2 ++ docs/pallas/CHANGELOG.md | 48 ++++++++++++++++++++++++++++++++++++++++ docs/pallas/index.rst | 9 ++++++++ 3 files changed, 59 insertions(+) create mode 100644 docs/pallas/CHANGELOG.md diff --git a/CHANGELOG.md b/CHANGELOG.md index c0b4714b7..052810777 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,8 @@ # Change log Best viewed [here](https://jax.readthedocs.io/en/latest/changelog.html). +For the changes specific to the experimental Pallas APIs, +see {ref}`pallas-changelog`. + +This is the list of changes specific to {class}`jax.experimental.pallas`. +For the overall JAX change log see [here](https://jax.readthedocs.io/en/latest/changelog.html). + + + +## Released with JAX 0.4.31 + +* Changes + * {class}`jax.experimental.pallas.BlockSpec` now expects `block_shape` to + be passed *before* `index_map`. The old argument order is deprecated and + will be removed in a future release. + * Fixed the interpreter mode to work with BlockSpec that involve padding + ({jax-issue}`#22275`). + Padding in interpreter mode will be with NaN, to help debug out-of-bounds + errors, but this behavior is not present when running in custom kernel mode, + and should not be depended on. + + +* Deprecations + + +* New Functionality + * Added documentation for BlockSpec: {ref}`pallas_grids_and_blockspecs`. + * Improved error messages for the {func}`jax.experimental.pallas.pallas_call` + API. + * Added lowering rules for TPU for `lax.shift_right_arithmetic` ({jax-issue}`#22279`) + and `lax.erf_inv` ({jax-issue}`#22310`). + * Added initial support for shape polymorphism for the Pallas TPU custom kernels\ + ({jax-issue}`#22084`). + +## Released with JAX 0.4.30 (June 18, 2024) + +* New Functionality + * Added checkify support for {func}`jax.experimental.pallas.pallas_call` in + interpret mode ({jax-issue}`#21862`). + * Improved support for PRNG keys for TPU kernels ({jax-issue}`#21773`). + + + + diff --git a/docs/pallas/index.rst b/docs/pallas/index.rst index 9fbb560d1..403a8ce9c 100644 --- a/docs/pallas/index.rst +++ b/docs/pallas/index.rst @@ -6,6 +6,10 @@ Pallas is an extension to JAX that enables writing custom kernels for GPU and TP This section contains tutorials, guides and examples for using Pallas. See also the :class:`jax.experimental.pallas` module API documentation. +.. warning:: + Pallas is experimental and is changing frequently. + See the :ref:`pallas-changelog` for the recent changes. + .. toctree:: :caption: Guides :maxdepth: 2 @@ -14,9 +18,14 @@ See also the :class:`jax.experimental.pallas` module API documentation. design grid_blockspec + .. toctree:: :caption: Platform Features :maxdepth: 2 tpu/index +.. toctree:: + :maxdepth: 1 + + CHANGELOG