Merge pull request #11640 from pschuh:pmap-shaped-array

PiperOrigin-RevId: 464623040
This commit is contained in:
jax authors 2022-08-01 14:22:41 -07:00
commit 75d69725c3
2 changed files with 10 additions and 4 deletions

View File

@@ -1964,9 +1964,6 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
local_axis_size = _mapped_axis_size(
in_tree, args, in_axes_flat, "pmap", kws=True)
for arg in args:
_check_arg(arg)
flat_fun, out_tree = flatten_fun(f, in_tree)
if config.jax_array:
@@ -2030,6 +2027,8 @@ def _get_f_mapped(
p = _prepare_pmap(
fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
global_arg_shapes, devices, args, kwargs)
for arg in p.flat_args:
_check_arg(arg)
out = pxla.xla_pmap(
p.flat_fun, *p.flat_args, backend=backend, axis_name=axis_name,
axis_size=p.local_axis_size, global_axis_size=axis_size,
@@ -2230,7 +2229,7 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
p = _prepare_pmap(
fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
global_arg_shapes, devices, args, kwargs)
abstract_args = list(map(xla.abstractify, p.flat_args))
abstract_args = list(map(shaped_abstractify, p.flat_args))
computation = pxla.lower_parallel_callable(
p.flat_fun, backend, axis_name,
axis_size=p.local_axis_size, global_axis_size=axis_size,

View File

@@ -308,6 +308,13 @@ class PythonPmapTest(jtu.JaxTestCase):
f = f.lower(x).compile()
self.assertIsNotNone(f.runtime_executable())
def testLowerShapedArray(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x_shape = jax.core.ShapedArray(x.shape, x.dtype)
self.assertAllClose(f.lower(x_shape).compile()(x), f(x))
def testMean(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')