mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Remove _maybe_device_put because jax.device_put accepts None
on the device parameter
PiperOrigin-RevId: 618223250
This commit is contained in:
parent
5f467b96af
commit
d7e5ddee5e
@ -2287,9 +2287,6 @@ def empty_like(prototype: ArrayLike | DuckTypedArray,
|
||||
return zeros_like(prototype, dtype=dtype, shape=shape, device=device)
|
||||
|
||||
|
||||
def _maybe_device_put(arr: Array, device: xc.Device | Sharding | None) -> Array:
|
||||
return arr if device is None else jax.device_put(arr, device)
|
||||
|
||||
def _normalize_to_sharding(device: xc.Device | Sharding | None) -> Sharding | None:
|
||||
if isinstance(device, xc.Device):
|
||||
return SingleDeviceSharding(device)
|
||||
@ -2308,7 +2305,8 @@ def full(shape: Any, fill_value: ArrayLike,
|
||||
shape = canonicalize_shape(shape)
|
||||
return lax.full(shape, fill_value, dtype, sharding=_normalize_to_sharding(device))
|
||||
else:
|
||||
return _maybe_device_put(broadcast_to(asarray(fill_value, dtype=dtype), shape), device)
|
||||
return jax.device_put(
|
||||
broadcast_to(asarray(fill_value, dtype=dtype), shape), device)
|
||||
|
||||
|
||||
@util.implements(np.full_like)
|
||||
@ -2328,7 +2326,8 @@ def full_like(a: ArrayLike | DuckTypedArray,
|
||||
else:
|
||||
shape = np.shape(a) if shape is None else shape # type: ignore[arg-type]
|
||||
dtype = result_type(a) if dtype is None else dtype # type: ignore[arg-type]
|
||||
return _maybe_device_put(broadcast_to(asarray(fill_value, dtype=dtype), shape), device)
|
||||
return jax.device_put(
|
||||
broadcast_to(asarray(fill_value, dtype=dtype), shape), device)
|
||||
|
||||
|
||||
@util.implements(np.zeros)
|
||||
|
Loading…
x
Reference in New Issue
Block a user