mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
192 lines
6.6 KiB
Python
192 lines
6.6 KiB
Python
# Copyright 2024 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Tests for TPU-specific interpret mode.
|
|
|
|
To work around https://github.com/jax-ml/jax/issues/25671 , this file
|
|
contains only tests that do not use shard_map.
|
|
"""
|
|
|
|
from absl.testing import absltest
|
|
from absl.testing import parameterized
|
|
|
|
import jax
|
|
from jax._src import test_util as jtu
|
|
import jax._src.pallas.mosaic.interpret as mosaic_interpret
|
|
from jax.experimental import pallas as pl
|
|
from jax.experimental.pallas import tpu as pltpu
|
|
import jax.numpy as jnp
|
|
|
|
import numpy as np
|
|
|
|
|
|
# Route JAX's configuration flags through absl's command-line flag parsing.
jax.config.parse_flags_with_absl()
|
|
|
|
|
|
class InterpretTest(jtu.JaxTestCase):
  """Pallas TPU interpret-mode tests that run on a single device."""

  def setUp(self):
    """Records the device count and skips the test when it exceeds one."""
    super().setUp()
    self.num_devices = jax.device_count()
    if self.num_devices <= 1:
      return
    # Workaround for https://github.com/jax-ml/jax/issues/25671
    self.skipTest(f'requires 1 device, found {self.num_devices}')

  def test_matmul_example(self):
    """Runs a blocked matmul on a 2x2 grid and compares it with `x @ y`."""

    def matmul_kernel(x_ref, y_ref, z_ref):
      z_ref[...] = x_ref[...] @ y_ref[...]

    @jax.jit
    def matmul(x: jax.Array, y: jax.Array):
      m, k_lhs = x.shape[0], x.shape[1]
      k_rhs, n = y.shape[0], y.shape[1]
      # Grid step (i, j) sees row block i of x, column block j of y, and
      # writes output block (i, j).
      lhs_spec = pl.BlockSpec((m // 2, k_lhs), lambda i, j: (i, 0))
      rhs_spec = pl.BlockSpec((k_rhs, n // 2), lambda i, j: (0, j))
      out_spec = pl.BlockSpec((m // 2, n // 2), lambda i, j: (i, j))
      call = pl.pallas_call(
          matmul_kernel,
          out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
          grid=(2, 2),
          in_specs=[lhs_spec, rhs_spec],
          out_specs=out_spec,
          interpret=mosaic_interpret.TPUInterpretParams(),
      )
      return call(x, y)

    key_lhs, key_rhs = jax.random.split(jax.random.key(0))
    lhs = jax.random.normal(key_lhs, (1024, 1024))
    rhs = jax.random.normal(key_rhs, (1024, 1024))
    actual = matmul(lhs, rhs)
    np.testing.assert_allclose(actual, lhs @ rhs, atol=1e-4)

  def test_dynamic_grid_and_aliasing(self):
    """Combines an array-valued grid size with input/output aliasing."""

    def kernel(s_ref, x_ref, o_ref):
      o_ref[...] = x_ref[...] + s_ref[0].astype(x_ref.dtype)

    # The grid extent is a JAX scalar array drawn from [10, 20), not a Python
    # int.
    iters = jax.random.randint(jax.random.key(0), (), 10, 20, dtype=jnp.int32)

    @jax.jit
    def f(s, x):
      scalar_spec = pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)
      input_spec = pl.BlockSpec(x.shape, lambda i: (0, 0))
      output_spec = pl.BlockSpec(x.shape, lambda i: (0, 0))
      call = pl.pallas_call(
          kernel,
          out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
          grid=(iters,),
          in_specs=[scalar_spec, input_spec],
          out_specs=output_spec,
          # Operand 1 (x) is aliased to output 0.
          input_output_aliases={1: 0},
          interpret=mosaic_interpret.TPUInterpretParams(),
      )
      return call(s, x)

    increment = jnp.array([1], dtype=jnp.int32)
    data = jnp.arange(32 * 128.).reshape((32, 128))
    result = f(increment, data)
    np.testing.assert_allclose(result, data + 1.0)

  @parameterized.parameters('eager', 'on_wait')
  def test_race_detection(self, dma_execution_mode):
    """Checks `races.races_found` after a race-free and a racy kernel.

    Both kernels copy the input into a VMEM scratch buffer with an async copy.
    The first waits for the copy before reading the scratch buffer; the second
    reads the scratch buffer before waiting.
    """

    def kernel_without_race(x_ref, o_ref, t_ref, sem):
      copy = pltpu.make_async_copy(x_ref, t_ref, sem)
      copy.start()
      copy.wait()
      o_ref[...] = t_ref[...] + 1.0

    def kernel_with_race(x_ref, o_ref, t_ref, sem):
      copy = pltpu.make_async_copy(x_ref, t_ref, sem)
      copy.start()
      # This read of t_ref races with the above DMA's write of t_ref.
      o_ref[...] = t_ref[...] + 1.0
      copy.wait()

    x = jnp.zeros((8, 128), jnp.float32)

    def run_with_race_detection(kernel):
      # Same call configuration for both kernels; only the kernel body differs.
      call = pl.pallas_call(
          kernel,
          out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
          in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)],
          scratch_shapes=[
              pltpu.VMEM(x.shape, x.dtype),
              pltpu.SemaphoreType.DMA,
          ],
          interpret=mosaic_interpret.TPUInterpretParams(
              detect_races=True, dma_execution_mode=dma_execution_mode),
      )
      return call(x).block_until_ready()

    y = run_with_race_detection(kernel_without_race)
    self.assertFalse(mosaic_interpret.races.races_found)
    np.testing.assert_allclose(y, x + 1.0)

    run_with_race_detection(kernel_with_race)
    self.assertTrue(mosaic_interpret.races.races_found)

  def test_skip_floating_point_ops(self):
    """Expects an all-inf result and no `dot_general` in the lowered text.

    The matmul kernel is run with `skip_floating_point_ops=True`.
    """

    def matmul_kernel(x_ref, y_ref, z_ref):
      z_ref[...] = x_ref[...] @ y_ref[...]

    def matmul(x: jax.Array, y: jax.Array):
      params = mosaic_interpret.TPUInterpretParams(
          skip_floating_point_ops=True
      )
      call = pl.pallas_call(
          matmul_kernel,
          out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
          interpret=params,
      )
      return call(x, y)

    key_lhs, key_rhs = jax.random.split(jax.random.key(0))
    lhs = jax.random.normal(key_lhs, (1024, 1024))
    rhs = jax.random.normal(key_rhs, (1024, 1024))
    result = jax.jit(matmul)(lhs, rhs)
    np.testing.assert_array_equal(result, jnp.full_like(result, jnp.inf))

    stablehlo_text = jax.jit(matmul).lower(lhs, rhs).as_text(
        dialect="stablehlo")
    self.assertNotIn("dot_general", stablehlo_text)

  @parameterized.parameters('nan', 'zero')
  def test_uninitialized_memory(self, uninitialized_memory):
    """Checks the values read from buffers the kernel never initializes.

    The kernel copies two untouched VMEM scratch buffers into the first two
    outputs and never writes the third output. With 'nan', the bfloat16 and
    float32 results must be all NaN and the int16 result all 32767 (the
    largest int16 value); with 'zero', all three results must be zero.
    """

    def kernel(o1_ref, o2_ref, o3_ref, t1_ref, t2_ref):
      o1_ref[...] = t1_ref[...]
      o2_ref[...] = t2_ref[...]

    buffer_shape = (8, 128)
    bf16_out, int16_out, f32_out = pl.pallas_call(
        kernel,
        out_shape=[
            jax.ShapeDtypeStruct(buffer_shape, jnp.bfloat16),
            jax.ShapeDtypeStruct(buffer_shape, jnp.int16),
            jax.ShapeDtypeStruct(buffer_shape, jnp.float32),
        ],
        in_specs=[],
        scratch_shapes=[
            pltpu.VMEM(buffer_shape, jnp.bfloat16),
            pltpu.VMEM(buffer_shape, jnp.int16),
        ],
        interpret=mosaic_interpret.TPUInterpretParams(
            uninitialized_memory=uninitialized_memory),
    )()

    if uninitialized_memory == 'nan':
      self.assertTrue(jnp.isnan(bf16_out).all())
      np.testing.assert_equal(np.array(int16_out), 32767)
      self.assertTrue(jnp.isnan(f32_out).all())
    elif uninitialized_memory == 'zero':
      for output in (bf16_out, int16_out, f32_out):
        np.testing.assert_equal(np.array(output), 0)
|
|
|
|
|
|
if __name__ == "__main__":
  # Hand the tests in this module to absltest using JAX's test loader.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)
|