1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 05:16:06 +00:00

Check that memory_kind of an aval is always None

PiperOrigin-RevId: 744136969
This commit is contained in:
Yash Katariya 2025-04-04 19:22:31 -07:00 committed by jax authors
parent aab6613944
commit fc5d9a4fce

@ -1899,6 +1899,7 @@ def get_sharding(sharding, shape):
raise ValueError("Mesh of an aval must be an AbstractMesh. "
f"Got {out_s.mesh} of type {type(out_s.mesh)}")
_check_divisibility(out_s, shape)
assert out_s.memory_kind is None
return out_s
def str_short_aval(shape, dtype, mesh, spec, vma,