From 79ff8e6232da195c3ca39784d852871e6b1a9c72 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 3 Oct 2024 20:47:21 -0700 Subject: [PATCH] Cache the iteration over jaxpr equation when extracting shardings because majority of the time, it's the same jaxpr so we don't need to evaluate it again and again. PiperOrigin-RevId: 682148975 --- jax/_src/dispatch.py | 25 ++++++++++++++----------- jax/_src/interpreters/pxla.py | 21 +++++++++++++-------- tests/pjit_test.py | 26 ++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 19 deletions(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index e8de7a350..7874a79cf 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -16,7 +16,7 @@ from __future__ import annotations import atexit -from collections.abc import Callable, Iterator, Sequence +from collections.abc import Callable, Sequence import contextlib import dataclasses import enum @@ -207,6 +207,7 @@ def jaxpr_has_primitive(jaxpr: core.Jaxpr, prim_name: str) -> bool: # stablehlo is oblivious of physical devices. prim_requires_devices_during_lowering: set[core.Primitive] = set() +@util.weakref_lru_cache def jaxpr_has_prim_requiring_devices(jaxpr: core.Jaxpr) -> bool: for eqn in jaxpr.eqns: if eqn.primitive in prim_requires_devices_during_lowering: @@ -222,23 +223,24 @@ class SourceInfo(NamedTuple): eqn_name: str +@util.weakref_lru_cache def get_intermediate_shardings( - jaxpr: core.Jaxpr, -) -> Iterator[tuple[Sharding, SourceInfo]]: + jaxpr: core.Jaxpr) -> Sequence[tuple[Sharding, SourceInfo]]: from jax._src import pjit from jax.experimental import shard_map + out = [] for eqn in jaxpr.eqns: if eqn.primitive is pjit.sharding_constraint_p: s = eqn.params['sharding'] if isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh): continue source_info = SourceInfo(eqn.source_info, eqn.primitive.name) - yield (s, source_info) + out.append((s, source_info)) elif eqn.primitive is pjit.pjit_p: source_info = SourceInfo(eqn.source_info, eqn.primitive.name) - yield from ((i, source_info) for i in eqn.params['in_shardings']) - yield from ((o, source_info) for o in eqn.params['out_shardings']) + out.extend((i, source_info) for i in eqn.params['in_shardings']) + out.extend((o, source_info) for o in eqn.params['out_shardings']) elif eqn.primitive is shard_map.shard_map_p: if not eqn.params['mesh']._is_jax_device_mesh: continue @@ -246,14 +248,15 @@ def get_intermediate_shardings( def _names_to_pspec(names): ndmin = max(names) + 1 if names else 0 return PartitionSpec(*(names.get(i) for i in range(ndmin))) - yield from ((NamedSharding(eqn.params['mesh'], _names_to_pspec(names)), source_info) - for names in [*eqn.params['in_names'], *eqn.params['out_names']]) + out.extend((NamedSharding(eqn.params['mesh'], _names_to_pspec(names)), source_info) + for names in [*eqn.params['in_names'], *eqn.params['out_names']]) elif eqn.primitive is device_put_p: source_info = SourceInfo(eqn.source_info, eqn.primitive.name) - yield from ((s, source_info) for s in eqn.params['devices'] - if isinstance(s, Sharding) and s.memory_kind is not None) + out.extend((s, source_info) for s in eqn.params['devices'] + if isinstance(s, Sharding) and s.memory_kind is not None) for subjaxpr in core.subjaxprs(jaxpr): - yield from get_intermediate_shardings(subjaxpr) + out.extend(get_intermediate_shardings(subjaxpr)) + return out def jaxpr_has_bints(jaxpr: core.Jaxpr) -> bool: diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 650f3fb6c..d531727f4 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -19,7 +19,7 @@ import enum from contextlib import contextmanager import collections from collections import namedtuple -from collections.abc import Callable, Sequence, Iterable, Iterator +from collections.abc import Callable, Sequence, Iterable import dataclasses from functools import partial, lru_cache, cached_property import functools @@ -1985,14 +1985,17 @@ def _create_da_object( # pytype: disable=invalid-annotation return xc.DeviceList(device_assignment) +@weakref_lru_cache def jaxpr_transfer_mem_kinds( - jaxpr: core.Jaxpr) -> Iterator[sharding_impls.TransferToMemoryKind]: + jaxpr: core.Jaxpr) -> Sequence[sharding_impls.TransferToMemoryKind]: + out = [] # type: ignore for eqn in jaxpr.eqns: if eqn.primitive is dispatch.device_put_p: - yield from (d for d in eqn.params['devices'] - if isinstance(d, sharding_impls.TransferToMemoryKind)) + out.extend(d for d in eqn.params['devices'] + if isinstance(d, sharding_impls.TransferToMemoryKind)) for subjaxpr in core.subjaxprs(jaxpr): - yield from jaxpr_transfer_mem_kinds(subjaxpr) + out.extend(jaxpr_transfer_mem_kinds(subjaxpr)) + return out def are_all_shardings_default_mem_kind(da_object, shardings): @@ -2001,7 +2004,9 @@ def are_all_shardings_default_mem_kind(da_object, shardings): except: return True for i in shardings: - if is_unspecified_or_auto(i) or i.memory_kind is None: + if is_unspecified_or_auto(i): + continue + if i.memory_kind is None: # pytype: disable=attribute-error continue if i.memory_kind != default_mem_kind: return False @@ -2174,7 +2179,7 @@ def lower_sharding_computation( # Device assignment across all inputs, outputs and shardings inside jaxpr # should be the same. unique_intermediate_shardings = util.stable_unique( - list(dispatch.get_intermediate_shardings(jaxpr))) + dispatch.get_intermediate_shardings(jaxpr)) unique_in_shardings = util.stable_unique(in_shardings) unique_out_shardings = util.stable_unique(out_shardings) backend, device_assignment = _get_and_check_device_assignment( @@ -2196,7 +2201,7 @@ def lower_sharding_computation( da_object = _create_da_object(tuple(device_assignment)) - transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr)) + transfer_mem_kind_in_jaxpr = jaxpr_transfer_mem_kinds(jaxpr) all_default_mem_kind = are_all_shardings_default_mem_kind( da_object, it.chain(unique_in_shardings, unique_out_shardings, diff --git a/tests/pjit_test.py b/tests/pjit_test.py index bda72f651..0dc5284d5 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -32,6 +32,7 @@ import jax import jax.numpy as jnp from jax._src import core from jax._src import config +from jax._src import dispatch from jax._src import test_util as jtu from jax import dtypes from jax import stages @@ -41,6 +42,7 @@ from jax.lax import with_sharding_constraint from jax._src import prng from jax.sharding import PartitionSpec as P, Mesh from jax.experimental import multihost_utils +from jax.experimental.shard_map import shard_map from jax.experimental.custom_partitioning import custom_partitioning from jax._src import array from jax._src.sharding import Sharding, common_devices_indices_map @@ -5294,6 +5296,30 @@ class UtilTest(jtu.JaxTestCase): self.assertArraysEqual(w, w_gt) + def test_get_intermediate_shardings(self): + mesh = jtu.create_mesh((2, 1), ('x', 'y')) + s = NamedSharding(mesh, P('x')) + arr = jax.device_put(np.arange(8), s) + + @jax.jit + def g(x): + x = with_sharding_constraint(x, s) + return with_sharding_constraint(x, s) + + @jax.jit + def f(x, y): + x, y = with_sharding_constraint((x, y), s) + x, y = shard_map(lambda x, y: (x, y), mesh=mesh, in_specs=P('x'), + out_specs=P('x'))(x, y) + x, y = jax.device_put((x, y), s) + x, y = jax.jit(lambda x, y: (x, y), in_shardings=s, out_shardings=s)(x, y) + return g(x), y + + jaxpr = f.trace(arr, arr).jaxpr + out = dispatch.get_intermediate_shardings(jaxpr) + self.assertLen(out, 16) + + @jtu.with_config(jax_use_shardy_partitioner=True) class SdyIntegrationTest(jtu.JaxTestCase):