1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 13:26:06 +00:00

[JAX] Change the default pmap() ordering to match the ordering of jax.devices() for single-process TPU jobs.

PiperOrigin-RevId: 484062717
This commit is contained in:
Peter Hawkins 2022-10-26 13:55:32 -07:00 committed by jax authors
parent a08ced86f3
commit bf21391248
3 changed files with 19 additions and 10 deletions
CHANGELOG.md
jax/interpreters
tests

@@ -7,14 +7,18 @@ Remember to align the itemized text with the first line of an item within a list
-->
## jax 0.3.24
* Changes
* JAX should be faster to import. We now import scipy lazily, which accounted
for a significant fraction of JAX's import time.
* Setting the env var `JAX_PERSISTENT_CACHE_MIN_INSTRUCTION_COUNT=$N` can be
used to limit the number of cache entries written to the persistent
cache. By default, computations with 6 or more instructions will be cached.
* Changes
* Added {func}`jax.scipy.stats.mode`.
* Breaking Changes
* The default device order used by `pmap` on TPU if no order is specified now
matches `jax.devices()` for single-process jobs. Previously the
two orderings differed, which could lead to unnecessary copies or
out-of-memory errors. Requiring the orderings to agree simplifies matters.
* Breaking Changes
* {func}`jax.numpy.gradient` now behaves like most other functions in {mod}`jax.numpy`,
and forbids passing lists or tuples in place of arrays ({jax-issue}`#12958`)

@@ -1538,11 +1538,10 @@ class PmapExecutable(stages.XlaExecutable):
xb.device_count(pci.backend),
replicas.num_global_replicas,
parts.num_partitions))
# On a single host, we use the platform's default device assignment to
# potentially take advantage of device locality. On multiple hosts, the
# default device assignment may interleave different hosts' replicas,
# violating pmap's semantics where data is sharded across replicas in
# row-major order. Instead, manually create a device assignment that ensures
# On a single host, we simply grab the first N devices from jax.devices().
# In the single host case, we want the default device order of pmap to
# match jax.devices().
# On multiple hosts, we create a default device assignment that ensures
# each host is responsible for a contiguous set of replicas.
if shards.num_global_shards > shards.num_local_shards:
# TODO(skye): use a locality-aware assignment that satisfies the above
@@ -1550,8 +1549,7 @@ class PmapExecutable(stages.XlaExecutable):
devices = [d for process_index in range(xb.process_count(pci.backend))
for d in xb.local_devices(process_index, pci.backend)]
else:
devices = xb.get_backend(pci.backend).get_default_device_assignment(
replicas.num_global_replicas, parts.num_partitions)
devices = xb.local_devices(backend=pci.backend)[:shards.num_local_shards]
else:
if shards.num_local_shards != len(pci.local_devices):
local_devices_str = ", ".join(map(str, pci.local_devices))
@@ -1577,7 +1575,7 @@ class PmapExecutable(stages.XlaExecutable):
# get_default_device_assignment() returns 2D assignment, caller may have
# provided 1D list of devices).
# Convert to 2D in case it's 1D and we have > 1 partitions.
device_assignment = np.array(devices).reshape(
device_assignment: np.ndarray = np.array(devices).reshape(
(replicas.num_global_replicas, parts.num_partitions))
# TODO(b/162356737): Enabling SPMD partitioning causes issues with some
# non-partitioned workloads, so disable unless needed.

@@ -181,6 +181,13 @@ class PythonPmapTest(jtu.JaxTestCase):
ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)
def testDefaultDeviceOrdering(self):
# Users rely on the fact that the default order of jax.devices() matches
# the default order of pmap for single-host jobs.
device_order = jax.devices()
pmap_sharding = pmap(lambda x: x)(np.arange(jax.device_count())).sharding
self.assertListEqual(device_order, pmap_sharding.devices.tolist())
def testLowerCompile(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)