diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py
index 00bbbbe88..1a956de1f 100644
--- a/jax/_src/pallas/core.py
+++ b/jax/_src/pallas/core.py
@@ -161,7 +161,7 @@ class ShapedArrayWithMemorySpace(jax_core.ShapedArray):
     else:
       sharding_str = ""
     memoryspace_str = (
-        "" if self.memory_space is None else f"{self.memory_space}>"
+        "" if self.memory_space is None else f"<{self.memory_space}>"
     )
     return f"{dt_str}{memoryspace_str}[{shapestr}]{sharding_str}"
 
@@ -206,8 +206,10 @@ class MemoryRef:
     )
 
   def get_ref_aval(self) -> AbstractMemoryRef:
+    # TODO(sharadmv): Clean this up. ShapedArrayWithMemorySpace fails when we
+    # try to apply JAX ops to it.
     return AbstractMemoryRef(
-        ShapedArrayWithMemorySpace(self.shape, self.dtype), self.memory_space)
+        jax_core.ShapedArray(self.shape, self.dtype), self.memory_space)
 
 
 class AbstractMemoryRef(state.AbstractRef):
diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index a3ca823a1..1c10d2bda 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -334,7 +334,13 @@ def _pallas_call_abstract_eval(
     *avals, out_avals: tuple[jax_core.AbstractValue, ...], **_
 ):
   del avals
-  return out_avals
+  # Make sure we don't return ShapedArrayWithMemorySpace to the outside world.
+  return [
+      jax_core.ShapedArray(a.shape, a.dtype, a.weak_type)
+      if isinstance(a, pallas_core.ShapedArrayWithMemorySpace)
+      else a
+      for a in out_avals
+  ]
 
 
 pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval)