Another take at enabling Pallas GPU tests on x64
Note that for_loop_p no longer assumes that the loop index is an int32. Closes #18847
commit b7715e279d (parent 433f66ad02)
@@ -2076,14 +2076,16 @@ def raise_to_shaped(aval: AbstractValue, weak_type=None):
   if handler: return handler(aval, weak_type)
   raise TypeError(type(aval))
 
-raise_to_shaped_mappings : dict[type, Callable] = {
-  AbstractToken: lambda aval, _: aval,
-  Bot: lambda aval, _: aval,
-  UnshapedArray: lambda aval, _: aval,
-  ShapedArray: lambda aval, weak_type: ShapedArray(
-      aval.shape, aval.dtype, weak_type, aval.named_shape),
-  DConcreteArray: lambda aval, weak_type: DShapedArray(
-      aval.shape, aval.dtype, weak_type),
+raise_to_shaped_mappings: dict[type, Callable] = {
+    AbstractToken: lambda aval, _: aval,
+    Bot: lambda aval, _: aval,
+    UnshapedArray: lambda aval, _: aval,
+    ShapedArray: lambda aval, weak_type: ShapedArray(
+        aval.shape, aval.dtype, weak_type, aval.named_shape
+    ),
+    DConcreteArray: lambda aval, weak_type: DShapedArray(
+        aval.shape, aval.dtype, weak_type
+    ),
 }
 
 ### Operations on shapes and dimension sizes.
@@ -132,7 +132,7 @@ def for_loop(nsteps: int | Sequence[int],
   nsteps, = nsteps
   flat_state, state_tree = tree_flatten(init_state)
   state_avals = map(state_utils.val_to_ref_aval, flat_state)
-  idx_aval = core.ShapedArray((), jnp.dtype("int32"))
+  idx_aval = core.ShapedArray((), dtypes.canonicalize_dtype(jnp.int64))
   jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs(
       body, state_tree, [idx_aval, *state_avals])
   if out_tree != tree_structure(None):
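Background, not part of the patch: `dtypes.canonicalize_dtype(jnp.int64)` resolves to the widest integer type the current configuration allows, so the loop-index aval is int32 under the default settings and int64 once `jax_enable_x64` is on. A minimal sketch:

```python
import jax
import jax.numpy as jnp
from jax._src import dtypes

# Under the default configuration, 64-bit types are canonicalized down to 32 bits.
print(dtypes.canonicalize_dtype(jnp.int64))  # int32

# With x64 enabled the same call returns int64, so the for_loop index aval
# follows the active precision mode instead of being hard-coded to int32.
jax.config.update("jax_enable_x64", True)
print(dtypes.canonicalize_dtype(jnp.int64))  # int64
```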
@@ -251,7 +251,7 @@ def _for_impl(*args, jaxpr, nsteps, reverse, which_linear, unroll):
 
 def _for_impl_unrolled(body, nsteps, unroll, *args):
   remainder = nsteps % unroll
-  i = jnp.int32(0)
+  i = jnp.astype(0, dtypes.canonicalize_dtype(jnp.int64))
   state = list(args)
 
   for _ in range(remainder):
@@ -748,7 +748,7 @@ def discharged_for_loop(nsteps, body, init_state, *, reverse: bool = False):
   """
   flat_state, state_tree = tree_flatten(init_state)
   state_avals = map(state_utils.val_to_ref_aval, flat_state)
-  idx_aval = core.ShapedArray((), jnp.dtype("int32"))
+  idx_aval = core.ShapedArray((), dtypes.canonicalize_dtype(jnp.int64))
   jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs(
       body, state_tree, [idx_aval, *state_avals])
   if out_tree != tree_structure(None):
@@ -756,7 +756,7 @@ def discharged_for_loop(nsteps, body, init_state, *, reverse: bool = False):
   discharged_jaxpr, discharged_consts = discharge_state(jaxpr, consts)
 
   def fori_body(i, carry):
-    i = jnp.int32(i)
+    i = jnp.astype(i, dtypes.canonicalize_dtype(jnp.int64))
    if reverse:
       i = nsteps - i - 1
     out_flat = core.eval_jaxpr(discharged_jaxpr, discharged_consts,
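Taken together, the changes above mean the index passed to a `for_loop` body carries the canonical integer dtype rather than a fixed int32. A small usage sketch, assuming the same private `for_loop` import the tests below use:

```python
import jax.numpy as jnp
from jax._src import dtypes
from jax._src.lax.control_flow.for_loop import for_loop

intx = dtypes.canonicalize_dtype(jnp.int64)  # int32 by default, int64 under x64

def body(i, x_ref):
  # `i` now has dtype intx, matching the ref below in either precision mode.
  x_ref[...] += i

# Accumulate 0 + 1 + ... + 4 into a scalar carried through the loop.
print(for_loop(5, body, jnp.zeros((), intx)))  # 10
```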
@@ -274,7 +274,7 @@ def _pallas_call_impl_interpret(
         len(blocks),
         len(scratch_values),
     )
-    blocks = jax.core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars,
+    blocks = jax_core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars,
                                  *blocks, *scratch)
     blocks = blocks[grid_mapping.num_index_operands:]
     blocks, out_scratch = split_list(blocks, [num_inout])
@@ -787,7 +787,7 @@ def pallas_call_checkify_rule(error: checkify.Error,
   # errors before other arguments.
   jaxpr_args = [*input_error_vals, *scalars, *inputs, *outputs, *scratch]
   assert len(checked_jaxpr.jaxpr.invars) == len(jaxpr_args)
-  result_flat = jax.core.eval_jaxpr(
+  result_flat = jax_core.eval_jaxpr(
       checked_jaxpr.jaxpr, checked_jaxpr.consts, *jaxpr_args)
   output_errors, _ = split_list(result_flat, [num_err_vals])
   # Store new errors back in the error refs.
@@ -26,6 +26,7 @@ from jax import lax
 from jax import tree_util
 from jax._src import ad_util
 from jax._src import core as jax_core
+from jax._src import dtypes
 from jax._src import effects
 from jax._src import pretty_printer as pp
 from jax._src import state
@@ -359,7 +360,12 @@ def _load_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
     # of bounds, it will instead move the start_index backwards so the slice
     # will fit in memory.
     ref = _pad_values_to_avoid_dynamic_slice_oob_shift(ref, slice_sizes)
-    out_ones = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes)
+    idx_dtype = dtypes.canonicalize_dtype(jnp.int64)
+    out_ones = lax.dynamic_slice(
+        ref,
+        [jnp.astype(s, idx_dtype) for s in slice_starts],
+        slice_sizes=slice_sizes,
+    )
     out_indexer = tuple(0 if scalar else slice(None) for scalar in scalar_dims)
     out = out_ones[out_indexer]
   elif all(not isinstance(s, Slice) for s in idx.indices):
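Background, not from the patch itself: `lax.dynamic_slice` requires every start index to share one integer dtype, so a 64-bit index mixed with 32-bit ones fails once x64 is enabled; casting all starts to the canonical index dtype avoids that. A hedged sketch:

```python
import jax.numpy as jnp
from jax import lax
from jax._src import dtypes

x = jnp.arange(16.0).reshape(4, 4)
starts = (jnp.int32(1), jnp.asarray(2))  # the second index is int64 when x64 is on

# Mixed start-index dtypes are rejected by dynamic_slice, so normalize them to
# the canonical integer dtype first (int32 by default, int64 under x64).
idx_dtype = dtypes.canonicalize_dtype(jnp.int64)
out = lax.dynamic_slice(x, [jnp.astype(s, idx_dtype) for s in starts], (2, 2))
```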
@@ -52,6 +52,7 @@ pytype_strict_library(
         "//jax",
         "//jax:ad_util",
         "//jax:api_util",
+        "//jax:config",
         "//jax:core",
         "//jax:mlir",
         "//jax:partial_eval",
@@ -29,6 +29,7 @@ from jax import tree_util
 from jax._src import ad_checkpoint
 from jax._src import ad_util
 from jax._src import api_util
+from jax._src import config
 from jax._src import core as jax_core
 from jax._src import custom_derivatives
 from jax._src import linear_util as lu
@@ -2263,9 +2264,10 @@ def _for_lowering_rule(
   del which_linear
   if reverse or unroll != 1:
     raise NotImplementedError
-  lower_bound = _i32_constant(0)
-  upper_bound = _i32_constant(nsteps)
-  step = _i32_constant(1)
+  _i_constant = _i64_constant if config.enable_x64.value else _i32_constant
+  lower_bound = _i_constant(0)
+  upper_bound = _i_constant(nsteps)
+  step = _i_constant(1)
   init_args = map(_ensure_ir_value, args, ctx.avals_in)
   # Partially discharge state from jaxpr for non-pointers
   should_discharge = [
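Aside, illustrative only: the lowering now picks between its existing 32- and 64-bit constant helpers from the live `jax_enable_x64` flag. The select-by-config pattern looks like this, sketched with hypothetical stand-in helpers rather than the lowering's real `_i32_constant`/`_i64_constant`:

```python
from jax._src import config

# Hypothetical stand-ins for the lowering's _i32_constant/_i64_constant helpers.
def make_i32(v: int) -> tuple[str, int]:
  return ("i32", v)

def make_i64(v: int) -> tuple[str, int]:
  return ("i64", v)

# Choose the builder once from the runtime flag so the loop bounds and step
# share the width of the (possibly 64-bit) loop index.
make_index = make_i64 if config.enable_x64.value else make_i32
lower, upper, step = make_index(0), make_index(10), make_index(1)
```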
@@ -190,7 +190,7 @@ def _swap_abstract_eval(ref_aval: AbstractRef,
                      f"Expected shape: {expected_out_shape}. "
                      f"Value shape: {val_aval.shape}. "
                      f"Indices: {indexers}. ")
-  if ref_aval.dtype != val_aval.dtype:
+  if ref_aval.dtype != val_aval.dtype and not val_aval.weak_type:
     raise ValueError("Invalid dtype for `swap`. "
                      f"Ref dtype: {ref_aval.dtype}. "
                      f"Value dtype: {val_aval.dtype}. ")
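Background on the weak-type escape hatch, not part of the patch: values built from bare Python scalars stay weakly typed and are allowed to adapt to the ref's dtype, whereas explicitly typed values remain subject to the strict check. For example:

```python
import jax.numpy as jnp

# A Python scalar produces a weakly typed array, so storing it into a ref of a
# different (but compatible) dtype no longer trips the `swap` dtype check.
print(jnp.asarray(1.0).weak_type)   # True
# An explicitly typed value is strong; a mismatched dtype is still an error.
print(jnp.float32(1.0).weak_type)   # False
```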
@@ -132,6 +132,12 @@ class AbstractRef(core.AbstractValue):
   def __init__(self, inner_aval: core.AbstractValue):
     self.inner_aval = inner_aval
 
+  @property
+  def weak_type(self) -> bool:
+    if not hasattr(self.inner_aval, "weak_type"):
+      raise AttributeError
+    return self.inner_aval.weak_type
+
   def update(self, inner_aval=None):
     if inner_aval is None:
       return AbstractRef(self.inner_aval)
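A quick sketch of what the new property does, assuming the internal module layout used by this change:

```python
import jax.numpy as jnp
from jax._src import core as jax_core
from jax._src.state.types import AbstractRef

# The ref aval simply forwards weak_type from the aval it wraps, so the
# `swap` rule above can consult it uniformly.
inner = jax_core.ShapedArray((4,), jnp.dtype("float32"), weak_type=True)
print(AbstractRef(inner).weak_type)  # True
```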
@@ -41,8 +41,6 @@ jax_test(
     disable_configs = [
-        "gpu",
-        "gpu_x32",
         "gpu_a100",
         "gpu_h100",
         "gpu_p100",
         "gpu_p100_x32",
     ],
@@ -29,6 +29,7 @@ from jax import lax
 from jax import random
 from jax._src import checkify
 from jax._src import config
+from jax._src import dtypes
 from jax._src import test_util as jtu
 from jax._src.lax.control_flow.for_loop import for_loop
 from jax._src.pallas.pallas_call import _trace_to_jaxpr
@@ -55,6 +56,10 @@ def smem_on_tpu():
     return None
 
 
+intx = dtypes.canonicalize_dtype(jnp.int64)
+floatx = dtypes.canonicalize_dtype(jnp.float64)
+
+
 @functools.partial(jax.jit, static_argnames=["bm", "bn", "gm", "bk",
                                               "interpret", "debug"])
 def matmul(x, y, *, bm, bn, gm, bk, interpret, debug=False):
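These module-level helpers give the tests dtypes that track the active precision mode; the pattern used throughout the tests below is roughly:

```python
import jax
import jax.numpy as jnp
from jax._src import dtypes

floatx = dtypes.canonicalize_dtype(jnp.float64)  # float32 by default, float64 under x64

# Output specs written against floatx stay valid in both 32- and 64-bit runs.
out_shape = jax.ShapeDtypeStruct((8,), floatx)
print(out_shape.dtype)
```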
@@ -65,7 +70,7 @@ def matmul(x, y, *, bm, bn, gm, bk, interpret, debug=False):
       debug=debug,
       grid=pl.cdiv(m, bm) * pl.cdiv(n, bn))
   def matmul_kernel(x_ref, y_ref, o_ref):
-    pid = pl.program_id(axis=0)
+    pid = pl.program_id(axis=0).astype(intx)
     num_pid_m = m // bm
     num_pid_n = n // bn
     num_pid_in_group = gm * num_pid_n
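Worth noting as background: `pl.program_id` yields an int32 inside the kernel, so casting it to `intx` keeps the later block-index arithmetic in the canonical integer dtype when x64 is enabled. A self-contained sketch of the pattern, using interpret mode so it runs anywhere and the BlockSpec call style used in these tests:

```python
import functools
import jax
import jax.numpy as jnp
from jax._src import dtypes
from jax.experimental import pallas as pl

intx = dtypes.canonicalize_dtype(jnp.int64)

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((4,), intx),
    grid=(4,),
    in_specs=[pl.BlockSpec((1,), lambda i: i)],
    out_specs=pl.BlockSpec((1,), lambda i: i),
    interpret=True)
def add_pid(x_ref, o_ref):
  # program_id is int32 regardless of jax_enable_x64; promote it explicitly so
  # the addition below stays in intx in either mode.
  o_ref[...] = x_ref[...] + pl.program_id(axis=0).astype(intx)

print(add_pid(jnp.arange(4, dtype=intx)))  # [0 2 4 6]
```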
@@ -133,8 +138,6 @@ class PallasBaseTest(jtu.JaxTestCase):
   def setUp(self):
     if jtu.test_device_matches(["cpu"]) and not self.INTERPRET:
       self.skipTest("On CPU the test works only in interpret mode")
-    if jtu.test_device_matches(["gpu"]) and jax.config.x64_enabled:
-      self.skipTest("On GPU the test works only in 32-bit")
     if (jtu.test_device_matches(["cuda"]) and
         not jtu.is_cuda_compute_capability_at_least("8.0")):
       self.skipTest("Only works on GPU with capability >= sm80")
@@ -151,13 +154,10 @@ class PallasBaseTest(jtu.JaxTestCase):
 class PallasCallTest(PallasBaseTest):
 
   def test_add_one(self):
-    if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
-      # TODO: assertion failures on CPU in 64-bit mode
-      self.skipTest("On CPU the test works only in 32-bit mode")
     if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
       self.skipTest("On TPU the test works only in interpret mode")
     @functools.partial(
-        self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32))
+        self.pallas_call, out_shape=jax.ShapeDtypeStruct((), floatx))
     def add_one(x_ref, o_ref):
       o_ref[()] = x_ref[()] + 1.
 
@@ -177,14 +177,11 @@ class PallasCallTest(PallasBaseTest):
     np.testing.assert_allclose(add_one(x), jnp.array([1.], jnp.float32))
 
   def test_add_vector_block_spec(self):
-    if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
-      # TODO: assertion failures on CPU in 64-bit mode
-      self.skipTest("On CPU the test works only in 32-bit mode")
     if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
       self.skipTest("On TPU the test works only in interpret mode")
     @functools.partial(
         self.pallas_call,
-        out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
+        out_shape=jax.ShapeDtypeStruct((8,), intx),
         in_specs=[pl.BlockSpec((1,), lambda i: i)],
         out_specs=pl.BlockSpec((1,), lambda i: i),
         grid=8,
@@ -195,14 +192,11 @@ class PallasCallTest(PallasBaseTest):
     np.testing.assert_allclose(add_one(jnp.arange(8)), jnp.arange(8) + 1)
 
   def test_add_matrix_block_spec(self):
-    if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
-      # TODO: assertion failures on CPU in 64-bit mode
-      self.skipTest("On CPU the test works only in 32-bit mode")
     if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
       self.skipTest("On TPU the test works only in interpret mode")
     @functools.partial(
         self.pallas_call,
-        out_shape=jax.ShapeDtypeStruct((8, 8), jnp.int32),
+        out_shape=jax.ShapeDtypeStruct((8, 8), intx),
         in_specs=[pl.BlockSpec((2, 2), lambda i, j: (i, j))],
         out_specs=pl.BlockSpec((2, 2), lambda i, j: (i, j)),
         grid=(4, 4),
@@ -225,13 +219,10 @@ class PallasCallTest(PallasBaseTest):
     self.assertTrue(jnp.all(logical_and(x)))
 
   def test_vector_indexing(self):
-    if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
-      # TODO: assertion failures on CPU in 64-bit mode
-      self.skipTest("On CPU the test works only in 32-bit mode")
     if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
       self.skipTest("On TPU the test works only in interpret mode")
     @functools.partial(
-        self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32),
+        self.pallas_call, out_shape=jax.ShapeDtypeStruct((), floatx),
         grid=1)
     def index(x_ref, i_ref, o_ref):
       o_ref[()] = x_ref[i_ref[()]]
@@ -485,13 +476,10 @@ class PallasCallTest(PallasBaseTest):
     self.assertAllClose(res[0:1], to_store)
 
   def test_vector_slicing(self):
-    if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
-      # TODO: assertion failures on CPU in 64-bit mode
-      self.skipTest("On CPU the test works only in 32-bit mode")
     if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
       self.skipTest("On TPU the test works only in interpret mode")
     @functools.partial(
-        self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),
+        self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), floatx),
         grid=1)
     def index(x_ref, idx_ref, o_ref):
       idx = idx_ref[()]
@@ -517,9 +505,6 @@ class PallasCallTest(PallasBaseTest):
       if block_size_m <= m and block_size_n <= n and block_size_k <= k
   ])
   def test_matmul(self, m, n, k, dtype, bm, bn, bk, gm):
-    if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
-      # TODO: all sort of assertion failures on CPU in 64-bit mode
-      self.skipTest("On CPU the test works only in 32-bit mode")
     if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
       self.skipTest("On TPU the test works only in interpret mode")
     k1, k2 = random.split(random.key(0))
@@ -543,9 +528,6 @@ class PallasCallTest(PallasBaseTest):
       if block_size_m <= m and block_size_n <= n and block_size_k <= k
   ])
   def test_matmul_block_spec(self, m, n, k, dtype, bm, bn, bk):
-    if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
-      # TODO: all sort of assertion failures on CPU in 64-bit mode
-      self.skipTest("On CPU the test works only in 32-bit mode")
     if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
       self.skipTest("On TPU the test works only in interpret mode")
     k1, k2 = random.split(random.key(0))
@@ -605,15 +587,12 @@ class PallasCallTest(PallasBaseTest):
     np.testing.assert_allclose(dummy(x), jnp.ones_like(x), atol=1e-5, rtol=1e-5)
 
   def test_with_input_output_aliasing(self):
-    if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
-      # TODO: assertion failures on CPU in 64-bit mode
-      self.skipTest("On CPU the test works only in 32-bit mode")
     if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
       self.skipTest("On TPU the test works only in interpret mode")
     def add_inplace_kernel(_, o_ref, *, block_size):
       pid = pl.program_id(axis=0)  # we use a 1d launch grid so axis is 0
       block_start = pid * block_size
-      offsets = block_start + jnp.arange(block_size)
+      offsets = block_start + jnp.arange(block_size, dtype=jnp.int32)
       mask = offsets < o_ref.shape[0]
       x = pl.load(o_ref, (offsets,), mask=mask)
       output = x + 1
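Context for the `jnp.arange` change, not part of the diff: without an explicit dtype, `arange` follows the x64 default and would promote the int32 `program_id` arithmetic to int64; pinning int32 keeps the offsets 32-bit in both modes:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
print(jnp.arange(4).dtype)                   # int64 once x64 is enabled
print(jnp.arange(4, dtype=jnp.int32).dtype)  # int32 in either mode
```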
@@ -634,13 +613,10 @@ class PallasCallTest(PallasBaseTest):
     np.testing.assert_allclose(out, expected)
 
   def test_using_pallas_slice(self):
-    if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
-      # TODO: assertion failures on CPU in 64-bit mode
-      self.skipTest("On CPU the test works only in 32-bit mode")
     if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
       self.skipTest("On TPU the test works only in interpret mode")
     m, n = 32, 4
-    out_shape = jax.ShapeDtypeStruct((4, n), jnp.float32)
+    out_shape = jax.ShapeDtypeStruct((4, n), floatx)
     @functools.partial(
         self.pallas_call,
         out_shape=out_shape,
@@ -996,7 +972,7 @@ class PallasControlFlowTest(PallasBaseTest):
 
     @functools.partial(
         self.pallas_call,
-        out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), jnp.float32),
+        out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), floatx),
         in_specs=[
             pl.BlockSpec((), lambda _, __: ()),
             pl.BlockSpec((bx, 1), lambda i, _: (i, 0)),
@@ -1073,15 +1049,13 @@ class PallasControlFlowTest(PallasBaseTest):
     #     dx, jnp.float32([0., 2, 4, 6, 0, 10, 12 + 12, 14]))
 
   def test_scan_cond_vm_explicit_ref_arg(self):
-    if jtu.test_device_matches(["cpu"]):
-      # TODO: fix this
-      self.skipTest("Fails on CPU: assertion error")
     if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
       self.skipTest("TODO: error on TPU")
 
     program = jnp.int32([0, 1, 2, 3, 2])
-    params = jnp.arange(len(program) * 3.).reshape(len(program), 3)
-    x = jnp.arange(7.)
+    params = jnp.arange(len(program) * 3., dtype=jnp.float32)
+    params = params.reshape(len(program), 3)
+    x = jnp.arange(7., dtype=jnp.float32)
     bx = 4
 
     @jax.jit
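The `dtype=jnp.float32` pins exist for the same reason on the floating-point side (background, not part of the diff): a float literal makes `jnp.arange` produce float64 under x64, which would no longer match the kernel's float32 refs:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)
print(jnp.arange(7.).dtype)                     # float64 under x64
print(jnp.arange(7., dtype=jnp.float32).dtype)  # float32 in either mode
```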
@@ -1113,7 +1087,7 @@ class PallasControlFlowTest(PallasBaseTest):
         return state, program_ref, params_ref
       out_ref[...] = jax.lax.fori_loop(
           0, len(program), body_fn,
-          (jnp.zeros(x.shape), program_ref, params_ref))[0]
+          (jnp.zeros(x.shape, dtype=jnp.float32), program_ref, params_ref))[0]
 
     expected = (x * params[0, 0] +
                 2 * x * params[1, 1] +
@@ -1127,16 +1101,14 @@ class PallasControlFlowTest(PallasBaseTest):
         params, x)
 
   def test_scan_cond_vm_closing_over_ref(self):
-    if jtu.test_device_matches(["cpu"]):
-      # TODO: fix this
-      self.skipTest("Fails on CPU: assertion error")
     if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
       self.skipTest("TODO: error on TPU")
 
     # ** Difference is the closure over params_ref in the switch branches. **
     program = jnp.int32([0, 1, 2, 3, 2, -1])
-    params = jnp.arange(len(program) * 3.).reshape(len(program), 3)
-    x = jnp.arange(7.)
+    params = jnp.arange(len(program) * 3., dtype=jnp.float32)
+    params = params.reshape(len(program), 3)
+    x = jnp.arange(7., dtype=jnp.float32)
     bx = 4
 
     @jax.jit
@@ -1169,7 +1141,7 @@ class PallasControlFlowTest(PallasBaseTest):
         return state, program_ref, params_ref
       out_ref[...] = jax.lax.fori_loop(
           0, len(program), body_fn,
-          (jnp.zeros(x.shape), program_ref, params_ref))[0]
+          (jnp.zeros(x.shape, dtype=jnp.float32), program_ref, params_ref))[0]
 
     expected = (x * params[0, 0] +
                 2 * x * params[1, 1] +
@@ -1375,7 +1347,7 @@ class PallasControlFlowTest(PallasBaseTest):
         kernel,
         grid=(1,),
         out_specs=pl.BlockSpec((1, 1), memory_space=smem_on_tpu()),
-        out_shape=jax.ShapeDtypeStruct([1, 1], jnp.int32),
+        out_shape=jax.ShapeDtypeStruct([1, 1], intx),
         in_specs=[
             pl.BlockSpec(
                 (1, 8, 128),
@@ -1439,7 +1411,7 @@ class PallasControlFlowTest(PallasBaseTest):
         kernel,
         grid=(4,),
         out_specs=pl.BlockSpec((1, 1), memory_space=smem_on_tpu()),
-        out_shape=jax.ShapeDtypeStruct([1, 1], jnp.int32),
+        out_shape=jax.ShapeDtypeStruct([1, 1], intx),
         in_specs=[
             pl.BlockSpec(
                 (1, 8, 128),
@@ -1634,7 +1606,7 @@ class PallasCallAutodifferentiationTest(PallasBaseTest):
       grad_tol = 1e-1
 
     @functools.partial(
-        self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32),
+        self.pallas_call, out_shape=jax.ShapeDtypeStruct((), floatx),
         grid=1)
     def pallas_impl(x_ref, o_ref):
       x = x_ref[()]
@@ -1656,7 +1628,7 @@ class PallasCallAutodifferentiationTest(PallasBaseTest):
   def test_pallas_around_grad(self, impl):
     @functools.partial(
         self.pallas_call,
-        out_shape=jax.ShapeDtypeStruct((), jnp.float32),
+        out_shape=jax.ShapeDtypeStruct((), floatx),
         name=self.id().split(".")[-1],
         grid=1)
     def pallas_impl(x_ref, o_ref):
@@ -1675,7 +1647,7 @@ class PallasCallAutodifferentiationTest(PallasBaseTest):
       grad_tol = 1e-1
 
     @functools.partial(
-        self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), jnp.float32),
+        self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), floatx),
         grid=1)
     def pallas_impl(x_ref, o_ref):
       x = x_ref[jnp.arange(2)]
@@ -1730,16 +1702,9 @@
 class PallasCallAutodifferentiationInterpreterTest(PallasCallAutodifferentiationTest):
   INTERPRET = True
 
-  def setUp(self):
-    super().setUp()
-    if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
-      # TODO: assertion failures on CPU in 64-bit mode
-      self.skipTest("On CPU the test works only in 32-bit mode")
-
 
 class PallasOutOfBoundsInterpreterTest(PallasBaseTest):
-
-  INTERPRET: bool = True
+  INTERPRET = True
 
   def test_interpret_mode_out_of_bounds_access(self):
     block_size = 32
@@ -1818,7 +1783,7 @@ class PallasOutOfBoundsInterpreterTest(PallasBaseTest):
 
 class PallasCheckifyInterpreterTest(PallasBaseTest):
   # TODO(b/346651778): Support non-interpret mode checkify.
-  INTERPRET: bool = True
+  INTERPRET = True
 
   def test_no_checkify(self,):
     def kernel(y_ref):