start adding EArray, a jax.Array analog that can contain extended dtypes

This commit is contained in:
Matthew Johnson 2024-03-14 15:53:33 -07:00
parent 9616900cc9
commit 89f26db36d
4 changed files with 179 additions and 0 deletions

View File

@ -201,6 +201,7 @@ py_library_providing_imports_info(
"_src/debugging.py",
"_src/dispatch.py",
"_src/dlpack.py",
"_src/earray.py",
"_src/flatten_util.py",
"_src/interpreters/__init__.py",
"_src/interpreters/ad.py",

110
jax/_src/earray.py Normal file
View File

@ -0,0 +1,110 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from jax._src import api_util
from jax._src import basearray
from jax._src import core
from jax._src import tree_util
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
from jax._src.util import safe_zip, safe_map
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
# EArray is an Array that can contain extended dtypes.
class EArray(basearray.Array):
  """An Array whose element type may be an extended dtype.

  Each instance pairs ``aval`` (which supplies the shape and dtype) with
  ``_data``, a backing array object. Shape-like attributes are read from
  ``aval``, device- and buffer-level attributes are forwarded to ``_data``,
  and ``sharding`` is computed by the dtype's ``_rules``.
  """

  __slots__ = ['aval', '_data']
  __hash__ = None  # type: ignore[assignment]
  __array_priority__ = 100

  def __init__(self, aval, data):
    self.aval = aval
    self._data = data

  def block_until_ready(self):
    """Block on the backing array, then return this EArray itself."""
    self._data.block_until_ready()
    return self

  def copy_to_host_async(self):
    """Forward the asynchronous host-copy request to the backing array."""
    self._data.copy_to_host_async()

  def copy(self):
    """Return a new EArray with the same aval wrapping a copy of the data."""
    return EArray(self.aval, self._data.copy())

  def __repr__(self):
    # The backing array's repr, prefixed with a literal 'E'.
    return f'E{self._data!r}'

  def __iter__(self):
    # Iteration is not implemented yet; 0-d arrays get a TypeError instead.
    if not self.ndim:
      raise TypeError('iteration over a 0-d array')
    raise NotImplementedError

  # Attributes read straight off the aval.
  shape = property(lambda arr: arr.aval.shape)  # type: ignore[assignment]
  dtype = property(lambda arr: arr.aval.dtype)  # type: ignore[assignment]

  # Attributes derived from the aval's shape and dtype.
  ndim = property(lambda arr: len(arr.aval.shape))  # type: ignore[assignment]
  size = property(lambda arr: math.prod(arr.aval.shape))  # type: ignore[assignment]
  itemsize = property(lambda arr: arr.aval.dtype.itemsize)  # type: ignore[assignment]

  def __len__(self):
    if not self.ndim:
      raise TypeError('len() of unsized object')
    return self.shape[0]

  # Attributes forwarded unchanged to the backing array. Each property hands
  # back whatever the same-named attribute of ``_data`` is.
  devices = property(lambda arr: arr._data.devices)  # type: ignore[assignment]
  _committed = property(lambda arr: arr._data._committed)
  is_fully_addressable = property(lambda arr: arr._data.is_fully_addressable)  # type: ignore[assignment]
  is_fully_replicated = property(lambda arr: arr._data.is_fully_replicated)  # type: ignore[assignment]
  delete = property(lambda arr: arr._data.delete)  # type: ignore[assignment]
  is_deleted = property(lambda arr: arr._data.is_deleted)  # type: ignore[assignment]
  on_device_size_in_bytes = property(lambda arr: arr._data.on_device_size_in_bytes)  # type: ignore[assignment]
  unsafe_buffer_pointer = property(lambda arr: arr._data.unsafe_buffer_pointer)  # type: ignore[assignment]

  # Deferred to the extended dtype's rules.
  @property
  def sharding(self):
    """Sharding produced by the dtype rules from the backing array's sharding."""
    rules = self.aval.dtype._rules
    return rules.logical_sharding(self.aval, self._data.sharding)

  # TODO(mattjj): not implemented below here, need more methods from ArrayImpl

  def addressable_data(self, index: int) -> EArray:
    raise NotImplementedError

  @property
  def addressable_shards(self):
    raise NotImplementedError

  @property
  def global_shards(self):
    raise NotImplementedError
# TODO(mattjj): _set_array_base_attributes
def _earray_shard_arg_handler(x, sharding):
  """Shard an EArray argument by delegating to its backing array's handler.

  The dtype rules of ``x`` map the requested ``sharding`` to a physical
  sharding; the handler registered in ``pxla.shard_arg_handlers`` for the
  backing array's type is then invoked with the backing array and that
  physical sharding, and its result is returned unchanged.
  """
  backing = x._data
  rules = x.aval.dtype._rules
  physical = rules.physical_sharding(x.aval, sharding)
  # Dispatch on the concrete type of the wrapped array, not on EArray.
  delegate = pxla.shard_arg_handlers[type(backing)]
  return delegate(backing, physical)
# Argument sharding for EArray goes through the handler defined above.
pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler

# An EArray already carries its abstract value, so both abstractification
# tables simply return it.
api_util._shaped_abstractify_handlers[EArray] = lambda earr: earr.aval
core.pytype_aval_mappings[EArray] = lambda earr: earr.aval

# dtype canonicalization leaves an EArray as it is.
xla.canonicalize_dtype_handlers[EArray] = lambda earr: earr

# Pytree node registration: flattening yields the backing array as the single
# child with the aval as auxiliary data; unflattening rebuilds the EArray from
# that aval and the first child.
tree_util.dispatch_registry.register_node(
    EArray,
    lambda earr: ((earr._data,), earr.aval),
    lambda aval, children: EArray(aval, children[0]),
)

View File

@ -22,3 +22,6 @@ from jax.experimental.x64_context import (
from jax._src.callback import (
io_callback as io_callback
)
from jax._src.earray import (
EArray as EArray
)

View File

@ -27,6 +27,7 @@ import numpy as np
import jax
from jax import numpy as jnp
from jax._src import earray
from jax._src import config
from jax._src import dtypes
from jax._src import test_util as jtu
@ -554,6 +555,70 @@ class DtypesTest(jtu.JaxTestCase):
_, new_scale = jax.jit(jax.grad(outer, (0, 1)))(jnp.float32(3.14), scale)
self.assertAllClose(new_scale, jnp.float32(1.0))
class EArrayTest(jtu.JaxTestCase):
  """Checks that values of an extended dtype can be produced and consumed."""

  @parameterized.parameters([True, False])
  def test_extended_dtypes_at_rest(self, jit):
    """Round-trip a trivial extended dtype, with and without ``jax.jit``."""
    # Test a trivial isomorphic-to-float32 extended dtype working with EArray
    from jax._src import core
    from jax._src.interpreters import pxla

    class foo(dtypes.extended):
      pass

    class FooTyRules:
      # Rule set for the toy dtype: every element is physically one float32
      # scalar, and shardings pass through unchanged in both directions.

      @staticmethod
      def convert_to(foo_dtype, target_dtype):
        return True

      @staticmethod
      def physical_element_aval(foo_dtype):
        return core.ShapedArray((), dtypes.dtype('float32'))

      @staticmethod
      def replicate_trailing_dims(ctx, val, aval):
        del ctx, aval
        return val

      @staticmethod
      def logical_sharding(aval, phys_sharding):
        return phys_sharding

      @staticmethod
      def global_sharded_result_handler(aval, out_sharding, committed):
        # unlike KeyTyRules, assume same shape, so the output sharding is used
        # directly as the physical sharding
        physical = core.physical_aval(aval)
        build_handler = pxla.global_result_handlers[core.ShapedArray]
        wrapped = build_handler(physical, out_sharding, committed)

        def handler(bufs):
          # Wrap the physical result in an EArray carrying the logical aval.
          return earray.EArray(aval, wrapped(bufs))

        return handler

      @staticmethod
      def physical_sharding(aval, sharding):
        return sharding  # unlike KeyTyRules, assume same shape

    @dataclasses.dataclass(frozen=True)
    class FooTy(dtypes.ExtendedDType):
      name: str = 'foo'
      _rules: type = FooTyRules
      type: type = foo

    # Can we make one?
    def to_foo(operand):
      return jax.lax.convert_element_type(operand, FooTy())

    producer = jax.jit(to_foo) if jit else to_foo
    produced = producer(jnp.arange(3, dtype='float32'))  # don't crash
    self.assertIsInstance(produced.dtype, FooTy)

    # Can we consume one?
    def passthrough(operand):
      self.assertIsInstance(operand.dtype, FooTy)
      return operand

    consumer = jax.jit(passthrough) if jit else passthrough
    consumed = consumer(produced)
    self.assertIsInstance(consumed.dtype, FooTy)
class TestPromotionTables(jtu.JaxTestCase):