From 82fcfc3851b57c7b6d269282947cf7ccdf19d668 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Fri, 31 Mar 2023 11:41:49 -0700 Subject: [PATCH] Buffer -> Array in some pxla type annotations. PiperOrigin-RevId: 520975371 --- jax/_src/interpreters/pxla.py | 12 ++++++------ jax/interpreters/xla.py | 10 ++++++++-- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 97e0d3481..be5e02af5 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -390,7 +390,7 @@ def shard_args( indices: Sequence[Sequence[Index]], shardings: Sequence[sharding_impls.XLACompatibleSharding], args, -) -> Sequence[Union[Sequence[xb.xla_client.Buffer], jax.Array]]: +) -> Sequence[jax.Array]: """Shard each argument data array along its leading axis. Args: @@ -554,7 +554,7 @@ def local_aval_to_result_handler( aval: core.AbstractValue, sharding: sharding_impls.XLACompatibleSharding, indices: Optional[Tuple[Index, ...]], -) -> Callable[[List[xb.xla_client.Buffer]], Any]: +) -> Callable[[List[xc.ArrayImpl]], Any]: """Returns a function for handling the raw buffers of a single output aval. Args: @@ -582,7 +582,7 @@ local_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {} def global_aval_to_result_handler( aval: core.AbstractValue, out_sharding, committed: bool, is_out_sharding_from_xla: bool -) -> Callable[[Sequence[xb.xla_client.Buffer]], Any]: +) -> Callable[[Sequence[xc.ArrayImpl]], Any]: """Returns a function for handling the raw buffers of a single output aval. Args: @@ -630,7 +630,7 @@ def make_sharded_device_array( aval: ShapedArray, sharding_spec: Optional[ShardingSpec], # Any is for JAX extensions implementing their own buffer. - device_buffers: List[Union[Any, xb.xla_client.Buffer]], + device_buffers: List[Any], indices: Optional[Tuple[Index, ...]] = None, ): """Returns a ShardedDeviceArray implementation based on arguments. @@ -3527,8 +3527,8 @@ class _ThreadLocalState(threading.local): _thread_local_state = _ThreadLocalState() -def device_put(x, devices: Sequence[xb.xla_client.Device], - replicate: bool=False) -> List[xb.xla_client.Buffer]: +def device_put(x, devices: Sequence[xc.ArrayImpl], + replicate: bool=False) -> List[xc.ArrayImpl]: """Call device_put on a sequence of devices and return a flat sequence of buffers.""" if replicate: return [jax.device_put(x, device) for device in devices] diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index 8c5c1487e..30789c1d9 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -55,8 +55,7 @@ _deprecated_Device = xc.Device XlaOp = xc.XlaOp xe = xc._xla Backend = xe.Client -Buffer = xc.Buffer -_CppDeviceArray = xe.Buffer +Buffer = _deprecated_DeviceArray _deprecations = { # Added Feb 9, 2023: @@ -71,6 +70,13 @@ _deprecations = { ), _deprecated_DeviceArray, ), + "Buffer": ( + ( + "jax.interpreters.xla.Buffer is deprecated. Use jax.Array" + " instead." + ), + _deprecated_DeviceArray, + ), "device_put": ( ( "jax.interpreters.xla.device_put is deprecated. Please use"