From 111f13e2795ea88ec42fb63df04498a1f8461fd0 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 25 Sep 2024 10:13:53 -0700 Subject: [PATCH] Reverts dffac29e63de6a51047fe77cf9d553ab762ef19b PiperOrigin-RevId: 678748794 --- CHANGELOG.md | 23 +++++++++++++++-------- tests/tree_util_test.py | 7 +++++++ 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5bdcd1c20..9c3c63b6e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,16 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * This release includes wheels for Python 3.13. Free-threading mode is not yet supported. +* Breaking changes + * `jax_pmap_no_rank_reduction` flag is set to `True` by default. + * array[0] on a pmap result now introduces a reshape (use array[0:1] + instead). + * The per-shard shape (accessable via jax_array.addressable_shards or + jax_array.addressable_data(0)) now has a leading (1, ...). Update code + that directly accesses shards accordingly. The rank of the per-shard-shape + now matches that of the global shape which is the same behavior as jit. + This avoids costly reshapes when passing results from pmap into jit. + * Deprecations * In {func}`jax.numpy.trim_zeros`, non-arraylike arguments or arraylike arguments with `ndim != 1` are now deprecated, and in the future will result @@ -34,6 +44,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. `jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`. * {class}`jax.ShapeDtypeStruct` no longer accepts the `named_shape` argument. The argument was only used by `xmap` which was removed in 0.4.31. + * `jax.tree.map(f, None, non-None)`, which previously emitted a + `DeprecationWarning`, now raises an error in a future version of jax. `None` + is only a tree-prefix of itself. To preserve the current behavior, you can + ask `jax.tree.map` to treat `None` as a leaf value by writing: + `jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`. * Bug fixes * Fixed a bug where {func}`jax.numpy.cumsum` would produce incorrect outputs @@ -62,14 +77,6 @@ See the 0.4.33 release notes for more details. C++ and CUDA code from JAX. * Changes - * `jax_pmap_no_rank_reduction` flag is set to `True` by default. - * array[0] on a pmap result now introduces a reshape (use array[0:1] - instead). - * The per-shard shape (accessable via jax_array.addressable_shards or - jax_array.addressable_data(0)) now has a leading (1, ...). Update code - that directly accesses shards accordingly. The rank of the per-shard-shape - now matches that of the global shape which is the same behavior as jit. - This avoids costly reshapes when passing results from pmap into jit. * `jax_enable_memories` flag is set to `True` by default. * {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard. See {ref}`python-array-api` for more information. diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index f8792a263..c5342a993 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -24,6 +24,7 @@ from absl.testing import parameterized import jax from jax import flatten_util from jax import tree_util +from jax._src.lib import xla_extension_version from jax._src import test_util as jtu from jax._src.tree_util import flatten_one_level, prefix_errors import jax.numpy as jnp @@ -395,6 +396,7 @@ class TreeTest(jtu.JaxTestCase): ({"a": 1, "b": (2, 3)}, {"a": [7], "b": ([8], (9,))}, [[7], [8], (9,)]), ({"a": 1}, {"a": (7,)}, [(7,)]), ({"a": 1}, {"a": {"a": 7}}, [{"a": 7}]), + (None, None, []) ) def testFlattenUpTo(self, tree, xs, expected): _, tree_def = tree_util.tree_flatten(tree) @@ -483,6 +485,11 @@ class TreeTest(jtu.JaxTestCase): [([1], (2,), {"a": [1]})], re.escape("Custom node type mismatch"), ), + *( + [] + if xla_extension_version < 288 + else [(None, [2], re.escape("Expected None, got [2]."))] + ), ) def testFlattenUpToErrors(self, tree, xs, error): _, tree_def = tree_util.tree_flatten(tree)