Remove _maybe_device_put because jax.device_put accepts None on the device parameter

PiperOrigin-RevId: 618223250
This commit is contained in:
Yash Katariya 2024-03-22 10:39:08 -07:00 committed by jax authors
parent 5f467b96af
commit d7e5ddee5e

View File

@ -2287,9 +2287,6 @@ def empty_like(prototype: ArrayLike | DuckTypedArray,
return zeros_like(prototype, dtype=dtype, shape=shape, device=device)
def _maybe_device_put(arr: Array, device: xc.Device | Sharding | None) -> Array:
return arr if device is None else jax.device_put(arr, device)
def _normalize_to_sharding(device: xc.Device | Sharding | None) -> Sharding | None:
if isinstance(device, xc.Device):
return SingleDeviceSharding(device)
@ -2308,7 +2305,8 @@ def full(shape: Any, fill_value: ArrayLike,
shape = canonicalize_shape(shape)
return lax.full(shape, fill_value, dtype, sharding=_normalize_to_sharding(device))
else:
return _maybe_device_put(broadcast_to(asarray(fill_value, dtype=dtype), shape), device)
return jax.device_put(
broadcast_to(asarray(fill_value, dtype=dtype), shape), device)
@util.implements(np.full_like)
@ -2328,7 +2326,8 @@ def full_like(a: ArrayLike | DuckTypedArray,
else:
shape = np.shape(a) if shape is None else shape # type: ignore[arg-type]
dtype = result_type(a) if dtype is None else dtype # type: ignore[arg-type]
return _maybe_device_put(broadcast_to(asarray(fill_value, dtype=dtype), shape), device)
return jax.device_put(
broadcast_to(asarray(fill_value, dtype=dtype), shape), device)
@util.implements(np.zeros)