mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Add support to pipeline emitter for shapes that don't perfectly divide the block shapes
PiperOrigin-RevId: 640471328
This commit is contained in:
parent d5e43dd1e9
commit c2a3c0bb80
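
Editor's note on the change: when the block shape does not evenly divide the array shape, the grid is sized by ceiling division and the final block along each dimension is only partially valid. A minimal arithmetic sketch (not part of the commit), using the (600, 600) array and (256, 256) block shape from the comments in the diff below:

# Editor's illustration only, not part of the commit.
size, block = 600, 256
num_blocks = (size + block - 1) // block          # ceiling division, as pl.cdiv does
valid_in_last_block = size - (num_blocks - 1) * block
assert num_blocks == 3 and valid_in_last_block == 88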
@@ -113,5 +113,5 @@ py_library(
         "//jax:api_util",
         "//jax:util",
         "//jax/_src/pallas",
-    ],
+    ] + py_deps("numpy"),
 )
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 """Module for emitting custom TPU pipelines within a Pallas call."""
+from __future__ import annotations
 
 import dataclasses
 import enum
@@ -29,6 +30,7 @@ from jax._src.pallas.mosaic import core as tpu_core
 from jax._src.pallas.mosaic import primitives as tpu_primitives
 from jax.experimental import pallas as pl
 import jax.numpy as jnp
+import numpy as np
 
 
 SMEM = tpu_core.TPUMemorySpace.SMEM
@@ -44,6 +46,9 @@ PipelineBlockSpecs = Union[Sequence[pallas_core.BlockSpec], Any]
 PipelineRefs = Union[Sequence[REF], Any]
 
 
+# TODO(sharadmv): make this a parameter and make it queryable from the Device.
+_TILING = (8, 128)
+
 def _broadcast_pytree_to(from_pytree, to_pytree):
   """Broadcast a prefix pytree to a given full tree."""
   proxy = object()
@@ -63,14 +68,73 @@ def _broadcast_pytree_to(from_pytree, to_pytree):
   return tree_util.tree_unflatten(treedef, broadcast_leaves)
 
 
+def _get_tpu_generation() -> int:
+  kind = jax.devices()[0].device_kind
+  if kind.endswith(' lite'):
+    kind = kind[:-len(' lite')]
+  assert kind[:-1] == "TPU v", kind
+  return int(kind[-1])
+
+def _make_tiling(shape: tuple[int, ...], dtype: np.dtype) -> tuple[int, ...]:
+  # For a n-dimensional shape, returns (8, 128) for the last 2 dimensions
+  # and 1 for the leading n - 2. For example, (256, 256) -> (8, 128) and
+  # (2, 3, 128, 128) -> (1, 1, 8, 128).
+  if len(shape) < 2:
+    raise ValueError(f"Shape must have at least 2 dimensions: {shape=}")
+  leading_dims, final_dims = shape[:-2], shape[-2:]
+  # We want to find the minimum power of 2 that fits the second-minor dimension
+  # of shape, with maximum value 8.
+  second_minor, _ = final_dims
+  packing = 4 // dtype.itemsize
+  max_tiling = _TILING[0]
+  second_minor_tiling = (1 + int(_get_tpu_generation() < 4)) * packing
+  while second_minor_tiling < min(second_minor, max_tiling):
+    second_minor_tiling *= 2
+  return (*(1,) * len(leading_dims), second_minor_tiling, _TILING[1])
+
+
+def _mod(a, n):
+  """"Calculates a mod n for positive and negative a with |a| <= n."""
+  return lax.rem(a + n, n)
+
+
-def _make_ds(idx, size):
+def _round_up_to_nearest_multiple(s: int, multiple: int) -> int:
+  if s % multiple == 0:
+    return s
+  # Subtract off the remainder, then add multiple
+  return s - s % multiple + multiple
+
+
+def _make_ds(
+    idx: jax.Array | int, size: jax.Array | int
+) -> pl.Slice:
   """Make a DMA slice with mosaic size hints."""
-  return pl.ds(pl.multiple_of(idx * size, size), size)
+  out = pl.ds(idx * size, size)
+  assert isinstance(out, pl.Slice)
+  return out
+
+
+def _make_block_slice(
+    block_index: jax.Array, block_size: int, size: int, tiling: int
+) -> pl.Slice | slice:
+  # Computes a slice given a block index and block size. In the default case,
+  # we return slice(block_index * block_size, (block_index + 1) * block_size).
+  # However, if the total size of the ref does not divide block size and we are
+  # selecting the last block, we need to pick the lowest tiling size multiple
+  # that contains the block.
+  if size % block_size == 0:
+    return _make_ds(block_index, block_size)
+  if block_size % tiling != 0:
+    raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}")
+  num_blocks = pl.cdiv(size, block_size)
+  is_last = block_index == num_blocks - 1
+  rounded_size = jnp.where(
+      is_last,
+      _round_up_to_nearest_multiple(size % block_size, tiling),
+      block_size,
+  )
+  rounded_size = pl.multiple_of(rounded_size, tiling)
+  return pl.ds(block_index * block_size, rounded_size)
+
+
 def _tuples_differ(xs, ys):
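
Editor's note: a standalone sketch (not part of the diff) of the rounding rule that `_round_up_to_nearest_multiple` and `_make_block_slice` above apply to the last, partial block, using the (600, 600) and (601, 600) examples discussed in `get_dma_slice` in the next hunk:

# Editor's sketch only; plain Python, no Pallas required.
def round_up_to_nearest_multiple(s: int, multiple: int) -> int:
  return s if s % multiple == 0 else s - s % multiple + multiple

def last_block_extent(size: int, block_size: int, tiling: int) -> int:
  # Extent of the final (possibly partial) block, rounded up to a tile multiple.
  remainder = size % block_size
  return block_size if remainder == 0 else round_up_to_nearest_multiple(remainder, tiling)

assert last_block_extent(600, 256, 8) == 88   # 600 % 256 == 88 is already a multiple of 8
assert last_block_extent(601, 256, 8) == 96   # 89 rounds up to the next multiple of 8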
@@ -259,49 +323,110 @@ class BufferedRef:
     if self.memory_space == VMEM: return
     self.current_slot[0] = self.next_slot[0]
 
+  def get_dma_slice(self, src_shape, src_dtype, grid_indices):
+    # We need to handle blocks that might go OOB in the src array. An in bounds
+    # block looks like this (for array shape (600, 600) and block shape
+    # (256, 256)):
+    #
+    # +--------------+------------------|
+    # | Block (0,0)  |                  |
+    # | (256, 256)   |                  |
+    # +--------------+                  |
+    # |           A (600, 600)          |
+    # |                                 |
+    # +---------------------------------+
+    #
+    # For in-bounds blocks, we don't need to do anything special.
+    # An out-of-bounds block looks like this:
+    #
+    # +--------------+------------------|
+    # |                                 |
+    # |                                 |
+    # +                                 |
+    # |           A (600, 600)          |
+    # +--------------+                  |
+    # | Block (2,0)  |                  |
+    # + --------------------------------|
+    # | XXXXXXXXXX   |
+    # +--------------+
+    # where the X's indicate where the block is out of bounds.
+    #
+    # When we have an out of bounds block like this, we need to truncate it to
+    # a tile boundary (tiles are (8, 128) along the two minormost dimensions).
+    # In this case, we'll have a block that is indexing the
+    # 512:768 elements of A along the first dimension. We need to convert 768
+    # into 600 (600 % 8 == 0), so our indexing will look like this:
+    #
+    # +--------------+------------------|
+    # |                                 |
+    # |                                 |
+    # +                                 |
+    # |           A (600, 600)          |
+    # +--------------+                  |
+    # | Block (2,0)  |                  |
+    # + --------------------------------|
+    # where it is now a (88, 256) sized block.
+    #
+    # Suppose A is now (601, 600), instead of picking a (88, 256)-sized block
+    # for the last iteration on that dimension, we will pick the next highest
+    # tile multiple, i.e. (96, 256).
+    if len(src_shape) < 2:
+      raise NotImplementedError("Must use >1D values.")
+
+    tiling = _make_tiling(src_shape, src_dtype)
+    block_shape = tuple(1 if b is None else b for b in self.block_shape)
+    block_indices = self.compute_index(*grid_indices)
+    return jax.tree.map(
+        _make_block_slice, block_indices, block_shape, src_shape, tiling
+    )
+
   def copy_in(self, src_ref, grid_indices):
     """Starts copy of HBM dma slice into the current slot."""
     assert self.is_input
     if self.memory_space == VMEM: return
-    dma_slice = self.compute_slice(grid_indices)
     next_slot = lax.rem(self.current_slot[0] + 1, 2)
     self.next_slot[0] = next_slot
+    src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
+    dst_slice = tuple(pl.ds(0, s.size) for s in src_slice)
     tpu_primitives.make_async_copy(
-        src_ref.at[dma_slice],
-        self.vmem_ref.at[next_slot],
+        src_ref.at[src_slice],
+        self.vmem_ref.at[next_slot].at[dst_slice],
         self.sem_recv).start()
 
   def copy_out(self, dst_ref, grid_indices):
     """Starts copy of HBM dma slice from the current slot."""
     assert self.is_output
     if self.memory_space == VMEM: return
-    dma_slice = self.compute_slice(grid_indices)
     slot = self.current_slot[0]
     self.next_slot[0] = lax.rem(slot + 1, 2)
+    dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
+    src_slice = tuple(pl.ds(0, s.size) for s in dst_slice)
     tpu_primitives.make_async_copy(
-        self.vmem_ref.at[slot],
-        dst_ref.at[dma_slice],
+        self.vmem_ref.at[slot].at[src_slice],
+        dst_ref.at[dst_slice],
         self.sem_send).start()
 
   def wait_in(self, src_ref, grid_indices):
     """Waits for input copy to finish."""
     assert self.is_input
     if self.memory_space == VMEM: return
-    dma_slice = self.compute_slice(grid_indices)
+    src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
+    dst_slice = tuple(pl.ds(0, s.size) for s in src_slice)
     tpu_primitives.make_async_copy(
-        src_ref.at[dma_slice],  # nb: doesn't matter
-        self.vmem_ref.at[self.current_slot[0]],  # only dst shape is important
+        src_ref.at[src_slice],  # nb: doesn't matter
+        self.vmem_ref.at[self.current_slot[0]].at[dst_slice],  # only dst shape is important
         self.sem_recv).wait()
 
   def wait_out(self, dst_ref, grid_indices):
     """Waits for output copy to finish."""
     assert self.is_output
     if self.memory_space == VMEM: return
-    dma_slice = self.compute_slice(grid_indices)
     prev_slot = lax.rem(self.current_slot[0] + 1, 2)
+    dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
+    src_slice = tuple(pl.ds(0, s.size) for s in dst_slice)
     tpu_primitives.make_async_copy(
-        self.vmem_ref.at[prev_slot],  # nb: doesn't matter
-        dst_ref.at[dma_slice],  # only dst shape is important
+        self.vmem_ref.at[prev_slot].at[src_slice],  # nb: doesn't matter
+        dst_ref.at[dst_slice],  # only dst shape is important
         self.sem_send).wait()
 
   # Accumulator methods
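
Editor's note: in each method above, the second slice is built as `tuple(pl.ds(0, s.size) for s in ...)`, so the VMEM side of the copy has the same extents as the (possibly truncated) HBM slice but starts at offset 0. A hedged standalone sketch of that invariant, with a namedtuple standing in for `pl.Slice`:

# Editor's sketch only; `Slice` is a stand-in for pl.Slice.
from collections import namedtuple

Slice = namedtuple('Slice', ['start', 'size'])

def zero_based_target(src_slices):
  # Same extent per axis as the source slice, anchored at 0, so a truncated
  # last block lands in the leading corner of the double-buffered VMEM ref.
  return tuple(Slice(0, s.size) for s in src_slices)

src = (Slice(512, 88), Slice(0, 256))   # truncated last row-block of a (600, 600) array
assert zero_based_target(src) == (Slice(0, 88), Slice(0, 256))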
@@ -15,7 +15,9 @@
 """Pallas utility functions."""
 
 from __future__ import annotations
+from typing import overload
 
+import jax
 from jax import lax
 from jax._src import core as jax_core
 from jax._src.util import split_list
@@ -32,9 +34,26 @@ def when(condition):
       lax.cond(condition, f, lambda: None)
   return _wrapped
 
 
+@overload
 def cdiv(a: int, b: int) -> int:
-  return (a + b - 1) // b
+  ...
+
+@overload
+def cdiv(a: int, b: jax.Array) -> jax.Array:
+  ...
+
+@overload
+def cdiv(a: jax.Array, b: int) -> jax.Array:
+  ...
+
+@overload
+def cdiv(a: jax.Array, b: jax.Array) -> jax.Array:
+  ...
+
+def cdiv(a: int | jax.Array, b: int | jax.Array) -> int | jax.Array:
+  if isinstance(a, int) and isinstance(b, int):
+    return (a + b - 1) // b
+  return lax.div(a + b - 1, b)
 
 
 def strides_from_shape(shape: tuple[int, ...]) -> tuple[int, ...]:
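
Editor's note: a quick check (not part of the diff) of the two paths the new `cdiv` takes; `pl.cdiv` here refers to the `jax.experimental.pallas` export used elsewhere in this commit.

# Editor's sketch only.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

assert pl.cdiv(600, 256) == 3                        # both Python ints: floor-division path
traced = jax.jit(lambda n: pl.cdiv(n, 256))(jnp.int32(600))
assert int(traced) == 3                              # traced operand: lax.div path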
@@ -119,7 +119,7 @@ def mha_forward_kernel(
     # Ceildiv (`pl.cdiv` and `//` do not work due to type of start_q)
     upper_bound = lax.div(block_q * (start_q + 1) + block_k - 1, block_k)
   else:
-    upper_bound = pl.cdiv(seq_len, block_k)  # type: ignore
+    upper_bound = pl.cdiv(seq_len, block_k)
   o, m_i, l_i = lax.fori_loop(0, upper_bound, body, (o, m_i, l_i))
 
   if residual_refs:
@@ -210,11 +210,12 @@ jax_test(
         "gpu",
     ],
     main = "pallas_pipeline_tpu_test.py",
+    shard_count = 2,
     deps = [
         "//jax:extend",
         "//jax:pallas_tpu",
         "//jax:pallas_tpu_ops",
-    ],
+    ] + py_deps("hypothesis"),
 )
 
 jax_test(
@@ -26,6 +26,24 @@ from jax.experimental import shard_map
 from jax.experimental.pallas import tpu as pltpu
 import jax.numpy as jnp
 import numpy as np
+try:
+  import hypothesis as hp
+  import hypothesis.strategies as hps
+  CAN_USE_HYPOTHESIS = True
+except (ModuleNotFoundError, ImportError):
+  CAN_USE_HYPOTHESIS = False
+
+
+if CAN_USE_HYPOTHESIS:
+  hp.settings.register_profile(
+      'deterministic',
+      database=None,
+      derandomize=True,
+      deadline=None,
+      max_examples=50,
+      print_blob=True,
+  )
+  hp.settings.load_profile('deterministic')
+
 
 jax.config.parse_flags_with_absl()
@@ -55,20 +73,46 @@ def basic_matmul_kernel(
     out_ref,
     acc_scratch_ref,
     *,
-    acc_steps: int,
+    k: int,
 ):
-  @pl.when(pl.program_id(2) == 0)
+  k_index = pl.program_id(2)
+  num_k = pl.num_programs(2)
+  bk = lhs_ref.shape[1]
+  @pl.when(k_index == 0)
   def _zero_acc():
     acc_scratch_ref[...] = jnp.zeros(
         acc_scratch_ref.shape, acc_scratch_ref.dtype)
 
-  acc_scratch_ref[...] += jnp.dot(
-      lhs_ref[...],
-      rhs_ref[...],
-      preferred_element_type=acc_scratch_ref.dtype,
-  )
+  divisible_k = k % bk == 0
+  if divisible_k:
+    acc_scratch_ref[...] += jnp.dot(
+        lhs_ref[...],
+        rhs_ref[...],
+        preferred_element_type=acc_scratch_ref.dtype,
+    )
+  else:
+    def _last_block():
+      accum_dtype = acc_scratch_ref.dtype
+      lhs_mask = k_index * bk + jax.lax.broadcasted_iota(jnp.int32, lhs_ref.shape, 1) < k
+      rhs_mask = k_index * bk + jax.lax.broadcasted_iota(jnp.int32, rhs_ref.shape, 0) < k
+      dtype = lhs_ref.dtype
+      lhs = lhs_ref[...].astype(accum_dtype)
+      lhs = jnp.where(lhs_mask, lhs, 0).astype(dtype)
+      rhs = rhs_ref[...].astype(accum_dtype)
+      rhs = jnp.where(rhs_mask, rhs, 0).astype(dtype)
+      acc_scratch_ref[...] += jnp.dot(
+          lhs, rhs, preferred_element_type=acc_scratch_ref.dtype)
+    def _not_last_block():
+      acc_scratch_ref[...] += jnp.dot(
+          lhs_ref[...],
+          rhs_ref[...],
+          preferred_element_type=acc_scratch_ref.dtype,
+      )
+    jax.lax.cond(
+        k_index == num_k - 1, _last_block, _not_last_block
+    )
 
-  @pl.when(pl.program_id(2) == acc_steps - 1)
+  @pl.when(k_index == num_k - 1)
   def _reduce_out():
     out_ref[...] = acc_scratch_ref[...].astype(out_ref.dtype)
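
Editor's note: a small standalone sketch (not part of the diff) of the masking idea `_last_block` uses above, zeroing the columns of the final K block that lie past the true contraction size before the dot:

# Editor's sketch only.
import jax
import jax.numpy as jnp

bk, k, k_index = 128, 300, 2    # last block covers columns 256..383, but k == 300
lhs_block = jnp.ones((8, bk), jnp.float32)
mask = k_index * bk + jax.lax.broadcasted_iota(jnp.int32, lhs_block.shape, 1) < k
masked = jnp.where(mask, lhs_block, 0)
assert int(masked.sum()) == 8 * (k - k_index * bk)   # only the 44 in-range columns survive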
@@ -208,9 +252,8 @@ class PallasCallColectivePipelineTest(parameterized.TestCase):
 
       sharded_k = k // num_devices
       inner_grid = (n // tn, m // tm, sharded_k // tk)
-      acc_steps = (sharded_k // tk)
 
-      inner_kernel = partial(basic_matmul_kernel, acc_steps=acc_steps)
+      inner_kernel = partial(basic_matmul_kernel, k=sharded_k)
 
       inner_allocs = [
           pltpu.BufferedRef.input(
@@ -506,9 +549,8 @@ class PallasCallColectivePipelineTest(parameterized.TestCase):
       sharded_k = k // num_devices
       half_m = m // 2
       inner_grid = (n // tn, half_m // tm, sharded_k // tk)
-      acc_steps = (sharded_k // tk)
 
-      inner_kernel = partial(basic_matmul_kernel, acc_steps=acc_steps)
+      inner_kernel = partial(basic_matmul_kernel, k=sharded_k)
 
       inner_allocs = [
           pltpu.BufferedRef.input(
@@ -747,10 +789,9 @@ class PallasCallColectivePipelineTest(parameterized.TestCase):
       sharded_k = k // num_devices
       inner_grid = (n // tn, sharded_m // tm, sharded_k // tk)
       outer_steps = num_devices // 2
-      acc_steps = sharded_k // tk
       reduce_grid = (sharded_m // tm,)
 
-      inner_kernel = partial(basic_matmul_kernel, acc_steps=acc_steps)
+      inner_kernel = partial(basic_matmul_kernel, k=sharded_k)
 
       def reduce_kernel(
           out_ref,  # [tm, tn]
@@ -1032,9 +1073,7 @@ class PallasCallColectivePipelineTest(parameterized.TestCase):
       sharded_k = k // num_devices
       inner_grid = (n // tn, half_m // tm, sharded_k // tk)
       outer_steps = num_devices
-      acc_steps = sharded_k // tk
 
-      inner_kernel = partial(basic_matmul_kernel, acc_steps=acc_steps)
+      inner_kernel = partial(basic_matmul_kernel, k=sharded_k)
 
       inner_allocs = [
           pltpu.BufferedRef.input(
@@ -1263,5 +1302,98 @@ class PallasCallColectivePipelineTest(parameterized.TestCase):
     )
 
 
+if CAN_USE_HYPOTHESIS:
+
+  @partial(jax.jit, static_argnames=['bm', 'bk', 'bn'])
+  def matmul(x: jax.Array, y: jax.Array, *, bm: int, bk: int, bn: int):
+
+    m, k = x.shape
+    _, n = y.shape
+
+    def kernel(x_hbm_ref, y_hbm_ref, o_hbm_ref):
+
+      # Partition iteration space according to number of cores
+      core_index = pl.program_id(0)
+      num_cores = pl.num_programs(0)
+      num_m_iters = pl.cdiv(m, bm)
+      # Floor divide to get number of iterations per core
+      iterations_per_core = jax.lax.div(num_m_iters, num_cores)
+      # Last core gets the remainder of iterations
+      num_iters_on_this_core = jnp.where(
+          core_index == num_cores - 1,
+          iterations_per_core + jax.lax.rem(num_m_iters, num_cores),
+          iterations_per_core,
+      )
+      m_offset = core_index * iterations_per_core
+      grid = (num_iters_on_this_core, pl.cdiv(n, bn), pl.cdiv(k, bk))
+
+      def run(acc_scratch_ref):
+        pltpu.emit_pipeline(
+            partial(basic_matmul_kernel, acc_scratch_ref=acc_scratch_ref, k=k),
+            in_specs=[
+                pl.BlockSpec(lambda i, j, k: (m_offset + i, k), (bm, bk)),
+                pl.BlockSpec(lambda i, j, k: (k, j), (bk, bn)),
+            ],
+            out_specs=pl.BlockSpec(lambda i, j, k: (m_offset + i, j), (bm, bn)),
+            grid=grid,
+        )(x_hbm_ref, y_hbm_ref, o_hbm_ref)
+
+      accum_dtype = (
+          jnp.float32 if jnp.issubdtype(x.dtype, jnp.floating) else jnp.int32
+      )
+      pltpu.run_scoped(run, pltpu.VMEM((bm, bn), accum_dtype))
+
+    num_cores = jax.devices()[0].num_cores
+    return pl.pallas_call(
+        kernel,
+        out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
+        in_specs=[
+            pl.BlockSpec(memory_space=pltpu.ANY),
+            pl.BlockSpec(memory_space=pltpu.ANY),
+        ],
+        out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
+        grid=(num_cores,),
+    )(x, y)
+
+  class PaddedPipelineEmitterTest(parameterized.TestCase):
+
+    def setUp(self):
+      super().setUp()
+      if not jtu.is_device_tpu_at_least(4):
+        self.skipTest('Only TPU v4+ allowed.')
+
+    @hp.given(
+        hps.sampled_from(['float32', 'bfloat16', 'int8']),
+        hps.integers(1, 1024),
+        hps.integers(1, 1024),
+        hps.integers(1, 1024),
+        hps.sampled_from([8, 16, 32, 128, 256, 512]),
+        hps.sampled_from([128, 256, 512]),
+        hps.sampled_from([128, 256, 512]),
+        hps.integers(0, 4),
+    )
+    def test_padded_matmul(self, dtype, m, k, n, bm, bk, bn, seed):
+      hp.assume(bm <= m)
+      hp.assume(bn <= n)
+      hp.assume(bk <= k)
+      if dtype == 'bfloat16':
+        hp.assume(bm >= 16)
+      if dtype == 'int8':
+        hp.assume(bm >= 32)
+        hp.assume(jtu.is_device_tpu_at_least(5))
+      k1, k2 = jax.random.split(jax.random.key(seed))
+      x = jax.random.normal(k1, (m, k), jnp.float32).astype(dtype)
+      y = jax.random.normal(k2, (k, n), jnp.float32).astype(dtype)
+
+      out = matmul(x, y, bm=bm, bk=bk, bn=bn)
+      expected = x @ y
+      atol = rtol = 1e-5
+      if dtype == 'bfloat16':
+        out = out.astype('float32')
+        expected = expected.astype('float32')
+        atol = rtol = 1e-2
+      np.testing.assert_allclose(out, expected, atol=atol, rtol=rtol)
+
 
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
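
Editor's note: a plain-Python sketch (not part of the diff) of the per-core split that `kernel` above computes with `jax.lax.div` and `jax.lax.rem`: every core handles floor(num_m_iters / num_cores) row-block iterations, and the last core also takes the remainder.

# Editor's sketch only.
def iters_for_core(core_index: int, num_cores: int, num_m_iters: int) -> int:
  base = num_m_iters // num_cores
  return base + (num_m_iters % num_cores if core_index == num_cores - 1 else 0)

assert [iters_for_core(c, 2, 7) for c in range(2)] == [3, 4]
assert sum(iters_for_core(c, 2, 7) for c in range(2)) == 7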