mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 06:16:05 +00:00

This change only supports pinned_host -> pinned_host copies on the same device. HBM -> HBM copies don't work yet and donation also doesn't work in PJRT. This CL also sets up the plumbing from JAX to PJRT so that in the future support for missing features can be added easily. Fixes https://github.com/jax-ml/jax/issues/24521 PiperOrigin-RevId: 694274616
124 lines
4.4 KiB
Python
124 lines
4.4 KiB
Python
# Copyright 2024 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
import math
|
|
|
|
from jax._src import api_util
|
|
from jax._src import basearray
|
|
from jax._src import core
|
|
from jax._src import tree_util
|
|
from jax._src import sharding_impls
|
|
from jax._src.interpreters import pxla
|
|
from jax._src.interpreters import xla
|
|
from jax._src.util import safe_zip, safe_map
|
|
|
|
# Shadow the builtins with jax._src.util's safe_map/safe_zip for this module,
# keeping the original builtins reachable as unsafe_map/unsafe_zip.
unsafe_map = map
map = safe_map
unsafe_zip = zip
zip = safe_zip
|
|
|
|
# EArray is an Array that can contain extended dtypes.
class EArray(basearray.Array):
  """A ``basearray.Array`` pairing a logical aval with a wrapped data array.

  ``aval`` supplies the user-visible metadata (shape, dtype and quantities
  derived from them), while runtime queries such as readiness, deletion and
  device placement are forwarded to the wrapped ``_data`` object. The
  ``sharding`` exposed here is the logical sharding computed from ``aval``
  and the wrapped data's sharding.
  """
  __slots__ = ['aval', '_data']
  __hash__ = None  # type: ignore[assignment]
  __array_priority__ = 100

  def __init__(self, aval, data):
    self.aval, self._data = aval, data

  def block_until_ready(self):
    """Forward ``block_until_ready`` to the wrapped data; return ``self``."""
    self._data.block_until_ready()
    return self

  def copy_to_host_async(self):
    """Forward ``copy_to_host_async`` to the wrapped data; return ``None``."""
    self._data.copy_to_host_async()

  def copy(self):
    """Return a new EArray with the same aval wrapping ``_data.copy()``."""
    data_copy = self._data.copy()
    return EArray(self.aval, data_copy)

  def __repr__(self):
    # The wrapped data's repr, prefixed with 'E'.
    return f'E{self._data!r}'

  def __iter__(self):
    if not self.ndim:
      raise TypeError('iteration over a 0-d array')
    # Iteration over arrays of rank >= 1 is not supported yet.
    raise NotImplementedError

  # Metadata read straight from the aval.
  shape = property(lambda self: self.aval.shape)  # type: ignore[assignment]
  dtype = property(lambda self: self.aval.dtype)  # type: ignore[assignment]

  # Metadata derived from the aval's shape and dtype.
  ndim = property(lambda self: len(self.aval.shape))  # type: ignore[assignment]
  size = property(lambda self: math.prod(self.aval.shape))  # type: ignore[assignment]
  itemsize = property(lambda self: self.aval.dtype.itemsize)  # type: ignore[assignment]

  def __len__(self):
    if not self.ndim:
      raise TypeError('len() of unsized object')
    return self.shape[0]

  # Attributes/methods forwarded unchanged to the wrapped data.
  devices = property(lambda self: self._data.devices)  # type: ignore[assignment]
  _committed = property(lambda self: self._data._committed)
  is_fully_addressable = property(lambda self: self._data.is_fully_addressable)  # type: ignore[assignment]
  is_fully_replicated = property(lambda self: self._data.is_fully_replicated)  # type: ignore[assignment]
  delete = property(lambda self: self._data.delete)  # type: ignore[assignment]
  is_deleted = property(lambda self: self._data.is_deleted)  # type: ignore[assignment]
  on_device_size_in_bytes = property(lambda self: self._data.on_device_size_in_bytes)  # type: ignore[assignment]
  unsafe_buffer_pointer = property(lambda self: self._data.unsafe_buffer_pointer)  # type: ignore[assignment]

  # Sharding-related queries defer to the extended dtype rules.
  @property
  def sharding(self):
    """Logical sharding computed from the aval and the data's sharding."""
    return sharding_impls.logical_sharding(self.aval, self._data.sharding)

  @property
  def committed(self):
    """The wrapped data's ``committed`` value."""
    return self._data.committed

  @property
  def device(self):
    """The data's device if it is single-device sharded, else ``sharding``."""
    data = self._data
    is_single = isinstance(data.sharding, sharding_impls.SingleDeviceSharding)
    return data.device if is_single else self.sharding

  # TODO(mattjj): not implemented below here, need more methods from ArrayImpl

  def addressable_data(self, index: int) -> EArray:
    raise NotImplementedError

  @property
  def addressable_shards(self):
    raise NotImplementedError

  @property
  def global_shards(self):
    raise NotImplementedError
|
|
|
|
# TODO(mattjj): _set_array_base_attributes
|
|
|
|
def _earray_shard_arg_handler(xs, shardings, layouts, copy_semantics):
  """Shard a batch of EArray arguments via their wrapped data.

  Each EArray in ``xs`` is replaced by its ``_data``, and each requested
  sharding is converted with ``sharding_impls.physical_sharding`` using the
  corresponding EArray's aval. The result of ``pxla.shard_args`` on those
  physical values is returned.
  """
  wrapped = [earr._data for earr in xs]
  phys_shardings = [
      sharding_impls.physical_sharding(earr.aval, logical)
      for earr, logical in zip(xs, shardings)
  ]
  # TODO(yashkatariya): `layouts` should be converted to physical layouts.
  return pxla.shard_args(phys_shardings, layouts, copy_semantics, wrapped)
|
|
# Route EArray arguments through the EArray-specific shard-arg handler.
pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler

# The abstract value of an EArray is the aval it stores.
api_util._shaped_abstractify_handlers[EArray] = lambda arr: arr.aval
core.pytype_aval_mappings[EArray] = lambda arr: arr.aval
# EArrays pass through dtype canonicalization unchanged.
xla.canonicalize_dtype_handlers[EArray] = lambda arr: arr
# Pytree node: the wrapped data is the single child, and the aval is carried
# alongside it and used to rebuild the EArray on unflatten.
tree_util.dispatch_registry.register_node(
    EArray,
    lambda arr: ((arr._data,), arr.aval),
    lambda aval, children: EArray(aval, children[0]),
)
|