# Copyright 2023 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from collections.abc import Callable, Generator, Iterable, Iterator, Sequence from functools import partial import itertools as it import math import operator as op from types import SimpleNamespace from typing import Any, NamedTuple, TypeVar import unittest from absl.testing import absltest from absl.testing import parameterized import numpy as np import jax import jax.ad_checkpoint from jax import api_util from jax import lax from jax.sharding import Mesh, NamedSharding from jax.sharding import PartitionSpec as P from jax._src import config from jax._src import core from jax._src import prng from jax._src import test_util as jtu from jax._src.lib.mlir.dialects import sdy from jax._src.util import safe_zip, safe_map, partition_list, merge_lists from jax._src.ad_checkpoint import saved_residuals from jax._src.mesh import AxisType from jax._src.interpreters import partial_eval as pe from jax._src import linear_util as lu from jax._src import tree_util import jax.numpy as jnp from jax.experimental.custom_partitioning import custom_partitioning from jax.experimental.shard_map import shard_map config.parse_flags_with_absl() jtu.request_cpu_devices(8) map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip # Helper for some tests. 
def create_inputs(a_sharding, b_sharding):
  """Builds a 2x2x2 ('x', 'y', 'z') mesh and two 8x8 integer arrays on it.

  Args:
    a_sharding: PartitionSpec used to place the first array, of shape (b, e).
    b_sharding: PartitionSpec used to place the second array, of shape (e, f).

  Returns:
    A tuple ``(mesh, m1, m2)``: ``m1`` is ``arange(b * e).reshape((b, e))``
    placed with ``NamedSharding(mesh, a_sharding)`` and ``m2`` is
    ``arange(e * f).reshape((e, f))`` placed with
    ``NamedSharding(mesh, b_sharding)``.
  """
  mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z'))
  # Dimension names of a [b, e] x [e, f] matmul; all of them are 8.
  b, e, f = 8, 8, 8  # pylint: disable=invalid-name
  m1 = jax.device_put(
      jnp.arange(b * e).reshape((b, e)),
      jax.sharding.NamedSharding(mesh, a_sharding))
  m2 = jax.device_put(
      jnp.arange(e * f).reshape((e, f)),
      jax.sharding.NamedSharding(mesh, b_sharding))
  return mesh, m1, m2


class ShardMapTest(jtu.JaxTestCase):
  """Tests for `shard_map` from `jax.experimental.shard_map`."""

  def test_identity(self):
    # An identity body under jit must keep the per-device block shape:
    # 8 rows / |z|=2 -> 4, 8 cols / |x*y|=4 -> 2.
    mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None))
    assert a.addressable_data(0).shape == (4, 2)

    def identity(x):
      return x

    @jax.jit
    def fwd(a):
      c = shard_map(
          identity,
          mesh,
          in_specs=(P('z', ('x', 'y')),),
          out_specs=P('z', ('x', 'y')))(a)
      return c

    c = fwd(a)
    self.assertEqual(c.addressable_data(0).shape, (4, 2))

  def test_all_gather(self):
    # Tiled all_gather over one axis ('z', dim 0) and over an axis tuple
    # (('x', 'y'), last dim) from (4, 2) blocks.
    mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None))
    assert a.addressable_data(0).shape == (4, 2)

    # NOTE(mattjj): to use out_specs=P(None, ('x', 'y')), we need to use
    # all_gather_invariant primitive, which differs in its output replication
    # type compared to all_gather.
    @jax.jit
    @partial(shard_map, mesh=mesh,
             in_specs=(P('z', ('x', 'y')),), out_specs=P('z', ('x', 'y')))
    def fwd(a):
      return (
          lax.all_gather(a, 'z', axis=0, tiled=True),
          lax.all_gather(a, ('x', 'y'), axis=-1, tiled=True),
      )
    c, d = fwd(a)
    # Gathering over 'z' restores all 8 rows of each 2-column strip.
    self.assertEqual(c.addressable_data(0).shape, (8, 2))
    for i, a_shard in enumerate(np.split(a, 4, axis=1)):
      self.assertAllClose(c.addressable_data(2 * i), a_shard)
    # Gathering over ('x', 'y') restores all 8 columns of each 4-row strip.
    self.assertEqual(d.addressable_data(0).shape, (4, 8))
    for i, a_shard in enumerate(np.split(a, 2, axis=0)):
      self.assertAllClose(d.addressable_data(i), a_shard)

  def test_all_gather_with_axis_index_groups(self):
    # all_gather restricted to groups (0, 1) and (2, 3) of the flattened
    # ('y', 'z') axis: each device ends up with only its group's columns.
    mesh, a, _ = create_inputs(P('x', ('y', 'z')), P(None, None))

    @jax.jit
    @partial(
        shard_map,
        mesh=mesh,
        in_specs=(P('x', ('y', 'z')),),
        out_specs=P('x', ('y', 'z')),
    )
    def fwd(a):
      return lax.all_gather(
          a, ('y', 'z'), axis_index_groups=((0, 1), (2, 3)), axis=-1, tiled=True
      )

    c = fwd(a)
    self.assertEqual(c.addressable_data(0).shape, (4, 4))
    # Both members of a group (devices 4*i + 2*j and 4*i + 2*j + 1) are
    # expected to hold the same (4, 4) quadrant of `a`.
    for i, row_block in enumerate(np.split(a, 2, axis=0)):
      for j, block in enumerate(np.split(row_block, 2, axis=-1)):
        self.assertAllClose(c.addressable_data(4 * i + 2 * j), block)
        self.assertAllClose(c.addressable_data(4 * i + 2 * j + 1), block)

  def test_matmul_partial(self):
    # Skipped unconditionally; everything after the raise is dead code.
    # NOTE(review): in_specs below has two entries while `fwd` takes only `a`
    # (`b` is closed over), so this body would likely need fixing before the
    # skip is removed -- confirm.
    raise unittest.SkipTest("invalid replication asserted by out_spec?")

    mesh, a, b = create_inputs(P('z', 'y'), P('y', None))
    assert a.addressable_data(0).shape == (4, 4)

    @jax.jit
    @partial(shard_map, mesh=mesh,
             in_specs=(P('z', 'y'), P('y', None)), out_specs=P('z', None))
    def fwd(a):
      c = jnp.matmul(a, b)  # [B.z, F] {y.unreduced}
      return c

    c = fwd(a)
    self.assertEqual(c.addressable_data(0).shape, (4, 8))

  def test_matmul_reduce_scatter(self):
    # Each device computes a partial matmul over its 'y' slice of the
    # contracting dim; psum_scatter then sums and re-shards the result.
    mesh, a, b = create_inputs(P('z', 'y'), P('y', None))
    assert a.addressable_data(0).shape == (4, 4)

    @jax.jit
    @partial(shard_map, mesh=mesh,
             in_specs=(P('z', 'y'), P('y', None)),
             out_specs=P(('z', 'y'), None))
    def fwd(a, b):
      c = jnp.matmul(a, b)  # [B.z, F] {y.unreduced}
      return (
          lax.psum_scatter(c, 'y', scatter_dimension=0, tiled=True),
          lax.psum_scatter(c, ('z', 'y'), scatter_dimension=0, tiled=True),
      )

    expected = jnp.matmul(a, b)
    c, d = fwd(a, b)
    # Summing over 'y' alone recovers the full product.
    self.assertEqual(c.addressable_data(0).shape, (2, 8))
    self.assertAllClose(expected, c)
    # Summing over ('z', 'y') also adds across the two 'z' row halves.
    self.assertEqual(d.addressable_data(0).shape, (1, 8))
    self.assertAllClose(expected[:4] + expected[4:], d)

  def test_reduce_scatter_with_axis_index_groups(self):
    # psum_scatter restricted to even / odd positions of the flattened
    # ('x', 'y', 'z') axis; each input block is a single column of `a`.
    axis_index_groups = ((0, 2, 4, 6), (1, 3, 5, 7))
    mesh, a, _ = create_inputs(P(None, ('x', 'y', 'z')), P(None, None))
    assert a.addressable_data(0).shape == (8, 1)

    @jax.jit
    @partial(
        shard_map,
        mesh=mesh,
        in_specs=(P(None, ('x', 'y', 'z')),),
        out_specs=P(None, ('x', 'y', 'z')),
    )
    def fwd(a):
      return lax.psum_scatter(
          a,
          ('x', 'y', 'z'),
          scatter_dimension=0,
          axis_index_groups=axis_index_groups,
          tiled=True,
      )

    c = fwd(a)

    self.assertEqual(c.addressable_data(0).shape, (2, 1))

    # Even-positioned devices hold 2-row slices of the sum of even columns.
    sum_of_even_columns = np.sum(a[..., axis_index_groups[0]], -1)
    for i, sums in enumerate(np.split(sum_of_even_columns, 4, 0)):
      self.assertAllClose(np.squeeze(c.addressable_data(2 * i), -1), sums)

    # Odd-positioned devices hold 2-row slices of the sum of odd columns.
    sum_of_odd_columns = np.sum(a[..., axis_index_groups[1]], -1)
    for i, sums in enumerate(np.split(sum_of_odd_columns, 4, 0)):
      self.assertAllClose(np.squeeze(c.addressable_data(2 * i + 1), -1), sums)

  def test_collective_permute(self):
    # Cyclic shift by one along 'x': row block j is sent to device j + 1.
    mesh = jtu.create_mesh((8,), 'x')
    a = jax.device_put(
        jnp.arange(8 * 8).reshape((8, 8)),
        jax.sharding.NamedSharding(mesh, P('x', None)))

    @jax.jit
    @partial(
        shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None)
    )
    def fwd(a):
      axis_size = lax.psum(1, 'x')
      perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
      return lax.ppermute(a, 'x', perm=perm)

    c = fwd(a)
    self.assertAllClose(c[1, :], a[0, :])

  def test_collective_permute_with_multiple_axis_names(self):
    # Cyclic shifts over axis tuples ('x', 'y') and ('y', 'z').
    mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z'))
    a = jax.device_put(
        jnp.arange(8 * 8).reshape((4, 16)),
        jax.sharding.NamedSharding(mesh, P('x', ('y', 'z'))),
    )

    @jax.jit
    @partial(
        shard_map,
        mesh=mesh,
        in_specs=(P('x', ('y', 'z')),),
        out_specs=P('x', ('y', 'z')),
    )
    def fwd(a):
      xy_axis_size = lax.psum(1, ('x', 'y'))
      yz_axis_size = lax.psum(1, ('y', 'z'))
      xy_perm = [(j, (j + 1) % xy_axis_size) for j in range(xy_axis_size)]
      yz_perm = [(j, (j + 1) % yz_axis_size) for j in range(yz_axis_size)]
      return (
          lax.ppermute(a, ('x', 'y'), perm=xy_perm),
          lax.ppermute(a, ('y', 'z'), perm=yz_perm),
      )

    c, d = fwd(a)

    # The ('x', 'y') shift moves shard i to position (i + 2) % 8; the
    # ('y', 'z') shift rotates by one within each group of four shards.
    for i in range(8):
      self.assertAllClose(
          a.addressable_data(i), c.addressable_data((i + 2) % 8)
      )
      self.assertAllClose(
          a.addressable_data(i), d.addressable_data(4 * (i // 4) + (i + 1) % 4)
      )

  @parameterized.named_parameters(
      dict(
          testcase_name='_single_axis_name', axis_name='x', mesh_axes=dict(x=8)
      ),
      dict(
          testcase_name='_multiple_axis_names',
          axis_name=('x', 'y'),
          mesh_axes=dict(x=4, y=2),
      ),
  )
  def test_all_to_all(self, axis_name, mesh_axes):
    # Row-sharded in, column-sharded out, over one mesh axis or an axis tuple
    # of total size 8.
    mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys()))
    a = jax.device_put(
        jnp.arange(8 * 8).reshape((8, 8)),
        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
    )

    @jax.jit
    @partial(
        shard_map,
        mesh=mesh,
        in_specs=(P(axis_name, None),),
        out_specs=P(None, axis_name),
    )
    def fwd(a):
      return lax.all_to_all(
          a, axis_name, split_axis=1, concat_axis=1, tiled=True
      )

    c = fwd(a)
    assert (c == jnp.reshape(a.T, (1, 64))).all()

  def test_all_to_all_with_axis_index_groups(self):
    # all_to_all restricted to device groups (0, 1) and (2, 3).
    mesh = jtu.create_mesh((4,), ('x',))
    a = jax.device_put(
        jnp.arange(4 * 4).reshape((4, 4)),
        jax.sharding.NamedSharding(mesh, P('x', None)),
    )
    self.assertEqual(a.addressable_data(0).shape, (1, 4))

    @jax.jit
    @partial(
        shard_map,
        mesh=mesh,
        in_specs=(P('x', None),),
        out_specs=P(None, 'x'),
    )
    def fwd(a):
      return lax.all_to_all(
          a,
          'x',
          split_axis=1,
          concat_axis=0,
          axis_index_groups=((0, 1), (2, 3)),
          tiled=True,
      )

    c = fwd(a)

    # Each shard corresponds to a quadrant rather than a row.
    self.assertEqual(c.addressable_data(0).shape, (2, 2))
    for i, row_block in enumerate(np.split(a, 2, axis=0)):
      for j, block in enumerate(np.split(row_block, 2, axis=-1)):
        self.assertAllClose(block, c.addressable_data(2 * i + j))

  def test_all_to_all_grad(self):
    # all_to_all that turns a row sharding into a column sharding, plus its
    # gradient through jit.
    mesh = jtu.create_mesh((4,), 'x')
    a = jax.device_put(
        jnp.arange(8 * 8, dtype=jnp.float32).reshape((8, 8)),
        jax.sharding.NamedSharding(mesh, P('x', None)),
    )
    self.assertEqual(a.addressable_data(0).shape, (2, 8))

    @jax.jit
    @partial(
        shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P(None, 'x')
    )
    def fwd(x):
      return lax.all_to_all(x, 'x', split_axis=1, concat_axis=0, tiled=True)

    c = fwd(a)
    self.assertEqual(c.addressable_data(0).shape, (8, 2))
    # The global value is unchanged; only its sharding differs.
    self.assertAllClose(a, c)

    @jax.jit
    @partial(jax.grad, has_aux=True)
    def loss_and_grad(x):
      loss = fwd(x).sum() * 2
      return loss, loss

    grad, loss = loss_and_grad(a)
    self.assertEqual(loss, 2 * sum(range(64)))
    # d(2 * sum(x)) / dx == 2 everywhere.
    self.assertAllClose(grad, 2 * np.ones_like(a))

  def test_eager_repr(self):
    # Stringifying a value inside an eager shard_map body should mention
    # mesh coordinates.
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    s = None

    @partial(shard_map, mesh=mesh, in_specs=P('x', 'y'), out_specs=P('x', 'y'))
    def f(x):
      nonlocal s
      s = str(x)
      return x
    _ = f(np.arange(8 * 8.).reshape(8, 8))

    self.assertIsInstance(s, str)
    self.assertIn('at mesh coordinates', s)

  def test_jvp_basic(self):
    # Forward-mode grads (order 2) of an elementwise body, with and without jit.
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh,
                  in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))
    args = np.arange(4 * 4.).reshape(4, 4),
    jtu.check_grads(g, args, 2, ['fwd'])
    jtu.check_grads(jax.jit(g), args, 2, ['fwd'])

  def test_linearize_basic(self):
    # jax.linearize must agree with jax.jvp for a 2-D sharded input.
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh,
                  in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))
    x = np.arange(4 * 4.).reshape(4, 4)

    y, y_dot = jax.jvp(g, [x], [x])

    y_, g_lin = jax.linearize(g, x)
    y_dot_ = g_lin(x)

    self.assertAllClose(y, y_, check_dtypes=False)
    self.assertAllClose(y_dot, y_dot_, check_dtypes=False)

  def test_linearize_basic_repres(self):
    # Same as test_linearize_basic, but the input is sharded only over 'x'
    # (so it is replicated over 'y') and the body uses lax primitives.
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    g = shard_map(lambda x: jax.lax.sin(jax.lax.cos(x)), mesh,
                  in_specs=(P('x',),), out_specs=P('x',))
    x = np.arange(4.)

    y, y_dot = jax.jvp(g, [x], [x])

    y_, g_lin = jax.linearize(g, x)
    y_dot_ = g_lin(x)

    self.assertAllClose(y, y_, check_dtypes=False)
    self.assertAllClose(y_dot, y_dot_, check_dtypes=False)

  def test_linearize_basic_repres_jit(self):
    # NOTE(review): despite the name, `g` is never wrapped in jax.jit here; the
    # only difference from test_linearize_basic_repres is jnp.sin/jnp.cos in
    # place of jax.lax.sin/cos.
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh,
                  in_specs=(P('x',),), out_specs=P('x',))
    x = np.arange(4.)

    y, y_dot = jax.jvp(g, [x], [x])

    y_, g_lin = jax.linearize(g, x)
    y_dot_ = g_lin(x)

    self.assertAllClose(y, y_, check_dtypes=False)
    self.assertAllClose(y_dot, y_dot_, check_dtypes=False)

  def test_replication_checker_eager(self):
    # Eager mode: out_specs=P(None, 'y') claims replication over 'x'. The body
    # `2 * x` cannot justify that claim, so it must be rejected. A psum over
    # 'x' does justify it.
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    x = np.arange(8 * 8.).reshape(8, 8)

    def f(x):
      return 2 * x
    def g(x):
      return shard_map(f, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x)

    with self.assertRaisesRegex(ValueError, 'statically inferred'):
      g(x)

    def f2(x):
      return jax.lax.psum(x, 'x')
    def g2(x):
      return shard_map(f2, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x)
    _ = g2(x)  # doesn't crash

  def test_replication_checker_jit(self):
    # Same as test_replication_checker_eager, but under jax.jit.
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    x = np.arange(8 * 8.).reshape(8, 8)

    def f(x):
      return 2 * x
    def g(x):
      return shard_map(f, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x)

    with self.assertRaisesRegex(ValueError, 'statically inferred'):
      jax.jit(g)(x)

    def f2(x):
      return jax.lax.psum(x, 'x')
    def g2(x):
      return shard_map(f2, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x)
    _ = jax.jit(g2)(x)  # doesn't crash

  def test_process_env_traces(self):
    # The shard_map body closes over `y`, which depends on the differentiated
    # input `x`. The shard_map's own argument is the constant np.arange(8.).
    mesh = Mesh(np.array(jax.devices()[:4]), ('x',))
    x = np.arange(8.)

    def g(x):
      y = (3. * x).sum()
      z = shard_map(lambda x: 2 * x * y, mesh,
                    in_specs=(P('x'),), out_specs=P('x'))(np.arange(8.))
      return z

    jtu.check_grads(g, (x,), modes=['fwd'], order=2)

  def test_eager_control_flow(self):
    # A Python `if` on a psum result inside an eager shard_map body. The
    # inputs are non-negative, so the `else` branch (-x) is expected.
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    x = jnp.arange(2 * 2.).reshape(2, 2)

    def f(x):
      y = jax.lax.psum(x, ('x', 'y'))
      if y < 0:
        return x
      else:
        return -x

    def g(x):
      return shard_map(f, mesh, in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))(x)
    y = g(x)
    self.assertAllClose(y, -x, check_dtypes=False)

  def test_outer_jit_detects_shard_map_mesh(self):
    # A replicated scalar input (P()) becomes a rank-1 output sharded over
    # 'x'. The enclosing jit has no sharded inputs to learn the mesh from.
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    f = shard_map(lambda x: x.reshape(1, *x.shape), mesh, P(), P('x'))
    _ = jax.jit(f)(jnp.array(2.0))  # doesn't crash

  def test_vmap_basic(self):
    # vmap over a function that contains a shard_map.
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    x = jnp.arange(8 * 8.).reshape(8, 8)

    def g(x):
      return shard_map(lambda x: 2. * x, mesh,
                       in_specs=P('y'), out_specs=P('y'))(x)
    y = jax.vmap(g)(x)
    self.assertAllClose(y, 2 * x, check_dtypes=False)

  def test_vmap_basic_axis_name(self):
    # Same as test_vmap_basic, with a vmap axis name not on the mesh.
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    x = jnp.arange(8 * 8.).reshape(8, 8)

    def g(x):
      return shard_map(lambda x: 2. * x, mesh,
                       in_specs=P('y'), out_specs=P('y'))(x)
    y = jax.vmap(g, axis_name='i')(x)
    self.assertAllClose(y, 2 * x, check_dtypes=False)

  def test_vmap_basic_axis_name_reuse_mesh_name(self):
    # Same as test_vmap_basic, with a vmap axis name that collides with a mesh
    # axis name.
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    x = jnp.arange(8 * 8.).reshape(8, 8)

    def g(x):
      return shard_map(lambda x: 2. * x, mesh,
                       in_specs=P('y'), out_specs=P('y'))(x)
    y = jax.vmap(g, axis_name='x')(x)  # NOTE reuse same 'x' as on mesh
    self.assertAllClose(y, 2 * x, check_dtypes=False)

  def test_tree_prefix_error(self):
    # in_specs[0] is a one-element list, but the argument is a two-element
    # list. The error should point at in_specs[0].
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))

    @partial(shard_map, mesh=mesh, in_specs=([P('x', 'y')],), out_specs=P('x', 'y'))
    def f(x):
      return x

    x = jnp.arange(8 * 8.).reshape(8, 8)
    with self.assertRaisesRegex(ValueError, r'shard_map in_specs\[0\]'):
      f([x, x])

  def test_rank_errors(self):
    # Spec/value rank mismatches must produce descriptive errors, eagerly and
    # under jit.
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))

    def foo():
      return {'hi': [3.]}

    with self.assertRaisesRegex(ValueError, 'which has length 1'):
      shard_map(foo, mesh=mesh, in_specs=(), out_specs={'hi': P('x')})()

    with self.assertRaisesRegex(ValueError, 'which has length 1'):
      jax.jit(lambda: shard_map(foo, mesh=mesh,
                                in_specs=(),
                                out_specs={'hi': P('x')})())()

    with self.assertRaisesRegex(ValueError, 'which has rank 0'):
      shard_map(foo, mesh=mesh, in_specs=({'hi': P('x')},), out_specs=())(
          {'hi': [jnp.array(3.)]})

    with self.assertRaisesRegex(ValueError,
                                r'consider using an in_specs entry of `P\(\)`'):
      shard_map(foo, mesh=mesh, in_specs=P(None), out_specs=())(3.)

  def test_reverse_mode_ad(self):
    # fwd + rev grads (order 2). `y` is replicated (P(None)) and broadcast
    # against the sharded `x`.
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))

    @jax.jit
    @partial(shard_map, mesh=mesh,
             in_specs=(P('x',), P(None)), out_specs=P('x',))
    def f(x, y):
      return jnp.sin(x) + 3 + jnp.tan(2.) * jnp.cos(x) + y

    x = jnp.arange(8.) / 10.
    y = jnp.arange(4.) / 10.
    jtu.check_grads(f, (x, y), modes=['fwd', 'rev'], order=2)

  def test_post_process(self):
    # JVPTrace.post_process_shard_map and JaxprTrace.post_process_shard_map
    # The shard_map body closes over `x`, which is the linearized input.
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))

    def f(x):
      @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
      def g(y):
        return jnp.sin(y) * jnp.sin(x).sum()
      return g(jnp.arange(8.))

    x = jnp.arange(8.)
    _, f_lin = jax.linearize(f, x)
    y_dot = f_lin(x)

    y_dot_expected = jnp.sin(jnp.arange(8.)) * (jnp.cos(x) * x).sum()
    self.assertAllClose(y_dot, y_dot_expected, check_dtypes=False)

  @jtu.run_on_devices('gpu', 'tpu')
  def test_axis_index(self):
    # Each device's axis_index('x') becomes one element of the output.
    mesh = jtu.create_mesh((4,), 'x')

    @jax.jit
    @partial(shard_map, mesh=mesh, in_specs=(), out_specs=P('x'))
    def f():
      return jax.lax.axis_index('x')[None]

    x = f()
    self.assertAllClose(x, jnp.arange(4), check_dtypes=False)

  def test_remat_basic(self):
    # this tests remat-of-shmap
    mesh = Mesh(np.array(jax.devices()[:4]), ('x',))

    # check param updating is handled
    @jax.remat
    @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
    def f(x):
      return jnp.sin(jnp.sin(x))

    x = jnp.arange(4.)
    g = jax.grad(lambda x: f(x).sum())(x)  # doesn't crash
    self.assertAllClose(g, jnp.cos(jnp.sin(x)) * jnp.cos(x), check_dtypes=False,
                        atol=1e-3, rtol=1e-3)
    # With the default remat policy, a single residual is expected.
    saved_res = saved_residuals(f, x)
    self.assertLen(saved_res, 1)

    # also check residuals are handled correctly
    @partial(jax.remat, policy=jax.checkpoint_policies.everything_saveable)
    @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
    def f2(x):
      return jnp.sin(jnp.sin(x))

    g2 = jax.grad(lambda x: f2(x).sum())(x)  # doesn't crash
    self.assertAllClose(g2, jnp.cos(jnp.sin(x)) * jnp.cos(x),
                        check_dtypes=False, atol=1e-3, rtol=1e-3)
    # With everything_saveable, two residuals are expected.
    saved_res = saved_residuals(f2, x)
    self.assertLen(saved_res, 2)

  def test_shmap_of_remat_basic(self):
    # The reverse nesting of test_remat_basic: remat inside shard_map.
    mesh = Mesh(np.array(jax.devices()[:4]), ('x',))
    x = jnp.arange(4.)

    @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
    @partial(jax.remat, policy=jax.checkpoint_policies.everything_saveable)
    def f2(x):
      return jnp.sin(x)

    g2 = jax.grad(lambda x: f2(x).sum())(x)  # doesn't crash
    self.assertAllClose(g2, jnp.cos(x), check_dtypes=False)

  def test_remat_scalar_residuals(self):
    # The body reduces its block to a scalar (x.sum()) before re-expanding it.
    mesh = Mesh(np.array(jax.devices()[:4]), ('x',))

    @partial(jax.remat, policy=jax.checkpoint_policies.everything_saveable)
    @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
    def f(x):
      return jnp.sin(jnp.sin(jnp.sin(x.sum()))[None])

    x = jnp.arange(8.)
    _ = jax.grad(lambda x: f(x).sum())(x)  # doesn't crash
    jtu.check_grads(f, (x,), modes=['rev'], order=2, atol=1e-2, rtol=1e-2)

  def test_collectives_not_saved(self):
    # regression test for bug in cl/612416803
    # Under plain jax.remat, exactly one saved residual is expected even though
    # the body performs two all_gathers.
    mesh = Mesh(np.array(jax.devices()[:4]), ('x',))

    @jax.remat
    @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
    def f(x):
      return jax.lax.all_gather(x, 'x') * jax.lax.all_gather(x, 'x')

    saved_res = saved_residuals(f, jnp.ones(4))
    self.assertLen(saved_res, 1)

  def test_check_rep_false_doesnt_hit_rep_rules(self):
    # A primitive without a replication rule is expected to:
    #   - raise NotImplementedError with check_rep=True (eager and jit);
    #   - be accepted with check_rep=False, including when bound inside an
    #     inner jax.jit.
    mesh = Mesh(np.array(jax.devices()[:4]), ('x',))

    prim = core.Primitive('prim')  # no rep rule here!
    prim.multiple_results = True
    prim.def_impl(lambda: [])
    prim.def_abstract_eval(lambda: [])

    @partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_rep=True)
    def f():
      prim.bind()

    with self.assertRaises(NotImplementedError):
      f()
    with self.assertRaises(NotImplementedError):
      jax.jit(f)()

    @partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_rep=False)
    def f2():
      prim.bind()

    f2()
    jax.jit(f2)()

    @partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_rep=False)
    def f3():
      jax.jit(prim.bind)()

    f3()
    jax.jit(f3)()

  def test_vmap_spmd_axis_name(self):
    # vmap with spmd_axis_name='y' should give the new leading dim the name 'y'
    # in the shard_map eqn's in_names/out_names, shifting 'x' to dim 1.
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))

    @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
    def f(x):
      return x

    x = jnp.arange(4 * 4).reshape(4, 4)
    jaxpr = jax.make_jaxpr(jax.vmap(f, spmd_axis_name='y'))(x).jaxpr
    e, = jaxpr.eqns
    self.assertIn('in_names', e.params)
    self.assertEqual(e.params['in_names'], ({0: ('y',), 1: ('x',)},))
    self.assertIn('out_names', e.params)
    self.assertEqual(e.params['out_names'], ({0: ('y',), 1: ('x',)},))

  def test_vmap_of_grad_spmd_axis_name(self):
    # vmap-of-grad must give the same values with and without spmd_axis_name.
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))

    @partial(
        shard_map, mesh=mesh, in_specs=P('y'), out_specs=P(), check_rep=False
    )
    def f(x):
      return jnp.sin(jnp.sum(x))

    x = jnp.arange(4 * 4, dtype=jnp.float32).reshape(4, 4)
    put_x = jax.device_put(
        x,
        jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y')),
    )
    vmap_spmd_axisname_result = jax.vmap(jax.grad(f), spmd_axis_name='x')(put_x)
    vmap_no_spmd_axisname_result = jax.vmap(jax.grad(f))(put_x)
    self.assertArraysEqual(
        vmap_spmd_axisname_result, vmap_no_spmd_axisname_result
    )

  def test_vmap_spmd_axis_name_pair(self):
    # A tuple spmd_axis_name should map the vmapped dim over both mesh axes.
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))

    @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P())
    def f(x):
      return x

    x = jnp.arange(4 * 4).reshape(4, 4)
    jaxpr = jax.make_jaxpr(jax.vmap(f, spmd_axis_name=('x', 'y')))(x).jaxpr
    e, = jaxpr.eqns
    self.assertIn('in_names', e.params)
    self.assertEqual(e.params['in_names'], ({0: ('x', 'y',)},))
    self.assertIn('out_names', e.params)
    self.assertEqual(e.params['out_names'], ({0: ('x', 'y',)},))

  def test_nested_vmap_with_capture_spmd_axis_name(self):
    # Currently skipped; the code below documents the failing scenario.
    self.skipTest('https://github.com/jax-ml/jax/issues/23476')
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))

    def to_map_with_capture(x, y):

      # We capture x from `to_map_with_capture`'s parameters.
      def with_capture(y_slice):
        # Inside all the maps, everything has been mapped away: we are just
        # adding two scalars. One is obtained by mapping across both
        # dimensions; the other by mapping across one dimension and capturing
        # the resulting scalar.
        self.assertEqual(x.shape, ())
        self.assertEqual(y_slice.shape, ())
        return x + y_slice

      # This vmap is referred to below as the 'inner vmap'.
      vmap_with_capture = jax.vmap(with_capture)
      shmap_vmap_capture = shard_map(
          vmap_with_capture, mesh=mesh, in_specs=P('y'), out_specs=P('y')
      )
      return shmap_vmap_capture(y)

    # And this one is the outer vmap.
    mapped = jax.vmap(to_map_with_capture, spmd_axis_name='x')
    x = jnp.arange(2).reshape(2)
    y = jnp.arange(2 * 2).reshape(2, 2)
    # Inner vmap inside of shard-map will be over an axis of size 1. Outer vmap
    # is over an axis of size 2. This is a problem at the moment.
    jax.make_jaxpr(mapped)(x, y).jaxpr

  def test_shard_map_abstract_mesh(self):
    # Passing mesh.abstract_mesh (instead of the concrete mesh) works under jit
    # and eagerly. Outputs land on NamedSharding(mesh, ...) for the concrete
    # mesh of the jax.Array input.
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    np_inp = np.arange(16).reshape(8, 2)
    arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))

    def f(x):
      return shard_map(lambda x: x, mesh=mesh.abstract_mesh, in_specs=P('x'),
                       out_specs=P('x'))(x)

    out1 = jax.jit(f)(arr)
    self.assertArraysEqual(out1, np_inp)
    self.assertEqual(out1.sharding, NamedSharding(mesh, P('x')))

    out_eager = f(arr)
    self.assertArraysEqual(out_eager, np_inp)
    self.assertEqual(out_eager.sharding, NamedSharding(mesh, P('x')))

    # Mixing a numpy input with a jax.Array input is accepted here.
    out1, out2 = shard_map(lambda x, y: (x, y), mesh=mesh.abstract_mesh,
                           in_specs=P('x'), out_specs=P('x'))(np_inp, arr)
    self.assertArraysEqual(out1, np_inp)
    self.assertEqual(out1.sharding, NamedSharding(mesh, P('x')))
    self.assertArraysEqual(out2, np_inp)
    self.assertEqual(out2.sharding, NamedSharding(mesh, P('x')))

  def test_different_devices_shmap_abstract_mesh_cache_hit(self):
    # Two meshes over different devices share one abstract mesh. Tracing and
    # lowering should each happen once; compilation happens once per device
    # set.
    if jax.device_count() < 4:
      self.skipTest('Requires >=4 devices')

    mesh1 = jax.sharding.Mesh(jax.devices()[:2], 'i')
    mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'i')
    abstract_mesh = mesh1.abstract_mesh

    @jax.jit
    def f(x):
      x = shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('i'),
                    out_specs=P('i'))(x)
      return jax.lax.sin(x)

    with (
        jtu.count_jit_tracing_cache_miss() as tracing_count,
        jtu.count_jit_and_pmap_lowerings() as lowering_count,
        jtu.count_jit_compilation_cache_miss() as compilation_count,
    ):
      a = jax.device_put(np.arange(8.), NamedSharding(mesh1, P()))
      out_a = f(a)  # tracing and lowering cached

      # same num_devices but different devices.
      b = jax.device_put(out_a, NamedSharding(mesh2, P()))
      f(b)  # tracing and lowering cache *hit*

    self.assertEqual(tracing_count(), 1)
    self.assertEqual(lowering_count(), 1)
    self.assertEqual(compilation_count(), 2)  # 2 misses since devices differ.
# ShardMapTest methods (continued): abstract-mesh errors, debug printing,
# replication rules for control flow, custom AD rules and axis_index.
  def test_shmap_abstract_mesh_errors(self):
    # With an AbstractMesh, inputs must be jax.Arrays carrying a NamedSharding
    # whose mesh shape matches the abstract mesh.
    mesh = jtu.create_mesh((2,), ('x',))
    np_inp = np.arange(8)
    abstract_mesh = mesh.abstract_mesh

    # An unsharded (default-placed) jax.Array is rejected.
    with self.assertRaisesRegex(
        ValueError,
        "Please pass `jax.Array`s with a `NamedSharding` as input to"
        " `shard_map` when passing `AbstractMesh` to the mesh argument"):
      shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('x'),
                out_specs=P('x'))(jnp.arange(8))

    # An input on an 'x' mesh does not match an abstract mesh over 'y'.
    arr = jax.device_put(np_inp, NamedSharding(mesh, P('x')))
    mesh2 = jtu.create_mesh((2,), 'y')
    abs_mesh2 = mesh2.abstract_mesh
    with self.assertRaisesRegex(
        ValueError,
        'Mesh shape of the input.*does not match the mesh shape passed to'
        ' shard_map'):
      shard_map(lambda x: x, mesh=abs_mesh2, in_specs=P('y'),
                out_specs=P('y'))(arr)

    # A bare numpy array as the only input is rejected.
    with self.assertRaisesRegex(
        ValueError,
        'Please pass `jax.Array`s with a `NamedSharding` as input to'
        ' `shard_map` when passing `AbstractMesh` to the mesh argument.'):
      shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('x'),
                out_specs=P('x'))(np_inp)

    # Inputs living on differently shaped meshes are rejected.
    arr_mesh2 = jax.device_put(np_inp, NamedSharding(mesh2, P('y')))
    with self.assertRaisesRegex(
        ValueError,
        'Mesh shape of the input.*does not match the mesh shape passed to'
        ' shard_map'):
      shard_map(lambda x, y: (x, y), mesh=abstract_mesh, in_specs=P('x'),
                out_specs=P('x'))(arr, arr_mesh2)

  @parameterized.parameters([True, False])
  @jtu.run_on_devices('cpu', 'gpu', 'tpu')
  @jtu.thread_unsafe_test()
  def test_debug_print_jit(self, jit):
    # jax.debug.print inside shard_map, with and without an outer jit. Every
    # device index must appear in captured stdout.
    if config.use_shardy_partitioner.value:
      self.skipTest('TODO(b/384938613): Failing under shardy')
    mesh = Mesh(jax.devices(), ('i',))

    @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
    def f(x):
      idx = jax.lax.axis_index('i')
      jax.debug.print("instance {i} has value x={x}", i=idx, x=x)
      y = jnp.cos(x)
      jax.debug.print("instance {i} has value y={y}", i=idx, y=y)
      return y

    if jit:
      f = jax.jit(f)

    x = jnp.arange(2 * len(jax.devices()))

    with jtu.capture_stdout() as output:
      f(x)
      # Wait for pending effects so the prints are captured.
      jax.effects_barrier()
    for i in range(len(jax.devices())):
      self.assertIn(f'instance {i} has value', output())

  def test_debug_print_eager(self):
    # Eager jax.debug.print: device i holds elements [2*i, 2*i+1] of x.
    mesh = Mesh(jax.devices(), ('i',))

    @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
    def f(x):
      jax.debug.print("x={x}", x=x)
      y = jnp.cos(x)
      jax.debug.print("y={y}", y=y)
      return y

    x = jnp.arange(2 * len(jax.devices()))

    with jtu.capture_stdout() as output:
      f(x)
      jax.effects_barrier()
    for i in range(len(jax.devices())):
      self.assertIn(f'x=[{2*i} {2*i+1}]', output())

  def test_partial_eval_custom_axis_env(self):
    # axis_index inside a scan, inside a shard_map wrapped in remat and
    # differentiated. Only shapes are evaluated, via eval_shape.
    mesh = Mesh(jax.devices(), ('i',))

    @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
    def f(_):
      _, idx = jax.lax.scan(lambda _, __: (None, jax.lax.axis_index('i')),
                            None, None, length=1)
      return idx

    xs = jnp.arange(16.)
    jax.eval_shape(jax.grad(lambda x: jax.remat(f)(x).sum().astype('float32')),
                   xs)

  @jax.legacy_prng_key('allow')
  def test_prngkeyarray_eager(self):
    # https://github.com/jax-ml/jax/issues/15398
    mesh = jtu.create_mesh((4,), ('x',))
    sharding = jax.sharding.NamedSharding(mesh, P('x'))

    rng = jax.random.PRNGKey(0)
    sharded_rng = jax.random.split(rng, num=4)
    sharded_rng = jax.device_put(sharded_rng, sharding)

    def f(key):
      return jax.random.randint(key[0], shape=(1, 16), minval=0, maxval=16,
                                dtype=jnp.int32)

    # The spec rank depends on the key representation. Presumably legacy keys
    # carry an extra trailing key-data dimension that custom PRNG keys lack.
    pspec = P('x') if config.enable_custom_prng.value else P('x', None)
    g = shard_map(f, mesh, in_specs=(pspec,), out_specs=pspec)
    _ = g(sharded_rng)  # don't crash!

  def test_functools_partial_rank_error(self):
    # Rank errors must still be raised when the mapped function is a bare
    # functools.partial object. The rank-2 in_spec does not fit the rank-1 x.
    mesh = jtu.create_mesh((4,), ('x',))

    @partial
    def f(x):
      return x

    g = shard_map(f, mesh, in_specs=(P('x', None),), out_specs=P('x',))
    x = jnp.arange(4)
    with self.assertRaises(ValueError):
      g(x)

  def test_in_specs_none_error(self):
    # in_specs=None is a TypeError; in_specs=P() is accepted.
    mesh = jtu.create_mesh((4,), ('x',))

    def f(x): return x

    with self.assertRaisesRegex(TypeError, "but it was None"):
      shard_map(f, mesh, in_specs=None, out_specs=P())(3.)

    # TODO(mattjj): enable this test once we fix the tree_map(f, None, 3.0) bug
    # with self.assertRaises(TypeError):
    #   shard_map(f, mesh, in_specs=(None,), out_specs=P())(3.)

    shard_map(f, mesh, in_specs=P(), out_specs=P())(3.)  # doesn't crash

  def test_scan_rep_rule(self):
    # Replication inference through lax.scan. In `f` the carry rotates through
    # all three positions, so the test expects every output to vary over both
    # 'x' and 'y'. In `g` the carry is unchanged, so each output keeps its
    # input's replication.
    mesh = jtu.create_mesh((2, 2,), ('x', 'y'))

    def f(x, y, z):
      x, y, z = x.sum(), y.sum(), z.sum()
      def body(c, _):
        c, *cs = c
        return (*cs, c), None
      out, _  = jax.lax.scan(body, (x, y, z), None, length=3)
      return [jnp.expand_dims(a, 0) for a in out]

    x = jnp.arange(4)

    # doesn't crash, because out_spec assumes no replication (and there is none)
    shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
              out_specs=P(('x', 'y')))(x, x, x)

    # does crash, because output incorrectly promises replication
    with self.assertRaisesRegex(ValueError, "require replication"):
      shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
                out_specs=P('x'))(x, x, x)
    with self.assertRaisesRegex(ValueError, "require replication"):
      shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
                out_specs=P('y'))(x, x, x)
    with self.assertRaisesRegex(ValueError, "require replication"):
      shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
                out_specs=P(None))(x, x, x)

    def g(x, y, z):
      x, y, z = x.sum(), y.sum(), z.sum()
      def body(c, _):
        return c, None
      out, _  = jax.lax.scan(body, (x, y, z), None, length=1)
      return [jnp.expand_dims(a, 0) for a in out]

    # doesn't crash, because everything matches
    shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
              out_specs=[P(None), P('x'), P(('x', 'y'))])(x, x, x)

    # does crash: the second out_spec, P(None), promises replication over 'x'
    # for an output that varies over 'x'.
    with self.assertRaisesRegex(ValueError, "require replication"):
      shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
                out_specs=[P(None), P(None), P(('x', 'y'))])(x, x, x)

  def test_cond_rep_rule(self):
    # Replication inference through lax.cond. Each case first checks the
    # tightest out_spec the test expects to pass, then checks that a spec
    # promising more replication is rejected.
    mesh = jtu.create_mesh((2, 2,), ('x', 'y'))
    x = jnp.arange(4)

    # Both branches derive from x (varies over 'x'), with a constant predicate.
    def f(x, y):
      def true_fn(x, y):
        return x
      def false_fun(x, y):
        return x + 1
      return jax.lax.cond(True, true_fn, false_fun, x, y)

    shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x)
    with self.assertRaisesRegex(ValueError, "require replication"):
      shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(None))(x, x)

    # Branches return x vs. y, so the output may vary over both axes.
    def f(x, y):
      def true_fn(x, y):
        return x
      def false_fun(x, y):
        return y
      return jax.lax.cond(True, true_fn, false_fun, x, y)

    shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x)
    with self.assertRaisesRegex(ValueError, "require replication"):
      shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x)

    # The predicate depends on x (varies over 'x').
    def f(x, y):
      def true_fn(x, y):
        return x
      def false_fun(x, y):
        return x + 1
      return jax.lax.cond(jnp.any(x > 0), true_fn, false_fun, x, y)

    shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x)
    with self.assertRaisesRegex(ValueError, "require replication"):
      shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(None))(x, x)

    # The predicate depends on y (varies over 'y'), so the output varies over
    # both axes.
    def f(x, y):
      def true_fn(x, y):
        return x
      def false_fun(x, y):
        return x + 1
      return jax.lax.cond(jnp.any(y > 0), true_fn, false_fun, x, y)

    shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x)
    with self.assertRaisesRegex(ValueError, "require replication"):
      shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x)

    # https://github.com/jax-ml/jax/issues/24418
    def f(a):
      c = jax.lax.cond(jnp.any(a), lambda: 1, lambda: 0)
      return jnp.reshape(c, a.shape)

    mesh = jtu.create_mesh((2,), ('x',))
    a = jnp.array([True, False])
    shard_map(f, mesh, in_specs=P('x'), out_specs=P('x'))(a)

  def test_switch_rep_rule(self):
    # Replication inference through lax.switch with a replicated index.
    mesh = jtu.create_mesh((2, 2,), ('x', 'y'))
    x = jnp.arange(4)

    def f(n, x, y):
      return jax.lax.switch(
          n, [lambda x, _: x, lambda x, _: x + 1, lambda x, _: x + 2], x, y)

    shard_map(f, mesh, in_specs=(P(), P('x'), P('y')), out_specs=P('x'))(1, x, x)

  def test_eager_custom_jvp_basic(self):
    # The custom JVP rule (3 * x_dot) must be used instead of the true
    # derivative of the primal (2 * x).
    @jax.custom_jvp
    def foo(x):
      return 2. * x

    @foo.defjvp
    def foo_jvp(primals, tangents):
      (x,), (x_dot,) = primals, tangents
      return foo(x), 3. * x_dot

    mesh = jtu.create_mesh((4,), ('x',))
    g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x'))
    y, x_bar = jax.value_and_grad(lambda x: g(x).sum())(jnp.arange(4.))
    self.assertAllClose(y, (2. * jnp.arange(4.)).sum())
    self.assertAllClose(x_bar, 3. * jnp.ones(4), check_dtypes=False)

  def test_eager_custom_vjp_basic(self):
    # The custom backward rule (3 * y_bar) must be used instead of the true
    # derivative of the primal (2 * x).
    @jax.custom_vjp
    def foo(x):
      return 2. * x

    def foo_fwd(x):
      return foo(x), None

    def foo_bwd(_, y_bar):
      return 3. * y_bar,

    foo.defvjp(foo_fwd, foo_bwd)

    mesh = jtu.create_mesh((4,), ('x',))
    g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x'))
    y, x_bar = jax.value_and_grad(lambda x: g(x).sum())(jnp.arange(4.))
    self.assertAllClose(y, (2. * jnp.arange(4.)).sum())
    self.assertAllClose(x_bar, 3. * jnp.ones(4), check_dtypes=False)

  @parameterized.parameters([True, False])
  def test_axis_index_basic(self, jit):
    # axis_index over a 1-D mesh, with the body optionally jitted.
    def foo():
      return jax.lax.axis_index('x')[None]

    if jit:
      foo = jax.jit(foo)

    mesh = jtu.create_mesh((4,), ('x',))
    ans = shard_map(foo, mesh, in_specs=(), out_specs=P('x'))()
    expected = jnp.arange(4.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  @parameterized.parameters([True, False])
  def test_axis_index_twoaxes(self, jit):
    # axis_index for each axis of a (4, 2) mesh, and for the axis tuple
    # ('i', 'j'). The expected values for the tuple are arange(8) laid out
    # as 4x2.
    def foo():
      out1 = jax.lax.axis_index('i')[None, None]
      out2 = jax.lax.axis_index('j')[None, None]
      out3 = jax.lax.axis_index(('i', 'j'))[None, None]
      return out1, out2, out3

    if jit:
      foo = jax.jit(foo)

    mesh = jtu.create_mesh((4, 2), ('i', 'j'))
    ans1, ans2, ans3 = shard_map(foo, mesh, in_specs=(),
                                 out_specs=P('i', 'j'))()
    expected1 = jnp.arange(4.)[:, None] + jnp.zeros((4, 2))
    expected2 = jnp.arange(2.)[None, :] + jnp.zeros((4, 2))
    expected3 = jnp.arange(8.).reshape(4, 2)
    self.assertAllClose(ans1, expected1, check_dtypes=False)
    self.assertAllClose(ans2, expected2, check_dtypes=False)
    self.assertAllClose(ans3, expected3, check_dtypes=False)

  def test_axis_index_eager(self):
    # Eager Python branching on a psum of axis indices. Over a size-4 axis the
    # sum is 0 + 1 + 2 + 3 = 6 > 0, so 1. is expected.
    mesh = jtu.create_mesh((4,), ('x',))

    @partial(shard_map, mesh=mesh, in_specs=(), out_specs=P())
    def foo():
      val = jax.lax.psum(jax.lax.axis_index('x'), 'x')
      return 1. if val > 0 else -1.

    out = foo()  # doesn't crash
    self.assertEqual(out, 1.)
def test_jaxpr_shardings_with_no_outputs(self): # https://github.com/jax-ml/jax/issues/15385 mesh = jtu.create_mesh((4,), ('i',)) @jax.jit @partial(shard_map, mesh=mesh, in_specs=(), out_specs=P('i')) def f(): return jax.lax.iota(jnp.dtype('int32'), 4) f() # don't crash @partial(shard_map, mesh=mesh, in_specs=(P('i'),), out_specs=P('i')) def g(a_block): i = jnp.arange(a_block.shape[0]) return i + a_block g(np.arange(32)) # don't crash def test_device_put(self): mesh = jtu.create_mesh((4,), ('i',)) @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) def f(x): return x + jax.device_put(1) x = jnp.arange(32.) f(x) # doesn't crash jax.jit(f)(x) # doesn't crash @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) def g(x): return x + jax.device_put(1, jax.devices()[0]) with self.assertRaisesRegex(ValueError, "got device"): g(x) # jit means device_puts are ignored, even those within shmap bodies, so no # error! jax.jit(g)(x) # doesn't crash @jtu.run_on_devices('cpu', 'gpu', 'tpu') def test_key_array_with_replicated_last_tile_dim(self): # See https://github.com/jax-ml/jax/issues/16137 mesh = jtu.create_mesh((2, 4), ('i', 'j')) def f(rng): @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'), check_rep=False) def g(rng): return jnp.array([jax.random.normal(rng[0])]) return g(jax.random.split(rng, 4)) jax.jit(f)(jax.random.key(0)) # doesn't crash # same method appears in api_test.py:DCETest # TODO(mattjj): consider moving this method to be a helper in jtu def assert_dce_result(self, jaxpr: core.Jaxpr, used_outputs: list[bool], expected_used_inputs: list[bool], expected_num_eqns: int | None = None, check_diff: bool = True): jaxpr_dce, used_inputs = pe.dce_jaxpr(jaxpr, used_outputs) core.check_jaxpr(jaxpr_dce) self.assertEqual(used_inputs, expected_used_inputs) if expected_num_eqns is not None: all_jaxprs = it.chain([jaxpr_dce], core.subjaxprs(jaxpr_dce)) num_eqns = sum(len(subjaxpr.eqns) for subjaxpr in all_jaxprs) 
self.assertEqual(num_eqns, expected_num_eqns, msg=str(jaxpr_dce)) rand_ = jtu.rand_small(np.random.RandomState(0)) rand = lambda v: rand_(v.aval.shape, v.aval.dtype) consts = [rand(v) for v in jaxpr.constvars] inputs = [rand(v) for v in jaxpr.invars ] inputs_dce = [x for x, used in zip(inputs, used_inputs) if used] full_outs = core.eval_jaxpr(jaxpr , consts, *inputs) expected_outs_dce = [y for y, used in zip(full_outs, used_outputs) if used] outs = core.eval_jaxpr(jaxpr_dce, consts, *inputs_dce) self.assertAllClose(outs, expected_outs_dce) if check_diff and expected_num_eqns != 0: f = lambda *args: core.eval_jaxpr(jaxpr_dce, consts, *args) jtu.check_grads(f, inputs_dce, order=2, modes=['rev']) def test_returned_out_sharding(self): mesh = jtu.create_mesh((1, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) inp = jax.device_put(jnp.zeros((2, 2)), s) out = shard_map(lambda x: x, mesh, P('x', 'y'), P('x', 'y'))(inp) self.assertEqual(out.sharding, s) self.assertArraysEqual(out, inp) def test_dce(self): mesh = jtu.create_mesh((4, 2), ('i', 'j')) def f(x, y, z): @partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P(None, 'i')), out_specs=(P(None, None), P(None, 'i'), P('i', 'j'))) def g(y, z): return jnp.sin(x), jnp.cos(z), jnp.tan(y) return g(y, z) x = jnp.zeros((4, 4)) y = jnp.zeros((8, 8)) z = jnp.zeros((16, 16)) jaxpr = jax.make_jaxpr(f)(x, y, z).jaxpr self.assertLen(jaxpr.eqns, 1) self.assertLen(jaxpr.eqns[0].params['jaxpr'].eqns, 3) # If we use all outputs, nothing should be deleted. self.assert_dce_result( jaxpr, used_outputs=[True, True, True], expected_used_inputs=[True, True, True], expected_num_eqns=1 + 3, # one outer eqn, three remain in body check_diff=False) # If we drop the last output, the second input should be dropped. 
self.assert_dce_result( jaxpr, used_outputs=[True, True, False], expected_used_inputs=[True, False, True], expected_num_eqns=1 + 2, # one outer eqn, two remain in body check_diff=False) # If we drop the second output, the last input should be dropped. self.assert_dce_result( jaxpr, used_outputs=[True, False, True], expected_used_inputs=[True, True, False], expected_num_eqns=1 + 2, # one outer eqn, two remain in body check_diff=False) # If we drop the latter two outputs, the latter two inputs should be dropped self.assert_dce_result( jaxpr, used_outputs=[True, False, False], expected_used_inputs=[True, False, False], expected_num_eqns=1 + 1, # one outer eqn, two remain in body check_diff=False) # Finally, try dropping the closed-over value. self.assert_dce_result( jaxpr, used_outputs=[False, True, False], expected_used_inputs=[False, False, True], expected_num_eqns=1 + 1, # one outer eqn, two remain in body check_diff=False) def test_post_process_partial_eval_with_scalar_res(self): mesh = jtu.create_mesh((4, 2), ('i', 'j')) g = jax.grad(lambda x: shard_map(lambda: jnp.sin(x), mesh=mesh, in_specs=P(), out_specs=P())())(2.0) self.assertAllClose(g, jnp.cos(2.0), check_dtypes=False) def test_sharding_metadata_in_hlo_attrs(self): mesh = Mesh(jax.devices(), ('i',)) x = jnp.arange(len(jax.devices()), dtype='float32') y = jnp.array([3.], dtype='float32') def foo(x): x = jnp.sin(x) x = shard_map(lambda x: jnp.cos(x * y), mesh, in_specs=P('i'), out_specs=P('i'))(x) x = shard_map(lambda x: jnp.cos(x * y), mesh, in_specs=P('i'), out_specs=P('i'))(x) return x hlo_str = jax.jit(foo).lower(x).as_text("stablehlo", debug_info=True) if config.use_shardy_partitioner.value: if len(jax.devices()) > 1: self.assertEqual(2, hlo_str.count('sdy.manual_computation')) else: # When devices == 1, the `sdy.manual_computation` is inlined. 
self.assertEqual(0, hlo_str.count('sdy.manual_computation')) else: self.assertIn('call @shmap_body', hlo_str) self.assertIn('call @shmap_body_0', hlo_str) self.assertIn('%arg0: tensor<1xf32>', hlo_str) self.assertIn('"[None]"', hlo_str) self.assertIn('%arg1: tensor<1xf32>', hlo_str) self.assertIn('"[(\'i\',)]"', hlo_str) self.assertIn( '-> (tensor<1xf32> {jax.result_info = "[(\'i\',)]"})', hlo_str ) def test_rewrite_process_call(self): def f(x): return core.call_p.bind( lu.wrap_init(lambda x: [2. * x], debug_info=api_util.debug_info("test", lambda x: [2. * x], (x,), {})), x)[0] * x mesh = jtu.create_mesh((4,), ('x',)) g = shard_map(f, mesh, in_specs=(P('x'),), out_specs=P('x')) x = jnp.arange(4.) y = jax.jit(g)(x) # eager requires shmap to have ShardMapTrace.process_call self.assertAllClose(y, 2 * x * x, check_dtypes=True) def test_rewrite_post_process_call(self): # We shouldn't hit post_process_call here because of RewriteTrace's dynamic # behavior (i.e. no data dependence). mesh = jtu.create_mesh((4,), ('x',)) @jax.jit @partial(shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x')) def f(x): return core.call_p.bind( lu.wrap_init(lambda: [2. * x], debug_info=api_util.debug_info("test", lambda: [2. * x], (), {})))[0] * x x = jnp.arange(4.) y = f(x) self.assertAllClose(y, 2 * x * x, check_dtypes=True) @parameterized.parameters([True, False]) def test_rewrite_process_custom_jvp_call(self, jit): @jax.custom_jvp def foo(x): return 2. * x @foo.defjvp def foo_jvp(primals, tangents): (x,), (x_dot,) = primals, tangents return foo(x), 2. * x_dot mesh = jtu.create_mesh((4,), ('x',)) g = shard_map(lambda x: foo(x) * x, mesh, in_specs=(P('x'),), out_specs=P('x')) if jit: g = jax.jit(g) x = jnp.arange(4.) 
y = g(x) self.assertAllClose(y, 2 * x * x, check_dtypes=True) y2, y_dot = jax.jvp(g, (x,), (3 * x,)) self.assertAllClose(y2, 2 * x * x, check_dtypes=True) self.assertAllClose(y_dot, 2 * 2 * 3 * x * x, check_dtypes=True) @parameterized.parameters([True, False]) def test_rewrite_process_custom_vjp_call(self, jit): @jax.custom_vjp def foo(x): return 2. * x def foo_fwd(x): return foo(x), None def foo_bwd(_, y_bar): return 2. * y_bar, foo.defvjp(foo_fwd, foo_bwd) mesh = jtu.create_mesh((4,), ('x',)) g = shard_map(lambda x: foo(x) * x, mesh, in_specs=(P('x'),), out_specs=P('x')) if jit: g = jax.jit(g) x = jnp.arange(4.) y = g(x) self.assertAllClose(y, 2 * x * x, check_dtypes=True) y_, x_bar = jax.value_and_grad(lambda x: g(x).sum())(x) self.assertAllClose(y_, (2 * x * x).sum(), check_dtypes=True) self.assertAllClose(x_bar, 2 * 2 * x, check_dtypes=True) @parameterized.parameters([True, False]) def test_rewrite_process_custom_vjp_call_match_more_replicated(self, jit): @jax.custom_vjp def foo(x): return 2. * x def foo_fwd(x): return foo(x), None def foo_bwd(_, y_bar): return jnp.ones_like(y_bar), # diff! more replicated than primal/tangent foo.defvjp(foo_fwd, foo_bwd) mesh = jtu.create_mesh((4,), ('x',)) g = shard_map(lambda x: foo(x) * x, mesh, in_specs=(P('x'),), out_specs=P('x')) if jit: g = jax.jit(g) x = jnp.arange(4.) y = g(x) self.assertAllClose(y, 2 * x * x, check_dtypes=True) y_, x_bar = jax.value_and_grad(lambda x: g(x).sum())(x) self.assertAllClose(y_, (2 * x * x).sum(), check_dtypes=True) self.assertAllClose(x_bar, jnp.ones_like(x) + 2 * x, check_dtypes=True) def test_same_pspec_eager_shard_map(self): # This behavior is not guaranteed by JAX and this test can be changed if # the behavior changes. 
mesh = jtu.create_mesh((1, 4, 1), ('data', 'seq', 'model')) def f(x): return x * x + 2 x = jnp.ones([2, 16, 4]) x_spec = jax.sharding.PartitionSpec("data", "seq", "model") x = jax.device_put(x, jax.sharding.NamedSharding(mesh, x_spec)) shard_f = shard_map(f, mesh=mesh, in_specs=x_spec, out_specs=x_spec) y = shard_f(x) self.assertEqual(x_spec, y.sharding.spec) @parameterized.parameters([True, False]) def test_rewrite_process_custom_vjp_call_match_less_replicated(self, jit): @jax.custom_vjp def foo(x, y): del y return 2. * x def foo_fwd(x, y): return foo(x, y), y def foo_bwd(y, _): return y, None # diff! x_bar less replicated than primal/tangent foo.defvjp(foo_fwd, foo_bwd) mesh = jtu.create_mesh((4,), ('x',)) g = shard_map(lambda x, y: foo(x, y) * y, mesh, in_specs=(P(), P('x')), out_specs=P('x')) if jit: g = jax.jit(g) x = jnp.arange(4.) y = jnp.arange(4 * 4.) z = g(x, y) self.assertAllClose(z, 2 * jnp.tile(x, (4,)) * y, check_dtypes=False) z_, x_bar = jax.value_and_grad(lambda x, y: g(x, y).sum())(x, y) self.assertAllClose(z.sum(), z_, check_dtypes=False) self.assertAllClose(x_bar, jnp.arange(16).reshape(4, 4).sum(0), check_dtypes=False) @parameterized.parameters([True, False]) def test_rewrite_custom_vjp_call_jaxpr(self, jit): @jax.custom_vjp def foo(x): return 2. * x def foo_fwd(x): return foo(x), None def foo_bwd(_, y_bar): return 2. * y_bar, foo.defvjp(foo_fwd, foo_bwd) def foo_scan(x): y, _ = jax.lax.scan(lambda x, _: (foo(x), None), x, None, length=1) return y mesh = jtu.create_mesh((4,), ('x',)) g = shard_map(lambda x: foo_scan(x) * x, mesh, in_specs=(P('x'),), out_specs=P('x')) if jit: g = jax.jit(g) x = jnp.arange(4.) 
y = g(x) self.assertAllClose(y, 2 * x * x, check_dtypes=True) y_, x_bar = jax.value_and_grad(lambda x: g(x).sum())(x) self.assertAllClose(y_, (2 * x * x).sum(), check_dtypes=True) self.assertAllClose(x_bar, 2 * 2 * x, check_dtypes=True) def test_transpose_identity(self): mesh = jtu.create_mesh((4,), ('x',)) @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P()) def f(x): return x jaxpr = jax.make_jaxpr(jax.vjp(f, 1.)[1])(1.) e, = jaxpr.jaxpr.eqns self.assertEmpty(e.params['jaxpr'].eqns) jaxpr = jax.make_jaxpr(jax.vjp(jax.vjp(f, 1.)[1], 1.)[1])((1.,)) e, = jaxpr.jaxpr.eqns self.assertEmpty(e.params['jaxpr'].eqns) @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P()) def g(x): return jax.jit(lambda x: 1. * x)(x) jaxpr = jax.make_jaxpr(jax.vjp(g, 1.)[1])(1.) e, = jaxpr.jaxpr.eqns e1, e2 = e.params['jaxpr'].eqns self.assertEmpty(e1.outvars) self.assertLen(e2.params['jaxpr'].eqns, 1) def test_fanout_specs_transpose_to_psum(self): mesh = jtu.create_mesh((4,), ('x',)) @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P('x')) def f(x): return x jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(1.))[1])(jnp.arange(4.)) e, = jaxpr.jaxpr.eqns e2, = e.params['jaxpr'].eqns self.assertEqual(str(e2.primitive), 'psum2') self.assertEqual(e2.params['axes'], ('x',)) def test_fanin_psum_transposes_to_fanout(self): mesh = jtu.create_mesh((4,), ('x',)) @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P()) def f(x): return jax.lax.psum(x, 'x') jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(4.))[1])(jnp.array([1.])) e, = jaxpr.jaxpr.eqns e1, = e.params['jaxpr'].eqns self.assertEqual(str(e1.primitive), 'pbroadcast') def test_psum_with_implicit_fanout_self_transposes(self): mesh = jtu.create_mesh((4,), ('x',)) @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x')) def f(x): return jax.lax.psum(x, 'x') jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(4.))[1])(jnp.arange(4.)) e, = jaxpr.jaxpr.eqns e1, e2 = e.params['jaxpr'].eqns self.assertEqual(str(e1.primitive), 
'psum2') self.assertEqual(str(e2.primitive), 'pbroadcast') def test_transpose_float0(self): mesh = jtu.create_mesh((4,), ('x',)) s = jax.sharding.NamedSharding(mesh, P(None, 'x')) # vjp that triggers float0 @jax.custom_vjp def f(x, _): return x def f_fwd(x, y): return x, jnp.zeros(shape=y.shape, dtype=np.int32) def f_rev(tmp, g): return (g, tmp) f.defvjp(f_fwd, f_rev) # trivial vjp that consumes float0 @jax.custom_vjp def g(x, y): return x, y def g_fwd(x, y): return jax.vjp(lambda x, y: (x, y), x, y) def g_bwd(vjp_fn, result): return vjp_fn(result) g.defvjp(g_fwd, g_bwd) @partial(shard_map, mesh=mesh, in_specs=(P('x'), P()), out_specs=P()) def f_shmapped(x, y): return jax.lax.psum(f(x, y).sum(), axis_name=('x')) @partial(shard_map, mesh=mesh, check_rep=False, in_specs=P('x'), out_specs=(P('x'), P())) def f_shmapped2(x, y): return g(x, y) def f_wrapper(x, y): x, y = jax.lax.map(lambda xs: f_shmapped2(xs[0], xs[1]), (x, y)) return jax.lax.map(lambda xs: f_shmapped(xs[0], xs[1]), (x, y)).sum() @partial(jax.jit, in_shardings=s, out_shardings=jax.sharding.NamedSharding(mesh, P())) def example(x, y): return jax.grad(f_wrapper, allow_int=True, argnums=(0, 1))(x, y) x = np.zeros(shape=(8,16), dtype=np.float32) y = np.zeros(shape=(8,16), dtype=np.int32) # Doesn't crash. 
dx, dy = example(x, y) self.assertEqual(dy.dtype, jax.dtypes.float0) def test_rewrite_binops(self): mesh = jtu.create_mesh((4,), ('x',)) @partial(shard_map, mesh=mesh, in_specs=(P(), P('x')), out_specs=P('x')) def f(x, y): return x * y jaxpr = jax.make_jaxpr(f)(jnp.arange(1.), jnp.arange(4.)) e, = jaxpr.jaxpr.eqns e = e.params['jaxpr'].eqns[0] self.assertEqual(e.primitive.name, 'pbroadcast') self.assertEqual(e.params['axes'], ('x',)) def test_rewrite_scan(self): mesh = jtu.create_mesh((4,), ('x',)) @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x')) def f(x): x, _ = jax.lax.scan(lambda x, _: (jax.lax.psum(x, 'x'), None), x, None, length=2) return x jaxpr = jax.make_jaxpr(f)(jnp.arange(4.)) e, = jaxpr.jaxpr.eqns e, = e.params['jaxpr'].eqns e1, e2 = e.params['jaxpr'].eqns self.assertEqual(e1.primitive.name, 'psum2') self.assertEqual(e2.primitive.name, 'pbroadcast') def test_check_rep_false_grads(self): if jtu.is_device_tpu(5, 'e'): self.skipTest('TODO(b/307508823): Test currently fails on TPU v5e') # This test is redundant with the systematic tests below, but it serves as a # direct regression test for a bug. 
mesh = jtu.create_mesh((4,), ('heads',)) def f(q, k, v): def body(q, k, v): return q * k[None, :] + v[None, :] out = shard_map(body, mesh, check_rep=False, in_specs=(q_spec, kv_spec, kv_spec,), out_specs=q_spec)(q, k, v) return out.sum() q_spec = P('heads', None) kv_spec = P(None) q = jax.device_put(jnp.arange(32.).reshape(4, 8), jax.sharding.NamedSharding(mesh, q_spec)) k = jax.device_put(jnp.arange(8.), jax.sharding.NamedSharding(mesh, kv_spec)) v = jax.device_put(jnp.arange(8.), jax.sharding.NamedSharding(mesh, kv_spec)) if jtu.device_under_test() == 'tpu': rtol = 2e-2 else: rtol = 1e-2 jtu.check_grads(f, (q, k, v), order=1, modes=['rev'], rtol=rtol) def test_axis_env_extension_regression(self): def foo(x): i = jax.lax.axis_index('x') return jnp.exp(x) + i.astype(x.dtype) @partial(jax.remat, policy=lambda *args, **kwargs: True) def bar(x): return shard_map(foo, mesh=Mesh(jax.devices(), ['x']), in_specs=(P('x'),), out_specs=P('x'), check_rep=False)(x) jax.jit(jax.grad(lambda x: bar(x).sum()))(jnp.arange(8.)) # doesn't crash @parameterized.parameters(it.product([True, False], repeat=2)) def test_res_forwarding_optimization(self, jit, remat): mesh = jtu.create_mesh((4,), ('i',)) @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) def f(x): return jax.lax.exp(x) if jit: f = jax.jit(f) if remat: policy = jax.ad_checkpoint.checkpoint_policies.everything_saveable f = jax.remat(f, policy=policy) g = lambda x: f(x).sum() x = jnp.arange(16.) 
jaxpr_ = jax.make_jaxpr(jax.grad(g))(x) jaxpr, _ = pe.dce_jaxpr(jaxpr_.jaxpr, [True] * len(jaxpr_.out_avals)) e1, _, e2 = jaxpr.eqns self.assertLen(e1.outvars, 1) # only primal output self.assertLen(e2.invars, 2) # res and cotangent inputs self.assertEqual(sum(e1.outvars[0] is v for v in e2.invars), 1) @parameterized.parameters(it.product([True, False], repeat=2)) def test_res_forwarding_optimization_complex(self, jit, remat): # like the above test, but a different function `f` mesh = jtu.create_mesh((4,), ('i',)) @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) def f(x): return jax.lax.exp(x.sum()) + x, jax.lax.exp(x) if jit: f = jax.jit(f) if remat: policy = jax.ad_checkpoint.checkpoint_policies.everything_saveable f = jax.remat(f, policy=policy) g = lambda x: sum(f(x)).sum() x = jnp.arange(16.) jaxpr_ = jax.make_jaxpr(jax.grad(g))(x) jaxpr, _ = pe.dce_jaxpr(jaxpr_.jaxpr, [True] * len(jaxpr_.out_avals)) e1, _, e2 = jaxpr.eqns self.assertLen(e1.outvars, 2) # one primal and one res output self.assertLen(e2.invars, 4) # two res and two cotangent inputs self.assertEqual(sum(e1.outvars[-1] is v for v in e2.invars), 1) @parameterized.parameters([True, False]) def test_check_rep_failure_inside_rule(self, jit): mesh = jtu.create_mesh((4,), ('i',)) def loss(w, x): @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P()) def f(x): return jax.lax.psum(((w * x) ** 2).sum(), 'i') return f(x) if jit: loss = jax.jit(loss) jax.grad(loss)(3.0, jnp.arange(8.)) # don't crash def test_conv_general_dilated(self): mesh = jtu.create_mesh((4,), ('i',)) dot = partial(lax.conv_general_dilated, window_strides=(), padding='VALID', dimension_numbers=('NC', 'IO', 'NC')) @partial(shard_map, mesh=mesh, in_specs=(P(None, 'i'), P('i', None)), out_specs=P(None, None)) def f(x, y): return lax.psum(dot(x, y), 'i') a = jnp.ones((16, 32)) b = jnp.ones((32, 8)) y = f(a, b) # don't crash self.assertAllClose(y, a @ b, check_dtypes=False, atol=1e-2, rtol=1e-2) def test_cumsum(self): 
mesh = jtu.create_mesh((4,), ('i',)) x = jnp.arange(8.) shard_map(jnp.cumsum, mesh=mesh, in_specs=P('i'), out_specs=P('i') )(x) # don't crash def test_custom_jvp_inside_jit(self): mesh = jtu.create_mesh((4,), ('batch',)) x = shard_map(jax.jit(jax.nn.relu), mesh=mesh, in_specs=P('batch'), out_specs=P('batch'))(jnp.arange(16.)) # don't crash def test_random_normal_rules(self): mesh = jtu.create_mesh((4,), ('i',)) keys = jax.random.split(jax.random.key(0), 4) shard_map(lambda k: jax.random.normal(k[0], (1,)), mesh=mesh, in_specs=P('i'), out_specs=P('i'))(keys) # don't crash def test_erf_rules(self): mesh = jtu.create_mesh((4,), ('i',)) x = jnp.arange(16.) shard_map(jax.lax.erf, mesh=mesh, in_specs=P('i'), out_specs=P('i'))(x) # don't crash def test_error_for_variable_num_args(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) def f(*args): return args[0] @ args[1] shard_f = shard_map( f, mesh, in_specs=(P('x', 'y', None), P('x', 'y', None)), out_specs=P('x', 'y')) with self.assertRaisesRegex(ValueError, "shard_map applied to the function 'f'"): shard_f(jnp.ones((8, 8)), jnp.ones((8, 8))) def test_custom_vjp_replication_error_message_hint(self): mesh = jtu.create_mesh((4,), 'i') @jax.custom_vjp def f(x): return jax.lax.psum(x, 'i') def f_fwd(x): return f(x), None def f_bwd(_, g): return jax.lax.psum(g, 'i'), f.defvjp(f_fwd, f_bwd) @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P()) def g(x): return f(f(x)) y, grad = jax.value_and_grad(lambda x: g(x).sum())(jnp.ones(4)) # first psum sums, second psum multiplies by 4 self.assertAllClose(y, (jnp.ones(4) * 4).sum(), check_dtypes=False) # two psums on the backward pass, each one multiplies by 4 self.assertAllClose(grad, jnp.ones(4) * 4 * 4, check_dtypes=False) def test_repeated_psum_allowed(self): # https://github.com/jax-ml/jax/issues/19175 mesh = jtu.create_mesh((4,), 'i') @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P()) def g(x): return jax.lax.psum(jax.lax.psum(x, 'i'), 'i') y = g(jnp.arange(4.)) 
self.assertAllClose(y, jnp.arange(4.).sum(keepdims=True) * 4, check_dtypes=False) def test_approx_top_k(self): mesh = Mesh(np.array(jax.devices()[:2]), ('i',)) x = jnp.array([3.0, 1.0, 4.0, 2.0]) _ = shard_map(lambda x: lax.approx_max_k(x, 2), mesh, P('i'), P('i'))(x) def test_disable_jit(self): mesh = Mesh(np.array(jax.devices()[:2]), ('i',)) @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) def f(x): return x x = jnp.arange(8.) with jax.disable_jit(): f(x) # don't crash @parameterized.parameters(it.product(range(4), repeat=3)) @jtu.run_on_devices("cpu") def test_forwarding_correctness(self, seed, num_input_fwd, num_output_fwd): num_args = 3 rng = np.random.RandomState(seed) mesh = Mesh(np.array(jax.devices()[:1]), ('i',)) in_perm = rng.permutation(num_args) out_perm = rng.permutation(num_args) @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) def f(inputs): inputs = [inputs[i] for i in in_perm] outputs = inputs[:num_input_fwd] + [ jnp.exp(inputs[i]) if i < num_output_fwd else jnp.sin(inputs[i]) for i in range(num_args - num_input_fwd)] return [outputs[i] for i in out_perm] jtu.check_grads(f, (list(jnp.arange(float(num_args))[:,None]),), order=1, modes=['rev'], atol=1e-3, rtol=1e-3) def test_partial_auto(self): mesh = jtu.create_mesh((2, 2), ('i', 'j')) def g(x): self.assertDictEqual(x.aval.sharding.mesh._axis_types_dict, {AxisType.Manual: ('i',), AxisType.Auto: ('j',)}) x = jax.lax.with_sharding_constraint( x, jax.sharding.NamedSharding(mesh, P(None, 'j'))) return x * x @jax.jit def f(x): x = shard_map(g, mesh, in_specs=P('i', None), out_specs=P('i', None), auto=frozenset({'j'}))(x) return jax.lax.with_sharding_constraint( x, jax.sharding.NamedSharding(mesh, P('i', 'j'))) v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) if config.use_shardy_partitioner.value: self.assertIn( 'in_shardings=[<@mesh, [{"i"}, {?}]>]' ' out_shardings=[<@mesh, [{"i"}, {?}]>] manual_axes={"i"}', 
f.lower(v).as_text(), ) else: self.assertIn( 'sharding={devices=[1,1,2,2]<=[4] last_tile_dims={manual,' ' replicated}}', f.lower(v).as_text('hlo'), ) self.assertAllClose(v * v, f(v), check_dtypes=False) def test_partial_auto_explicit_no_use_mesh(self): mesh = jtu.create_mesh((2, 2), ('i', 'j'), axis_types=(AxisType.Explicit,) * 2) def g(x): self.assertDictEqual(x.aval.sharding.mesh._axis_types_dict, {AxisType.Manual: ('i',), AxisType.Explicit: ('j',)}) self.assertEqual(x.aval.sharding.spec, P(None, 'j')) out = x * x self.assertEqual(out.aval.sharding.spec, P(None, 'j')) return out @jax.jit def f(x): x = shard_map(g, mesh, in_specs=P('i', None), out_specs=P('i', None), auto=frozenset({'j'}))(x) self.assertEqual(x.aval.sharding.spec, P('i', 'j')) return x v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) out = f(v) self.assertEqual(out.sharding, NamedSharding(mesh, P('i', 'j'))) self.assertAllClose(v * v, out, check_dtypes=False) @jtu.with_user_mesh((2, 2), ('i', 'j')) def test_partial_auto_explicit(self, mesh): def g(x): self.assertDictEqual(x.aval.sharding.mesh._axis_types_dict, {AxisType.Manual: ('i',), AxisType.Explicit: ('j',)}) self.assertEqual(x.aval.sharding.spec, P(None, 'j')) out = x * x self.assertEqual(out.aval.sharding.spec, P(None, 'j')) return out @jax.jit def f(x): x = shard_map(g, mesh, in_specs=P('i', None), out_specs=P('i', None), auto=frozenset({'j'}))(x) self.assertEqual(x.aval.sharding.spec, P('i', 'j')) return x v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) out = f(v) self.assertEqual(out.sharding, NamedSharding(mesh, P('i', 'j'))) self.assertAllClose(v * v, out, check_dtypes=False) if config.use_shardy_partitioner.value: self.assertIn( 'sdy.sharding_constraint %1 <@mesh, [{}, {"j"}]>', f.lower(v).as_text(), ) else: self.assertIn( 'mhlo.sharding = "{devices=[1,2,2]<=[2,2]T(1,0) last_tile_dims={manual}}"}', f.lower(v).as_text(), ) 
@jax.jit def h(x): return jnp.sum(f(x)) jax.grad(h)(v) # doesn't crash jax.jit(jax.grad(h))(v) # doesn't crash @jtu.with_user_mesh((2, 1, 2, 2), ('i', 'j', 'k', 'l')) def test_partial_auto_explicit_multi_explicit(self, mesh): def g(x): self.assertDictEqual(x.aval.sharding.mesh._axis_types_dict, {AxisType.Manual: ('i', 'j'), AxisType.Explicit: ('k', 'l')}) self.assertEqual(x.aval.sharding.spec, P(None, None, 'k', 'l')) out = x.T self.assertEqual(out.aval.sharding.spec, P('l', 'k', None, None)) return out @jax.jit def f(x): x = shard_map(g, mesh, in_specs=P('i', 'j', None, None), out_specs=P('i', 'j', None, None), auto=frozenset({'k', 'l'}))(x) self.assertEqual(x.aval.sharding.spec, P(('i', 'l'), ('j', 'k'), None, None)) return x v = jnp.arange(64.).reshape(4, 2, 2, 4) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j', 'k', 'l'))) out = f(v) self.assertEqual( out.sharding, NamedSharding(mesh, P(('i', 'l'), ('j', 'k'), None, None))) def test_partial_auto_propagate_through(self): mesh = jtu.create_mesh((2, 2, 2), ('i', 'j', 'k')) sharding = jax.sharding.NamedSharding(mesh, P('i')) def g(x): return jax.lax.with_sharding_constraint(x * x, sharding) @jax.jit def f(x): return shard_map( g, mesh, in_specs=P(), out_specs=P(), check_rep=False, auto=frozenset({'i'}), )(x) v = jnp.arange(32.0).reshape(4, 8) v = jax.device_put(v, sharding) if config.use_shardy_partitioner.value: self.assertIn( 'in_shardings=[<@mesh, [{?}, {?}]>]' ' out_shardings=[<@mesh, [{?}, {?}]>] manual_axes={"j", "k"}', f.lower(v).as_text(), ) else: self.assertIn( 'sharding={devices=[1,1,4,2]<=[2,4]T(1,0) last_tile_dims={manual,' ' replicated}}', f.lower(v).as_text('hlo'), ) actual = f(v) self.assertAllClose(v * v, actual, check_dtypes=False) self.assertEqual(actual.sharding, sharding) def test_shmap_close_over_unused_params(self): mesh = jtu.create_mesh((2,), ("data",)) def loss_fn(_, batch): return jnp.sum(batch) @jax.jit def update_fn(params, batch): def grad_fn(batch): return 
jax.value_and_grad(loss_fn)(params, batch) return shard_map(grad_fn, mesh=mesh, in_specs=P("data"), out_specs=P(), check_rep=False)(batch) arr_sharded = jax.device_put(jnp.arange(32.0).reshape(4, 8), NamedSharding(mesh, P())) params = jnp.copy(arr_sharded) update_fn(params, arr_sharded) # doesn't crash def test_shmap_close_over_unused_params_vmap(self): mesh = jtu.create_mesh((2,), ("data",)) def loss_fn(params, batch): return jnp.sum(params) + jnp.sum(batch) @jax.jit def update_fn(params, batch): def grad_fn(batch): return jax.value_and_grad(loss_fn)(params, batch) return shard_map(jax.vmap(grad_fn), mesh=mesh, in_specs=P("data"), out_specs=P("data"), check_rep=False)(batch) arr_sharded = jax.device_put(jnp.arange(32.0).reshape(4, 8), NamedSharding(mesh, P())) params = jnp.copy(arr_sharded) update_fn(params, arr_sharded) # doesn't crash def test_sharded_prng_with_abstract_mesh(self): shape = (8, 2, 2) mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) np_inp = np.arange(math.prod(shape), dtype=np.uint32).reshape(shape) key = prng.random_seed(np_inp, impl=prng.threefry_prng_impl) key = jax.device_put(key, NamedSharding(mesh, P())) @jax.jit def shard_key(key): return shard_map( lambda x: x, mesh=mesh.abstract_mesh, in_specs=P(), out_specs=P())(key) out = shard_key(key) self.assertTrue(out.sharding.is_equivalent_to(NamedSharding(mesh, P()), out.ndim)) def test_partial_auto_error_wsc_manual(self): mesh = jtu.create_mesh((2, 2), ('i', 'j')) def g(x): x = jax.lax.with_sharding_constraint( x, jax.sharding.NamedSharding(mesh, P('i', 'j'))) return x * x @jax.jit def f(x): x = shard_map(g, mesh, in_specs=P('i', None), out_specs=P('i', None), check_rep=False, auto=frozenset({'j'}))(x) return jax.lax.with_sharding_constraint( x, jax.sharding.NamedSharding(mesh, P('i', 'j'))) v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) with self.assertRaisesRegex(ValueError, "manual"): f(v) def 
test_partial_auto_error_invalid_auto(self): mesh = jtu.create_mesh((2, 2), ('i', 'j')) def g(x): x = jax.lax.with_sharding_constraint( x, jax.sharding.NamedSharding(mesh, P('i', 'j'))) return x * x @jax.jit def f(x): x = shard_map(g, mesh, in_specs=P('i', None), out_specs=P('i', None), check_rep=False, auto=frozenset({'k'}))(x) return jax.lax.with_sharding_constraint( x, jax.sharding.NamedSharding(mesh, P('i', 'j'))) v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) with self.assertRaisesRegex(ValueError, "to be a subset of mesh.axis_names"): f(v) def test_partial_auto_error_wrong_in_specs(self): mesh = jtu.create_mesh((2, 2), ('i', 'j')) def g(x): x = jax.lax.with_sharding_constraint( x, jax.sharding.NamedSharding(mesh, P('i', 'j'))) return x * x @jax.jit def f(x): x = shard_map(g, mesh, in_specs=P('i', 'j'), out_specs=P('i', None), check_rep=False, auto=frozenset({'j'}))(x) return jax.lax.with_sharding_constraint( x, jax.sharding.NamedSharding(mesh, P('i', 'j'))) v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) with self.assertRaisesRegex(ValueError, "in_specs refers to 'j'"): f(v) def test_nested_partial_auto(self): mesh = jtu.create_mesh((2, 2), ('i', 'j')) def g(x): return x * x def h(x): return shard_map(g, mesh, in_specs=P(None, 'j'), out_specs=P(None, 'j'))(x) @jax.jit def f(x): return shard_map(h, mesh, in_specs=P('i', None), out_specs=P('i', None), check_rep=False, auto=frozenset({'j'}))(x) v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) self.assertAllClose(v*v, f(v), check_dtypes=False) def test_grad_nested_partial_auto(self): mesh = jtu.create_mesh((2, 2), ('i', 'j')) def g(x): # manual: 'i', 'j' return x * x def h(x): # auto: 'j', manual: 'i' return shard_map(g, mesh, in_specs=P(None, 'j'), out_specs=P(None, 'j'))(x) @jax.jit def f(x): # auto: 'i', 'j' return shard_map(h, mesh, in_specs=P('i', 
None),
                       out_specs=P('i', None),
                       check_rep=False,
                       auto=frozenset({'j'}))(x).sum()

    v = jnp.arange(32.).reshape(4, 8)
    v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
    self.assertAllClose(v*2, jax.grad(f)(v), check_dtypes=False)

  def test_grad_nested_partial_auto_with_residuals(self):
    # Gradient through a shard_map (manual over 'j') nested inside a
    # partial-auto shard_map (manual over 'i', auto over 'j'). The body x**3
    # needs residuals; d/dx sum(x**3) == 3 * x**2.
    mesh = jtu.create_mesh((2, 2), ('i', 'j'))

    def g(x):
      return x * x * x

    def h(x):
      return shard_map(g, mesh, in_specs=P(None, 'j'), out_specs=P(None, 'j'))(x)

    @jax.jit
    def f(x):
      return shard_map(h, mesh,
                       in_specs=P('i', None),
                       out_specs=P('i', None),
                       check_rep=False,
                       auto=frozenset({'j'}))(x).sum()

    v = jnp.arange(32.).reshape(4, 8)
    v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
    self.assertAllClose(v*v*3, jax.grad(f)(v), check_dtypes=False)

  def test_axis_size_1_partial_auto(self):
    # Partial-auto shard_map whose only manual axis ('i') has size 1, with
    # both remaining axes ('j', 'k') left to the automatic partitioner.
    mesh = jtu.create_mesh((1, 2, 2), ('i', 'j', 'k'))

    def h(x):
      return x * x

    @jax.jit
    def f(x):
      return shard_map(h, mesh,
                       in_specs=P('i', None),
                       out_specs=P('i', None),
                       check_rep=False,
                       auto=frozenset({'j', 'k'}))(x)

    v = jnp.arange(32.).reshape(4, 8)
    v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
    self.assertAllClose(v*v, f(v), check_dtypes=False)

  def test_partial_auto_of_pjit(self):
    # A jit with explicit out_shardings (on the same mesh) called inside the
    # body of a partial-auto shard_map that takes no inputs.
    mesh = jtu.create_mesh((2, 2), ('i', 'j'))

    def h():
      def _make_zeros():
        return jnp.zeros(())
      s = jax.sharding.NamedSharding(mesh, P())
      y = jax.jit(_make_zeros, out_shardings=s)()
      return y.reshape((1,))

    def f():
      return shard_map(
          h, mesh, in_specs=(),
          out_specs=P('i'), check_rep=False, auto=frozenset({'j'}))()

    # One element per 'i' shard (axis size 2) -> global shape (2,).
    self.assertAllClose(jax.jit(f)(), jnp.zeros((2,)))

  def test_partial_auto_of_pjit_different_mesh(self):
    # Like test_partial_auto_of_pjit, but the inner jit's sharding uses a mesh
    # over the same devices with different axis names ('k', 'l').
    if config.use_shardy_partitioner.value:
      self.skipTest(
          'Shardy requires the mesh axis names to be the same across '
          'the entire computation.'
      )
    mesh = jtu.create_mesh((2, 2), ('i', 'j'))
    mesh2 = jax.sharding.Mesh(mesh.devices, ('k', 'l'))

    def h():
      def _make_zeros():
        return jnp.zeros(())
      s = jax.sharding.NamedSharding(mesh2, P())
      y = jax.jit(_make_zeros, out_shardings=s)()
      return y.reshape((1,))

    def f():
      return shard_map(
          h, mesh, in_specs=(),
          out_specs=P('i'), check_rep=False, auto=frozenset({'j'}))()

    self.assertAllClose(jax.jit(f)(), jnp.zeros((2,)))

  def test_partial_auto_axis_index(self):
    # axis_index over the manual axis 'i' inside a partial-auto shard_map:
    # each of the four 'i' shards contributes its own index 0..3.
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))
    out_sharding = NamedSharding(mesh, P('i', None))

    @partial(jax.jit, out_shardings=out_sharding)
    def f():
      return shard_map(lambda: jax.lax.axis_index('i').reshape(1,1),
                       mesh, in_specs=P('i', None), out_specs=P('i', None),
                       check_rep=False, auto=frozenset({'j'}))()

    self.assertAllClose(f(), np.arange(4, dtype=np.int32).reshape(-1, 1))

  def test_partial_auto_axis_index_degenerated_axis(self):
    # Same as test_partial_auto_axis_index, but the manual axis 'i' has size 1.
    mesh = jtu.create_mesh((1, 2), ('i', 'j'))
    out_sharding = NamedSharding(mesh, P('i', None))

    @partial(jax.jit, out_shardings=out_sharding)
    def f():
      return shard_map(lambda: jax.lax.axis_index('i').reshape(1, 1),
                       mesh, in_specs=P('i', None), out_specs=P('i', None),
                       check_rep=False, auto=frozenset({'j'}))()

    self.assertAllClose(f(), np.arange(1, dtype=np.int32).reshape(-1, 1))

  def test_partial_auto_ppermute(self):
    # Cyclic ppermute over the manual axis 'i' (shard n -> shard n+1 mod 4).
    # Each of the 4 'i' shards holds two elements, so the result is the input
    # rolled right by two positions.
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))
    x = jnp.arange(8.)

    def g(x):
      x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('j')))
      return jax.lax.ppermute(x, 'i', [(0, 1), (1, 2), (2, 3), (3, 0)])

    @jax.jit
    def f(x):
      return shard_map(g,
                       mesh, in_specs=P('i'), out_specs=P('i'),
                       check_rep=False, auto=frozenset({'j'}))(x)

    y = f(x)  # don't crash
    self.assertAllClose(y, jnp.array([6., 7., 0., 1., 2., 3., 4., 5.]),
                        check_dtypes=False)

  # TODO(parkers,mattjj): get XLA to support this too
  # def test_partial_auto_all_to_all(self):
  #
  #   mesh = jtu.create_mesh((4, 2), ('i', 'j'))
  #   x = jnp.arange(128.).reshape(16, 8)
  #
  #   def g(x):
  #     x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('j')))
  #     return jax.lax.all_to_all(x, 'i', 0, 1, tiled=True)
  #
  #   @jax.jit
  #   def f(x):
  #     return shard_map(g,
  #                      mesh, in_specs=P('i', None), out_specs=P(None, 'i'),
  #                      check_rep=False, auto=frozenset({'j'}))(x)
  #
  #   f(x)  # don't crash

  def test_partial_auto_debug_print(self):
    # jax.debug.print inside a partial-auto shard_map body; only checks that
    # it runs. Skipped when Shardy is enabled.
    if config.use_shardy_partitioner.value:
      raise unittest.SkipTest("shardy error")
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))
    x = jnp.arange(8.)

    def g(x):
      jax.debug.print('{}', x)

    @jax.jit
    def f(x):
      return shard_map(g,
                       mesh, in_specs=P('i'), out_specs=None,
                       check_rep=False, auto=frozenset({'j'}))(x)

    y = f(x)  # don't crash

  def test_partial_auto_of_random_keys(self):
    # A typed PRNG key array passes through a partial-auto identity shard_map
    # unchanged (compared via key_data).
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))
    keys = jax.random.split(jax.random.key(0), 8)

    @jax.jit
    def f(x):
      # NOTE(review): `x` is unused; the shard_map is applied to the
      # closed-over `keys`, unlike the `_slice` variant below which passes `x`.
      # Confirm whether exercising a closed-over key constant is intended.
      return shard_map(lambda k: k,
                       mesh, in_specs=P('i'), out_specs=P('i'),
                       check_rep=False, auto=frozenset({'j'}))(keys)

    y = f(keys)  # doesn't crash
    self.assertAllClose(jax.random.key_data(y), jax.random.key_data(keys),
                        check_dtypes=False)

  def test_partial_auto_of_random_keys_slice(self):
    # Indexing a (4, 2) key array inside a partial-auto shard_map body.
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))
    keys = jax.random.split(jax.random.key(0), 8).reshape(4, 2)

    @jax.jit
    def f(x):
      return shard_map(lambda k: k[0],
                       mesh, in_specs=P('i'), out_specs=P('i'),
                       check_rep=False, auto=frozenset({'j'}))(x)

    f(keys)  # doesn't crash

  def test_vmap_grad_shmap_spmd_axis_name_residuals(self):
    # https://github.com/jax-ml/jax/pull/21032
    # vmap (spmd_axis_name='i') of grad of a shard_map that is manual over 'j'.
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))

    @partial(
        shard_map,
        mesh=mesh,
        in_specs=P('j'),
        out_specs=P('j'),
    )
    def f(x):
      return jnp.sin(x)

    xs = jnp.arange(4 * 16.).reshape(4, 16)

    jax.vmap(jax.grad(lambda x: f(x).sum()), spmd_axis_name='i')(xs)  # don't crash

  def test_vmap_grad_remat_shmap_spmd_axis_name_residuals(self):
    # https://github.com/jax-ml/jax/pull/21056
    # Same as above, with the shard_map wrapped in jax.remat using a policy
    # that returns True for every query.
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))

    @partial(jax.remat, policy=lambda *_, **__: True)
    @partial(
        shard_map,
        mesh=mesh,
        in_specs=P('j'),
        out_specs=P('j'),
    )
    def f(x):
      return jnp.sin(x)

    xs = jnp.arange(4 * 16.).reshape(4, 16)

    jax.vmap(jax.grad(lambda x: f(x).sum()), spmd_axis_name='i')(xs)  # don't crash

  def test_grad_shmap_residuals_axis_names_in_mesh_order(self):
    # https://github.com/jax-ml/jax/issues/21236
    # The lowered gradient's output sharding must list the mesh axes in mesh
    # order ('i', 'j', 'k', 'a'); checked in both Shardy and non-Shardy text.
    mesh = jtu.create_mesh((4, 2, 1, 1), ('i', 'j', 'k', 'a'))

    @partial(
        shard_map,
        mesh=mesh,
        in_specs=P('j'),
        out_specs=P('j'),
    )
    def f(x):
      return jnp.sin(x)

    xs = jnp.arange(16.)

    ir = jax.jit(jax.grad(lambda x: f(x).sum())).lower(xs)
    if config.use_shardy_partitioner.value:
      self.assertIn(
          'out_shardings=[<@mesh, [{"i", "j", "k", "a"}]>]', ir.as_text()
      )
    else:
      self.assertIn(
          "{jax.result_info = \"[('i', 'j', 'k', 'a')]\"}", ir.as_text()
      )

  def test_vmap_spmd_axis_name_error(self):
    # vmap with spmd_axis_name='i' must raise ValueError when 'i' appears in the
    # shard_map's specs: first in both in/out specs, then only in out_specs.
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))

    @partial(
        shard_map,
        mesh=mesh,
        in_specs=P('i'),
        out_specs=P('i'),
    )
    def f(x):
      return jnp.sin(x)

    xs = jnp.arange(4 * 16.).reshape(4, 16)
    with self.assertRaisesRegex(ValueError, "spmd_axis_name cannot appear"):
      jax.vmap(f, spmd_axis_name='i')(xs)

    @partial(
        shard_map,
        mesh=mesh,
        in_specs=P('j'),
        out_specs=P(('i', 'j')),
        check_rep=False,
    )
    def g(x):
      return jnp.sin(x)

    xs = jnp.arange(4 * 16.).reshape(4, 16)
    with self.assertRaisesRegex(ValueError, "spmd_axis_name cannot appear"):
      jax.vmap(g, spmd_axis_name='i')(xs)

  def test_in_spec_none(self):
    # A None in_spec passes the argument into the body as the very same object
    # (checked with assertIs): plain objects, None, dicts and tuples, in any
    # argument position.
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))

    x = jnp.arange(8).reshape(4, 2)

    def f(o, x):
      self.assertIs(o, obj)
      return jnp.sin(x)

    obj = object()
    y = shard_map(f, mesh, (None, P('i')), P('i'))(obj, x)
    self.assertAllClose(y, jnp.sin(x), check_dtypes=False)

    obj = None
    y = shard_map(f, mesh, (None, P('i')), P('i'))(None, x)
    self.assertAllClose(y, jnp.sin(x), check_dtypes=False)

    def f2(o, x):
      self.assertIsInstance(o, dict)
      self.assertIs(o['a'], obj['a'])
      return jnp.sin(x)

    obj = {'a': object()}
    y = shard_map(f2, mesh, ({'a': None}, P('i')), P('i'))(obj, x)
    self.assertAllClose(y, jnp.sin(x), check_dtypes=False)

    def f3(x, o):
      self.assertIs(o, obj)
      return jnp.sin(x)

    obj = object()
    y = shard_map(f3, mesh, (P('i'), None), P('i'))(x, obj)
    self.assertAllClose(y, jnp.sin(x), check_dtypes=False)

    obj = None
    y = shard_map(f3, mesh, (P('i'), None), P('i'))(x, obj)
    self.assertAllClose(y, jnp.sin(x), check_dtypes=False)

    def f4(o1, o2, x, o3):
      self.assertIs(o1, obj1)
      self.assertIs(o2[0], obj2[0])
      self.assertIs(o2[1], obj2[1])
      self.assertIs(o3, obj3)
      return jnp.sin(x)

    obj1 = object()
    obj2 = (object(), object())
    obj3 = object()
    y = shard_map(f4, mesh, (None, None, P('i'), None),
                  P('i'))(obj1, obj2, x, obj3)
    self.assertAllClose(y, jnp.sin(x), check_dtypes=False)

  def test_in_spec_none_divisibility_errors(self):
    # x's leading dim (2) is not divisible by the size of axis 'i' (4); the
    # 'divisible' ValueError must surface wherever the None-spec args sit.
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))
    x = jnp.arange(4).reshape(2, 2)

    with self.assertRaisesRegex(ValueError, 'divisible'):
      shard_map(lambda *_: None, mesh, (None, P('i')), None)(object(), x)

    with self.assertRaisesRegex(ValueError, 'divisible'):
      shard_map(lambda *_: None, mesh, (P('i'), None), None)(x, object())

    with self.assertRaisesRegex(ValueError, 'divisible'):
      shard_map(lambda *_: None, mesh, (P('i'), None), None
                )(x, (object(), object()))

    with self.assertRaisesRegex(ValueError, 'divisible'):
      shard_map(lambda *_: None, mesh, (P('i'), (None, None)), None,
                )(x, (object(), object()))

    with self.assertRaisesRegex(ValueError, 'divisible'):
      shard_map(lambda *_: None, mesh, ((None, None), P('i')), None,
                )((object(), object()), x)

  def test_in_spec_none_rank_errors(self):
    # x is rank 1 but P('i', 'j') describes two dims; expect a 'rank'
    # ValueError regardless of where the None-spec args sit.
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))
    x = jnp.arange(4)

    with self.assertRaisesRegex(ValueError, 'rank'):
      shard_map(lambda *_: None, mesh, (None, P('i', 'j')), None)(object(), x)

    with self.assertRaisesRegex(ValueError, 'rank'):
      shard_map(lambda *_: None, mesh, (P('i', 'j'), None), None)(x, object())

    with self.assertRaisesRegex(ValueError, 'rank'):
      shard_map(lambda *_: None, mesh, (P('i', 'j'), None), None
                )(x, (object(), object()))

    with self.assertRaisesRegex(ValueError, 'rank'):
      shard_map(lambda *_: None, mesh, (P('i', 'j'), (None, None)), None,
                )(x, (object(), object()))

    with self.assertRaisesRegex(ValueError, 'rank'):
      shard_map(lambda *_: None, mesh, ((None, None), P('i', 'j')), None,
                )((object(), object()), x)

  def test_custom_linear_solve_rep_rules(self):
    # https://github.com/jax-ml/jax/issues/20162
    # jnp.linalg.solve inside shard_map; only checks that it runs.
    mesh = jtu.create_mesh((1,), ('i',))
    a = jnp.array(1).reshape(1, 1)
    b = jnp.array(1).reshape(1)

    @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
    def f(a, b):
      c = jnp.linalg.solve(a, b)
      return c

    _ = f(a, b)  # don't crash

  def test_temporary_error_suppression_flag(self):
    # vmap(spmd_axis_name='i') over a function containing a shard_map over 'i'
    # raises by default; config.disable_vmap_shmap_error() suppresses it.
    mesh = jtu.create_mesh((2,), ('i',))

    def f(x, y):
      z = shard_map(lambda x, y: x + jax.lax.all_gather(y, 'i', tiled=True),
                    mesh=mesh, in_specs=(P(None), P('i')), out_specs=P(None),
                    check_rep=False,
                    )(x, y)
      return z

    y = jnp.arange(8)
    xs = jnp.arange(32).reshape(4, 8)
    with self.assertRaisesRegex(ValueError, 'vmap spmd_axis_name cannot appear in'):
      _ = jax.vmap(f, in_axes=(0, None), spmd_axis_name='i')(xs, y)

    with config.disable_vmap_shmap_error():
      _ = jax.vmap(f, in_axes=(0, None), spmd_axis_name='i')(xs, y)

  def test_in_spec_none_hashability(self):
    # An argument with a None in_spec must not be hashed: A.__hash__ raises.
    mesh = jtu.create_mesh((2,), ('i',))

    class A:
      def __hash__(self):
        raise Exception

    @partial(shard_map, mesh=mesh, in_specs=(None,), out_specs=())
    def f(a):
      return ()

    f(A())  # don't crash

  def test_get_check_rep(self):
    # After psum over `reduce_along`, get_replication reports exactly those
    # axes, both directly and under vmap, with and without jit.
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))

    def f(x, reduce_along, use_jit):
      # Reduced axes are replicated, so they are left out of the out spec.
      out_spec = P(*(n for n in ('x', 'y') if n not in reduce_along))

      @partial(shard_map, mesh=mesh, in_specs=P('x', 'y'), out_specs=out_spec)
      def g(x):
        result = lax.psum(x, axis_name=reduce_along)
        def check_rep(result):
          self.assertEqual(
              jax.experimental.shard_map.get_replication(result),
              set(reduce_along))
          return result
        result = check_rep(result)
        result = jax.vmap(check_rep)(result)
        return result
      if use_jit:
        return jax.jit(g)(x)
      else:
        return g(x)

    for use_jit in [True, False]:
      x = np.zeros((8, 8), dtype=np.float32)
      f(x, reduce_along=('y',), use_jit=use_jit)
      f(x, reduce_along=('x',), use_jit=use_jit)
      f(x, reduce_along=('x', 'y'), use_jit=use_jit)

  def test_pmin(self):
    # arange(8) over 4 shards gives [0,1], [2,3], [4,5], [6,7]; the elementwise
    # min across shards is [0, 1].
    mesh = jtu.create_mesh((4,), ('i',))
    x = jnp.arange(8., dtype=np.float32)
    y = shard_map(lambda x: jax.lax.pmin(x, 'i'),
                  mesh=mesh, in_specs=P('i'), out_specs=P()
                  )(x)  # don't crash
    self.assertArraysEqual(y, np.array([0, 1], dtype=np.float32))

  def test_pmax(self):
    # Same sharding as test_pmin; the elementwise max across shards is [6, 7].
    mesh = jtu.create_mesh((4,), ('i',))
    x = jnp.arange(8., dtype=np.float32)
    y = shard_map(lambda x: jax.lax.pmax(x, 'i'),
                  mesh=mesh, in_specs=P('i'), out_specs=P()
                  )(x)  # don't crash
    self.assertArraysEqual(y, np.array([6, 7], dtype=np.float32))


class FunSpec(NamedTuple):
  """A per-shard function exercised by the systematic shard_map tests."""
  # Short name; becomes the prefix of generated test-case names.
  name: str
  # Number of positional arguments `fun` takes.
  num_inputs: int
  # The per-shard body passed to shard_map.
  fun: Callable
  # Maps, per input, the set of mesh axes not mentioned in that input's spec
  # to the corresponding set(s) for the output(s) of `fun`.
  out_rep: Callable
  # Optional predicate over the input ShapeDtypeStructs; None accepts all.
  valid_types: Callable | None = None

fun_specs = [
    FunSpec('id', 1, lambda x: x, lambda r: r),
    FunSpec('flip', 2, lambda x, y: (y, x), lambda r_x, r_y: (r_y, r_x)),
    FunSpec('transpose', 1, lambda x: x.T, lambda r: r),
    FunSpec('ravel', 1, lambda x: x.ravel(), lambda r: r),
    FunSpec(
        'dot', 2, jnp.dot, lambda r1, r2: r1 & r2,
        # Contraction dims must agree (non-scalar inputs only).
        lambda x1, x2: (x1.shape and x2.shape and
                        x1.shape[-1] == x2.shape[-2 if x2.ndim > 1 else 0]),
    ),
    FunSpec(
        'sin_dot_sin', 2,
        lambda x1, x2: jnp.sin(jnp.dot(jnp.sin(x1), x2)),
        lambda r1, r2: r1 & r2,
        lambda x1, x2: (x1.shape and x2.shape and
                        x1.shape[-1] == x2.shape[-2 if x2.ndim > 1 else 0])),
    FunSpec('relu', 1, lambda x: jax.nn.relu(x + 1) - 1, lambda r: r),
]

# Candidate per-shard input types. NOTE(review): permutations have distinct
# entries, so `len(set(shape)) > 1` only drops the length-1 shapes and
# `not shape` never holds (k >= 1); the surviving shapes have rank 2 or 3.
input_shapes = [
    jax.ShapeDtypeStruct(shape, jnp.dtype('float32'))
    # TODO(mattjj): 0 axis sizes lead to XLA sigfpe, file bug!
    for k in range(1, 4) for shape in it.permutations(range(1, 4), k)
    if not shape or len(set(shape)) > 1  # skip all-equal shapes, boring!
]

# Candidate mesh shapes; axis names are assigned in sample_shmap.
mesh_shapes = [
    (1,),
    (1, 1),
    (1, 2),
    (2, 2),
    (2, 4),
    (4, 2),
]

# Reference implementation of shard_map.
ShapeDtypeDuck = Any  # has shape and dtype attributes
Specs = Any  # pytree of PartitionSpec

def shmap_reference(
    body_in_types: Sequence[ShapeDtypeDuck],
    body_out_types: Sequence[ShapeDtypeDuck],
    out_types: Sequence[ShapeDtypeDuck],
    f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs
  ) -> Callable:
  """Returns a slow, loop-based reference implementation of ``shard_map``.

  The returned function visits every device coordinate of ``mesh``, slices each
  argument according to ``in_specs``, applies ``f`` to those shards, and writes
  the per-device results into zero-initialized global outputs according to
  ``out_specs``.

  Args:
    body_in_types: expected shapes of the per-shard inputs seen by ``f``; only
      used for sanity-check assertions.
    body_out_types: expected shapes of ``f``'s per-shard outputs; only used for
      sanity-check assertions.
    out_types: shapes/dtypes of the global outputs (pytree matching
      ``out_specs``).
    f: the per-shard body function.
    mesh: object whose ``shape`` maps axis names to sizes.
    in_specs: one spec per positional argument.
    out_specs: pytree of specs matching the structure of ``f``'s output.
  """
  def f_shmapped(*args):
    outs = jax.tree.map(lambda y: jnp.zeros(y.shape, y.dtype), out_types)
    getters = [make_indexer(mesh, s, x) for s, x in zip(in_specs, args)]
    putters = jax.tree.map(partial(make_indexer, mesh), out_specs, outs)
    # Every device coordinate, last mesh axis varying fastest.
    for idx in it.product(*map(range, mesh.shape.values())):
      args_shards = [x[indexer(idx)] for x, indexer in zip(args, getters)]
      assert all(x.shape == r.shape for x, r in zip(args_shards, body_in_types))
      out_shards = f(*args_shards)
      assert jax.tree.all(jax.tree.map(lambda y, r: y.shape == r.shape,
                                       out_shards, body_out_types))
      # Functional update: this device writes its own block of each output.
      outs = jax.tree.map(lambda y, out, indexer: out.at[indexer(idx)].set(y),
                          out_shards, outs, putters)
    return outs
  return f_shmapped

def make_indexer(mesh: Mesh, spec: P, x: Any
                 ) -> Callable[[tuple[int, ...]], tuple[slice, ...]]:
  """Returns a function mapping a device coordinate to ``x``'s shard slices.

  Each entry of ``spec`` describes one dimension of ``x``. ``None`` or ``()``
  means the dimension is not split. A single axis name splits it over that mesh
  axis. A tuple of axis names splits it over the product of those axes, with the
  last name varying fastest.

  Args:
    mesh: object whose ``shape`` maps axis names to sizes, in mesh order.
    spec: per-dimension partitioning, one entry per dimension of ``x``.
    x: anything with a ``shape`` attribute; only the shape is read.

  Returns:
    ``indexer(idx)``, where ``idx`` holds one coordinate per mesh axis (in the
    iteration order of ``mesh.shape``). It returns one ``slice`` per dimension,
    selecting that device's block of ``x``.
  """
  def axes_of(elt) -> tuple:
    # Normalize a spec entry to a tuple of axis names. A bare name must not be
    # iterated directly: iterating a string yields its characters, which only
    # happens to work for one-character axis names.
    if elt is None:
      return ()
    return elt if type(elt) is tuple else (elt,)

  # Position of each axis name in a device coordinate tuple (hoisted so each
  # indexer call avoids repeated list(...).index lookups).
  axis_pos = {name: i for i, name in enumerate(mesh.shape)}
  spec_axes = [axes_of(elt) for elt in spec]
  block_shape = [d // math.prod(mesh.shape[ax] for ax in axes)
                 for d, axes in zip(x.shape, spec_axes)]

  def indexer(idx):
    # Mixed-radix linearization of the named axes' coordinates (last axis
    # fastest) gives the block number along each dimension.
    starts = [sum(idx[axis_pos[ax]]
                  * math.prod(mesh.shape[e] for e in axes[i+1:])
                  for i, ax in enumerate(axes))
              for axes in spec_axes]
    return tuple(slice(start * size, (start + 1) * size)
                 for start, size in zip(starts, block_shape))
  return indexer

# The code below is similar to named_cases_from_sampler in test_util.py, but it
# uses generators instead of passing a "select" function around.

# To sample test cases efficiently, we construct a generator which yields to the
# caller to choose one of an iterable's options. That is, we can read 'yield' in
# this code as 'choose one'. To call functions which themselves need to make
# choices, we use 'yield from'.
# That is, we can read 'yield from' in this code
# as 'call this choice-making function'.

Option = Any
CaseSpec = tuple  # first element is a string test name
Chooser = Generator[Iterable[Option], Option, CaseSpec]

def sample_shmap() -> Chooser:
  """Chooses one random shard_map test case, one decision per ``yield``.

  Every ``yield`` hands the driver an iterable of options and receives the
  chosen option back. ``yield from`` delegates to other choice-making helpers.

  Returns (as the generator's return value):
    ``(name, fun, mesh_shape, jit, in_specs, out_specs, args, ref)``, where
    ``ref`` still needs ``(fun, mesh, in_specs, out_specs)`` to become a
    reference implementation.
  """
  spec = yield fun_specs
  shape_of_mesh = yield mesh_shapes
  names = ('i', 'j', 'k', 'l')[:len(shape_of_mesh)]
  mesh = SimpleNamespace(shape=dict(zip(names, shape_of_mesh)),
                         axis_names=names)

  # Offer only input-type combinations the function accepts.
  def accepts(tys):
    return not spec.valid_types or spec.valid_types(*tys)
  candidate_types = filter(
      accepts, it.product(input_shapes, repeat=spec.num_inputs))
  body_in_types = yield candidate_types
  body_out_types = jax.eval_shape(spec.fun, *body_in_types)

  in_types, in_specs = yield from make_in_specs(mesh, body_in_types)
  args = [np.arange(t.size, dtype=t.dtype).reshape(t.shape) / t.size
          for t in in_types]

  # Output replication follows from which axes each input spec leaves out.
  reps_per_input = map(partial(unmentioned, mesh), in_specs)
  out_reps = spec.out_rep(*reps_per_input)
  out_specs = yield from make_out_specs(mesh, body_out_types, out_reps)
  out_types = jax.tree.map(partial(dilate, mesh), out_specs, body_out_types)
  ref = partial(shmap_reference, body_in_types, body_out_types, out_types)

  short_strs = ','.join(jax.core.ShapedArray(t.shape, t.dtype).str_short()
                        for t in in_types)
  in_str = f'({short_strs})'
  jit = yield [True, False]
  name = f'{spec.name}_{mesh.shape}_jit={jit}_{in_specs}_{out_specs}_{in_str}'
  return name, spec.fun, mesh.shape, jit, in_specs, out_specs, args, ref

def unmentioned(mesh: Mesh, pspec: P) -> set[core.AxisName]:
  """Returns the mesh axis names that ``pspec`` does not mention anywhere."""
  mentioned = set()
  for entry in pspec:
    if entry is None:
      continue
    # An entry is either a single axis name or a tuple of axis names.
    mentioned.update(entry if type(entry) is tuple else (entry,))
  return set(mesh.axis_names) - mentioned

# The `sample` function below drives the sampler with a simple loop.
def sample(num: int, make_gen: Callable[[], Chooser]) -> Iterator[CaseSpec]:
  """Yields the specs of `num` distinct sampled cases, without their names.

  Sampling is deterministic (fixed seed 0). Cases whose name was already seen
  are discarded. NOTE: the loop only ends once `num` distinct names have been
  produced, so it would not terminate if fewer distinct names were possible.
  """
  rng = np.random.RandomState(0)
  seen: set[str] = set()
  while len(seen) < num:
    name, *case = sample_one(rng, make_gen())
    if name not in seen:
      seen.add(name)
      yield case

# To sample one test spec, we run the generator, getting back sequences of
# options from it and sending in our choices from those options until finally a
# test case spec is produced.
def sample_one(rng: np.random.RandomState, gen: Chooser) -> CaseSpec:
  """Drives `gen`, answering each offered option list with a uniform pick."""
  lst = list(next(gen))
  try:
    while True:
      choice = lst[rng.randint(len(lst))]
      lst = list(gen.send(choice))
  except StopIteration as e:
    # The generator's return value is the finished case spec.
    return e.value

# Next are some choice-making functions for shard_map test specifications.

MeshDuck = Any  # same attributes as a Mesh

def make_in_specs(mesh: MeshDuck, in_types: Sequence[ShapeDtypeDuck]
                  ) -> Chooser:
  """Chooses a spec for each per-shard input type.

  Returns `(global_types, specs)`, a pair of tuples aligned with `in_types`.
  NOTE(review): with empty `in_types` this returns `()`, which callers that
  unpack two values could not handle; every FunSpec here has >= 1 input.
  """
  pairs = []
  for ty in in_types:
    pair = yield from make_in_spec(mesh, ty)
    pairs.append(pair)
  return tuple(zip(*pairs))

def make_in_spec(mesh: Mesh, in_type_base: ShapeDtypeDuck) -> Chooser:
  """Chooses which mesh axes to use, then how to spread them over the dims.

  Returns `(global_type, PartitionSpec)`. The global type is the per-shard
  type scaled up by `dilate`.
  """
  # Always true: powerset includes the empty tuple.
  assert len(list(powerset(mesh.shape)))
  subset = yield powerset(mesh.shape)
  elts = yield partitions(subset, len(in_type_base.shape))
  partition_spec = P(*(tuple(e) if e else None for e in elts))
  new_type = dilate(mesh, partition_spec, in_type_base)
  return new_type, partition_spec

def dilate(mesh: Mesh, spec: P, shape: ShapeDtypeDuck) -> ShapeDtypeDuck:
  """Per-shard type -> global type: each dim times its spec axes' sizes.

  NOTE(review): `(elt or ())` iterates a bare-string entry character by
  character; specs built by make_in_spec/make_out_spec only contain tuples or
  None, so this is fine here, but it is not safe for arbitrary PartitionSpecs.
  """
  new_shape = tuple(d * math.prod(mesh.shape[ax] for ax in (elt or ()))
                    for d, elt in zip(shape.shape, spec))
  return jax.ShapeDtypeStruct(new_shape, shape.dtype)

def make_out_specs(
    mesh: MeshDuck, out_types: ShapeDtypeDuck | Sequence[ShapeDtypeDuck],
    out_reps: set[core.AxisName] | Sequence[set[core.AxisName]]
  ) -> Chooser:
  """Chooses one out spec per output: a tuple for tuple outputs, else one."""
  if type(out_types) is not tuple:
    out_spec = yield from make_out_spec(mesh, out_types, out_reps)  # type: ignore
    return out_spec
  else:
    out_specs = []
    for ty, rep in zip(out_types, out_reps):
      out_spec = yield from make_out_spec(mesh, ty, rep)  # type: ignore
      out_specs.append(out_spec)
    return tuple(out_specs)

def make_out_spec(
    mesh: Mesh, out_type: ShapeDtypeDuck, out_rep: set[core.AxisName]
  ) -> Chooser:
  """Chooses a valid out spec for one output.

  Only axis subsets `s` with `out_rep | s` equal to all mesh axes are offered.
  So every mesh axis outside `out_rep` must appear in the spec.
  """
  subset = yield (s for s in powerset(mesh.shape)
                  if out_rep | set(s) == set(mesh.shape))
  elts = yield partitions(subset, len(out_type.shape))
  return P(*(tuple(e) if e else None for e in elts))

# Combinatorial helper functions

T = TypeVar('T')
def partitions(s: Sequence[T], k: int) -> Iterator[list[list[T]]]:
  """Yields every assignment of the elements of `s` to `k` ordered buckets.

  There are k**len(s) results. Relative order within a bucket follows `s`.
  """
  for indices in it.product(range(k), repeat=len(s)):
    outs: list[list[T]] = [[] for _ in range(k)]
    for i, elt in zip(indices, s):
      outs[i].append(elt)
    yield outs

def powerset(s: Iterable[T]) -> Iterator[Sequence[T]]:
  """Returns all subsets of `s` as tuples, in order of increasing size."""
  s = list(s)
  return it.chain.from_iterable(it.combinations(s, r) for r in range(len(s)+1))

# Vmap test helpers

Arr = Any

def sample_shmap_batched(bdim_size: int) -> Chooser:
  """Extends sample_shmap with batch dims and batched args of `bdim_size`."""
  name, *shmap_specs, args, ref = yield from sample_shmap()
  bdims = yield all_bdims(*map(op.attrgetter('shape'), args))
  batch_args = map(partial(batchify_arg, bdim_size), bdims, args)
  return name + f'_vmap_{bdims}', bdims, *shmap_specs, batch_args, ref

def all_bdims(*shapes: tuple[int, ...]
              ) -> Iterator[Sequence[int | None]]:
  """Yields every bdim combination, excluding the all-unbatched one.

  Each argument gets None or an insertion position in 0..ndim.
  """
  bdims = ((None, *range(len(shape) + 1)) for shape in shapes)
  return (t for t in it.product(*bdims) if not all(e is None for e in t))

def batchify_arg(size: int, bdim: int | None, x: Arr) -> Arr:
  """Inserts a batch axis of length `size` at `bdim` (None: leave x as is).

  Batch element j equals x * (j + 1), so the slices are distinguishable.
  """
  if bdim is None:
    return x
  else:
    iota = np.arange(1, size + 1, dtype=x.dtype).reshape(
        [1 if i != bdim else -1 for i in range(len(x.shape) + 1)])
    return np.expand_dims(x, bdim) * iota

def args_slicer(args: Sequence[Arr], bdims: Sequence[int | None]
                ) -> Callable[[int], Sequence[Arr]]:
  """Returns `i -> args` with batch element `i` taken along each arg's bdim."""
  def slicer(x, bdim):
    if bdim is None:
      return lambda _: x
    else:
      return lambda i: x.take(indices=i, axis=bdim)
  # `map` is this file's safe_map and returns a list, so `slicers` can be
  # iterated on every call (a builtin map iterator would be exhausted).
  slicers = map(slicer, args, bdims)
  return lambda i: [sl(i) for sl in slicers]

class ShardMapSystematicTest(jtu.JaxTestCase):
  # Randomly sampled shard_map cases checked against shmap_reference, with
  # check_grads, and under vmap. Cases are sampled when the class is defined.

  @staticmethod
  def make_mesh(mesh_shape):
    # Builds a device mesh from a sampled {axis_name: size} dict.
    return jtu.create_mesh(tuple(mesh_shape.values()), tuple(mesh_shape))

  @parameterized.parameters(
      sample(jtu.NUM_GENERATED_CASES.value, sample_shmap))
  def test_eager_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):
    mesh = self.make_mesh(mesh)
    args = map(jnp.array, args)
    out = shard_map(fun, mesh, in_specs, out_specs)(*args)
    expected = ref(fun, mesh, in_specs, out_specs)(*args)
    self.assertAllClose(expected, out, check_dtypes=False)

  @parameterized.parameters(
      sample(jtu.NUM_GENERATED_CASES.value, sample_shmap))
  def test_jit_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):
    mesh = self.make_mesh(mesh)
    args = map(jnp.array, args)
    out = jax.jit(shard_map(fun, mesh, in_specs, out_specs))(*args)
    expected = ref(fun, mesh, in_specs, out_specs)(*args)
    self.assertAllClose(expected, out, check_dtypes=False)

  @parameterized.parameters(
      (*params, check_rep)
      for params in sample(jtu.NUM_GENERATED_CASES.value, sample_shmap)
      for check_rep in [True, False]
  )
  @jax.default_matmul_precision("float32")
  def test_grads(self, fun, mesh, jit, in_specs, out_specs, args, _, check_rep):
    # Second-order numerical gradient check, with check_rep on and off.
    mesh = self.make_mesh(mesh)
    args = map(jnp.array, args)
    f = shard_map(fun, mesh, in_specs, out_specs, check_rep=check_rep)
    if jit:
      f = jax.jit(f)
    jtu.check_grads(f, args, order=2, atol=1e-2, rtol=1e-2)

  @parameterized.parameters(
      sample(jtu.NUM_GENERATED_CASES.value, sample_shmap))
  @jax.default_matmul_precision("float32")
  def test_grads_closure(self, fun, mesh, jit, in_specs, out_specs, args, _):
    # Args whose spec shards no dimension are closed over instead of passed.
    # The remaining args are scaled by `x`. Gradients are taken w.r.t. `x` and
    # the closed-over args.
    mesh = self.make_mesh(mesh)
    no_sharding = [all(elt is None for elt in spec) for spec in in_specs]
    args, closed_over_args = partition_list(no_sharding, args)
    in_specs, _ = partition_list(no_sharding, in_specs)
    def f(x, *closed_over_args):
      @partial(shard_map, mesh=mesh, in_specs=(*in_specs,), out_specs=out_specs)
      def g(*args):
        args = [x * arg for arg in args]
        args = merge_lists(no_sharding, args, closed_over_args)
        return fun(*args)
      if jit:
        g = jax.jit(g)
      return g(*args)
    jtu.check_grads(f, (0.2, *closed_over_args), order=2, atol=1e-2, rtol=1e-2)

  @parameterized.parameters(
      sample(jtu.NUM_GENERATED_CASES.value,
             partial(sample_shmap_batched, 5)))
  def test_vmap(self, bdims, fun, mesh, jit, in_specs, out_specs, args, ref):
    # vmap of shard_map vs. stacking per-example results. The batch size 5
    # matches the bdim_size passed to sample_shmap_batched.
    mesh = self.make_mesh(mesh)
    args = map(jnp.array, args)

    f = shard_map(fun, mesh, in_specs, out_specs)
    if jit:
      f = jax.jit(f)
    ans = jax.vmap(f, bdims)(*args)

    args_slice = args_slicer(args, bdims)
    expected_slices = [f(*args_slice(i)) for i in range(5)]
    treedef = jax.tree.structure(ans)
    if tree_util.treedef_is_strict_leaf(treedef):
      expected = jnp.stack(expected_slices)
    else:
      slices = map(jnp.stack, zip(*expected_slices))
      expected = jax.tree.unflatten(treedef, slices)
    tol = 1e-2 if jtu.test_device_matches(['tpu']) else None
    self.assertAllClose(ans, expected, check_dtypes=False, atol=tol, rtol=tol)

  @parameterized.parameters(
      sample(jtu.NUM_GENERATED_CASES.value,
             partial(sample_shmap_batched, 5)))
  def test_vmap_closure(self, bdims, fun, mesh, jit, in_specs, out_specs, args, _):
    # Like test_vmap, but unsharded args are closed over. The outer vmap runs
    # over a scale factor `x` and the closed-over args.
    mesh = self.make_mesh(mesh)
    args = map(jnp.array, args)

    no_sharding = [all(elt is None for elt in spec) for spec in in_specs]
    args, closed_over_args = partition_list(no_sharding, args)
    in_specs, _ = partition_list(no_sharding, in_specs)
    explicit_bdims, closed_over_bdims = partition_list(no_sharding, bdims)

    def f(x, *closed_over_args):
      @partial(shard_map, mesh=mesh, in_specs=(*in_specs,), out_specs=out_specs)
      def g(*args):
        args = [x * arg for arg in args]
        args = merge_lists(no_sharding, args, closed_over_args)
        return fun(*args)
      if jit:
        g = jax.jit(g)
      if any(d is not None for d in explicit_bdims):
        return jax.vmap(g, explicit_bdims)(*args)
      else:
        return g(*args)

    xs = jnp.arange(5., dtype='float32')
    ans = jax.vmap(f, (0, *closed_over_bdims))(xs, *closed_over_args)

    args_slice = args_slicer((xs, *closed_over_args), (0, *closed_over_bdims))
    expected_slices = [f(*args_slice(i)) for i in range(5)]
    treedef = jax.tree.structure(ans)
    if tree_util.treedef_is_strict_leaf(treedef):
      expected = jnp.stack(expected_slices)
    else:
      slices = map(jnp.stack, zip(*expected_slices))
      expected = jax.tree.unflatten(treedef, slices)
    tol = 1e-2 if jtu.test_device_matches(['tpu']) else None
    self.assertAllClose(ans, expected, check_dtypes=False, atol=tol, rtol=tol)


@jtu.pytest_mark_if_available('multiaccelerator')
class CustomPartitionerTest(jtu.JaxTestCase):

  def skip_if_custom_partitioning_not_supported(self):
    if jtu.is_cloud_tpu():
      raise unittest.SkipTest("Custom partitioning is not supported on libtpu.")

  def test_custom_partitioning(self):
    # An identity custom_partitioning op inside shard_map keeps the (4, 2)
    # per-device shard shape of its input.
    self.skip_if_custom_partitioning_not_supported()

    mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None))
    assert a.addressable_data(0).shape == (4, 2)

    def partition(mesh, arg_shapes, result_shape):
      def lower_fn(x):
        return x

      return (
          mesh,
          lower_fn,
          arg_shapes[0].sharding,
          (arg_shapes[0].sharding,),
      )

    def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
      return arg_shapes[0].sharding

    def propagate_user_sharding(mesh, user_shape):
      return user_shape.sharding

    @custom_partitioning
    def f(x):
      return x

    f.def_partition(
        infer_sharding_from_operands=infer_sharding_from_operands,
        partition=partition,
        propagate_user_sharding=propagate_user_sharding,
    )

    @jax.jit
    def fwd(a):
      c = shard_map(
          f,
          mesh,
          check_rep=False,
          in_specs=(P('z', ('x', 'y')),),
          out_specs=P('z', ('x', 'y')))(a)
      return c

    c = fwd(a)
    self.assertEqual(c.addressable_data(0).shape, (4, 2))

  def test_partially_sharded_dim_with_auto(self):
    # Input dim sharded over ('i', 'j') while the shard_map is manual only over
    # 'i'. Each 'i' shard sums its two elements: [0+1, 2+3, 4+5, 6+7].
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))

    def g(x):
      return jnp.sum(x)[None]

    @jax.jit
    def f(x):
      x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P(('i', 'j'))))
      re = shard_map(g, mesh, in_specs=P('i'), out_specs=P('i'),
                     check_rep=False, auto=frozenset({'j'}))(x)
      re = jax.lax.with_sharding_constraint(re, NamedSharding(mesh, P(('i', 'j'))))
      return re

    self.assertAllClose(f(jnp.arange(8.)), jnp.array([1., 5., 9., 13.]))


@jtu.with_config(jax_use_shardy_partitioner=True)
# TODO(phawkins): enable this test unconditionally once shardy is the default.
@unittest.skipIf(sdy is None, "shardy is not enabled")
class SdyIntegrationTest(jtu.JaxTestCase):

  # Verify we can lower to a `ManualComputationOp`.
  def test_shardy_collective_permute(self):
    mesh = jtu.create_mesh((2,), ('x',))
    a = jax.device_put(
        jnp.arange(8 * 8).reshape((8, 8)),
        jax.sharding.NamedSharding(mesh, P('x', None)),
    )

    @jax.jit
    @partial(
        shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None)
    )
    def fwd(a):
      # Cyclic shift by one along 'x'.
      axis_size = lax.psum(1, 'x')
      perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
      return lax.ppermute(a, 'x', perm=perm)

    self.assertIn('sdy.manual_computation', jax.jit(fwd).lower(a).as_text())


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())