mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add a flag jax_require_devices_during_lowering
to control if physical devices are passed during lowering to stablehlo. This is temporary to unblock nvidia.
PiperOrigin-RevId: 590318918
This commit is contained in:
parent
1851447b3c
commit
f210b0f95a
@ -191,7 +191,6 @@ def should_tuple_args(num_args: int, platform: str) -> bool:
|
||||
else:
|
||||
return False
|
||||
|
||||
@util.weakref_lru_cache
|
||||
def jaxpr_has_primitive(jaxpr: core.Jaxpr, prim_name: str) -> bool:
|
||||
"""Whether there is a primitive given by user anywhere inside a Jaxpr."""
|
||||
for eqn in jaxpr.eqns:
|
||||
@ -207,7 +206,6 @@ def jaxpr_has_primitive(jaxpr: core.Jaxpr, prim_name: str) -> bool:
|
||||
# stablehlo is oblivious of physical devices.
|
||||
prim_requires_devices_during_lowering: set[core.Primitive] = set()
|
||||
|
||||
@util.weakref_lru_cache
|
||||
def jaxpr_has_prim_requiring_devices(jaxpr: core.Jaxpr):
|
||||
for eqn in jaxpr.eqns:
|
||||
if eqn.primitive in prim_requires_devices_during_lowering:
|
||||
|
@ -102,6 +102,11 @@ MeshAxisName = sharding_impls.MeshAxisName
|
||||
MeshDimAssignment = Union[ShardedAxis, Replicated]
|
||||
ShardingSpec = sharding_specs.ShardingSpec
|
||||
|
||||
# TODO(yashkatariya): Remove this flag when nvidia's use cases are fixed.
|
||||
_JAX_REQUIRE_DEVICES_DURING_LOWERING = config.DEFINE_bool(
|
||||
"jax_require_devices_during_lowering",
|
||||
True,
|
||||
help="Forces physical devices to be passed during lowering to stablehlo.")
|
||||
|
||||
### util
|
||||
|
||||
@ -1970,12 +1975,17 @@ def lower_sharding_computation(
|
||||
semantic_in_shardings = SemanticallyEqualShardings(in_shardings) # type: ignore
|
||||
semantic_out_shardings = SemanticallyEqualShardings(out_shardings) # type: ignore
|
||||
prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)
|
||||
materialized_da = (
|
||||
tuple(da_object)
|
||||
if prim_requires_devices or _JAX_REQUIRE_DEVICES_DURING_LOWERING.value
|
||||
else None)
|
||||
|
||||
(module, keepalive, host_callbacks, unordered_effects, ordered_effects,
|
||||
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
|
||||
closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
|
||||
semantic_out_shardings, in_layouts, out_layouts, len(da_object),
|
||||
tuple(da_object) if prim_requires_devices else None, donated_invars,
|
||||
name_stack, all_default_mem_kind, lowering_parameters=lowering_parameters)
|
||||
materialized_da, donated_invars, name_stack, all_default_mem_kind,
|
||||
lowering_parameters=lowering_parameters)
|
||||
|
||||
# backend and device_assignment is passed through to MeshExecutable because
|
||||
# if keep_unused=False and all in_shardings are pruned, then there is no way
|
||||
|
@ -3859,9 +3859,14 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
b = jax.device_put(out_a, NamedSharding(mesh2, P('y')))
|
||||
f(b) # lowering cache *hit*
|
||||
|
||||
with jtu.count_jit_and_pmap_compiles() as count:
|
||||
g(np.arange(8))
|
||||
self.assertEqual(count[0], 1)
|
||||
prev_value = pxla._JAX_REQUIRE_DEVICES_DURING_LOWERING.value
|
||||
try:
|
||||
jax.config.update('jax_require_devices_during_lowering', False)
|
||||
with jtu.count_jit_and_pmap_compiles() as count:
|
||||
g(np.arange(8))
|
||||
self.assertEqual(count[0], 1)
|
||||
finally:
|
||||
jax.config.update('jax_require_devices_during_lowering', prev_value)
|
||||
|
||||
def test_lowering_cache_miss_different_devices_and_sharding(self):
|
||||
if jax.device_count() < 4:
|
||||
|
Loading…
x
Reference in New Issue
Block a user