From 4794a827ca29723ee1234bf069386921092012dc Mon Sep 17 00:00:00 2001 From: Alexey Radul Date: Wed, 12 Jul 2023 13:08:57 -0400 Subject: [PATCH] Explicitly check for DShapedArray in _to_logical_sharding instead of returning a sharding by default. --- jax/_src/interpreters/pxla.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 1a198f590..261b53240 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -49,6 +49,7 @@ from jax._src import util from jax._src import xla_bridge as xb from jax._src.abstract_arrays import array_types from jax._src.config import config +from jax._src.core import DShapedArray from jax._src.core import ShapedArray from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -2125,14 +2126,13 @@ def _to_logical_sharding( ) -> Optional[sharding_impls.XLACompatibleSharding]: if is_unspecified(sharding) or is_auto(sharding): return None - elif isinstance(aval, ShapedArray): + elif isinstance(aval, (ShapedArray, DShapedArray)): assert isinstance(sharding, sharding_impls.XLACompatibleSharding) return sharding elif isinstance(aval, core.AbstractToken): return None else: - assert isinstance(sharding, sharding_impls.XLACompatibleSharding) - return sharding + raise TypeError(aval) @profiler.annotate_function