Buffer -> Array in some pxla type annotations.

PiperOrigin-RevId: 520975371
This commit is contained in:
Parker Schuh 2023-03-31 11:41:49 -07:00 committed by jax authors
parent b37c741c6f
commit 82fcfc3851
2 changed files with 14 additions and 8 deletions

View File

@ -390,7 +390,7 @@ def shard_args(
indices: Sequence[Sequence[Index]],
shardings: Sequence[sharding_impls.XLACompatibleSharding],
args,
) -> Sequence[Union[Sequence[xb.xla_client.Buffer], jax.Array]]:
) -> Sequence[jax.Array]:
"""Shard each argument data array along its leading axis.
Args:
@ -554,7 +554,7 @@ def local_aval_to_result_handler(
aval: core.AbstractValue,
sharding: sharding_impls.XLACompatibleSharding,
indices: Optional[Tuple[Index, ...]],
) -> Callable[[List[xb.xla_client.Buffer]], Any]:
) -> Callable[[List[xc.ArrayImpl]], Any]:
"""Returns a function for handling the raw buffers of a single output aval.
Args:
@ -582,7 +582,7 @@ local_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
def global_aval_to_result_handler(
aval: core.AbstractValue, out_sharding, committed: bool,
is_out_sharding_from_xla: bool
) -> Callable[[Sequence[xb.xla_client.Buffer]], Any]:
) -> Callable[[Sequence[xc.ArrayImpl]], Any]:
"""Returns a function for handling the raw buffers of a single output aval.
Args:
@ -630,7 +630,7 @@ def make_sharded_device_array(
aval: ShapedArray,
sharding_spec: Optional[ShardingSpec],
# Any is for JAX extensions implementing their own buffer.
device_buffers: List[Union[Any, xb.xla_client.Buffer]],
device_buffers: List[Any],
indices: Optional[Tuple[Index, ...]] = None,
):
"""Returns a ShardedDeviceArray implementation based on arguments.
@ -3527,8 +3527,8 @@ class _ThreadLocalState(threading.local):
_thread_local_state = _ThreadLocalState()
def device_put(x, devices: Sequence[xb.xla_client.Device],
replicate: bool=False) -> List[xb.xla_client.Buffer]:
def device_put(x, devices: Sequence[xc.ArrayImpl],
replicate: bool=False) -> List[xc.ArrayImpl]:
"""Call device_put on a sequence of devices and return a flat sequence of buffers."""
if replicate:
return [jax.device_put(x, device) for device in devices]

View File

@ -55,8 +55,7 @@ _deprecated_Device = xc.Device
XlaOp = xc.XlaOp
xe = xc._xla
Backend = xe.Client
Buffer = xc.Buffer
_CppDeviceArray = xe.Buffer
Buffer = _deprecated_DeviceArray
_deprecations = {
# Added Feb 9, 2023:
@ -71,6 +70,13 @@ _deprecations = {
),
_deprecated_DeviceArray,
),
"Buffer": (
(
"jax.interpreters.xla.Buffer is deprecated. Use jax.Array"
" instead."
),
_deprecated_DeviceArray,
),
"device_put": (
(
"jax.interpreters.xla.device_put is deprecated. Please use"