mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
start adding EArray, a jax.Array analog that can contain extended dtypes
This commit is contained in:
parent
9616900cc9
commit
89f26db36d
@ -201,6 +201,7 @@ py_library_providing_imports_info(
|
||||
"_src/debugging.py",
|
||||
"_src/dispatch.py",
|
||||
"_src/dlpack.py",
|
||||
"_src/earray.py",
|
||||
"_src/flatten_util.py",
|
||||
"_src/interpreters/__init__.py",
|
||||
"_src/interpreters/ad.py",
|
||||
|
110
jax/_src/earray.py
Normal file
110
jax/_src/earray.py
Normal file
@ -0,0 +1,110 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
from jax._src import api_util
|
||||
from jax._src import basearray
|
||||
from jax._src import core
|
||||
from jax._src import tree_util
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.util import safe_zip, safe_map
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
# EArray is an Array that can contain extended dtypes.
|
||||
class EArray(basearray.Array):
|
||||
__slots__ = ['aval', '_data']
|
||||
__hash__ = None # type: ignore[assignment]
|
||||
__array_priority__ = 100
|
||||
|
||||
def __init__(self, aval, data):
|
||||
self.aval = aval
|
||||
self._data = data
|
||||
|
||||
def block_until_ready(self):
|
||||
_ = self._data.block_until_ready()
|
||||
return self
|
||||
|
||||
def copy_to_host_async(self):
|
||||
self._data.copy_to_host_async()
|
||||
|
||||
def copy(self):
|
||||
return EArray(self.aval, self._data.copy())
|
||||
|
||||
def __repr__(self):
|
||||
return 'E' + repr(self._data)
|
||||
|
||||
def __iter__(self):
|
||||
if self.ndim == 0: raise TypeError('iteration over a 0-d array')
|
||||
raise NotImplementedError
|
||||
|
||||
# forward to aval
|
||||
shape = property(lambda self: self.aval.shape) # type: ignore[assignment]
|
||||
dtype = property(lambda self: self.aval.dtype) # type: ignore[assignment]
|
||||
|
||||
# computed from shape and dtype
|
||||
ndim = property(lambda self: len(self.aval.shape)) # type: ignore[assignment]
|
||||
size = property(lambda self: math.prod(self.aval.shape)) # type: ignore[assignment]
|
||||
itemsize = property(lambda self: self.aval.dtype.itemsize) # type: ignore[assignment]
|
||||
def __len__(self):
|
||||
if self.ndim == 0: raise TypeError('len() of unsized object')
|
||||
return self.shape[0]
|
||||
|
||||
# forward to self._data
|
||||
devices = property(lambda self: self._data.devices) # type: ignore[assignment]
|
||||
_committed = property(lambda self: self._data._committed)
|
||||
is_fully_addressable = property(lambda self: self._data.is_fully_addressable) # type: ignore[assignment]
|
||||
is_fully_replicated = property(lambda self: self._data.is_fully_replicated) # type: ignore[assignment]
|
||||
delete = property(lambda self: self._data.delete) # type: ignore[assignment]
|
||||
is_deleted = property(lambda self: self._data.is_deleted) # type: ignore[assignment]
|
||||
on_device_size_in_bytes = property(lambda self: self._data.on_device_size_in_bytes) # type: ignore[assignment]
|
||||
unsafe_buffer_pointer = property(lambda self: self._data.unsafe_buffer_pointer) # type: ignore[assignment]
|
||||
|
||||
# defer to extended dtype rules
|
||||
@property
|
||||
def sharding(self):
|
||||
phys_sharding = self._data.sharding
|
||||
return self.aval.dtype._rules.logical_sharding(self.aval, phys_sharding)
|
||||
|
||||
# TODO(mattjj): not implemented below here, need more methods from ArrayImpl
|
||||
|
||||
def addressable_data(self, index: int) -> EArray:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def addressable_shards(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def global_shards(self):
|
||||
raise NotImplementedError
|
||||
|
||||
# TODO(mattjj): _set_array_base_attributes
|
||||
|
||||
def _earray_shard_arg_handler(x, sharding):
|
||||
arr = x._data
|
||||
phys_sharding = x.aval.dtype._rules.physical_sharding(x.aval, sharding)
|
||||
return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding)
|
||||
pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler
|
||||
|
||||
api_util._shaped_abstractify_handlers[EArray] = lambda self: self.aval
|
||||
core.pytype_aval_mappings[EArray] = lambda x: x.aval
|
||||
xla.canonicalize_dtype_handlers[EArray] = lambda x: x
|
||||
tree_util.dispatch_registry.register_node(
|
||||
EArray, lambda x: ((x._data,), x.aval), lambda a, xs: EArray(a, xs[0]))
|
@ -22,3 +22,6 @@ from jax.experimental.x64_context import (
|
||||
from jax._src.callback import (
|
||||
io_callback as io_callback
|
||||
)
|
||||
from jax._src.earray import (
|
||||
EArray as EArray
|
||||
)
|
||||
|
@ -27,6 +27,7 @@ import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
from jax._src import earray
|
||||
from jax._src import config
|
||||
from jax._src import dtypes
|
||||
from jax._src import test_util as jtu
|
||||
@ -554,6 +555,70 @@ class DtypesTest(jtu.JaxTestCase):
|
||||
_, new_scale = jax.jit(jax.grad(outer, (0, 1)))(jnp.float32(3.14), scale)
|
||||
self.assertAllClose(new_scale, jnp.float32(1.0))
|
||||
|
||||
class EArrayTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.parameters([True, False])
|
||||
def test_extended_dtypes_at_rest(self, jit):
|
||||
# Test a trivial isomorphic-to-float32 extended dtype working with EArray
|
||||
from jax._src import core
|
||||
from jax._src.interpreters import pxla
|
||||
|
||||
class foo(dtypes.extended): pass
|
||||
|
||||
class FooTyRules:
|
||||
|
||||
@staticmethod
|
||||
def convert_to(foo_dtype, target_dtype):
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def physical_element_aval(foo_dtype):
|
||||
return core.ShapedArray((), dtypes.dtype('float32'))
|
||||
|
||||
@staticmethod
|
||||
def replicate_trailing_dims(ctx, val, aval):
|
||||
del ctx, aval
|
||||
return val
|
||||
|
||||
@staticmethod
|
||||
def logical_sharding(aval, phys_sharding):
|
||||
return phys_sharding
|
||||
|
||||
@staticmethod
|
||||
def global_sharded_result_handler(aval, out_sharding, committed):
|
||||
phys_sharding = out_sharding # unlike KeyTyRules, assume same shape
|
||||
phys_aval = core.physical_aval(aval)
|
||||
phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
|
||||
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed)
|
||||
return lambda bufs: earray.EArray(aval, phys_handler(bufs))
|
||||
|
||||
@staticmethod
|
||||
def physical_sharding(aval, sharding):
|
||||
return sharding # unlike KeyTyRules, assume same shape
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class FooTy(dtypes.ExtendedDType):
|
||||
name: str = 'foo'
|
||||
_rules: type = FooTyRules
|
||||
type: type = foo
|
||||
|
||||
# Can we make one?
|
||||
def f(x):
|
||||
return jax.lax.convert_element_type(x, FooTy())
|
||||
if jit:
|
||||
f = jax.jit(f)
|
||||
x = f(jnp.arange(3, dtype='float32')) # don't crash
|
||||
self.assertIsInstance(x.dtype, FooTy)
|
||||
|
||||
# Can we consume one?
|
||||
def g(x):
|
||||
self.assertIsInstance(x.dtype, FooTy)
|
||||
return x
|
||||
if jit:
|
||||
g = jax.jit(g)
|
||||
y = g(x)
|
||||
self.assertIsInstance(y.dtype, FooTy)
|
||||
|
||||
|
||||
class TestPromotionTables(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user