support jax.experimental.array.Array as a base array for key arrays

Only handle host-locally sharded `Array`s for now (as with SDAs under
`pmap`), leaving global sharding for a follow-up.

Also re-enable a previously skipped test as a result.

Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 469885160
Author: Roy Frostig, 2022-08-24 19:48:36 -07:00 (committed by jax authors)
parent 711fc7ffc7
commit 8e2d1be0a5
4 changed files with 48 additions and 13 deletions
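
For context, a minimal sketch of what the change enables; it is illustrative
only (not part of the commit), and assumes both experimental flags exist
under these names:

    import jax
    from jax import random
    from jax.config import config

    config.update('jax_enable_custom_prng', True)  # keys as custom-eltype arrays
    config.update('jax_array', True)               # Array as the base array type

    # one key per local device; pmap's result handler reassembles the output
    # keys from per-device buffers, now backed by Array rather than by an SDA
    keys = random.split(random.PRNGKey(0), jax.device_count())
    new_keys = jax.pmap(lambda k: random.fold_in(k, 7))(keys)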

jax/_src/prng.py

@@ -16,7 +16,6 @@
 import abc
 from functools import partial
 from typing import Any, Callable, Hashable, Iterator, NamedTuple, Sequence
-import warnings
 
 import numpy as np
 
@@ -24,7 +23,6 @@ import jax
 from jax import lax
 from jax import core
 from jax import numpy as jnp
-from jax import tree_util
 from jax.config import config
 from jax.dtypes import float0
 from jax.interpreters import ad
@@ -287,7 +285,8 @@ class KeyTy:
   # handlers
 
   @staticmethod
-  def physical_avals(aval):
+  def physical_avals(aval):  # TODO(frostig): rename to `grounded_avals`
+    # TODO(frostig): dedup with `keys_aval_to_base_arr_aval``
     return [core.ShapedArray((*aval.shape, *aval.dtype.impl.key_shape),
                              jnp.dtype('uint32'))]
 
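The grounding above is plain shape concatenation: trailing key dimensions are
appended to the logical shape. A standalone sketch, assuming the default
threefry impl, whose `key_shape` is `(2,)`:

    from jax import core
    import jax.numpy as jnp

    key_shape = (2,)     # threefry2x32: one key is two uint32 words
    aval_shape = (3, 4)  # logical shape of the key array
    phys_aval = core.ShapedArray((*aval_shape, *key_shape), jnp.dtype('uint32'))
    assert phys_aval.shape == (3, 4, 2)
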
@@ -304,15 +303,45 @@ class KeyTy:
     return handler
 
   @staticmethod
-  def sharded_result_handler(aval, sharding, indices):
+  def local_sharded_result_handler(aval, sharding, indices):
     phys_aval, = KeyTy.physical_avals(aval)
+    key_shape = aval.dtype.impl.key_shape
+
+    # TODO(yashkatariya,frostig): remove this conditional and inline it when
+    # the transient config ever settles
+    if config.jax_array:
+      output_type = pxla.OutputType.Array
+    else:
+      output_type = pxla.OutputType.ShardedDeviceArray
+
     phys_handler_maker = pxla.local_result_handlers[
-        (core.ShapedArray, pxla.OutputType.ShardedDeviceArray)]
-    phys_handler = phys_handler_maker(phys_aval, sharding, indices)
+        (core.ShapedArray, output_type)]
+
+    # set up a grounded sharding (with a grounded sharding spec)
+    trailing_sharding = [pxla.NoSharding()] * len(key_shape)
+    phys_sharding_spec = pxla.ShardingSpec(
+        sharding=(*sharding.sharding_spec.sharding, *trailing_sharding),
+        mesh_mapping=sharding.sharding_spec.mesh_mapping)
+    phys_sharding = jax.experimental.sharding.PmapSharding(
+        devices=sharding.devices,
+        sharding_spec=phys_sharding_spec)
+
+    # set up grounded indices
+    trailing_inds = [slice(None)] * len(key_shape)
+    phys_indices = [(*inds, *trailing_inds) for inds in indices]
+
+    # make a physical handler
+    phys_handler = phys_handler_maker(phys_aval, phys_sharding, phys_indices)
+
+    # set up a handler that calls the physical one and wraps back up
     def handler(bufs):
       return PRNGKeyArray(aval.dtype.impl, phys_handler(bufs))
     return handler
 
+  @staticmethod
+  def global_sharded_result_handler(aval, sharding):
+    raise NotImplementedError  # TODO(frostig,yashkatariya): implement!
+
   # eltype-polymorphic primitive lowering rules
 
   @staticmethod
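
The "grounding" steps above only extend per-device metadata over the trailing
key dimensions. A self-contained sketch of the indices half, with hypothetical
two-device values (pure Python, no JAX needed):

    key_shape = (2,)                            # trailing dims of the physical array
    indices = [(slice(0, 1),), (slice(1, 2),)]  # per-device indices, logical view
    trailing_inds = [slice(None)] * len(key_shape)
    phys_indices = [(*inds, *trailing_inds) for inds in indices]
    assert phys_indices == [(slice(0, 1), slice(None)),
                            (slice(1, 2), slice(None))]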

jax/experimental/array.py

@@ -437,14 +437,21 @@
 pxla.shard_arg_handlers[Array] = _array_shard_arg
 
 
 def _array_global_result_handler(global_aval, out_sharding):
-  return lambda bufs: Array(global_aval, out_sharding, bufs, committed=True,
-                            _skip_checks=True)
+  if core.aval_has_custom_eltype(global_aval):
+    return global_aval.dtype.global_sharded_result_handler(
+        global_aval, out_sharding)
+  else:
+    return lambda bufs: Array(global_aval, out_sharding, bufs, committed=True,
+                              _skip_checks=True)
 pxla.global_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_global_result_handler
 pxla.global_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_global_result_handler
 
 def _array_local_result_handler(aval, sharding, indices):
-  return lambda bufs: Array(aval, sharding, bufs, committed=True,
-                            _skip_checks=True)
+  if core.aval_has_custom_eltype(aval):
+    return aval.dtype.local_sharded_result_handler(aval, sharding, indices)
+  else:
+    return lambda bufs: Array(aval, sharding, bufs, committed=True,
+                              _skip_checks=True)
 pxla.local_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_local_result_handler
 pxla.local_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_local_result_handler
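
Both handlers now defer to the aval's dtype when it is a custom eltype. A
hedged sketch of the duck-typed protocol an eltype implements to participate;
the method names come from this commit, but the class itself is hypothetical:

    class MyEltype:
      """Hypothetical custom eltype opting into sharded result handling."""

      @staticmethod
      def local_sharded_result_handler(aval, sharding, indices):
        # under pmap: given per-device buffers, rebuild the logical value
        return lambda bufs: ...  # wrap bufs in the eltype's array class

      @staticmethod
      def global_sharded_result_handler(aval, sharding):
        # for globally sharded Arrays; left unimplemented here, as for keys
        raise NotImplementedError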

jax/interpreters/pxla.py

@@ -570,8 +570,8 @@ local_result_handlers: Dict[Tuple[Type[core.AbstractValue], OutputType], PxlaRes
 def sda_array_result_handler(aval: ShapedArray, sharding, indices):
   sharding_spec = _get_sharding_specs([sharding], [aval])[0]
-  if type(aval.dtype) in core.custom_eltypes:
-    return aval.dtype.sharded_result_handler(aval, sharding, indices)
+  if core.aval_has_custom_eltype(aval):
+    return aval.dtype.local_sharded_result_handler(aval, sharding, indices)
   else:
     return lambda bufs: make_sharded_device_array(aval, sharding_spec, bufs,
                                                   indices)
 
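
`core.aval_has_custom_eltype` centralizes the membership test that was
inlined here before. Its definition is not shown in this diff; an assumed
sketch, consistent with the removed line:

    custom_eltypes = set()  # stand-in for the registry kept in jax.core

    def aval_has_custom_eltype(aval):
      return type(aval.dtype) in custom_eltypes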

tests/pmap_test.py

@@ -1383,7 +1383,6 @@ class PythonPmapTest(jtu.JaxTestCase):
     bx = vmap(f1)(ax)
     self.assertAllClose(ax, bx, check_dtypes=False)
 
-  @jtu.skip_on_flag('jax_array', True)  # TODO(yashkatariya,frostig): fix
   def testVmapOfPmap2(self):
     N_DEVICES = jax.device_count()
     keys = random.split(random.PRNGKey(1), 13)  # [13, 2]
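
The re-enabled test exercises `vmap` of `pmap` over split keys. A hedged
sketch of that pattern (not the test body itself), runnable under the default
PRNG:

    import jax
    import jax.numpy as jnp
    from jax import pmap, random, vmap

    n = jax.device_count()

    @pmap
    def g(key):
      return random.normal(key, ())  # one sample per device

    @vmap
    def s(key):
      keys = jnp.broadcast_to(key, (n, *key.shape))  # replicate across devices
      return g(keys)

    out = s(random.split(random.PRNGKey(1), 13))
    assert out.shape == (13, n)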