Explicitly check for DShapedArray in _to_logical_sharding instead of returning a sharding by default.

This commit is contained in:
Alexey Radul 2023-07-12 13:08:57 -04:00
parent ef9f1cbec3
commit 4794a827ca

View File

@ -49,6 +49,7 @@ from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.abstract_arrays import array_types
from jax._src.config import config
from jax._src.core import DShapedArray
from jax._src.core import ShapedArray
from jax._src.interpreters import ad
from jax._src.interpreters import batching
@ -2125,14 +2126,13 @@ def _to_logical_sharding(
) -> Optional[sharding_impls.XLACompatibleSharding]:
if is_unspecified(sharding) or is_auto(sharding):
return None
elif isinstance(aval, ShapedArray):
elif isinstance(aval, (ShapedArray, DShapedArray)):
assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
return sharding
elif isinstance(aval, core.AbstractToken):
return None
else:
assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
return sharding
raise TypeError(aval)
@profiler.annotate_function