Check for ArrayImpl rather than sharding because this code is supposed to check for concrete Array until a shard_like primitive exists.

PiperOrigin-RevId: 486689809
This commit is contained in:
Yash Katariya 2022-11-07 09:52:02 -08:00 committed by jax authors
parent 587885bbd3
commit da519f3b2c

View File

@ -1352,12 +1352,12 @@ def full_like(x: ArrayLike, fill_value: ArrayLike, dtype: Optional[DTypeLike] =
# probably in the form of a primitive like `val = match_sharding_p.bind(x, val)`
# (so it works in staged-out code as well as 'eager' code). Related to
# equi-sharding.
if config.jax_array and shape is None and hasattr(x, 'sharding'):
if config.jax_array and shape is None and isinstance(x, array.ArrayImpl):
sharding = x.sharding # type: ignore[union-attr]
if (not dispatch.is_single_device_sharding(sharding) and
not isinstance(sharding, PmapSharding)):
return array.make_array_from_callback(
type_cast(array.Shape, fill_shape), sharding, lambda idx: val[idx])
type_cast(array.Shape, fill_shape), sharding, lambda idx: val[idx])
return val