[sharding_in_types] Support jnp.array with sharding_in_types: when the input array carries a sharding, propagate that sharding through jnp.array instead of dropping it.

PiperOrigin-RevId: 687089357
This commit is contained in:
Yash Katariya 2024-10-17 16:51:00 -07:00 committed by jax authors
parent 5df4878ad0
commit 57a95a77ff
3 changed files with 27 additions and 3 deletions

View File

@ -5299,7 +5299,11 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
# whenever x is weak, but avoids introducing weak types with something like
# array([1, 2, 3])
weak_type = dtype is None and dtypes.is_weakly_typed(object)
sharding = canonicalize_device_to_sharding(device)
if (config.sharding_in_types.value and device is None and
isinstance(object, Array)):
sharding = object.sharding
else:
sharding = canonicalize_device_to_sharding(device) # type: ignore
# Use device_put to avoid a copy for ndarray inputs.
if (not copy and isinstance(object, np.ndarray) and

View File

@ -251,7 +251,12 @@ def promote_dtypes(*args: ArrayLike) -> list[Array]:
else:
to_dtype, weak_type = dtypes._lattice_result_type(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_extended_dtype=True) # type: ignore[assignment]
return [lax._convert_element_type(x, to_dtype, weak_type) for x in args]
if config.sharding_in_types.value:
return [lax._convert_element_type(x, to_dtype, weak_type,
getattr(x, "sharding", None))
for x in args]
else:
return [lax._convert_element_type(x, to_dtype, weak_type) for x in args]
def promote_dtypes_inexact(*args: ArrayLike) -> list[Array]:

View File

@ -4420,7 +4420,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
return jnp.array(inp, dtype=np.float32, device=s)
out = f()
print(f.trace().jaxpr)
self.assertArraysEqual(out, inp.astype('float32'))
self.assertEqual(out.sharding, s)
self.assertEqual(out.dtype, np.float32)
@ -4884,6 +4883,22 @@ class ShardingInTypesTest(jtu.JaxTestCase):
lowered_text = f.lower(arr).as_text()
self.assertIn('@Sharding', lowered_text)
def test_jnp_array(self):
    """Checks that jnp.array inside jit keeps the input's sharding spec.

    Builds an int32 array sharded over a 2x2 ('x', 'y') mesh, then converts
    it to float32 with jnp.array under jit and asserts that both the dtype
    cast and the original PartitionSpec survive the conversion.
    """
    device_mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    sharding = NamedSharding(device_mesh, P('x', 'y'))
    host_data = np.arange(16, dtype=jnp.int32).reshape(8, 2)
    sharded_input = jax.device_put(host_data, sharding)

    @jax.jit
    def convert(x):
        # Sanity check: the traced input should still be int32 here.
        assert x.dtype == jnp.int32
        converted = jnp.array(x, dtype=jnp.float32)
        self.assertEqual(converted.dtype, jnp.float32)
        # The sharding spec must be preserved through jnp.array.
        self.assertEqual(converted.sharding.spec, sharding.spec)
        return converted

    convert(sharded_input)
@jtu.pytest_mark_if_available('multiaccelerator')
class PJitErrorTest(jtu.JaxTestCase):