From 8bf3a8686040819b9823ad941d527c6c1728f614 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Thu, 14 Dec 2023 09:13:43 -0800
Subject: [PATCH] [roll forward 2] Remove the
 `jax_require_devices_during_lowering` flag since it was temporary. Added the
 semi-breaking change to CHANGELOG.md.

Reverts b52bcc1639368069075284eefc763f824ca155f1

PiperOrigin-RevId: 590959383
---
 CHANGELOG.md                  | 10 ++++++++++
 jax/_src/interpreters/pxla.py | 14 ++------------
 tests/pjit_test.py            | 11 +++--------
 3 files changed, 15 insertions(+), 20 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 597366d00..fecafe3c0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,16 @@ Remember to align the itemized text with the first line of an item within a list
 
 ## jax 0.4.24
 
+* Changes
+  * JAX lowering to StableHLO no longer depends on physical devices. If your
+    primitive wraps custom_partitioning or JAX callbacks in its lowering rule
+    (i.e., the function passed to the `rule` parameter of
+    `mlir.register_lowering`), add your primitive to the
+    `jax._src.dispatch.prim_requires_devices_during_lowering` set. This is
+    needed because custom_partitioning and JAX callbacks need physical devices
+    to create `Sharding`s during lowering. This is a temporary state until we
+    can create `Sharding`s without physical devices.
+
 ## jaxlib 0.4.24
 
 ## jax 0.4.23 (Dec 13, 2023)

diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 6a19d2844..fdbe4edb8 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -102,12 +102,6 @@ MeshAxisName = sharding_impls.MeshAxisName
 MeshDimAssignment = Union[ShardedAxis, Replicated]
 ShardingSpec = sharding_specs.ShardingSpec
 
-# TODO(yashkatariya): Remove this flag when nvidia's use cases are fixed.
-_JAX_REQUIRE_DEVICES_DURING_LOWERING = config.DEFINE_bool(
-    "jax_require_devices_during_lowering",
-    True,
-    help="Forces physical devices to be passed during lowering to stablehlo.")
-
 ### util
 
 def identity(x): return x
@@ -1977,17 +1971,13 @@ def lower_sharding_computation(
   semantic_in_shardings = SemanticallyEqualShardings(in_shardings)  # type: ignore
   semantic_out_shardings = SemanticallyEqualShardings(out_shardings)  # type: ignore
   prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)
-  materialized_da = (
-      tuple(da_object)
-      if prim_requires_devices or _JAX_REQUIRE_DEVICES_DURING_LOWERING.value
-      else None)
 
   (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
       closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
       semantic_out_shardings, in_layouts, out_layouts, len(da_object),
-      materialized_da, donated_invars, name_stack, all_default_mem_kind,
-      lowering_parameters=lowering_parameters)
+      tuple(da_object) if prim_requires_devices else None, donated_invars,
+      name_stack, all_default_mem_kind, lowering_parameters=lowering_parameters)
 
   # backend and device_assignment is passed through to MeshExecutable because
   # if keep_unused=False and all in_shardings are pruned, then there is no way

diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index b6a5ab8d1..73642ee1d 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -3802,14 +3802,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
     b = jax.device_put(out_a, NamedSharding(mesh2, P('y')))
     f(b)  # lowering cache *hit*
 
-    prev_value = pxla._JAX_REQUIRE_DEVICES_DURING_LOWERING.value
-    try:
-      jax.config.update('jax_require_devices_during_lowering', False)
-      with jtu.count_jit_and_pmap_compiles() as count:
-        g(np.arange(8))
-        self.assertEqual(count[0], 1)
-    finally:
-      jax.config.update('jax_require_devices_during_lowering', prev_value)
+    with jtu.count_jit_and_pmap_compiles() as count:
+      g(np.arange(8))
+      self.assertEqual(count[0], 1)
 
   def test_lowering_cache_miss_different_devices_and_sharding(self):
     if jax.device_count() < 4:
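
Note on the changelog entry above: a minimal sketch of the opt-in it describes,
kept for reference. Only the set name `prim_requires_devices_during_lowering`
comes from this patch; `my_prim` and `_my_lowering_rule` are hypothetical
names, not part of the change.

# Minimal sketch, assuming jax._src internals as of this patch. A primitive
# whose lowering rule wraps custom_partitioning or a JAX callback must opt in
# so lower_sharding_computation still materializes the device assignment
# (tuple(da_object)) when lowering jaxprs that contain it.
from jax._src import core, dispatch
from jax.interpreters import mlir

my_prim = core.Primitive("my_prim")  # hypothetical primitive
my_prim.def_abstract_eval(lambda x: x)  # output aval matches the input aval

def _my_lowering_rule(ctx, x):
  # Imagine this rule wraps custom_partitioning or a JAX callback, i.e. it
  # needs physical devices to build `Sharding`s during lowering.
  return [x]

mlir.register_lowering(my_prim, _my_lowering_rule)

# The opt-in from the changelog: mark the primitive as requiring devices
# during lowering, so dispatch.jaxpr_has_prim_requiring_devices finds it.
dispatch.prim_requires_devices_during_lowering.add(my_prim)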