mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #11640 from pschuh:pmap-shaped-array
PiperOrigin-RevId: 464623040
This commit is contained in:
commit
75d69725c3
@ -1964,9 +1964,6 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
|
||||
local_axis_size = _mapped_axis_size(
|
||||
in_tree, args, in_axes_flat, "pmap", kws=True)
|
||||
|
||||
for arg in args:
|
||||
_check_arg(arg)
|
||||
|
||||
flat_fun, out_tree = flatten_fun(f, in_tree)
|
||||
|
||||
if config.jax_array:
|
||||
@ -2030,6 +2027,8 @@ def _get_f_mapped(
|
||||
p = _prepare_pmap(
|
||||
fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
|
||||
global_arg_shapes, devices, args, kwargs)
|
||||
for arg in p.flat_args:
|
||||
_check_arg(arg)
|
||||
out = pxla.xla_pmap(
|
||||
p.flat_fun, *p.flat_args, backend=backend, axis_name=axis_name,
|
||||
axis_size=p.local_axis_size, global_axis_size=axis_size,
|
||||
@ -2230,7 +2229,7 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
|
||||
p = _prepare_pmap(
|
||||
fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
|
||||
global_arg_shapes, devices, args, kwargs)
|
||||
abstract_args = list(map(xla.abstractify, p.flat_args))
|
||||
abstract_args = list(map(shaped_abstractify, p.flat_args))
|
||||
computation = pxla.lower_parallel_callable(
|
||||
p.flat_fun, backend, axis_name,
|
||||
axis_size=p.local_axis_size, global_axis_size=axis_size,
|
||||
|
@ -308,6 +308,13 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
f = f.lower(x).compile()
|
||||
self.assertIsNotNone(f.runtime_executable())
|
||||
|
||||
def testLowerShapedArray(self):
|
||||
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
|
||||
shape = (jax.device_count(), 4)
|
||||
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
|
||||
x_shape = jax.core.ShapedArray(x.shape, x.dtype)
|
||||
self.assertAllClose(f.lower(x_shape).compile()(x), f(x))
|
||||
|
||||
def testMean(self):
|
||||
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user