add ShardedDeviceTuple constant handler, fixes #1062

This commit is contained in:
Matthew Johnson 2019-07-24 21:45:56 +03:00
parent 27b46e6615
commit 3f9c001c33
2 changed files with 25 additions and 0 deletions

View File

@ -399,6 +399,8 @@ batching.pytype_aval_mappings[ShardedDeviceTuple] = op.attrgetter('aval')
xla.canonicalize_dtype_handlers[ShardedDeviceTuple] = \
xla.canonicalize_dtype_handlers[xla.DeviceTuple]
xb.register_constant_handler(ShardedDeviceTuple, xla._device_tuple_constant_handler)
class ShardedDeviceArray(ShardedDeviceValue, xla.DeviceArray):
"""A ShardedDeviceArray is an ndarray sharded across devices.

View File

@ -17,6 +17,7 @@ from __future__ import division
from __future__ import print_function
import collections
from functools import partial
from absl.testing import absltest
import numpy as onp
@ -29,6 +30,7 @@ from jax.core import Primitive, pack, JaxTuple
from jax.interpreters import ad
from jax.interpreters.xla import DeviceArray, DeviceTuple
from jax.abstract_arrays import concretization_err_msg
from jax.lib import xla_bridge as xb
from jax import test_util as jtu
from jax.config import config
@ -899,5 +901,26 @@ class APITest(jtu.JaxTestCase):
jtu.check_raises_regexp(lambda: api.jit(3), TypeError,
"Expected a callable value.*")
def test_issue_1062(self):
# code from https://github.com/google/jax/issues/1062 @shoyer
# this tests, among other things, whether ShardedDeviceTuple constants work
device_count = xb.device_count()
@jit
def multi_step(state, count):
return lax.fori_loop(0, count, lambda i, s: s, state)
@jit
def multi_step_pmap(state, count=2):
@partial(api.pmap, axis_name='x')
def pmapped_multi_step(state):
return multi_step(state, count)
return pmapped_multi_step(state)
u = np.ones((device_count, 100))
u_final = multi_step_pmap(u) # doesn't crash
if __name__ == '__main__':
absltest.main()