Raise rather than return error

This commit is contained in:
Jake VanderPlas 2024-12-17 11:20:55 -08:00
parent 0fa541972e
commit 2518c6233e

View File

@ -1669,7 +1669,7 @@ def get_sharding(sharding, ndim):
context_mesh = mesh_lib.get_abstract_mesh()
if not context_mesh:
return RuntimeError("Please set the mesh via `jax.set_mesh` API.")
raise RuntimeError("Please set the mesh via `jax.set_mesh` API.")
assert sharding is None
return NamedSharding(context_mesh, P(*[None] * ndim))