mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 13:26:06 +00:00
[JAX] Change the default pmap() ordering to match the ordering of jax.devices() for single-process TPU jobs.
PiperOrigin-RevId: 484062717
This commit is contained in:
parent
a08ced86f3
commit
bf21391248
@ -7,14 +7,18 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
-->
|
||||
|
||||
## jax 0.3.24
|
||||
* Changes
|
||||
* JAX should be faster to import. We now import scipy lazily, which accounted
|
||||
for a significant fraction of JAX's import time.
|
||||
* Setting the env var `JAX_PERSISTENT_CACHE_MIN_INSTRUCTION_COUNT=$N` can be
|
||||
used to limit the number of cache entries written to the persistent
|
||||
cache. By default, computations with 6 or more instructions will be cached.
|
||||
* Changes
|
||||
* Added {func}`jax.scipy.stats.mode`.
|
||||
* Breaking Changes
|
||||
* The default device order used by `pmap` on TPU if no order is specified now
|
||||
matches `jax.devices()` for single-process jobs. Previously the
|
||||
two orderings differed, which could lead to unnecessary copies or
|
||||
out-of-memory errors. Requiring the orderings to agree simplifies matters.
|
||||
* Breaking Changes
|
||||
* {func}`jax.numpy.gradient` now behaves like most other functions in {mod}`jax.numpy`,
|
||||
and forbids passing lists or tuples in place of arrays ({jax-issue}`#12958`)
|
||||
|
||||
|
@ -1538,11 +1538,10 @@ class PmapExecutable(stages.XlaExecutable):
|
||||
xb.device_count(pci.backend),
|
||||
replicas.num_global_replicas,
|
||||
parts.num_partitions))
|
||||
# On a single host, we use the platform's default device assignment to
|
||||
# potentially take advantage of device locality. On multiple hosts, the
|
||||
# default device assignment may interleave different hosts' replicas,
|
||||
# violating pmap's semantics where data is sharded across replicas in
|
||||
# row-major order. Instead, manually create a device assignment that ensures
|
||||
# On a single host, we simply grab the first N devices from jax.devices().
|
||||
# In the single host case, we want the default device order of pmap to
|
||||
# match jax.devices().
|
||||
# On multiple hosts, we create a default device assignment that ensures
|
||||
# each host is responsible for a contiguous set of replicas.
|
||||
if shards.num_global_shards > shards.num_local_shards:
|
||||
# TODO(skye): use a locality-aware assignment that satisfies the above
|
||||
@ -1550,8 +1549,7 @@ class PmapExecutable(stages.XlaExecutable):
|
||||
devices = [d for process_index in range(xb.process_count(pci.backend))
|
||||
for d in xb.local_devices(process_index, pci.backend)]
|
||||
else:
|
||||
devices = xb.get_backend(pci.backend).get_default_device_assignment(
|
||||
replicas.num_global_replicas, parts.num_partitions)
|
||||
devices = xb.local_devices(backend=pci.backend)[:shards.num_local_shards]
|
||||
else:
|
||||
if shards.num_local_shards != len(pci.local_devices):
|
||||
local_devices_str = ", ".join(map(str, pci.local_devices))
|
||||
@ -1577,7 +1575,7 @@ class PmapExecutable(stages.XlaExecutable):
|
||||
# get_default_device_assignment() returns 2D assignment, caller may have
|
||||
# provided 1D list of devices).
|
||||
# Convert to 2D in case it's 1D and we have > 1 partitions.
|
||||
device_assignment = np.array(devices).reshape(
|
||||
device_assignment: np.ndarray = np.array(devices).reshape(
|
||||
(replicas.num_global_replicas, parts.num_partitions))
|
||||
# TODO(b/162356737): Enabling SPMD partitioning causes issues with some
|
||||
# non-partitioned workloads, so disable unless needed.
|
||||
|
@ -181,6 +181,13 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
ans = f(x)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testDefaultDeviceOrdering(self):
|
||||
# Users rely on the fact that the default order of jax.devices() matches
|
||||
# the default order of pmap for single-host jobs.
|
||||
device_order = jax.devices()
|
||||
pmap_sharding = pmap(lambda x: x)(np.arange(jax.device_count())).sharding
|
||||
self.assertListEqual(device_order, pmap_sharding.devices.tolist())
|
||||
|
||||
def testLowerCompile(self):
|
||||
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
|
||||
shape = (jax.device_count(), 4)
|
||||
|
Loading…
x
Reference in New Issue
Block a user