[Pallas TPU] Fix some issues introduced by the recent changes

The new Pallas-specific aval interacts very badly with the default abstract
eval rules of most lax ops, causing frequent failures.

PiperOrigin-RevId: 676362377
This commit is contained in:
Adam Paszke 2024-09-19 04:39:11 -07:00 committed by jax authors
parent 8a8f74663f
commit 3ccca54e42
2 changed files with 11 additions and 3 deletions

View File

@ -161,7 +161,7 @@ class ShapedArrayWithMemorySpace(jax_core.ShapedArray):
else:
sharding_str = ""
memoryspace_str = (
"" if self.memory_space is None else f"{self.memory_space}>"
"" if self.memory_space is None else f"<{self.memory_space}>"
)
return f"{dt_str}{memoryspace_str}[{shapestr}]{sharding_str}"
@ -206,8 +206,10 @@ class MemoryRef:
)
def get_ref_aval(self) -> AbstractMemoryRef:
# TODO(sharadmv): Clean this up. ShapedArrayWithMemorySpace fails when we
# try to apply JAX ops to it.
return AbstractMemoryRef(
ShapedArrayWithMemorySpace(self.shape, self.dtype), self.memory_space)
jax_core.ShapedArray(self.shape, self.dtype), self.memory_space)
class AbstractMemoryRef(state.AbstractRef):

View File

@ -334,7 +334,13 @@ def _pallas_call_abstract_eval(
*avals, out_avals: tuple[jax_core.AbstractValue, ...], **_
):
del avals
return out_avals
# Make sure we don't return ShapedArrayWithMemorySpace to the outside world.
return [
jax_core.ShapedArray(a.shape, a.dtype, a.weak_type)
if isinstance(a, pallas_core.ShapedArrayWithMemorySpace)
else a
for a in out_avals
]
pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval)