Cache the iteration over jaxpr equation when extracting shardings because majority of the time, it's the same jaxpr so we don't need to evaluate it again and again.

PiperOrigin-RevId: 682148975
This commit is contained in:
Yash Katariya 2024-10-03 20:47:21 -07:00 committed by jax authors
parent be76fb6abf
commit 79ff8e6232
3 changed files with 53 additions and 19 deletions

View File

@ -16,7 +16,7 @@
from __future__ import annotations
import atexit
from collections.abc import Callable, Iterator, Sequence
from collections.abc import Callable, Sequence
import contextlib
import dataclasses
import enum
@ -207,6 +207,7 @@ def jaxpr_has_primitive(jaxpr: core.Jaxpr, prim_name: str) -> bool:
# stablehlo is oblivious of physical devices.
prim_requires_devices_during_lowering: set[core.Primitive] = set()
@util.weakref_lru_cache
def jaxpr_has_prim_requiring_devices(jaxpr: core.Jaxpr) -> bool:
for eqn in jaxpr.eqns:
if eqn.primitive in prim_requires_devices_during_lowering:
@ -222,23 +223,24 @@ class SourceInfo(NamedTuple):
eqn_name: str
@util.weakref_lru_cache
def get_intermediate_shardings(
jaxpr: core.Jaxpr,
) -> Iterator[tuple[Sharding, SourceInfo]]:
jaxpr: core.Jaxpr) -> Sequence[tuple[Sharding, SourceInfo]]:
from jax._src import pjit
from jax.experimental import shard_map
out = []
for eqn in jaxpr.eqns:
if eqn.primitive is pjit.sharding_constraint_p:
s = eqn.params['sharding']
if isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh):
continue
source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
yield (s, source_info)
out.append((s, source_info))
elif eqn.primitive is pjit.pjit_p:
source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
yield from ((i, source_info) for i in eqn.params['in_shardings'])
yield from ((o, source_info) for o in eqn.params['out_shardings'])
out.extend((i, source_info) for i in eqn.params['in_shardings'])
out.extend((o, source_info) for o in eqn.params['out_shardings'])
elif eqn.primitive is shard_map.shard_map_p:
if not eqn.params['mesh']._is_jax_device_mesh:
continue
@ -246,14 +248,15 @@ def get_intermediate_shardings(
def _names_to_pspec(names):
ndmin = max(names) + 1 if names else 0
return PartitionSpec(*(names.get(i) for i in range(ndmin)))
yield from ((NamedSharding(eqn.params['mesh'], _names_to_pspec(names)), source_info)
out.extend((NamedSharding(eqn.params['mesh'], _names_to_pspec(names)), source_info)
for names in [*eqn.params['in_names'], *eqn.params['out_names']])
elif eqn.primitive is device_put_p:
source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
yield from ((s, source_info) for s in eqn.params['devices']
out.extend((s, source_info) for s in eqn.params['devices']
if isinstance(s, Sharding) and s.memory_kind is not None)
for subjaxpr in core.subjaxprs(jaxpr):
yield from get_intermediate_shardings(subjaxpr)
out.extend(get_intermediate_shardings(subjaxpr))
return out
def jaxpr_has_bints(jaxpr: core.Jaxpr) -> bool:

View File

@ -19,7 +19,7 @@ import enum
from contextlib import contextmanager
import collections
from collections import namedtuple
from collections.abc import Callable, Sequence, Iterable, Iterator
from collections.abc import Callable, Sequence, Iterable
import dataclasses
from functools import partial, lru_cache, cached_property
import functools
@ -1985,14 +1985,17 @@ def _create_da_object( # pytype: disable=invalid-annotation
return xc.DeviceList(device_assignment)
@weakref_lru_cache
def jaxpr_transfer_mem_kinds(
jaxpr: core.Jaxpr) -> Iterator[sharding_impls.TransferToMemoryKind]:
jaxpr: core.Jaxpr) -> Sequence[sharding_impls.TransferToMemoryKind]:
out = [] # type: ignore
for eqn in jaxpr.eqns:
if eqn.primitive is dispatch.device_put_p:
yield from (d for d in eqn.params['devices']
out.extend(d for d in eqn.params['devices']
if isinstance(d, sharding_impls.TransferToMemoryKind))
for subjaxpr in core.subjaxprs(jaxpr):
yield from jaxpr_transfer_mem_kinds(subjaxpr)
out.extend(jaxpr_transfer_mem_kinds(subjaxpr))
return out
def are_all_shardings_default_mem_kind(da_object, shardings):
@ -2001,7 +2004,9 @@ def are_all_shardings_default_mem_kind(da_object, shardings):
except:
return True
for i in shardings:
if is_unspecified_or_auto(i) or i.memory_kind is None:
if is_unspecified_or_auto(i):
continue
if i.memory_kind is None: # pytype: disable=attribute-error
continue
if i.memory_kind != default_mem_kind:
return False
@ -2174,7 +2179,7 @@ def lower_sharding_computation(
# Device assignment across all inputs, outputs and shardings inside jaxpr
# should be the same.
unique_intermediate_shardings = util.stable_unique(
list(dispatch.get_intermediate_shardings(jaxpr)))
dispatch.get_intermediate_shardings(jaxpr))
unique_in_shardings = util.stable_unique(in_shardings)
unique_out_shardings = util.stable_unique(out_shardings)
backend, device_assignment = _get_and_check_device_assignment(
@ -2196,7 +2201,7 @@ def lower_sharding_computation(
da_object = _create_da_object(tuple(device_assignment))
transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr))
transfer_mem_kind_in_jaxpr = jaxpr_transfer_mem_kinds(jaxpr)
all_default_mem_kind = are_all_shardings_default_mem_kind(
da_object,
it.chain(unique_in_shardings, unique_out_shardings,

View File

@ -32,6 +32,7 @@ import jax
import jax.numpy as jnp
from jax._src import core
from jax._src import config
from jax._src import dispatch
from jax._src import test_util as jtu
from jax import dtypes
from jax import stages
@ -41,6 +42,7 @@ from jax.lax import with_sharding_constraint
from jax._src import prng
from jax.sharding import PartitionSpec as P, Mesh
from jax.experimental import multihost_utils
from jax.experimental.shard_map import shard_map
from jax.experimental.custom_partitioning import custom_partitioning
from jax._src import array
from jax._src.sharding import Sharding, common_devices_indices_map
@ -5294,6 +5296,30 @@ class UtilTest(jtu.JaxTestCase):
self.assertArraysEqual(w, w_gt)
def test_get_intermediate_shardings(self):
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
s = NamedSharding(mesh, P('x'))
arr = jax.device_put(np.arange(8), s)
@jax.jit
def g(x):
x = with_sharding_constraint(x, s)
return with_sharding_constraint(x, s)
@jax.jit
def f(x, y):
x, y = with_sharding_constraint((x, y), s)
x, y = shard_map(lambda x, y: (x, y), mesh=mesh, in_specs=P('x'),
out_specs=P('x'))(x, y)
x, y = jax.device_put((x, y), s)
x, y = jax.jit(lambda x, y: (x, y), in_shardings=s, out_shardings=s)(x, y)
return g(x), y
jaxpr = f.trace(arr, arr).jaxpr
out = dispatch.get_intermediate_shardings(jaxpr)
self.assertLen(out, 16)
@jtu.with_config(jax_use_shardy_partitioner=True)
class SdyIntegrationTest(jtu.JaxTestCase):