
Allow pspec to be passed to device_put if there is a mesh in the surrounding context

PiperOrigin-RevId: 737812111
Authored by Yash Katariya on 2025-03-17 17:47:21 -07:00; committed by jax authors
parent f174b00f23
commit 549973dec6
2 changed files with 30 additions and 3 deletions
Changed paths: jax/_src, tests
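For orientation, a minimal usage sketch of what this change enables (illustrative, not part of the commit; assumes a JAX build that ships `jax.sharding.use_mesh` and a backend with at least two devices):

import numpy as np
import jax
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((2,), ('x',))  # 1-D mesh over two devices
with jax.sharding.use_mesh(mesh):
  # A bare PartitionSpec is now accepted; it is resolved against the
  # mesh installed by use_mesh.
  out = jax.device_put(np.arange(8), P('x'))
assert out.sharding == NamedSharding(mesh, P('x'))

Outside a `use_mesh` context, the same call raises the ValueError introduced below.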

@@ -67,7 +67,9 @@ from jax._src.lib import jax_jit
 from jax._src.lib import xla_client as xc
 from jax._src.lib import pmap_lib
 from jax._src.sharding import Sharding
-from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind
+from jax._src.mesh import get_concrete_mesh
+from jax._src.sharding_impls import (
+    PmapSharding, TransferToMemoryKind, PartitionSpec as P, NamedSharding)
 from jax._src.layout import Layout, AutoLayout
 from jax._src.traceback_util import api_boundary
 from jax._src import tree_util
@@ -2280,11 +2282,20 @@ def _check_sharding(aval, s):
       (s,), (aval,), ("",), "device_put args", allow_uneven_sharding=False)
   s.shard_shape(aval.shape)  # should raise an Error if incompatible
 
+def pspec_to_sharding(val):
+  if isinstance(val, P):
+    mesh = get_concrete_mesh()
+    if mesh is None:
+      raise ValueError(
+          "Please set a mesh via `jax.sharding.use_mesh` if a PartitionSpec is"
+          " passed to device_put")
+    return NamedSharding(mesh, val)
+  return val
 
 def device_put(
     x,
-    device: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None,
-    *, src: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None,
+    device: None | xc.Device | Sharding | P | Layout | Any | TransferToMemoryKind = None,
+    *, src: None | xc.Device | Sharding | P | Layout | Any | TransferToMemoryKind = None,
     donate: bool | Any = False, may_alias: bool | None | Any = None):
   """Transfers ``x`` to ``device``.
@@ -2333,6 +2344,9 @@ def device_put(
   src_flat = flatten_axes("device_put source", treedef, src)
   src_flat = list(map(_infer_src_sharding, src_flat, x_flat))
+
+  device_flat = map(pspec_to_sharding, device_flat)
+  src_flat = map(pspec_to_sharding, src_flat)
 
   if isinstance(donate, bool):
     donate_flat = [donate] * len(x_flat)
   else:
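Because `device` and `src` are flattened with `flatten_axes` before the per-leaf conversion, a PartitionSpec can also appear as a leaf in a pytree of placements. A hedged sketch (the example data and specs are illustrative, not from the commit):

import numpy as np
import jax
from jax.sharding import PartitionSpec as P

mesh = jax.make_mesh((2,), ('x',))
inputs = {'a': np.arange(8), 'b': np.ones((4, 4))}
with jax.sharding.use_mesh(mesh):
  # Each leaf spec is resolved against the active mesh independently.
  outs = jax.device_put(inputs, {'a': P('x'), 'b': P()})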

@@ -6138,6 +6138,19 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertDictEqual(out.sharding.mesh._axis_types_dict,
                          {AxisType.Auto: ('x',)})
 
+  @jtu.with_user_mesh((2,), 'x')
+  def test_device_put_use_mesh(self, mesh):
+    out = jax.device_put(np.arange(8), P('x'))
+    self.assertArraysEqual(out, np.arange(8))
+    self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
+
+  def test_device_put_no_use_mesh_error(self):
+    with self.assertRaisesRegex(
+        ValueError,
+        'Please set a mesh via `jax.sharding.use_mesh` if a PartitionSpec is'
+        ' passed to device_put'):
+      jax.device_put(np.arange(8), P('x'))
+
   @jtu.with_user_mesh((2,), 'x')
   def test_inputs_different_context(self, mesh):
     np_inp = np.arange(16).reshape(8, 2)