mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Set jax_enable_memories
flag to True
by default
PiperOrigin-RevId: 660579462
This commit is contained in:
parent
7efca0490f
commit
be53ee10b1
@ -13,6 +13,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
## jax 0.4.32
|
||||
|
||||
* Changes
|
||||
* `jax_enable_memories` flag is set to `True` by default.
|
||||
* {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard.
|
||||
See {ref}`python-array-api` for more information.
|
||||
* Computations on the CPU backend may now be dispatched asynchronously in
|
||||
|
@ -1054,7 +1054,7 @@ def _update_jax_memories_thread_local(val):
|
||||
|
||||
enable_memories = bool_state(
|
||||
'jax_enable_memories',
|
||||
default=False,
|
||||
default=True,
|
||||
upgrade=True,
|
||||
update_global_hook=_update_jax_memories_global,
|
||||
update_thread_local_hook=_update_jax_memories_thread_local,
|
||||
|
@ -565,12 +565,6 @@ mlir.register_lowering(
|
||||
|
||||
|
||||
def _common_device_put_lowering(ctx, *xs, devices, srcs):
|
||||
for device in devices:
|
||||
if (isinstance(device, (Sharding, TransferToMemoryKind)) and
|
||||
device.memory_kind is not None):
|
||||
raise NotImplementedError(
|
||||
"Passing memory_kind to device_put via Shardings is not supported on"
|
||||
f" platforms {ctx.module_context.platforms}")
|
||||
return xs
|
||||
mlir.register_lowering(device_put_p, _common_device_put_lowering)
|
||||
|
||||
|
@ -579,6 +579,8 @@ class CheckpointTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(s.data, np_inp[s.index])
|
||||
|
||||
def test_deserialization_with_int4(self):
|
||||
if jtu.test_device_matches(['gpu']):
|
||||
self.skipTest("Fails on GPU. Enable after it's fixed")
|
||||
dtype = jnp.int4
|
||||
shape = (8, 2)
|
||||
arr = jnp.arange(np.prod(shape)).reshape(shape).astype(dtype)
|
||||
|
@ -3824,19 +3824,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
' manager.*SingleDeviceSharding'):
|
||||
jax.jit(jax.vmap(f, spmd_axis_name='x'))(arr)
|
||||
|
||||
@jtu.skip_on_devices("tpu", "gpu")
|
||||
def test_device_put_memory_kind_not_tpu_gpu(self):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = x * 2
|
||||
return jax.device_put(y, sharding_impls.TransferToMemoryKind('unpinned_host'))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
'Passing memory_kind to device_put via Shardings is not supported on'
|
||||
' platform.*'):
|
||||
f(jnp.arange(8))
|
||||
|
||||
def test_no_output_multiple_devices(self):
|
||||
mesh = jtu.create_global_mesh((2,), ('x',))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user