PRNGKeyArray: implement scatter/gather via .at()

This commit is contained in:
Jake VanderPlas 2023-04-25 15:54:33 -07:00
parent 4e88943d10
commit 6e84ed2992
3 changed files with 89 additions and 11 deletions

View File

@ -617,7 +617,7 @@ def scatter(
An array containing the sum of `operand` and the scattered updates.
"""
jaxpr, consts = lax._reduction_jaxpr(_scatter_reduction_computation,
lax._abstractify(lax._const(operand, 0)))
core.ShapedArray((), lax.dtype(operand)))
return scatter_p.bind(
operand, scatter_indices, updates, update_jaxpr=jaxpr,
update_consts=consts, dimension_numbers=dimension_numbers,
@ -1423,7 +1423,7 @@ def _scatter_dtype_rule(operand, indices, updates, **kwargs):
if not dtypes.issubdtype(indices.dtype, np.integer):
raise ValueError("indices must have an integer type")
lax.check_same_dtypes("scatter", operand, updates)
return dtypes.canonicalize_dtype(operand.dtype)
return dtypes.canonicalize_dtype(operand.dtype, allow_opaque_dtype=True)
def _scatter_shape_rule(operand, indices, updates, *, update_jaxpr,
update_consts, dimension_numbers, indices_are_sorted,
@ -1998,12 +1998,18 @@ batching.primitive_batchers[scatter_p] = (
def _scatter_lower(ctx, operand, indices, updates, *,
update_jaxpr, update_consts, dimension_numbers,
indices_are_sorted, unique_indices, mode):
aval_out, = ctx.avals_out
if core.is_opaque_dtype(aval_out.dtype):
return [aval_out.dtype._rules.scatter_mlir(
ctx, ctx.avals_in, aval_out, operand, indices, updates,
update_jaxpr=update_jaxpr, update_consts=update_consts,
dimension_numbers=dimension_numbers, unique_indices=unique_indices,
indices_are_sorted=indices_are_sorted, mode=mode)]
if mode == GatherScatterMode.CLIP:
clip_fn = mlir.lower_fun(_clamp_scatter_indices, multiple_results=False)
(indices,), = clip_fn(ctx.replace(avals_out=None), operand, indices,
updates, dnums=dimension_numbers)
aval_out, = ctx.avals_out
dnums = dimension_numbers
scatter_dnums = hlo.ScatterDimensionNumbers.get(
update_window_dims=list(dnums.update_window_dims),

View File

@ -45,7 +45,7 @@ from jax._src.lax import lax as lax_internal
from jax._src.lax import utils as lax_utils
from jax._src.lib import gpu_prng
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy.array_methods import _set_array_base_attributes
from jax._src.numpy.array_methods import _set_array_base_attributes, _IndexUpdateHelper
from jax._src.partition_spec import PartitionSpec
from jax._src.sharding_impls import (
NamedSharding, PmapSharding, GSPMDSharding)
@ -152,6 +152,10 @@ class PRNGKeyArray(abc.ABC, metaclass=PRNGKeyArrayMeta):
@abc.abstractmethod
def sharding(self): ...
# Abstract indexed-access/update accessor, used as `keys.at[idx]`.
# Concrete key-array classes must supply it; the annotation is the
# `_IndexUpdateHelper` class imported from jax._src.numpy.array_methods.
@property
@abc.abstractmethod
def at(self) -> _IndexUpdateHelper: ...
# Abstract: concrete key-array classes must implement `len(keys)`.
@abc.abstractmethod
def __len__(self) -> int: ...
@abc.abstractmethod
@ -267,10 +271,6 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
# still needed? If, with some work, none are needed, then do we want
# to remove stackables altogether? This may be the only application.
# Dynamically overridden below.
def reshape(self, newshape, order=None) -> PRNGKeyArrayImpl:
raise NotImplementedError("reshape method must be overridden")
def __repr__(self):
  # Rendered as `<ClassName>[<impl tag>] { <base array> }`.
  prefix = f'{self.__class__.__name__}[{self.impl.tag}]'
  payload = f' {{ {self._base_array} }}'
  return prefix + payload
@ -284,18 +284,21 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
# Overwritten immediately below
@property
def at(self) -> _IndexUpdateHelper: assert False
@property
def T(self) -> PRNGKeyArray: assert False
def __getitem__(self, _) -> PRNGKeyArray: assert False
def flatten(self, *_, **__) -> PRNGKeyArray: assert False
def ravel(self, *_, **__) -> PRNGKeyArray: assert False
def reshape(self, *_, **__) -> PRNGKeyArray: assert False
def squeeze(self, *_, **__) -> PRNGKeyArray: assert False
def swapaxes(self, *_, **__) -> PRNGKeyArray: assert False
def take(self, *_, **__) -> PRNGKeyArray: assert False
def transpose(self, *_, **__) -> PRNGKeyArray: assert False
def flatten(self, *_, **__) -> PRNGKeyArray: assert False
_set_array_base_attributes(PRNGKeyArrayImpl, include=[
'__getitem__', 'ravel', 'squeeze', 'swapaxes', 'take', 'reshape',
'transpose', 'flatten', 'T'])
'__getitem__', 'at', 'flatten', 'ravel', 'reshape',
'squeeze', 'swapaxes', 'take', 'transpose', 'T'])
basearray.Array.register(PRNGKeyArrayImpl)
@ -491,6 +494,28 @@ class KeyTyRules:
avals_out=[keys_aval_to_base_arr_aval(aval_y)])
return res
@staticmethod
def scatter_mlir(ctx, avals_in, aval_out, x, indices, updates, *,
                 update_jaxpr, update_consts, dimension_numbers,
                 unique_indices, indices_are_sorted, mode):
  """Lower a scatter on key arrays by delegating to the regular scatter rule.

  The key dtype's `impl.key_shape` contributes extra trailing dimensions to
  the operand and updates. Those dimensions are appended to
  `update_window_dims`, and the lowering is delegated to
  `lax_internal.slicing._scatter_lower` with the operand, updates and output
  avals passed through `keys_aval_to_base_arr_aval`.

  Args:
    ctx: lowering context forwarded to `mlir.delegate_lowering`.
    avals_in: `(operand aval, indices aval, updates aval)`.
    aval_out: aval of the scatter result.
    x, indices, updates: lowered operand, indices and updates values.
    update_jaxpr, update_consts, dimension_numbers, unique_indices,
      indices_are_sorted, mode: scatter parameters, forwarded unchanged
      except for the extended `dimension_numbers`.

  Returns:
    The single result produced by the delegated lowering.
  """
  operand_aval, indices_aval, updates_aval = avals_in

  # One extra window dimension per key_shape dimension, numbered right after
  # the existing dimensions of the (key-typed) updates.
  num_key_dims = len(operand_aval.dtype.impl.key_shape)
  first_key_dim = updates_aval.ndim
  key_window_dims = range(first_key_dim, first_key_dim + num_key_dims)
  base_dnums = dimension_numbers._replace(
      update_window_dims=(
          *dimension_numbers.update_window_dims, *key_window_dims))

  lower_on_base = partial(
      lax_internal.slicing._scatter_lower,
      update_jaxpr=update_jaxpr,
      update_consts=update_consts,
      dimension_numbers=base_dnums,
      unique_indices=unique_indices,
      indices_are_sorted=indices_are_sorted,
      mode=mode)

  # Indices are ordinary integers, so only operand/updates/output avals are
  # converted.
  base_avals_in = [
      keys_aval_to_base_arr_aval(operand_aval),
      indices_aval,
      keys_aval_to_base_arr_aval(updates_aval),
  ]
  base_avals_out = [keys_aval_to_base_arr_aval(aval_out)]

  (lowered,) = mlir.delegate_lowering(
      ctx, lower_on_base, x, indices, updates,
      avals_in=base_avals_in, avals_out=base_avals_out)
  return lowered
class KeyTy:
impl: Hashable # prng.PRNGImpl. TODO(mattjj,frostig): protocol really

View File

@ -2102,6 +2102,53 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
self.assertKeysEqual(key, jnp.array(key, dtype=key.dtype))
self.assertKeysEqual(key, jnp.asarray(key, dtype=key.dtype))
@parameterized.parameters(
    [
        (0,),
        (slice(1),),
        (np.array([0, 2]),),
        (np.array([False, True, True]),),
    ]
)
def test_getitem(self, idx):
  # Plain `x[idx]` indexing on a batch of three keys, for integer, slice,
  # integer-array and boolean-mask indices. The same function serves as both
  # the key-array function and the reference function.
  keys = jax.random.split(random.PRNGKey(123), 3)

  def index_with(x):
    return x[idx]

  self.check_shape(index_with, keys)
  self.check_against_reference(index_with, index_with, keys)
@parameterized.parameters(
    [
        (0,),
        (slice(1),),
        (np.array([0, 2]),),
        (np.array([False, True, True]),),
    ]
)
def test_gather(self, idx):
  # Gather through `x.at[idx].get()` on a batch of three keys, for integer,
  # slice, integer-array and boolean-mask indices. The same function serves
  # as both the key-array function and the reference function.
  keys = jax.random.split(random.PRNGKey(123), 3)

  def gather_with(x):
    return x.at[idx].get()

  self.check_shape(gather_with, keys)
  self.check_against_reference(gather_with, gather_with, keys)
@parameterized.parameters(
    [
        (0,),
        (slice(1),),
        (np.array([0, 2]),),
        (np.array([False, True, True]),),
    ]
)
def test_scatter(self, idx):
  # Scatter through `x.at[idx].set(y)`: the single source key is written
  # into a batch of three keys split from it, for integer, slice,
  # integer-array and boolean-mask indices. The same function serves as both
  # the key-array function and the reference function.
  key = random.PRNGKey(123)
  keys = jax.random.split(key, 3)

  def scatter_with(x, y):
    return x.at[idx].set(y)

  self.check_shape(scatter_with, keys, key)
  self.check_against_reference(scatter_with, scatter_with, keys, key)
def test_errors(self):
key = random.PRNGKey(123)
with self.assertRaisesRegex(ValueError, "dtype=key<fry> is not a valid dtype"):