mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Check for ArrayImpl
rather than sharding
because this code is supposed to check for concrete Array until a shard_like
primitive exists.
PiperOrigin-RevId: 486689809
This commit is contained in:
parent
587885bbd3
commit
da519f3b2c
@ -1352,12 +1352,12 @@ def full_like(x: ArrayLike, fill_value: ArrayLike, dtype: Optional[DTypeLike] =
|
||||
# probably in the form of a primitive like `val = match_sharding_p.bind(x, val)`
|
||||
# (so it works in staged-out code as well as 'eager' code). Related to
|
||||
# equi-sharding.
|
||||
if config.jax_array and shape is None and hasattr(x, 'sharding'):
|
||||
if config.jax_array and shape is None and isinstance(x, array.ArrayImpl):
|
||||
sharding = x.sharding # type: ignore[union-attr]
|
||||
if (not dispatch.is_single_device_sharding(sharding) and
|
||||
not isinstance(sharding, PmapSharding)):
|
||||
return array.make_array_from_callback(
|
||||
type_cast(array.Shape, fill_shape), sharding, lambda idx: val[idx])
|
||||
type_cast(array.Shape, fill_shape), sharding, lambda idx: val[idx])
|
||||
return val
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user