diff --git a/CHANGELOG.md b/CHANGELOG.md index 7fbe947fa..038c0131a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## jax 0.4.32 * Changes + * `jax_enable_memories` flag is set to `True` by default. * {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard. See {ref}`python-array-api` for more information. * Computations on the CPU backend may now be dispatched asynchronously in diff --git a/jax/_src/config.py b/jax/_src/config.py index 5b4226f8f..46b327327 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1054,7 +1054,7 @@ def _update_jax_memories_thread_local(val): enable_memories = bool_state( 'jax_enable_memories', - default=False, + default=True, upgrade=True, update_global_hook=_update_jax_memories_global, update_thread_local_hook=_update_jax_memories_thread_local, diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index e7fd8657c..6c0b46077 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -565,12 +565,6 @@ mlir.register_lowering( def _common_device_put_lowering(ctx, *xs, devices, srcs): - for device in devices: - if (isinstance(device, (Sharding, TransferToMemoryKind)) and - device.memory_kind is not None): - raise NotImplementedError( - "Passing memory_kind to device_put via Shardings is not supported on" - f" platforms {ctx.module_context.platforms}") return xs mlir.register_lowering(device_put_p, _common_device_put_lowering) diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index 04a64fe55..2712e2b4a 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -579,6 +579,8 @@ class CheckpointTest(jtu.JaxTestCase): self.assertArraysEqual(s.data, np_inp[s.index]) def test_deserialization_with_int4(self): + if jtu.test_device_matches(['gpu']): + self.skipTest("Fails on GPU. Enable after it's fixed") dtype = jnp.int4 shape = (8, 2) arr = jnp.arange(np.prod(shape)).reshape(shape).astype(dtype) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 516d1fec7..df87fed4b 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3824,19 +3824,6 @@ class ArrayPjitTest(jtu.JaxTestCase): ' manager.*SingleDeviceSharding'): jax.jit(jax.vmap(f, spmd_axis_name='x'))(arr) - @jtu.skip_on_devices("tpu", "gpu") - def test_device_put_memory_kind_not_tpu_gpu(self): - @jax.jit - def f(x): - y = x * 2 - return jax.device_put(y, sharding_impls.TransferToMemoryKind('unpinned_host')) - - with self.assertRaisesRegex( - NotImplementedError, - 'Passing memory_kind to device_put via Shardings is not supported on' - ' platform.*'): - f(jnp.arange(8)) - def test_no_output_multiple_devices(self): mesh = jtu.create_global_mesh((2,), ('x',))