mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Reverts dffac29e63de6a51047fe77cf9d553ab762ef19b
PiperOrigin-RevId: 678748794
This commit is contained in:
parent
b49d8b2615
commit
111f13e279
23
CHANGELOG.md
23
CHANGELOG.md
@ -16,6 +16,16 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
|||||||
* This release includes wheels for Python 3.13. Free-threading mode is not yet
|
* This release includes wheels for Python 3.13. Free-threading mode is not yet
|
||||||
supported.
|
supported.
|
||||||
|
|
||||||
|
* Breaking changes
|
||||||
|
* `jax_pmap_no_rank_reduction` flag is set to `True` by default.
|
||||||
|
* array[0] on a pmap result now introduces a reshape (use array[0:1]
|
||||||
|
instead).
|
||||||
|
* The per-shard shape (accessable via jax_array.addressable_shards or
|
||||||
|
jax_array.addressable_data(0)) now has a leading (1, ...). Update code
|
||||||
|
that directly accesses shards accordingly. The rank of the per-shard-shape
|
||||||
|
now matches that of the global shape which is the same behavior as jit.
|
||||||
|
This avoids costly reshapes when passing results from pmap into jit.
|
||||||
|
|
||||||
* Deprecations
|
* Deprecations
|
||||||
* In {func}`jax.numpy.trim_zeros`, non-arraylike arguments or arraylike
|
* In {func}`jax.numpy.trim_zeros`, non-arraylike arguments or arraylike
|
||||||
arguments with `ndim != 1` are now deprecated, and in the future will result
|
arguments with `ndim != 1` are now deprecated, and in the future will result
|
||||||
@ -34,6 +44,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
|||||||
`jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.
|
`jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.
|
||||||
* {class}`jax.ShapeDtypeStruct` no longer accepts the `named_shape` argument.
|
* {class}`jax.ShapeDtypeStruct` no longer accepts the `named_shape` argument.
|
||||||
The argument was only used by `xmap` which was removed in 0.4.31.
|
The argument was only used by `xmap` which was removed in 0.4.31.
|
||||||
|
* `jax.tree.map(f, None, non-None)`, which previously emitted a
|
||||||
|
`DeprecationWarning`, now raises an error in a future version of jax. `None`
|
||||||
|
is only a tree-prefix of itself. To preserve the current behavior, you can
|
||||||
|
ask `jax.tree.map` to treat `None` as a leaf value by writing:
|
||||||
|
`jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`.
|
||||||
|
|
||||||
* Bug fixes
|
* Bug fixes
|
||||||
* Fixed a bug where {func}`jax.numpy.cumsum` would produce incorrect outputs
|
* Fixed a bug where {func}`jax.numpy.cumsum` would produce incorrect outputs
|
||||||
@ -62,14 +77,6 @@ See the 0.4.33 release notes for more details.
|
|||||||
C++ and CUDA code from JAX.
|
C++ and CUDA code from JAX.
|
||||||
|
|
||||||
* Changes
|
* Changes
|
||||||
* `jax_pmap_no_rank_reduction` flag is set to `True` by default.
|
|
||||||
* array[0] on a pmap result now introduces a reshape (use array[0:1]
|
|
||||||
instead).
|
|
||||||
* The per-shard shape (accessable via jax_array.addressable_shards or
|
|
||||||
jax_array.addressable_data(0)) now has a leading (1, ...). Update code
|
|
||||||
that directly accesses shards accordingly. The rank of the per-shard-shape
|
|
||||||
now matches that of the global shape which is the same behavior as jit.
|
|
||||||
This avoids costly reshapes when passing results from pmap into jit.
|
|
||||||
* `jax_enable_memories` flag is set to `True` by default.
|
* `jax_enable_memories` flag is set to `True` by default.
|
||||||
* {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard.
|
* {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard.
|
||||||
See {ref}`python-array-api` for more information.
|
See {ref}`python-array-api` for more information.
|
||||||
|
@ -24,6 +24,7 @@ from absl.testing import parameterized
|
|||||||
import jax
|
import jax
|
||||||
from jax import flatten_util
|
from jax import flatten_util
|
||||||
from jax import tree_util
|
from jax import tree_util
|
||||||
|
from jax._src.lib import xla_extension_version
|
||||||
from jax._src import test_util as jtu
|
from jax._src import test_util as jtu
|
||||||
from jax._src.tree_util import flatten_one_level, prefix_errors
|
from jax._src.tree_util import flatten_one_level, prefix_errors
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
@ -395,6 +396,7 @@ class TreeTest(jtu.JaxTestCase):
|
|||||||
({"a": 1, "b": (2, 3)}, {"a": [7], "b": ([8], (9,))}, [[7], [8], (9,)]),
|
({"a": 1, "b": (2, 3)}, {"a": [7], "b": ([8], (9,))}, [[7], [8], (9,)]),
|
||||||
({"a": 1}, {"a": (7,)}, [(7,)]),
|
({"a": 1}, {"a": (7,)}, [(7,)]),
|
||||||
({"a": 1}, {"a": {"a": 7}}, [{"a": 7}]),
|
({"a": 1}, {"a": {"a": 7}}, [{"a": 7}]),
|
||||||
|
(None, None, [])
|
||||||
)
|
)
|
||||||
def testFlattenUpTo(self, tree, xs, expected):
|
def testFlattenUpTo(self, tree, xs, expected):
|
||||||
_, tree_def = tree_util.tree_flatten(tree)
|
_, tree_def = tree_util.tree_flatten(tree)
|
||||||
@ -483,6 +485,11 @@ class TreeTest(jtu.JaxTestCase):
|
|||||||
[([1], (2,), {"a": [1]})],
|
[([1], (2,), {"a": [1]})],
|
||||||
re.escape("Custom node type mismatch"),
|
re.escape("Custom node type mismatch"),
|
||||||
),
|
),
|
||||||
|
*(
|
||||||
|
[]
|
||||||
|
if xla_extension_version < 288
|
||||||
|
else [(None, [2], re.escape("Expected None, got [2]."))]
|
||||||
|
),
|
||||||
)
|
)
|
||||||
def testFlattenUpToErrors(self, tree, xs, error):
|
def testFlattenUpToErrors(self, tree, xs, error):
|
||||||
_, tree_def = tree_util.tree_flatten(tree)
|
_, tree_def = tree_util.tree_flatten(tree)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user