mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Check that memory_kind of an aval is always None
PiperOrigin-RevId: 744136969
This commit is contained in:
parent
aab6613944
commit
fc5d9a4fce
@ -1899,6 +1899,7 @@ def get_sharding(sharding, shape):
|
||||
raise ValueError("Mesh of an aval must be an AbstractMesh. "
|
||||
f"Got {out_s.mesh} of type {type(out_s.mesh)}")
|
||||
_check_divisibility(out_s, shape)
|
||||
assert out_s.memory_kind is None
|
||||
return out_s
|
||||
|
||||
def str_short_aval(shape, dtype, mesh, spec, vma,
|
||||
|
Loading…
x
Reference in New Issue
Block a user