Raise an error if a user passes None to host_local_array_to_global_array or global_array_to_host_local_array

PiperOrigin-RevId: 543596009
This commit is contained in:
Yash Katariya 2023-06-26 18:15:04 -07:00 committed by jax authors
parent c6a60054b9
commit c632cace1e
2 changed files with 13 additions and 0 deletions

View File

@ -42,6 +42,11 @@ Remember to align the itemized text with the first line of an item within a list
* Executable.cost_analysis() works on Cloud TPU
* Added a warning if a non-allowlisted `jaxlib` plugin is in use.
* Added `jax.tree_util.tree_leaves_with_path`.
* `None` is not a valid input to
`jax.experimental.multihost_utils.host_local_array_to_global_array` or
`jax.experimental.multihost_utils.global_array_to_host_local_array`.
Please use `jax.sharding.PartitionSpec()` if you wanted to replicate your
input.
* Bug fixes
* Fixed incorrect wheel name in CUDA 12 releases (#16362); the correct wheel

View File

@ -225,6 +225,10 @@ def _global_to_local_aval(global_aval, mesh, pspec):
def host_local_array_to_global_array_impl(
arr: Any, *, global_mesh: jax.sharding.Mesh, pspec: Any):
if pspec is None:
raise ValueError(
'`None` is not a valid input to the pspecs argument. Please use '
'jax.sharding.PartitionSpec() if you wanted to replicate your input.')
# If the Array is not fully addressable i.e. not host local, return it.
if isinstance(arr, array.ArrayImpl) and not arr.is_fully_addressable:
return arr
@ -326,6 +330,10 @@ mlir.register_lowering(host_local_array_to_global_array_p, _ltg_lowering)
def global_array_to_host_local_array_impl(
arr: Any, *, global_mesh: jax.sharding.Mesh, pspec: Any):
if pspec is None:
raise ValueError(
'`None` is not a valid input to the pspecs argument. Please use '
'jax.sharding.PartitionSpec() if you wanted to replicate your input.')
# If the Array is already fully addressable i.e. host local, return it.
if isinstance(arr, array.ArrayImpl) and arr.is_fully_addressable:
return arr