Reverts 255c30303d32e7473262b2e35348175c87e4348f

PiperOrigin-RevId: 674083626
This commit is contained in:
Peter Hawkins 2024-09-12 18:13:51 -07:00 committed by jax authors
parent 620a686ac4
commit dffac29e63
2 changed files with 9 additions and 22 deletions

View File

@ -12,15 +12,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
## jax 0.4.33
* Changes
* `jax_pmap_no_rank_reduction` flag is set to `True` by default.
* array[0] on a pmap result now introduces a reshape (use array[0:1]
instead).
* The per-shard shape (accessable via jax_array.addressable_shards or
jax_array.addressable_data(0)) now has a leading (1, ...). Update code
that directly accesses shards accordingly. The rank of the per-shard-shape
now matches that of the global shape which is the same behavior as jit.
This avoids costly reshapes when passing results from pmap into jit.
* Deletion:
* `jax.xla_computation` is deleted. It's been 3 months since it's deprecation
in 0.4.30 JAX release.
Please use the AOT APIs to get the same functionality as `jax.xla_computation`.
@ -31,12 +23,6 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* For cross-backend lowering, you can replace
`jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with
`jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.
* `jax.tree.map(f, None, non-None)`, which previously emitted a
`DeprecationWarning`, now raises an error in a future version of jax. `None`
is only a tree-prefix of itself. To preserve the current behavior, you can
ask `jax.tree.map` to treat `None` as a leaf value by writing:
`jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`.
## jaxlib 0.4.33
@ -49,6 +35,14 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
C++ and CUDA code from JAX.
* Changes
* `jax_pmap_no_rank_reduction` flag is set to `True` by default.
* array[0] on a pmap result now introduces a reshape (use array[0:1]
instead).
* The per-shard shape (accessable via jax_array.addressable_shards or
jax_array.addressable_data(0)) now has a leading (1, ...). Update code
that directly accesses shards accordingly. The rank of the per-shard-shape
now matches that of the global shape which is the same behavior as jit.
This avoids costly reshapes when passing results from pmap into jit.
* `jax_enable_memories` flag is set to `True` by default.
* {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard.
See {ref}`python-array-api` for more information.

View File

@ -24,7 +24,6 @@ from absl.testing import parameterized
import jax
from jax import flatten_util
from jax import tree_util
from jax._src.lib import xla_extension_version
from jax._src import test_util as jtu
from jax._src.tree_util import flatten_one_level, prefix_errors
import jax.numpy as jnp
@ -396,7 +395,6 @@ class TreeTest(jtu.JaxTestCase):
({"a": 1, "b": (2, 3)}, {"a": [7], "b": ([8], (9,))}, [[7], [8], (9,)]),
({"a": 1}, {"a": (7,)}, [(7,)]),
({"a": 1}, {"a": {"a": 7}}, [{"a": 7}]),
(None, None, [])
)
def testFlattenUpTo(self, tree, xs, expected):
_, tree_def = tree_util.tree_flatten(tree)
@ -485,11 +483,6 @@ class TreeTest(jtu.JaxTestCase):
[([1], (2,), {"a": [1]})],
re.escape("Custom node type mismatch"),
),
*(
[]
if xla_extension_version < 284
else [(None, [2], re.escape("Expected None, got [2]."))]
),
)
def testFlattenUpToErrors(self, tree, xs, error):
_, tree_def = tree_util.tree_flatten(tree)