mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Cache the iteration over jaxpr equation when extracting shardings because majority of the time, it's the same jaxpr so we don't need to evaluate it again and again.
PiperOrigin-RevId: 682148975
This commit is contained in:
parent
be76fb6abf
commit
79ff8e6232
@ -16,7 +16,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
from collections.abc import Callable, Iterator, Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import enum
|
||||
@ -207,6 +207,7 @@ def jaxpr_has_primitive(jaxpr: core.Jaxpr, prim_name: str) -> bool:
|
||||
# stablehlo is oblivious of physical devices.
|
||||
prim_requires_devices_during_lowering: set[core.Primitive] = set()
|
||||
|
||||
@util.weakref_lru_cache
|
||||
def jaxpr_has_prim_requiring_devices(jaxpr: core.Jaxpr) -> bool:
|
||||
for eqn in jaxpr.eqns:
|
||||
if eqn.primitive in prim_requires_devices_during_lowering:
|
||||
@ -222,23 +223,24 @@ class SourceInfo(NamedTuple):
|
||||
eqn_name: str
|
||||
|
||||
|
||||
@util.weakref_lru_cache
|
||||
def get_intermediate_shardings(
|
||||
jaxpr: core.Jaxpr,
|
||||
) -> Iterator[tuple[Sharding, SourceInfo]]:
|
||||
jaxpr: core.Jaxpr) -> Sequence[tuple[Sharding, SourceInfo]]:
|
||||
from jax._src import pjit
|
||||
from jax.experimental import shard_map
|
||||
|
||||
out = []
|
||||
for eqn in jaxpr.eqns:
|
||||
if eqn.primitive is pjit.sharding_constraint_p:
|
||||
s = eqn.params['sharding']
|
||||
if isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh):
|
||||
continue
|
||||
source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
|
||||
yield (s, source_info)
|
||||
out.append((s, source_info))
|
||||
elif eqn.primitive is pjit.pjit_p:
|
||||
source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
|
||||
yield from ((i, source_info) for i in eqn.params['in_shardings'])
|
||||
yield from ((o, source_info) for o in eqn.params['out_shardings'])
|
||||
out.extend((i, source_info) for i in eqn.params['in_shardings'])
|
||||
out.extend((o, source_info) for o in eqn.params['out_shardings'])
|
||||
elif eqn.primitive is shard_map.shard_map_p:
|
||||
if not eqn.params['mesh']._is_jax_device_mesh:
|
||||
continue
|
||||
@ -246,14 +248,15 @@ def get_intermediate_shardings(
|
||||
def _names_to_pspec(names):
|
||||
ndmin = max(names) + 1 if names else 0
|
||||
return PartitionSpec(*(names.get(i) for i in range(ndmin)))
|
||||
yield from ((NamedSharding(eqn.params['mesh'], _names_to_pspec(names)), source_info)
|
||||
for names in [*eqn.params['in_names'], *eqn.params['out_names']])
|
||||
out.extend((NamedSharding(eqn.params['mesh'], _names_to_pspec(names)), source_info)
|
||||
for names in [*eqn.params['in_names'], *eqn.params['out_names']])
|
||||
elif eqn.primitive is device_put_p:
|
||||
source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
|
||||
yield from ((s, source_info) for s in eqn.params['devices']
|
||||
if isinstance(s, Sharding) and s.memory_kind is not None)
|
||||
out.extend((s, source_info) for s in eqn.params['devices']
|
||||
if isinstance(s, Sharding) and s.memory_kind is not None)
|
||||
for subjaxpr in core.subjaxprs(jaxpr):
|
||||
yield from get_intermediate_shardings(subjaxpr)
|
||||
out.extend(get_intermediate_shardings(subjaxpr))
|
||||
return out
|
||||
|
||||
|
||||
def jaxpr_has_bints(jaxpr: core.Jaxpr) -> bool:
|
||||
|
@ -19,7 +19,7 @@ import enum
|
||||
from contextlib import contextmanager
|
||||
import collections
|
||||
from collections import namedtuple
|
||||
from collections.abc import Callable, Sequence, Iterable, Iterator
|
||||
from collections.abc import Callable, Sequence, Iterable
|
||||
import dataclasses
|
||||
from functools import partial, lru_cache, cached_property
|
||||
import functools
|
||||
@ -1985,14 +1985,17 @@ def _create_da_object( # pytype: disable=invalid-annotation
|
||||
return xc.DeviceList(device_assignment)
|
||||
|
||||
|
||||
@weakref_lru_cache
|
||||
def jaxpr_transfer_mem_kinds(
|
||||
jaxpr: core.Jaxpr) -> Iterator[sharding_impls.TransferToMemoryKind]:
|
||||
jaxpr: core.Jaxpr) -> Sequence[sharding_impls.TransferToMemoryKind]:
|
||||
out = [] # type: ignore
|
||||
for eqn in jaxpr.eqns:
|
||||
if eqn.primitive is dispatch.device_put_p:
|
||||
yield from (d for d in eqn.params['devices']
|
||||
if isinstance(d, sharding_impls.TransferToMemoryKind))
|
||||
out.extend(d for d in eqn.params['devices']
|
||||
if isinstance(d, sharding_impls.TransferToMemoryKind))
|
||||
for subjaxpr in core.subjaxprs(jaxpr):
|
||||
yield from jaxpr_transfer_mem_kinds(subjaxpr)
|
||||
out.extend(jaxpr_transfer_mem_kinds(subjaxpr))
|
||||
return out
|
||||
|
||||
|
||||
def are_all_shardings_default_mem_kind(da_object, shardings):
|
||||
@ -2001,7 +2004,9 @@ def are_all_shardings_default_mem_kind(da_object, shardings):
|
||||
except:
|
||||
return True
|
||||
for i in shardings:
|
||||
if is_unspecified_or_auto(i) or i.memory_kind is None:
|
||||
if is_unspecified_or_auto(i):
|
||||
continue
|
||||
if i.memory_kind is None: # pytype: disable=attribute-error
|
||||
continue
|
||||
if i.memory_kind != default_mem_kind:
|
||||
return False
|
||||
@ -2174,7 +2179,7 @@ def lower_sharding_computation(
|
||||
# Device assignment across all inputs, outputs and shardings inside jaxpr
|
||||
# should be the same.
|
||||
unique_intermediate_shardings = util.stable_unique(
|
||||
list(dispatch.get_intermediate_shardings(jaxpr)))
|
||||
dispatch.get_intermediate_shardings(jaxpr))
|
||||
unique_in_shardings = util.stable_unique(in_shardings)
|
||||
unique_out_shardings = util.stable_unique(out_shardings)
|
||||
backend, device_assignment = _get_and_check_device_assignment(
|
||||
@ -2196,7 +2201,7 @@ def lower_sharding_computation(
|
||||
|
||||
da_object = _create_da_object(tuple(device_assignment))
|
||||
|
||||
transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr))
|
||||
transfer_mem_kind_in_jaxpr = jaxpr_transfer_mem_kinds(jaxpr)
|
||||
all_default_mem_kind = are_all_shardings_default_mem_kind(
|
||||
da_object,
|
||||
it.chain(unique_in_shardings, unique_out_shardings,
|
||||
|
@ -32,6 +32,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
from jax._src import core
|
||||
from jax._src import config
|
||||
from jax._src import dispatch
|
||||
from jax._src import test_util as jtu
|
||||
from jax import dtypes
|
||||
from jax import stages
|
||||
@ -41,6 +42,7 @@ from jax.lax import with_sharding_constraint
|
||||
from jax._src import prng
|
||||
from jax.sharding import PartitionSpec as P, Mesh
|
||||
from jax.experimental import multihost_utils
|
||||
from jax.experimental.shard_map import shard_map
|
||||
from jax.experimental.custom_partitioning import custom_partitioning
|
||||
from jax._src import array
|
||||
from jax._src.sharding import Sharding, common_devices_indices_map
|
||||
@ -5294,6 +5296,30 @@ class UtilTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertArraysEqual(w, w_gt)
|
||||
|
||||
def test_get_intermediate_shardings(self):
|
||||
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
|
||||
s = NamedSharding(mesh, P('x'))
|
||||
arr = jax.device_put(np.arange(8), s)
|
||||
|
||||
@jax.jit
|
||||
def g(x):
|
||||
x = with_sharding_constraint(x, s)
|
||||
return with_sharding_constraint(x, s)
|
||||
|
||||
@jax.jit
|
||||
def f(x, y):
|
||||
x, y = with_sharding_constraint((x, y), s)
|
||||
x, y = shard_map(lambda x, y: (x, y), mesh=mesh, in_specs=P('x'),
|
||||
out_specs=P('x'))(x, y)
|
||||
x, y = jax.device_put((x, y), s)
|
||||
x, y = jax.jit(lambda x, y: (x, y), in_shardings=s, out_shardings=s)(x, y)
|
||||
return g(x), y
|
||||
|
||||
jaxpr = f.trace(arr, arr).jaxpr
|
||||
out = dispatch.get_intermediate_shardings(jaxpr)
|
||||
self.assertLen(out, 16)
|
||||
|
||||
|
||||
@jtu.with_config(jax_use_shardy_partitioner=True)
|
||||
class SdyIntegrationTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user