mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
[Pallas TPU] Fix some issues introduced by the recent changes
The new Pallas-specific aval interacts very badly with the default abstract eval rules of most lax ops, causing frequent failures. PiperOrigin-RevId: 676362377
This commit is contained in:
parent
8a8f74663f
commit
3ccca54e42
@ -161,7 +161,7 @@ class ShapedArrayWithMemorySpace(jax_core.ShapedArray):
|
||||
else:
|
||||
sharding_str = ""
|
||||
memoryspace_str = (
|
||||
"" if self.memory_space is None else f"{self.memory_space}>"
|
||||
"" if self.memory_space is None else f"<{self.memory_space}>"
|
||||
)
|
||||
return f"{dt_str}{memoryspace_str}[{shapestr}]{sharding_str}"
|
||||
|
||||
@ -206,8 +206,10 @@ class MemoryRef:
|
||||
)
|
||||
|
||||
def get_ref_aval(self) -> AbstractMemoryRef:
|
||||
# TODO(sharadmv): Clean this up. ShapedArrayWithMemorySpace fails when we
|
||||
# try to apply JAX ops to it.
|
||||
return AbstractMemoryRef(
|
||||
ShapedArrayWithMemorySpace(self.shape, self.dtype), self.memory_space)
|
||||
jax_core.ShapedArray(self.shape, self.dtype), self.memory_space)
|
||||
|
||||
|
||||
class AbstractMemoryRef(state.AbstractRef):
|
||||
|
@ -334,7 +334,13 @@ def _pallas_call_abstract_eval(
|
||||
*avals, out_avals: tuple[jax_core.AbstractValue, ...], **_
|
||||
):
|
||||
del avals
|
||||
return out_avals
|
||||
# Make sure we don't return ShapedArrayWithMemorySpace to the outside world.
|
||||
return [
|
||||
jax_core.ShapedArray(a.shape, a.dtype, a.weak_type)
|
||||
if isinstance(a, pallas_core.ShapedArrayWithMemorySpace)
|
||||
else a
|
||||
for a in out_avals
|
||||
]
|
||||
|
||||
|
||||
pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval)
|
||||
|
Loading…
x
Reference in New Issue
Block a user