[pallas] Add more documentation and tests for BlockSpec.
This PR covers the default values for the parameters of the `BlockSpec` constructor and the handling of mapped block dimensions. It also fixes a bug where a missing `block_shape` combined with a present `index_map` resulted in a crash.
This commit is contained in:
parent 6f79925d61
commit ea548e7c86
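Before the diff itself, a minimal editorial sketch (not part of this commit) of what the documented defaults mean in practice. It assumes the `jax.experimental.pallas` API exercised by the new tests below, and uses `interpret=True` so it runs without an accelerator:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl

def fill_kernel(o_ref):
  # With a fully-defaulted BlockSpec the kernel sees the whole (8,) array.
  o_ref[...] = jnp.full(o_ref.shape, 42)

out = pl.pallas_call(
    fill_kernel,
    out_shape=jax.ShapeDtypeStruct((8,), np.int32),
    # block_shape defaults to the overall array shape and index_map defaults
    # to an all-zeros index, per the documentation added in this commit.
    out_specs=pl.BlockSpec(),
    grid=(1,),
    interpret=True)()
print(out)  # [42 42 42 42 42 42 42 42]
```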
@@ -33,7 +33,10 @@ for i in range(n):
 This generalizes to any tuple of integers (a length `d` grid will correspond
 to `d` nested loops).
 The kernel is executed as many times
-as `prod(grid)`. Each of these invocations is referred to as a "program".
+as `prod(grid)`.
+The default grid value `None` stands for `()`, and results in one
+kernel invocation.
+Each of these invocations is referred to as a "program".
 To access which program (i.e. which element of the grid) the kernel is currently
 executing, we use {func}`jax.experimental.pallas.program_id`.
 For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and
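The sentence added above about the default grid (`None`, standing for `()`) is covered by the new `test_pallas_call_no_grid` later in this diff; here is a hedged standalone sketch of the same behavior, again assuming `interpret=True`:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl

def kernel(o_ref):
  # grid was left as None (i.e. ()), so this body runs exactly once and the
  # output block is the whole (8,) array.
  o_ref[...] = jnp.full(o_ref.shape, 42)

out = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct((8,), np.int32),
    interpret=True)()
print(out)  # [42 42 42 42 42 42 42 42]
```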
@@ -233,3 +236,47 @@ See [the Pallas TPU documentation](https://jax.readthedocs.io/en/latest/pallas/t
 [309 309 309 319 319 319]]
 
 ```
+
+A `None` value appearing as a dimension value in the `block_shape` behaves
+as the value `1`, except that the corresponding
+block axis is squeezed. In the example below, observe that the
+shape of the `o_ref` is (2,) when the block shape was specified as
+`(None, 2)` (the leading dimension was squeezed).
+
+```python
+>>> def kernel(o_ref):
+...   assert o_ref.shape == (2,)
+...   o_ref[...] = jnp.full((2,), 10 * pl.program_id(1) + pl.program_id(0))
+>>> pl.pallas_call(kernel,
+...                jax.ShapeDtypeStruct((3, 4), dtype=np.int32),
+...                out_specs=pl.BlockSpec((None, 2), lambda i, j: (i, j)),
+...                grid=(3, 2), interpret=True)()
+Array([[ 0, 0, 10, 10],
+       [ 1, 1, 11, 11],
+       [ 2, 2, 12, 12]], dtype=int32)
+
+```
+
+When we construct a `BlockSpec` we can use the value `None` for the
+`block_shape` parameter, in which case the shape of the overall array
+is used as `block_shape`.
+And if we use the value `None` for the `index_map` parameter
+then a default index map function that returns a tuple of zeros is
+used: `index_map=lambda *invocation_indices: (0,) * len(block_shape)`.
+
+```python
+>>> show_invocations(x_shape=(4, 4), block_shape=None, grid=(2, 3),
+...                  out_index_map=None)
+[[12 12 12 12]
+ [12 12 12 12]
+ [12 12 12 12]
+ [12 12 12 12]]
+
+>>> show_invocations(x_shape=(4, 4), block_shape=(4, 4), grid=(2, 3),
+...                  out_index_map=None)
+[[12 12 12 12]
+ [12 12 12 12]
+ [12 12 12 12]
+ [12 12 12 12]]
+
+```
@@ -211,12 +211,15 @@ class BlockSpec:
 
   def compute_index(self, *args):
     assert self.index_map is not None
-    assert self.block_shape is not None
     out = self.index_map(*args)
     if not isinstance(out, tuple):
      out = (out,)
     return out
 
+class NoBlockSpec:
+  pass
+no_block_spec = NoBlockSpec()
+
 
 # A PyTree of BlockSpec | NoBlockSpec.
 BlockSpecTree = Any
@@ -310,9 +313,11 @@ def _convert_block_spec_to_block_mapping(
     return None
   if block_spec.index_map is None:
     compute_index = lambda *args, **kwargs: (0,) * len(aval.shape)
-    block_shape = aval.shape
   else:
     compute_index = block_spec.compute_index
+  if block_spec.block_shape is None:
+    block_shape = aval.shape
+  else:
     block_shape = block_spec.block_shape
   block_shape = tuple(
       mapped if s is None else s for s in block_shape)
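This hunk is the bug fix mentioned in the commit message: a `BlockSpec` with an `index_map` but no `block_shape` now falls back to the overall array shape instead of crashing. A sketch of the call it enables, mirroring the new `test_block_spec_no_block_shape` added further down, with `interpret=True` assumed:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl

def kernel(o_ref):
  # block_shape was omitted, so one block covers the whole (8,) output.
  o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0))

out = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct((8,), np.int32),
    # index_map given, block_shape left as None: this combination used to crash.
    out_specs=pl.BlockSpec(None, lambda i: i),
    grid=(1,),
    interpret=True)()
print(out)  # [0 0 0 0 0 0 0 0]
```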
@@ -338,8 +343,7 @@ def _tile_ref(ref: state.AbstractRef, block_shape: tuple[int, ...] | None
   return ref.update(inner_aval=ref.inner_aval.update(shape=shape))
 
 
-def _get_ref_avals(grid,
-                   in_avals: Sequence[jax_core.ShapedArray],
+def _get_ref_avals(in_avals: Sequence[jax_core.ShapedArray],
                    in_specs: Sequence[BlockSpec],
                    in_paths: Sequence[tree_util.KeyPath],
                    out_avals: Sequence[jax_core.ShapedArray],
@@ -363,9 +367,9 @@ def _get_ref_avals(grid,
          f"Block shape for {what}{tree_util.keystr(path)} (= {block_shape}) "
          f"must have the same number of dimensions as the array shape {ref_aval.shape}"
      )
-    trimmed_block_shape = tuple(s for s in block_shape if s is not None)
+    block_shape_unmapped = tuple(s for s in block_shape if s is not None)
     ref_aval = ref_aval.update(
-        inner_aval=ref_aval.inner_aval.update(shape=trimmed_block_shape))
+        inner_aval=ref_aval.inner_aval.update(shape=block_shape_unmapped))
 
     if not jax_core.is_constant_shape(ref_aval.shape):
       raise ValueError(
@@ -382,11 +386,7 @@ def _get_ref_avals(grid,
       make_ref_aval(aval, out_spec, out_path, "output")
       for aval, out_spec, out_path in zip(out_avals, out_specs, out_paths)
   ]
-  return in_specs, in_ref_avals, out_specs, out_ref_avals
+  return in_ref_avals, out_ref_avals
 
-class NoBlockSpec:
-  pass
-no_block_spec = NoBlockSpec()
-
 
 @dataclasses.dataclass(init=False, unsafe_hash=True)
@@ -453,8 +453,8 @@ class GridSpec:
     )
     flat_in_specs, flat_out_specs = self._get_in_out_specs(
         in_avals, in_tree, out_avals, out_tree)
-    in_specs, in_ref_avals, out_specs, out_ref_avals = _get_ref_avals(
-        self.grid, in_avals, flat_in_specs, in_paths,
+    in_ref_avals, out_ref_avals = _get_ref_avals(
+        in_avals, flat_in_specs, in_paths,
         out_avals, flat_out_specs, out_paths)
     grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(self.grid)
     # Create args, kwargs pytree def
@@ -468,7 +468,7 @@ class GridSpec:
             mapped_dims=(),
             what="input",
         ),
-        in_specs,
+        flat_in_specs,
         in_paths,
         in_ref_avals,
     )
@@ -481,7 +481,7 @@ class GridSpec:
             mapped_dims=(),
             what="output",
         ),
-        out_specs,
+        flat_out_specs,
         out_paths,
         out_ref_avals,
     )
@@ -187,9 +187,9 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
     in_avals, in_avals_tree = tree_util.tree_flatten(tuple(unflat_in_avals))
     flat_in_specs, flat_out_specs = self._get_in_out_specs(
         in_avals, in_avals_tree, out_avals, out_tree)
-    in_specs, in_ref_avals, out_specs, out_ref_avals = (
+    in_ref_avals, out_ref_avals = (
         pallas_core._get_ref_avals(
-            self.grid, in_avals, flat_in_specs, in_paths[num_flat_scalar_prefetch:],
+            in_avals, flat_in_specs, in_paths[num_flat_scalar_prefetch:],
             out_avals, flat_out_specs, out_paths))
     scalar_ref_avals = [
         AbstractMemoryRef(jax_core.ShapedArray(aval.shape, aval.dtype),
@@ -209,7 +209,7 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
             mapped_dims=(),
             what="input",
         ),
-        in_specs,
+        flat_in_specs,
         in_paths[num_flat_scalar_prefetch:],
         in_ref_avals,
     )
@@ -222,7 +222,7 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
             mapped_dims=(),
             what="output",
         ),
-        out_specs,
+        flat_out_specs,
         out_paths,
         out_ref_avals,
     )
@@ -1034,12 +1034,14 @@ def pallas_call(
       See details at :ref:`pallas_grid`.
     in_specs: a PyTree of :class:`jax.experimental.pallas.BlockSpec` with
       a structure matching that of the positional arguments.
+      The default value for ``in_specs`` specifies the whole array for all
+      inputs, e.g., as ``pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)``.
       See details at :ref:`pallas_blockspec`.
     out_specs: a PyTree of :class:`jax.experimental.pallas.BlockSpec` with
       a structure matching that of the outputs.
+      The default value for ``out_specs`` specifies the whole array,
+      e.g., as ``pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)``.
       See details at :ref:`pallas_blockspec`.
-      The default value for `out_specs` specifies the whole array,
-      e.g., as ``pl.BlockSpec(x.shape, lambda *indices: indices)``.
     input_output_aliases: a dictionary mapping the index of some inputs to
       the index of the output that aliases them. These indices are in the
       flattened inputs and outputs.
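As an editorial aside on the docstring text above: with the default `in_specs` and `out_specs`, a kernel can omit both and every ref covers the corresponding whole array. A minimal sketch, assuming `interpret=True`:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
  # Default in_specs/out_specs: each ref covers the corresponding whole array.
  o_ref[...] = x_ref[...] + y_ref[...]

x = jnp.arange(8, dtype=jnp.float32)
y = jnp.ones(8, dtype=jnp.float32)
out = pl.pallas_call(
    add_kernel,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
    interpret=True)(x, y)
np.testing.assert_allclose(out, x + y)
```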
@@ -1063,6 +1065,9 @@ def pallas_call(
   if grid_spec is None:
     grid_spec = GridSpec(grid, in_specs, out_specs)
   grid_spec, dynamic_grid_bounds = grid_spec.unzip_dynamic_grid_bounds()
+  # TODO(necula): this canonicalization may be convenient for some usage
+  # but it is lossy, because it prevents expressing functions that return
+  # lists.
   if isinstance(out_shape, list):
     out_shape = tuple(out_shape)
   flat_out_shapes_with_paths, out_tree = tree_util.tree_flatten_with_path(out_shape)
@@ -253,6 +253,114 @@ class PallasCallTest(PallasBaseTest):
     # TODO(necula): we normalize out_shape to a tuple, we shouldn't.
     self.assertIsInstance(res, tuple)
 
+  @jtu.skip_on_devices("gpu")  # TODO: RET_CHECK failure
+  def test_block_spec_with_padding(self):
+    def f(*, shape, block_shape):
+      def kernel(o1_ref):
+        assert o1_ref.shape == block_shape
+        o1_ref[...] = jnp.full(o1_ref.shape, pl.program_id(0))
+
+      return self.pallas_call(kernel,
+                              jax.ShapeDtypeStruct(shape, dtype=np.int32),
+                              grid=((shape[0] + block_shape[0] - 1) // block_shape[0],),
+                              out_specs=pl.BlockSpec(block_shape, lambda i: i))()
+    # No padding
+    pids = f(shape=(8,), block_shape=(2,))
+    self.assertAllClose(pids,
+                        np.array([0, 0, 1, 1, 2, 2, 3, 3], dtype=np.int32))
+    # Pad the last block
+    pids = f(shape=(8,), block_shape=(3,))
+    self.assertAllClose(pids,
+                        np.array([0, 0, 0, 1, 1, 1, 2, 2], dtype=np.int32))
+    # Works even if the shape is smaller than 1 block
+    pids = f(shape=(3,), block_shape=(8,))
+    self.assertAllClose(pids,
+                        np.array([0, 0, 0], dtype=np.int32))
+
+  def test_block_spec_mapped_dimension(self):
+    @functools.partial(
+        self.pallas_call,
+        out_shape=jax.ShapeDtypeStruct((4,), jnp.float32),
+        in_specs=[
+            pl.BlockSpec((None, 4), lambda _: (0, 0)),
+            pl.BlockSpec((None, 4), lambda _: (1, 0)),
+        ],
+        grid=1,
+    )
+    def add_vectors(x_ref, y_ref, o_ref):
+      o_ref[:] = x_ref[:] + y_ref[:]
+    xy = jnp.arange(8., dtype=np.float32).reshape((2, 4))
+    out = add_vectors(xy, xy)
+    out_ref = xy[0] + xy[1]
+    np.testing.assert_allclose(out, out_ref)
+
+  def test_pallas_call_no_grid(self):
+    o_ref_shape = None
+    def kernel(o_ref):
+      nonlocal o_ref_shape
+      o_ref_shape = o_ref.shape
+      o_ref[...] = jnp.full(o_ref.shape, 42)
+
+    pids = self.pallas_call(kernel,
+                            jax.ShapeDtypeStruct((8,), dtype=np.int32))()
+    self.assertAllClose(pids, np.array([42] * 8, dtype=np.int32))
+    self.assertEqual(o_ref_shape, (8,))
+
+  def test_pallas_call_no_block_spec(self):
+    o_ref_shape = None
+    def kernel(o_ref):
+      nonlocal o_ref_shape
+      o_ref_shape = o_ref.shape
+      o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0))
+
+    pids = self.pallas_call(kernel,
+                            jax.ShapeDtypeStruct((8,), dtype=np.int32),
+                            grid=(1,))()
+    self.assertEqual(o_ref_shape, (8,))
+    self.assertAllClose(pids, np.array([0] * 8, dtype=np.int32))
+
+  def test_block_spec_no_block_shape_and_no_index_map(self):
+    o_ref_shape = None
+    def kernel(o_ref):
+      nonlocal o_ref_shape
+      o_ref_shape = o_ref.shape
+      o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0))
+
+    pids = self.pallas_call(kernel,
+                            jax.ShapeDtypeStruct((8,), dtype=np.int32),
+                            out_specs=pl.BlockSpec(),
+                            grid=(1,))()
+    self.assertEqual(o_ref_shape, (8,))
+    self.assertAllClose(pids, np.array([0] * 8, dtype=np.int32))
+
+  def test_block_spec_no_block_shape(self):
+    o_ref_shape = None
+    def kernel(o_ref):
+      nonlocal o_ref_shape
+      o_ref_shape = o_ref.shape
+      o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0))
+
+    pids = self.pallas_call(kernel,
+                            jax.ShapeDtypeStruct((8,), dtype=np.int32),
+                            out_specs=pl.BlockSpec(None, lambda i: i),
+                            grid=(1,))()
+    self.assertEqual(o_ref_shape, (8,))
+    self.assertAllClose(pids, np.array([0] * 8, dtype=np.int32))
+
+  def test_block_spec_no_index_map(self):
+    o_ref_shape = None
+    def kernel(o_ref):
+      nonlocal o_ref_shape
+      o_ref_shape = o_ref.shape
+      o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0))
+
+    pids = self.pallas_call(kernel,
+                            jax.ShapeDtypeStruct((8,), dtype=np.int32),
+                            out_specs=pl.BlockSpec((4,)),
+                            grid=(1,))()
+    self.assertEqual(o_ref_shape, (4,))
+    self.assertAllClose(pids[0:4], np.array([0] * 4, dtype=np.int32))
+
   def test_hoisted_consts(self):
     # See https://github.com/google/jax/issues/21557.
     x = jnp.zeros(32)
@@ -440,50 +548,6 @@ class PallasCallTest(PallasBaseTest):
     self.assertEqual(f(x), 2.)
     self.assertEqual(trace_count, 1)
 
-  def test_custom_jvp_call(self):
-    @functools.partial(jax.custom_jvp, nondiff_argnums=(1,))
-    def softmax(x, axis=-1):
-      unnormalized = jnp.exp(x - jnp.max(x, axis, keepdims=True))
-      return unnormalized / jnp.sum(unnormalized, axis, keepdims=True)
-
-    @softmax.defjvp
-    def softmax_jvp(axis, primals, tangents):
-      (x,), (x_dot,) = primals, tangents
-      y = softmax(x, axis)
-      return y, y * (x_dot - (y * x_dot).sum(axis, keepdims=True))
-
-    m, n = 16, 32
-    x = random.normal(random.key(0), (m, n))
-
-    @functools.partial(self.pallas_call, out_shape=x, grid=1)
-    def softmax_kernel(x_ref, y_ref):
-      y_ref[:] = softmax(x_ref[:])
-
-    np.testing.assert_allclose(softmax_kernel(x), jax.nn.softmax(x), atol=1e-7)
-
-  @jtu.skip_on_devices("gpu")  # TODO: RET_CHECK failure
-  def test_block_spec_with_padding(self):
-    def f(*, shape, block_shape):
-      def kernel(o1_ref):
-        assert o1_ref.shape == block_shape
-        o1_ref[...] = jnp.full(o1_ref.shape, pl.program_id(0))
-
-      return self.pallas_call(kernel,
-                              jax.ShapeDtypeStruct(shape, dtype=np.int32),
-                              grid=((shape[0] + block_shape[0] - 1) // block_shape[0],),
-                              out_specs=pl.BlockSpec(block_shape, lambda i: i))()
-    # No padding
-    pids = f(shape=(8,), block_shape=(2,))
-    self.assertAllClose(pids,
-                        np.array([0, 0, 1, 1, 2, 2, 3, 3], dtype=np.int32))
-    # Pad the last block
-    pids = f(shape=(8,), block_shape=(3,))
-    self.assertAllClose(pids,
-                        np.array([0, 0, 0, 1, 1, 1, 2, 2], dtype=np.int32))
-    # Works even if the shape is smaller than 1 block
-    pids = f(shape=(3,), block_shape=(8,))
-    self.assertAllClose(pids,
-                        np.array([0, 0, 0], dtype=np.int32))
 
 class PallasCallInterpreterTest(PallasCallTest):
   INTERPRET = True
@@ -1495,6 +1559,27 @@ class PallasCallAutodifferentiationTest(PallasBaseTest):
     jtu.check_grads(pallas_impl, (x,), modes=["fwd"], order=2,
                     atol=grad_tol, rtol=grad_tol)
 
+  def test_custom_jvp_call(self):
+    @functools.partial(jax.custom_jvp, nondiff_argnums=(1,))
+    def softmax(x, axis=-1):
+      unnormalized = jnp.exp(x - jnp.max(x, axis, keepdims=True))
+      return unnormalized / jnp.sum(unnormalized, axis, keepdims=True)
+
+    @softmax.defjvp
+    def softmax_jvp(axis, primals, tangents):
+      (x,), (x_dot,) = primals, tangents
+      y = softmax(x, axis)
+      return y, y * (x_dot - (y * x_dot).sum(axis, keepdims=True))
+
+    m, n = 16, 32
+    x = random.normal(random.key(0), (m, n))
+
+    @functools.partial(self.pallas_call, out_shape=x, grid=1)
+    def softmax_kernel(x_ref, y_ref):
+      y_ref[:] = softmax(x_ref[:])
+
+    np.testing.assert_allclose(softmax_kernel(x), jax.nn.softmax(x), atol=1e-7)
+
   # TODO(sharadmv): enable this when we update Triton
   # def test_jvp_matmul(self):
   #   k1, k2 = random.split(random.key(0))
@@ -1505,23 +1590,6 @@ class PallasCallAutodifferentiationTest(PallasBaseTest):
   #   interpret=self.INTERPRET)
   #   jtu.check_grads(mm, (x, y), modes=["fwd"], order=1)
 
-  def test_slicing_block_spec(self):
-    @functools.partial(
-        self.pallas_call,
-        out_shape=jax.ShapeDtypeStruct((4,), jnp.float32),
-        in_specs=[
-            pl.BlockSpec((None, 4), lambda _: (0, 0)),
-            pl.BlockSpec((None, 4), lambda _: (1, 0)),
-        ],
-        grid=1,
-    )
-    def add_vectors(x_ref, y_ref, o_ref):
-      o_ref[:] = x_ref[:] + y_ref[:]
-    xy = jnp.arange(8.).reshape((2, 4))
-    out = add_vectors(xy, xy)
-    out_ref = xy[0] + xy[1]
-    np.testing.assert_allclose(out, out_ref)
-
 
 class PallasCallAutodifferentiationInterpreterTest(PallasCallAutodifferentiationTest):
   INTERPRET = True