Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 11:56:07 +00:00)
support jax.experimental.array.Array as a base array for key arrays

Only handle host-locally sharded `Array`s for now (as in SDAs under `pmap`),
leaving global sharding for a follow-up. Also re-enable a previously skipped
test as a result.

Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 469885160
commit 8e2d1be0a5
parent 711fc7ffc7
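For orientation, a minimal usage sketch of the path this change serves. The sketch is not part of the commit: it assumes the transient `jax_enable_custom_prng` and `jax_array` flags of this era are both available and enabled, and the helper name `per_device_key` is purely illustrative. With both flags on, a `pmap`ped computation that returns PRNG keys produces a key array whose per-device uint32 buffers are wrapped in a `jax.experimental.array.Array` base rather than a `ShardedDeviceArray`.

import jax
from jax import random

# Transient flags of this era (assumed available in this JAX version).
jax.config.update('jax_enable_custom_prng', True)
jax.config.update('jax_array', True)

def per_device_key(seed):
  # Each device folds its integer seed into a base key, producing one key per device.
  return random.fold_in(random.PRNGKey(0), seed)

# The pmapped result is a host-locally sharded key array; with jax_array on, its
# base array is built by the local sharded result handler as an Array, not an SDA.
keys = jax.pmap(per_device_key)(jax.numpy.arange(jax.device_count()))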
@@ -16,7 +16,6 @@
 import abc
 from functools import partial
 from typing import Any, Callable, Hashable, Iterator, NamedTuple, Sequence
 import warnings

 import numpy as np

@@ -24,7 +23,6 @@ import jax
 from jax import lax
 from jax import core
 from jax import numpy as jnp
 from jax import tree_util
 from jax.config import config
 from jax.dtypes import float0
 from jax.interpreters import ad
@@ -287,7 +285,8 @@ class KeyTy:
   # handlers

   @staticmethod
-  def physical_avals(aval):
+  def physical_avals(aval):  # TODO(frostig): rename to `grounded_avals`
+    # TODO(frostig): dedup with `keys_aval_to_base_arr_aval``
     return [core.ShapedArray((*aval.shape, *aval.dtype.impl.key_shape),
                              jnp.dtype('uint32'))]

@@ -304,15 +303,45 @@ class KeyTy:
     return handler

   @staticmethod
-  def sharded_result_handler(aval, sharding, indices):
+  def local_sharded_result_handler(aval, sharding, indices):
     phys_aval, = KeyTy.physical_avals(aval)
     key_shape = aval.dtype.impl.key_shape
+
+    # TODO(yashkatariya,frostig): remove this conditional and inline it when
+    # the transient config ever settles
+    if config.jax_array:
+      output_type = pxla.OutputType.Array
+    else:
+      output_type = pxla.OutputType.ShardedDeviceArray
     phys_handler_maker = pxla.local_result_handlers[
-        (core.ShapedArray, pxla.OutputType.ShardedDeviceArray)]
-    phys_handler = phys_handler_maker(phys_aval, sharding, indices)
+        (core.ShapedArray, output_type)]
+
+    # set up a grounded sharding (with a grounded sharding spec)
+    trailing_sharding = [pxla.NoSharding()] * len(key_shape)
+    phys_sharding_spec = pxla.ShardingSpec(
+        sharding=(*sharding.sharding_spec.sharding, *trailing_sharding),
+        mesh_mapping=sharding.sharding_spec.mesh_mapping)
+    phys_sharding = jax.experimental.sharding.PmapSharding(
+        devices=sharding.devices,
+        sharding_spec=phys_sharding_spec)
+
+    # set up grounded indices
+    trailing_inds = [slice(None)] * len(key_shape)
+    phys_indices = [(*inds, *trailing_inds) for inds in indices]
+
+    # make a physical handler
+    phys_handler = phys_handler_maker(phys_aval, phys_sharding, phys_indices)
+
+    # set up a handler that calls the physical one and wraps back up
     def handler(bufs):
       return PRNGKeyArray(aval.dtype.impl, phys_handler(bufs))
+
     return handler
+
+  @staticmethod
+  def global_sharded_result_handler(aval, sharding):
+    raise NotImplementedError  # TODO(frostig,yashkatariya): implement!
+
   # eltype-polymorphic primitive lowering rules

   @staticmethod
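The handler above "grounds" the logical key aval: a key array of logical shape `shape` is physically a uint32 array of shape `(*shape, *key_shape)`, so the incoming pmap sharding spec and per-device indices get padded with trailing unsharded dimensions and full slices. A standalone sketch of the index padding, using plain NumPy and hypothetical shapes rather than JAX internals:

import numpy as np

key_shape = (2,)                    # e.g. a threefry key is two uint32 words
logical_shape = (8,)                # eight keys, say one per device
physical = np.zeros((*logical_shape, *key_shape), dtype=np.uint32)

# Per-device indices into the logical key array, one row each.
indices = [(slice(i, i + 1),) for i in range(logical_shape[0])]

# Pad each index with full slices over the trailing physical key dimensions,
# mirroring `phys_indices = [(*inds, *trailing_inds) for inds in indices]` above.
trailing_inds = [slice(None)] * len(key_shape)
phys_indices = [(*inds, *trailing_inds) for inds in indices]

assert all(physical[ix].shape == (1, *key_shape) for ix in phys_indices)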
@@ -437,14 +437,21 @@ pxla.shard_arg_handlers[Array] = _array_shard_arg


 def _array_global_result_handler(global_aval, out_sharding):
-  return lambda bufs: Array(global_aval, out_sharding, bufs, committed=True,
-                            _skip_checks=True)
+  if core.aval_has_custom_eltype(global_aval):
+    return global_aval.dtype.global_sharded_result_handler(
+        global_aval, out_sharding)
+  else:
+    return lambda bufs: Array(global_aval, out_sharding, bufs, committed=True,
+                              _skip_checks=True)
 pxla.global_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_global_result_handler
 pxla.global_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_global_result_handler


 def _array_local_result_handler(aval, sharding, indices):
-  return lambda bufs: Array(aval, sharding, bufs, committed=True,
-                            _skip_checks=True)
+  if core.aval_has_custom_eltype(aval):
+    return aval.dtype.local_sharded_result_handler(aval, sharding, indices)
+  else:
+    return lambda bufs: Array(aval, sharding, bufs, committed=True,
+                              _skip_checks=True)
 pxla.local_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_local_result_handler
 pxla.local_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_local_result_handler
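Both factories above now follow the same dispatch: if the output aval carries a custom element type, defer to that eltype's own result-handler factory; otherwise build the plain `Array` wrapper. A simplified, self-contained sketch of that pattern, with all names hypothetical stand-ins rather than JAX APIs:

from typing import Any, Callable, List

class PlainArray:
  # Stands in for jax.experimental.array.Array: wraps raw per-device buffers.
  def __init__(self, aval, sharding, bufs):
    self.aval, self.sharding, self.bufs = aval, sharding, bufs

class KeyEltype:
  # Stands in for a custom eltype like KeyTy: it knows how to wrap its own buffers.
  def local_sharded_result_handler(self, aval, sharding, indices):
    return lambda bufs: ('key-array', bufs)  # KeyTy would build a PRNGKeyArray here

class Aval:
  def __init__(self, dtype):
    self.dtype = dtype

def aval_has_custom_eltype(aval) -> bool:
  return isinstance(aval.dtype, KeyEltype)

def local_result_handler(aval, sharding, indices) -> Callable[[List[Any]], Any]:
  # Dispatch: custom eltypes build their own result; everything else gets PlainArray.
  if aval_has_custom_eltype(aval):
    return aval.dtype.local_sharded_result_handler(aval, sharding, indices)
  return lambda bufs: PlainArray(aval, sharding, bufs)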
@@ -570,8 +570,8 @@ local_result_handlers: Dict[Tuple[Type[core.AbstractValue], OutputType], PxlaRes

 def sda_array_result_handler(aval: ShapedArray, sharding, indices):
   sharding_spec = _get_sharding_specs([sharding], [aval])[0]
-  if type(aval.dtype) in core.custom_eltypes:
-    return aval.dtype.sharded_result_handler(aval, sharding, indices)
+  if core.aval_has_custom_eltype(aval):
+    return aval.dtype.local_sharded_result_handler(aval, sharding, indices)
   else:
     return lambda bufs: make_sharded_device_array(aval, sharding_spec, bufs,
                                                   indices)
@@ -1383,7 +1383,6 @@ class PythonPmapTest(jtu.JaxTestCase):
     bx = vmap(f1)(ax)
     self.assertAllClose(ax, bx, check_dtypes=False)

-  @jtu.skip_on_flag('jax_array', True)  # TODO(yashkatariya,frostig): fix
   def testVmapOfPmap2(self):
     N_DEVICES = jax.device_count()
     keys = random.split(random.PRNGKey(1), 13)  # [13, 2]