mirror of
https://github.com/ROCm/jax.git
synced 2025-04-20 13:56:07 +00:00
Allow pspec to be passed to device_put if there is a mesh in the surrounding context
PiperOrigin-RevId: 737812111
This commit is contained in:
parent
f174b00f23
commit
549973dec6
@ -67,7 +67,9 @@ from jax._src.lib import jax_jit
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import pmap_lib
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind
|
||||
from jax._src.mesh import get_concrete_mesh
|
||||
from jax._src.sharding_impls import (
|
||||
PmapSharding, TransferToMemoryKind, PartitionSpec as P, NamedSharding)
|
||||
from jax._src.layout import Layout, AutoLayout
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src import tree_util
|
||||
@ -2280,11 +2282,20 @@ def _check_sharding(aval, s):
|
||||
(s,), (aval,), ("",), "device_put args", allow_uneven_sharding=False)
|
||||
s.shard_shape(aval.shape) # should raise an Error if incompatible
|
||||
|
||||
def pspec_to_sharding(val):
|
||||
if isinstance(val, P):
|
||||
mesh = get_concrete_mesh()
|
||||
if mesh is None:
|
||||
raise ValueError(
|
||||
"Please set a mesh via `jax.sharding.use_mesh` if a PartitionSpec is"
|
||||
" passed to device_put")
|
||||
return NamedSharding(mesh, val)
|
||||
return val
|
||||
|
||||
def device_put(
|
||||
x,
|
||||
device: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None,
|
||||
*, src: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None,
|
||||
device: None | xc.Device | Sharding | P | Layout | Any | TransferToMemoryKind = None,
|
||||
*, src: None | xc.Device | Sharding | P | Layout | Any | TransferToMemoryKind = None,
|
||||
donate: bool | Any = False, may_alias: bool | None | Any = None):
|
||||
"""Transfers ``x`` to ``device``.
|
||||
|
||||
@ -2333,6 +2344,9 @@ def device_put(
|
||||
src_flat = flatten_axes("device_put source", treedef, src)
|
||||
src_flat = list(map(_infer_src_sharding, src_flat, x_flat))
|
||||
|
||||
device_flat = map(pspec_to_sharding, device_flat)
|
||||
src_flat = map(pspec_to_sharding, src_flat)
|
||||
|
||||
if isinstance(donate, bool):
|
||||
donate_flat = [donate] * len(x_flat)
|
||||
else:
|
||||
|
@ -6138,6 +6138,19 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
self.assertDictEqual(out.sharding.mesh._axis_types_dict,
|
||||
{AxisType.Auto: ('x',)})
|
||||
|
||||
@jtu.with_user_mesh((2,), 'x')
|
||||
def test_device_put_use_mesh(self, mesh):
|
||||
out = jax.device_put(np.arange(8), P('x'))
|
||||
self.assertArraysEqual(out, np.arange(8))
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
|
||||
|
||||
def test_device_put_no_use_mesh_error(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'Please set a mesh via `jax.sharding.use_mesh` if a PartitionSpec is'
|
||||
' passed to device_put'):
|
||||
jax.device_put(np.arange(8), P('x'))
|
||||
|
||||
@jtu.with_user_mesh((2,), 'x')
|
||||
def test_inputs_different_context(self, mesh):
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user