mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[sharding_in_types] Support jnp.array with sharding_in_types. When the input array has a sharding, propagate it through without dropping the sharding.
PiperOrigin-RevId: 687089357
This commit is contained in:
parent
5df4878ad0
commit
57a95a77ff
@ -5299,7 +5299,11 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
|
||||
# whenever x is weak, but avoids introducing weak types with something like
|
||||
# array([1, 2, 3])
|
||||
weak_type = dtype is None and dtypes.is_weakly_typed(object)
|
||||
sharding = canonicalize_device_to_sharding(device)
|
||||
if (config.sharding_in_types.value and device is None and
|
||||
isinstance(object, Array)):
|
||||
sharding = object.sharding
|
||||
else:
|
||||
sharding = canonicalize_device_to_sharding(device) # type: ignore
|
||||
|
||||
# Use device_put to avoid a copy for ndarray inputs.
|
||||
if (not copy and isinstance(object, np.ndarray) and
|
||||
|
@ -251,7 +251,12 @@ def promote_dtypes(*args: ArrayLike) -> list[Array]:
|
||||
else:
|
||||
to_dtype, weak_type = dtypes._lattice_result_type(*args)
|
||||
to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_extended_dtype=True) # type: ignore[assignment]
|
||||
return [lax._convert_element_type(x, to_dtype, weak_type) for x in args]
|
||||
if config.sharding_in_types.value:
|
||||
return [lax._convert_element_type(x, to_dtype, weak_type,
|
||||
getattr(x, "sharding", None))
|
||||
for x in args]
|
||||
else:
|
||||
return [lax._convert_element_type(x, to_dtype, weak_type) for x in args]
|
||||
|
||||
|
||||
def promote_dtypes_inexact(*args: ArrayLike) -> list[Array]:
|
||||
|
@ -4420,7 +4420,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
return jnp.array(inp, dtype=np.float32, device=s)
|
||||
|
||||
out = f()
|
||||
print(f.trace().jaxpr)
|
||||
self.assertArraysEqual(out, inp.astype('float32'))
|
||||
self.assertEqual(out.sharding, s)
|
||||
self.assertEqual(out.dtype, np.float32)
|
||||
@ -4884,6 +4883,22 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
lowered_text = f.lower(arr).as_text()
|
||||
self.assertIn('@Sharding', lowered_text)
|
||||
|
||||
def test_jnp_array(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
np_inp = np.arange(16, dtype=jnp.int32).reshape(8, 2)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
arr = jax.device_put(np_inp, s)
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
assert x.dtype == jnp.int32
|
||||
y = jnp.array(x, dtype=jnp.float32)
|
||||
self.assertEqual(y.dtype, jnp.float32)
|
||||
self.assertEqual(y.sharding.spec, s.spec)
|
||||
return y
|
||||
|
||||
f(arr)
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
class PJitErrorTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user