diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 3fab08da3..98a2b5b72 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5299,7 +5299,11 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, # whenever x is weak, but avoids introducing weak types with something like # array([1, 2, 3]) weak_type = dtype is None and dtypes.is_weakly_typed(object) - sharding = canonicalize_device_to_sharding(device) + if (config.sharding_in_types.value and device is None and + isinstance(object, Array)): + sharding = object.sharding + else: + sharding = canonicalize_device_to_sharding(device) # type: ignore # Use device_put to avoid a copy for ndarray inputs. if (not copy and isinstance(object, np.ndarray) and diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index c5b1530ca..27496ad99 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -251,7 +251,12 @@ def promote_dtypes(*args: ArrayLike) -> list[Array]: else: to_dtype, weak_type = dtypes._lattice_result_type(*args) to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_extended_dtype=True) # type: ignore[assignment] - return [lax._convert_element_type(x, to_dtype, weak_type) for x in args] + if config.sharding_in_types.value: + return [lax._convert_element_type(x, to_dtype, weak_type, + getattr(x, "sharding", None)) + for x in args] + else: + return [lax._convert_element_type(x, to_dtype, weak_type) for x in args] def promote_dtypes_inexact(*args: ArrayLike) -> list[Array]: diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 2585c4171..9d21b5962 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4420,7 +4420,6 @@ class ArrayPjitTest(jtu.JaxTestCase): return jnp.array(inp, dtype=np.float32, device=s) out = f() - print(f.trace().jaxpr) self.assertArraysEqual(out, inp.astype('float32')) self.assertEqual(out.sharding, s) self.assertEqual(out.dtype, np.float32) @@ -4884,6 +4883,22 @@ class ShardingInTypesTest(jtu.JaxTestCase): lowered_text = f.lower(arr).as_text() self.assertIn('@Sharding', lowered_text) + def test_jnp_array(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + np_inp = np.arange(16, dtype=jnp.int32).reshape(8, 2) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + @jax.jit + def f(x): + assert x.dtype == jnp.int32 + y = jnp.array(x, dtype=jnp.float32) + self.assertEqual(y.dtype, jnp.float32) + self.assertEqual(y.sharding.spec, s.spec) + return y + + f(arr) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase):