mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
PRNGKeyArray: implement scatter/gather via .at()
This commit is contained in:
parent
4e88943d10
commit
6e84ed2992
@ -617,7 +617,7 @@ def scatter(
|
||||
An array containing the sum of `operand` and the scattered updates.
|
||||
"""
|
||||
jaxpr, consts = lax._reduction_jaxpr(_scatter_reduction_computation,
|
||||
lax._abstractify(lax._const(operand, 0)))
|
||||
core.ShapedArray((), lax.dtype(operand)))
|
||||
return scatter_p.bind(
|
||||
operand, scatter_indices, updates, update_jaxpr=jaxpr,
|
||||
update_consts=consts, dimension_numbers=dimension_numbers,
|
||||
@ -1423,7 +1423,7 @@ def _scatter_dtype_rule(operand, indices, updates, **kwargs):
|
||||
if not dtypes.issubdtype(indices.dtype, np.integer):
|
||||
raise ValueError("indices must have an integer type")
|
||||
lax.check_same_dtypes("scatter", operand, updates)
|
||||
return dtypes.canonicalize_dtype(operand.dtype)
|
||||
return dtypes.canonicalize_dtype(operand.dtype, allow_opaque_dtype=True)
|
||||
|
||||
def _scatter_shape_rule(operand, indices, updates, *, update_jaxpr,
|
||||
update_consts, dimension_numbers, indices_are_sorted,
|
||||
@ -1998,12 +1998,18 @@ batching.primitive_batchers[scatter_p] = (
|
||||
def _scatter_lower(ctx, operand, indices, updates, *,
|
||||
update_jaxpr, update_consts, dimension_numbers,
|
||||
indices_are_sorted, unique_indices, mode):
|
||||
aval_out, = ctx.avals_out
|
||||
if core.is_opaque_dtype(aval_out.dtype):
|
||||
return [aval_out.dtype._rules.scatter_mlir(
|
||||
ctx, ctx.avals_in, aval_out, operand, indices, updates,
|
||||
update_jaxpr=update_jaxpr, update_consts=update_consts,
|
||||
dimension_numbers=dimension_numbers, unique_indices=unique_indices,
|
||||
indices_are_sorted=indices_are_sorted, mode=mode)]
|
||||
if mode == GatherScatterMode.CLIP:
|
||||
clip_fn = mlir.lower_fun(_clamp_scatter_indices, multiple_results=False)
|
||||
(indices,), = clip_fn(ctx.replace(avals_out=None), operand, indices,
|
||||
updates, dnums=dimension_numbers)
|
||||
|
||||
aval_out, = ctx.avals_out
|
||||
dnums = dimension_numbers
|
||||
scatter_dnums = hlo.ScatterDimensionNumbers.get(
|
||||
update_window_dims=list(dnums.update_window_dims),
|
||||
|
@ -45,7 +45,7 @@ from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import utils as lax_utils
|
||||
from jax._src.lib import gpu_prng
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.numpy.array_methods import _set_array_base_attributes
|
||||
from jax._src.numpy.array_methods import _set_array_base_attributes, _IndexUpdateHelper
|
||||
from jax._src.partition_spec import PartitionSpec
|
||||
from jax._src.sharding_impls import (
|
||||
NamedSharding, PmapSharding, GSPMDSharding)
|
||||
@ -152,6 +152,10 @@ class PRNGKeyArray(abc.ABC, metaclass=PRNGKeyArrayMeta):
|
||||
@abc.abstractmethod
|
||||
def sharding(self): ...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def at(self) -> _IndexUpdateHelper: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def __len__(self) -> int: ...
|
||||
@abc.abstractmethod
|
||||
@ -267,10 +271,6 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
|
||||
# still needed? If, with some work, none are needed, then do we want
|
||||
# to remove stackables altogether? This may be the only application.
|
||||
|
||||
# Dynamically overridden below.
|
||||
def reshape(self, newshape, order=None) -> PRNGKeyArrayImpl:
|
||||
raise NotImplementedError("reshape method must be overridden")
|
||||
|
||||
def __repr__(self):
|
||||
return (f'{self.__class__.__name__}[{self.impl.tag}]'
|
||||
f' {{ {self._base_array} }}')
|
||||
@ -284,18 +284,21 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
|
||||
|
||||
# Overwritten immediately below
|
||||
@property
|
||||
def at(self) -> _IndexUpdateHelper: assert False
|
||||
@property
|
||||
def T(self) -> PRNGKeyArray: assert False
|
||||
def __getitem__(self, _) -> PRNGKeyArray: assert False
|
||||
def flatten(self, *_, **__) -> PRNGKeyArray: assert False
|
||||
def ravel(self, *_, **__) -> PRNGKeyArray: assert False
|
||||
def reshape(self, *_, **__) -> PRNGKeyArray: assert False
|
||||
def squeeze(self, *_, **__) -> PRNGKeyArray: assert False
|
||||
def swapaxes(self, *_, **__) -> PRNGKeyArray: assert False
|
||||
def take(self, *_, **__) -> PRNGKeyArray: assert False
|
||||
def transpose(self, *_, **__) -> PRNGKeyArray: assert False
|
||||
def flatten(self, *_, **__) -> PRNGKeyArray: assert False
|
||||
|
||||
_set_array_base_attributes(PRNGKeyArrayImpl, include=[
|
||||
'__getitem__', 'ravel', 'squeeze', 'swapaxes', 'take', 'reshape',
|
||||
'transpose', 'flatten', 'T'])
|
||||
'__getitem__', 'at', 'flatten', 'ravel', 'reshape',
|
||||
'squeeze', 'swapaxes', 'take', 'transpose', 'T'])
|
||||
basearray.Array.register(PRNGKeyArrayImpl)
|
||||
|
||||
|
||||
@ -491,6 +494,28 @@ class KeyTyRules:
|
||||
avals_out=[keys_aval_to_base_arr_aval(aval_y)])
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def scatter_mlir(ctx, avals_in, aval_out, x, indices, updates, *,
|
||||
update_jaxpr, update_consts, dimension_numbers,
|
||||
unique_indices, indices_are_sorted, mode):
|
||||
aval_x, aval_indices, aval_updates = avals_in
|
||||
aval_y = aval_out
|
||||
key_shape = aval_x.dtype.impl.key_shape
|
||||
trailing_window_dims = [aval_updates.ndim + i for i in range(len(key_shape))]
|
||||
dimension_numbers = dimension_numbers._replace(
|
||||
update_window_dims=(*dimension_numbers.update_window_dims, *trailing_window_dims))
|
||||
scatter_lower = partial(
|
||||
lax_internal.slicing._scatter_lower, update_jaxpr=update_jaxpr,
|
||||
update_consts=update_consts, dimension_numbers=dimension_numbers,
|
||||
unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
|
||||
mode=mode)
|
||||
res, = mlir.delegate_lowering(
|
||||
ctx, scatter_lower, x, indices, updates,
|
||||
avals_in=[keys_aval_to_base_arr_aval(aval_x), aval_indices,
|
||||
keys_aval_to_base_arr_aval(aval_updates)],
|
||||
avals_out=[keys_aval_to_base_arr_aval(aval_y)])
|
||||
return res
|
||||
|
||||
|
||||
class KeyTy:
|
||||
impl: Hashable # prng.PRNGImpl. TODO(mattjj,frostig): protocol really
|
||||
|
@ -2102,6 +2102,53 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
|
||||
self.assertKeysEqual(key, jnp.array(key, dtype=key.dtype))
|
||||
self.assertKeysEqual(key, jnp.asarray(key, dtype=key.dtype))
|
||||
|
||||
@parameterized.parameters([
|
||||
(0,),
|
||||
(slice(1),),
|
||||
(np.array([0, 2]),),
|
||||
(np.array([False, True, True]),)
|
||||
])
|
||||
def test_getitem(self, idx):
|
||||
key = random.PRNGKey(123)
|
||||
keys = jax.random.split(key, 3)
|
||||
|
||||
key_func = arr_func = lambda x: x[idx]
|
||||
|
||||
self.check_shape(key_func, keys)
|
||||
self.check_against_reference(key_func, arr_func, keys)
|
||||
|
||||
@parameterized.parameters([
|
||||
(0,),
|
||||
(slice(1),),
|
||||
(np.array([0, 2]),),
|
||||
(np.array([False, True, True]),)
|
||||
])
|
||||
def test_gather(self, idx):
|
||||
key = random.PRNGKey(123)
|
||||
keys = jax.random.split(key, 3)
|
||||
|
||||
key_func = arr_func = lambda x: x.at[idx].get()
|
||||
|
||||
self.check_shape(key_func, keys)
|
||||
self.check_against_reference(key_func, arr_func, keys)
|
||||
|
||||
@parameterized.parameters([
|
||||
(0,),
|
||||
(slice(1),),
|
||||
(np.array([0, 2]),),
|
||||
(np.array([False, True, True]),)
|
||||
])
|
||||
def test_scatter(self, idx):
|
||||
key = random.PRNGKey(123)
|
||||
keys = jax.random.split(key, 3)
|
||||
|
||||
key_func = arr_func = lambda x, y: x.at[idx].set(y)
|
||||
|
||||
self.check_shape(key_func, keys, key)
|
||||
self.check_against_reference(key_func, arr_func, keys, key)
|
||||
|
||||
|
||||
|
||||
def test_errors(self):
|
||||
key = random.PRNGKey(123)
|
||||
with self.assertRaisesRegex(ValueError, "dtype=key<fry> is not a valid dtype"):
|
||||
|
Loading…
x
Reference in New Issue
Block a user