mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Reverts 255c30303d32e7473262b2e35348175c87e4348f
PiperOrigin-RevId: 674083626
This commit is contained in:
parent
620a686ac4
commit
dffac29e63
24
CHANGELOG.md
24
CHANGELOG.md
@ -12,15 +12,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
|
||||
## jax 0.4.33
|
||||
|
||||
* Changes
|
||||
* `jax_pmap_no_rank_reduction` flag is set to `True` by default.
|
||||
* array[0] on a pmap result now introduces a reshape (use array[0:1]
|
||||
instead).
|
||||
* The per-shard shape (accessable via jax_array.addressable_shards or
|
||||
jax_array.addressable_data(0)) now has a leading (1, ...). Update code
|
||||
that directly accesses shards accordingly. The rank of the per-shard-shape
|
||||
now matches that of the global shape which is the same behavior as jit.
|
||||
This avoids costly reshapes when passing results from pmap into jit.
|
||||
* Deletion:
|
||||
* `jax.xla_computation` is deleted. It's been 3 months since it's deprecation
|
||||
in 0.4.30 JAX release.
|
||||
Please use the AOT APIs to get the same functionality as `jax.xla_computation`.
|
||||
@ -31,12 +23,6 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
* For cross-backend lowering, you can replace
|
||||
`jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with
|
||||
`jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.
|
||||
* `jax.tree.map(f, None, non-None)`, which previously emitted a
|
||||
`DeprecationWarning`, now raises an error in a future version of jax. `None`
|
||||
is only a tree-prefix of itself. To preserve the current behavior, you can
|
||||
ask `jax.tree.map` to treat `None` as a leaf value by writing:
|
||||
`jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`.
|
||||
|
||||
|
||||
## jaxlib 0.4.33
|
||||
|
||||
@ -49,6 +35,14 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
C++ and CUDA code from JAX.
|
||||
|
||||
* Changes
|
||||
* `jax_pmap_no_rank_reduction` flag is set to `True` by default.
|
||||
* array[0] on a pmap result now introduces a reshape (use array[0:1]
|
||||
instead).
|
||||
* The per-shard shape (accessable via jax_array.addressable_shards or
|
||||
jax_array.addressable_data(0)) now has a leading (1, ...). Update code
|
||||
that directly accesses shards accordingly. The rank of the per-shard-shape
|
||||
now matches that of the global shape which is the same behavior as jit.
|
||||
This avoids costly reshapes when passing results from pmap into jit.
|
||||
* `jax_enable_memories` flag is set to `True` by default.
|
||||
* {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard.
|
||||
See {ref}`python-array-api` for more information.
|
||||
|
@ -24,7 +24,6 @@ from absl.testing import parameterized
|
||||
import jax
|
||||
from jax import flatten_util
|
||||
from jax import tree_util
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.tree_util import flatten_one_level, prefix_errors
|
||||
import jax.numpy as jnp
|
||||
@ -396,7 +395,6 @@ class TreeTest(jtu.JaxTestCase):
|
||||
({"a": 1, "b": (2, 3)}, {"a": [7], "b": ([8], (9,))}, [[7], [8], (9,)]),
|
||||
({"a": 1}, {"a": (7,)}, [(7,)]),
|
||||
({"a": 1}, {"a": {"a": 7}}, [{"a": 7}]),
|
||||
(None, None, [])
|
||||
)
|
||||
def testFlattenUpTo(self, tree, xs, expected):
|
||||
_, tree_def = tree_util.tree_flatten(tree)
|
||||
@ -485,11 +483,6 @@ class TreeTest(jtu.JaxTestCase):
|
||||
[([1], (2,), {"a": [1]})],
|
||||
re.escape("Custom node type mismatch"),
|
||||
),
|
||||
*(
|
||||
[]
|
||||
if xla_extension_version < 284
|
||||
else [(None, [2], re.escape("Expected None, got [2]."))]
|
||||
),
|
||||
)
|
||||
def testFlattenUpToErrors(self, tree, xs, error):
|
||||
_, tree_def = tree_util.tree_flatten(tree)
|
||||
|
Loading…
x
Reference in New Issue
Block a user