mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add ShardedDeviceTuple constant handler, fixes #1062
This commit is contained in:
parent
27b46e6615
commit
3f9c001c33
@ -399,6 +399,8 @@ batching.pytype_aval_mappings[ShardedDeviceTuple] = op.attrgetter('aval')
|
||||
xla.canonicalize_dtype_handlers[ShardedDeviceTuple] = \
|
||||
xla.canonicalize_dtype_handlers[xla.DeviceTuple]
|
||||
|
||||
xb.register_constant_handler(ShardedDeviceTuple, xla._device_tuple_constant_handler)
|
||||
|
||||
|
||||
class ShardedDeviceArray(ShardedDeviceValue, xla.DeviceArray):
|
||||
"""A ShardedDeviceArray is an ndarray sharded across devices.
|
||||
|
@ -17,6 +17,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
from functools import partial
|
||||
|
||||
from absl.testing import absltest
|
||||
import numpy as onp
|
||||
@ -29,6 +30,7 @@ from jax.core import Primitive, pack, JaxTuple
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters.xla import DeviceArray, DeviceTuple
|
||||
from jax.abstract_arrays import concretization_err_msg
|
||||
from jax.lib import xla_bridge as xb
|
||||
from jax import test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
@ -899,5 +901,26 @@ class APITest(jtu.JaxTestCase):
|
||||
jtu.check_raises_regexp(lambda: api.jit(3), TypeError,
|
||||
"Expected a callable value.*")
|
||||
|
||||
def test_issue_1062(self):
|
||||
# code from https://github.com/google/jax/issues/1062 @shoyer
|
||||
# this tests, among other things, whether ShardedDeviceTuple constants work
|
||||
device_count = xb.device_count()
|
||||
|
||||
@jit
|
||||
def multi_step(state, count):
|
||||
return lax.fori_loop(0, count, lambda i, s: s, state)
|
||||
|
||||
@jit
|
||||
def multi_step_pmap(state, count=2):
|
||||
@partial(api.pmap, axis_name='x')
|
||||
def pmapped_multi_step(state):
|
||||
return multi_step(state, count)
|
||||
|
||||
return pmapped_multi_step(state)
|
||||
|
||||
u = np.ones((device_count, 100))
|
||||
u_final = multi_step_pmap(u) # doesn't crash
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user