mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Buffer -> Array in some pxla type annotations.
PiperOrigin-RevId: 520975371
This commit is contained in:
parent
b37c741c6f
commit
82fcfc3851
@ -390,7 +390,7 @@ def shard_args(
|
||||
indices: Sequence[Sequence[Index]],
|
||||
shardings: Sequence[sharding_impls.XLACompatibleSharding],
|
||||
args,
|
||||
) -> Sequence[Union[Sequence[xb.xla_client.Buffer], jax.Array]]:
|
||||
) -> Sequence[jax.Array]:
|
||||
"""Shard each argument data array along its leading axis.
|
||||
|
||||
Args:
|
||||
@ -554,7 +554,7 @@ def local_aval_to_result_handler(
|
||||
aval: core.AbstractValue,
|
||||
sharding: sharding_impls.XLACompatibleSharding,
|
||||
indices: Optional[Tuple[Index, ...]],
|
||||
) -> Callable[[List[xb.xla_client.Buffer]], Any]:
|
||||
) -> Callable[[List[xc.ArrayImpl]], Any]:
|
||||
"""Returns a function for handling the raw buffers of a single output aval.
|
||||
|
||||
Args:
|
||||
@ -582,7 +582,7 @@ local_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
|
||||
def global_aval_to_result_handler(
|
||||
aval: core.AbstractValue, out_sharding, committed: bool,
|
||||
is_out_sharding_from_xla: bool
|
||||
) -> Callable[[Sequence[xb.xla_client.Buffer]], Any]:
|
||||
) -> Callable[[Sequence[xc.ArrayImpl]], Any]:
|
||||
"""Returns a function for handling the raw buffers of a single output aval.
|
||||
|
||||
Args:
|
||||
@ -630,7 +630,7 @@ def make_sharded_device_array(
|
||||
aval: ShapedArray,
|
||||
sharding_spec: Optional[ShardingSpec],
|
||||
# Any is for JAX extensions implementing their own buffer.
|
||||
device_buffers: List[Union[Any, xb.xla_client.Buffer]],
|
||||
device_buffers: List[Any],
|
||||
indices: Optional[Tuple[Index, ...]] = None,
|
||||
):
|
||||
"""Returns a ShardedDeviceArray implementation based on arguments.
|
||||
@ -3527,8 +3527,8 @@ class _ThreadLocalState(threading.local):
|
||||
|
||||
_thread_local_state = _ThreadLocalState()
|
||||
|
||||
def device_put(x, devices: Sequence[xb.xla_client.Device],
|
||||
replicate: bool=False) -> List[xb.xla_client.Buffer]:
|
||||
def device_put(x, devices: Sequence[xc.ArrayImpl],
|
||||
replicate: bool=False) -> List[xc.ArrayImpl]:
|
||||
"""Call device_put on a sequence of devices and return a flat sequence of buffers."""
|
||||
if replicate:
|
||||
return [jax.device_put(x, device) for device in devices]
|
||||
|
@ -55,8 +55,7 @@ _deprecated_Device = xc.Device
|
||||
XlaOp = xc.XlaOp
|
||||
xe = xc._xla
|
||||
Backend = xe.Client
|
||||
Buffer = xc.Buffer
|
||||
_CppDeviceArray = xe.Buffer
|
||||
Buffer = _deprecated_DeviceArray
|
||||
|
||||
_deprecations = {
|
||||
# Added Feb 9, 2023:
|
||||
@ -71,6 +70,13 @@ _deprecations = {
|
||||
),
|
||||
_deprecated_DeviceArray,
|
||||
),
|
||||
"Buffer": (
|
||||
(
|
||||
"jax.interpreters.xla.Buffer is deprecated. Use jax.Array"
|
||||
" instead."
|
||||
),
|
||||
_deprecated_DeviceArray,
|
||||
),
|
||||
"device_put": (
|
||||
(
|
||||
"jax.interpreters.xla.device_put is deprecated. Please use"
|
||||
|
Loading…
x
Reference in New Issue
Block a user