mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Explicitly check for DShapedArray in _to_logical_sharding instead of returning a sharding by default.
This commit is contained in:
parent
ef9f1cbec3
commit
4794a827ca
@ -49,6 +49,7 @@ from jax._src import util
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.abstract_arrays import array_types
|
||||
from jax._src.config import config
|
||||
from jax._src.core import DShapedArray
|
||||
from jax._src.core import ShapedArray
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
@ -2125,14 +2126,13 @@ def _to_logical_sharding(
|
||||
) -> Optional[sharding_impls.XLACompatibleSharding]:
|
||||
if is_unspecified(sharding) or is_auto(sharding):
|
||||
return None
|
||||
elif isinstance(aval, ShapedArray):
|
||||
elif isinstance(aval, (ShapedArray, DShapedArray)):
|
||||
assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
|
||||
return sharding
|
||||
elif isinstance(aval, core.AbstractToken):
|
||||
return None
|
||||
else:
|
||||
assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
|
||||
return sharding
|
||||
raise TypeError(aval)
|
||||
|
||||
|
||||
@profiler.annotate_function
|
||||
|
Loading…
x
Reference in New Issue
Block a user