mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Raise an error if a user passes None to host_local_array_to_global_array
or global_array_to_host_local_array
PiperOrigin-RevId: 543596009
This commit is contained in:
parent
c6a60054b9
commit
c632cace1e
@ -42,6 +42,11 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* Executable.cost_analysis() works on Cloud TPU
|
||||
* Added a warning if a non-allowlisted `jaxlib` plugin is in use.
|
||||
* Added `jax.tree_util.tree_leaves_with_path`.
|
||||
* `None` is not a valid input to
|
||||
`jax.experimental.multihost_utils.host_local_array_to_global_array` or
|
||||
`jax.experimental.multihost_utils.global_array_to_host_local_array`.
|
||||
Please use `jax.sharding.PartitionSpec()` if you wanted to replicate your
|
||||
input.
|
||||
|
||||
* Bug fixes
|
||||
* Fixed incorrect wheel name in CUDA 12 releases (#16362); the correct wheel
|
||||
|
@ -225,6 +225,10 @@ def _global_to_local_aval(global_aval, mesh, pspec):
|
||||
|
||||
def host_local_array_to_global_array_impl(
|
||||
arr: Any, *, global_mesh: jax.sharding.Mesh, pspec: Any):
|
||||
if pspec is None:
|
||||
raise ValueError(
|
||||
'`None` is not a valid input to the pspecs argument. Please use '
|
||||
'jax.sharding.PartitionSpec() if you wanted to replicate your input.')
|
||||
# If the Array is not fully addressable i.e. not host local, return it.
|
||||
if isinstance(arr, array.ArrayImpl) and not arr.is_fully_addressable:
|
||||
return arr
|
||||
@ -326,6 +330,10 @@ mlir.register_lowering(host_local_array_to_global_array_p, _ltg_lowering)
|
||||
|
||||
def global_array_to_host_local_array_impl(
|
||||
arr: Any, *, global_mesh: jax.sharding.Mesh, pspec: Any):
|
||||
if pspec is None:
|
||||
raise ValueError(
|
||||
'`None` is not a valid input to the pspecs argument. Please use '
|
||||
'jax.sharding.PartitionSpec() if you wanted to replicate your input.')
|
||||
# If the Array is already fully addressable i.e. host local, return it.
|
||||
if isinstance(arr, array.ArrayImpl) and arr.is_fully_addressable:
|
||||
return arr
|
||||
|
Loading…
x
Reference in New Issue
Block a user