mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Raise rather than return error
This commit is contained in:
parent
0fa541972e
commit
2518c6233e
@ -1669,7 +1669,7 @@ def get_sharding(sharding, ndim):
|
||||
|
||||
context_mesh = mesh_lib.get_abstract_mesh()
|
||||
if not context_mesh:
|
||||
return RuntimeError("Please set the mesh via `jax.set_mesh` API.")
|
||||
raise RuntimeError("Please set the mesh via `jax.set_mesh` API.")
|
||||
assert sharding is None
|
||||
return NamedSharding(context_mesh, P(*[None] * ndim))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user