Another take at enabling Pallas GPU tests on x64

Note that for_loop_p no longer assumes that the loop index is an int32.

Closes #18847
Sergei Lebedev 2024-07-22 11:20:15 +00:00
parent 433f66ad02
commit b7715e279d
10 changed files with 65 additions and 85 deletions

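For context, a minimal sketch (not part of the diff) of the dtype canonicalization the change relies on: requesting a 64-bit dtype yields a 32-bit dtype under the default mode and stays 64-bit when jax_enable_x64 is set, so avals built this way track the x64 flag.

    import jax.numpy as jnp
    from jax._src import dtypes

    # int32 under the default 32-bit mode, int64 when jax_enable_x64 is on.
    print(dtypes.canonicalize_dtype(jnp.int64))
    # Same idea for floats: float32 by default, float64 under x64.
    print(dtypes.canonicalize_dtype(jnp.float64))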

@ -2076,14 +2076,16 @@ def raise_to_shaped(aval: AbstractValue, weak_type=None):
if handler: return handler(aval, weak_type)
raise TypeError(type(aval))
raise_to_shaped_mappings : dict[type, Callable] = {
AbstractToken: lambda aval, _: aval,
Bot: lambda aval, _: aval,
UnshapedArray: lambda aval, _: aval,
ShapedArray: lambda aval, weak_type: ShapedArray(
aval.shape, aval.dtype, weak_type, aval.named_shape),
DConcreteArray: lambda aval, weak_type: DShapedArray(
aval.shape, aval.dtype, weak_type),
raise_to_shaped_mappings: dict[type, Callable] = {
AbstractToken: lambda aval, _: aval,
Bot: lambda aval, _: aval,
UnshapedArray: lambda aval, _: aval,
ShapedArray: lambda aval, weak_type: ShapedArray(
aval.shape, aval.dtype, weak_type, aval.named_shape
),
DConcreteArray: lambda aval, weak_type: DShapedArray(
aval.shape, aval.dtype, weak_type
),
}
### Operations on shapes and dimension sizes.

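This hunk only reformats the raise_to_shaped_mappings dispatch table. For reference, a small usage sketch (assumed, not from the change) of how raise_to_shaped resolves an aval through it:

    import jax.numpy as jnp
    from jax._src import core

    aval = core.get_aval(jnp.ones((2, 3), jnp.float32))
    # raise_to_shaped looks up type(aval) in raise_to_shaped_mappings and
    # returns the shape/dtype-only abstraction of the value.
    print(core.raise_to_shaped(aval))  # -> float32[2,3]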

@ -132,7 +132,7 @@ def for_loop(nsteps: int | Sequence[int],
nsteps, = nsteps
flat_state, state_tree = tree_flatten(init_state)
state_avals = map(state_utils.val_to_ref_aval, flat_state)
idx_aval = core.ShapedArray((), jnp.dtype("int32"))
idx_aval = core.ShapedArray((), dtypes.canonicalize_dtype(jnp.int64))
jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs(
body, state_tree, [idx_aval, *state_avals])
if out_tree != tree_structure(None):
@ -251,7 +251,7 @@ def _for_impl(*args, jaxpr, nsteps, reverse, which_linear, unroll):
def _for_impl_unrolled(body, nsteps, unroll, *args):
remainder = nsteps % unroll
i = jnp.int32(0)
i = jnp.astype(0, dtypes.canonicalize_dtype(jnp.int64))
state = list(args)
for _ in range(remainder):
@ -748,7 +748,7 @@ def discharged_for_loop(nsteps, body, init_state, *, reverse: bool = False):
"""
flat_state, state_tree = tree_flatten(init_state)
state_avals = map(state_utils.val_to_ref_aval, flat_state)
idx_aval = core.ShapedArray((), jnp.dtype("int32"))
idx_aval = core.ShapedArray((), dtypes.canonicalize_dtype(jnp.int64))
jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs(
body, state_tree, [idx_aval, *state_avals])
if out_tree != tree_structure(None):
@ -756,7 +756,7 @@ def discharged_for_loop(nsteps, body, init_state, *, reverse: bool = False):
discharged_jaxpr, discharged_consts = discharge_state(jaxpr, consts)
def fori_body(i, carry):
i = jnp.int32(i)
i = jnp.astype(i, dtypes.canonicalize_dtype(jnp.int64))
if reverse:
i = nsteps - i - 1
out_flat = core.eval_jaxpr(discharged_jaxpr, discharged_consts,

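A hypothetical usage sketch (not in the diff) of the user-visible effect on for_loop: the index passed to the body now carries the canonical index dtype instead of a hard-coded int32.

    import jax.numpy as jnp
    from jax._src import dtypes
    from jax._src.lax.control_flow.for_loop import for_loop

    intx = dtypes.canonicalize_dtype(jnp.int64)  # int32 or int64, per the x64 flag

    def body(i, x_ref):
        # i now has dtype intx rather than always int32.
        x_ref[...] += i

    print(for_loop(4, body, jnp.zeros((), intx)))  # 0 + 1 + 2 + 3 = 6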

@ -274,7 +274,7 @@ def _pallas_call_impl_interpret(
len(blocks),
len(scratch_values),
)
blocks = jax.core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars,
blocks = jax_core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars,
*blocks, *scratch)
blocks = blocks[grid_mapping.num_index_operands:]
blocks, out_scratch = split_list(blocks, [num_inout])
@ -787,7 +787,7 @@ def pallas_call_checkify_rule(error: checkify.Error,
# errors before other arguments.
jaxpr_args = [*input_error_vals, *scalars, *inputs, *outputs, *scratch]
assert len(checked_jaxpr.jaxpr.invars) == len(jaxpr_args)
result_flat = jax.core.eval_jaxpr(
result_flat = jax_core.eval_jaxpr(
checked_jaxpr.jaxpr, checked_jaxpr.consts, *jaxpr_args)
output_errors, _ = split_list(result_flat, [num_err_vals])
# Store new errors back in the error refs.

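Both hunks in this file switch from the public jax.core re-export to the jax._src.core alias the module already imports as jax_core. A minimal standalone sketch of the eval_jaxpr call they make (example values assumed, not the pallas_call internals):

    import jax
    from jax._src import core as jax_core

    closed = jax.make_jaxpr(lambda x: x * 2.0)(1.0)
    # eval_jaxpr interprets a jaxpr eagerly: consts first, then the flat args.
    out, = jax_core.eval_jaxpr(closed.jaxpr, closed.consts, 1.0)
    print(out)  # 2.0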

@ -26,6 +26,7 @@ from jax import lax
from jax import tree_util
from jax._src import ad_util
from jax._src import core as jax_core
from jax._src import dtypes
from jax._src import effects
from jax._src import pretty_printer as pp
from jax._src import state
@ -359,7 +360,12 @@ def _load_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
# of bounds, it will instead move the start_index backwards so the slice
# will fit in memory.
ref = _pad_values_to_avoid_dynamic_slice_oob_shift(ref, slice_sizes)
out_ones = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes)
idx_dtype = dtypes.canonicalize_dtype(jnp.int64)
out_ones = lax.dynamic_slice(
ref,
[jnp.astype(s, idx_dtype) for s in slice_starts],
slice_sizes=slice_sizes,
)
out_indexer = tuple(0 if scalar else slice(None) for scalar in scalar_dims)
out = out_ones[out_indexer]
elif all(not isinstance(s, Slice) for s in idx.indices):

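The cast above keeps all dynamic_slice start indices in one canonical index dtype. A standalone sketch of the same pattern (example values assumed):

    import jax.numpy as jnp
    from jax import lax
    from jax._src import dtypes

    ref = jnp.arange(16.0).reshape(4, 4)
    # Under x64 the starts can arrive with mixed widths (e.g. an int32
    # program id next to an int64 loop index); casting them to a single
    # canonical index dtype keeps lax.dynamic_slice consistent.
    idx_dtype = dtypes.canonicalize_dtype(jnp.int64)
    starts = [jnp.int32(1), jnp.asarray(2)]
    out = lax.dynamic_slice(
        ref,
        [jnp.astype(s, idx_dtype) for s in starts],
        slice_sizes=(2, 2),
    )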

@ -52,6 +52,7 @@ pytype_strict_library(
"//jax",
"//jax:ad_util",
"//jax:api_util",
"//jax:config",
"//jax:core",
"//jax:mlir",
"//jax:partial_eval",


@ -29,6 +29,7 @@ from jax import tree_util
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import api_util
from jax._src import config
from jax._src import core as jax_core
from jax._src import custom_derivatives
from jax._src import linear_util as lu
@ -2263,9 +2264,10 @@ def _for_lowering_rule(
del which_linear
if reverse or unroll != 1:
raise NotImplementedError
lower_bound = _i32_constant(0)
upper_bound = _i32_constant(nsteps)
step = _i32_constant(1)
_i_constant = _i64_constant if config.enable_x64.value else _i32_constant
lower_bound = _i_constant(0)
upper_bound = _i_constant(nsteps)
step = _i_constant(1)
init_args = map(_ensure_ir_value, args, ctx.avals_in)
# Partially discharge state from jaxpr for non-pointers
should_discharge = [

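The lowering now picks the width of the loop-bound constants from the x64 flag so they match the index dtype for_loop produces. The same selection in plain jnp terms (a sketch, not the Triton lowering code):

    import jax.numpy as jnp
    from jax._src import config

    # Emit 64-bit bounds only when jax_enable_x64 is on, mirroring
    # _i_constant = _i64_constant if config.enable_x64.value else _i32_constant.
    make_index = jnp.int64 if config.enable_x64.value else jnp.int32
    lower_bound, upper_bound, step = make_index(0), make_index(8), make_index(1)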

@ -190,7 +190,7 @@ def _swap_abstract_eval(ref_aval: AbstractRef,
f"Expected shape: {expected_out_shape}. "
f"Value shape: {val_aval.shape}. "
f"Indices: {indexers}. ")
if ref_aval.dtype != val_aval.dtype:
if ref_aval.dtype != val_aval.dtype and not val_aval.weak_type:
raise ValueError("Invalid dtype for `swap`. "
f"Ref dtype: {ref_aval.dtype}. "
f"Value dtype: {val_aval.dtype}. ")

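The relaxed check lets a weakly typed value (e.g. one originating from a Python scalar) be swapped into a ref whose dtype only differs because of x64 canonicalization. A small sketch (assumed example) of what weak typing means here:

    import jax.numpy as jnp

    x = jnp.zeros((), jnp.float32)   # explicit dtype: not weak
    y = jnp.asarray(1.0)             # from a Python scalar: weakly typed
    print(x.weak_type, y.weak_type)  # False True
    # A weak value defers to the other operand's dtype instead of forcing a
    # promotion, which is why the exact-dtype check can be skipped for it.
    print((x + y).dtype)             # float32 even when jax_enable_x64 is on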

@ -132,6 +132,12 @@ class AbstractRef(core.AbstractValue):
def __init__(self, inner_aval: core.AbstractValue):
self.inner_aval = inner_aval
@property
def weak_type(self) -> bool:
if not hasattr(self.inner_aval, "weak_type"):
raise AttributeError
return self.inner_aval.weak_type
def update(self, inner_aval=None):
if inner_aval is None:
return AbstractRef(self.inner_aval)


@ -41,8 +41,6 @@ jax_test(
disable_configs = [
"gpu",
"gpu_x32",
"gpu_a100",
"gpu_h100",
"gpu_p100",
"gpu_p100_x32",
],


@ -29,6 +29,7 @@ from jax import lax
from jax import random
from jax._src import checkify
from jax._src import config
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.lax.control_flow.for_loop import for_loop
from jax._src.pallas.pallas_call import _trace_to_jaxpr
@ -55,6 +56,10 @@ def smem_on_tpu():
return None
intx = dtypes.canonicalize_dtype(jnp.int64)
floatx = dtypes.canonicalize_dtype(jnp.float64)
@functools.partial(jax.jit, static_argnames=["bm", "bn", "gm", "bk",
"interpret", "debug"])
def matmul(x, y, *, bm, bn, gm, bk, interpret, debug=False):
@ -65,7 +70,7 @@ def matmul(x, y, *, bm, bn, gm, bk, interpret, debug=False):
debug=debug,
grid=pl.cdiv(m, bm) * pl.cdiv(n, bn))
def matmul_kernel(x_ref, y_ref, o_ref):
pid = pl.program_id(axis=0)
pid = pl.program_id(axis=0).astype(intx)
num_pid_m = m // bm
num_pid_n = n // bn
num_pid_in_group = gm * num_pid_n
@ -133,8 +138,6 @@ class PallasBaseTest(jtu.JaxTestCase):
def setUp(self):
if jtu.test_device_matches(["cpu"]) and not self.INTERPRET:
self.skipTest("On CPU the test works only in interpret mode")
if jtu.test_device_matches(["gpu"]) and jax.config.x64_enabled:
self.skipTest("On GPU the test works only in 32-bit")
if (jtu.test_device_matches(["cuda"]) and
not jtu.is_cuda_compute_capability_at_least("8.0")):
self.skipTest("Only works on GPU with capability >= sm80")
@ -151,13 +154,10 @@ class PallasBaseTest(jtu.JaxTestCase):
class PallasCallTest(PallasBaseTest):
def test_add_one(self):
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
# TODO: assertion failures on CPU in 64-bit mode
self.skipTest("On CPU the test works only in 32-bit mode")
if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
self.skipTest("On TPU the test works only in interpret mode")
@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32))
self.pallas_call, out_shape=jax.ShapeDtypeStruct((), floatx))
def add_one(x_ref, o_ref):
o_ref[()] = x_ref[()] + 1.
@ -177,14 +177,11 @@ class PallasCallTest(PallasBaseTest):
np.testing.assert_allclose(add_one(x), jnp.array([1.], jnp.float32))
def test_add_vector_block_spec(self):
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
# TODO: assertion failures on CPU in 64-bit mode
self.skipTest("On CPU the test works only in 32-bit mode")
if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
self.skipTest("On TPU the test works only in interpret mode")
@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
out_shape=jax.ShapeDtypeStruct((8,), intx),
in_specs=[pl.BlockSpec((1,), lambda i: i)],
out_specs=pl.BlockSpec((1,), lambda i: i),
grid=8,
@ -195,14 +192,11 @@ class PallasCallTest(PallasBaseTest):
np.testing.assert_allclose(add_one(jnp.arange(8)), jnp.arange(8) + 1)
def test_add_matrix_block_spec(self):
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
# TODO: assertion failures on CPU in 64-bit mode
self.skipTest("On CPU the test works only in 32-bit mode")
if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
self.skipTest("On TPU the test works only in interpret mode")
@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((8, 8), jnp.int32),
out_shape=jax.ShapeDtypeStruct((8, 8), intx),
in_specs=[pl.BlockSpec((2, 2), lambda i, j: (i, j))],
out_specs=pl.BlockSpec((2, 2), lambda i, j: (i, j)),
grid=(4, 4),
@ -225,13 +219,10 @@ class PallasCallTest(PallasBaseTest):
self.assertTrue(jnp.all(logical_and(x)))
def test_vector_indexing(self):
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
# TODO: assertion failures on CPU in 64-bit mode
self.skipTest("On CPU the test works only in 32-bit mode")
if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
self.skipTest("On TPU the test works only in interpret mode")
@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32),
self.pallas_call, out_shape=jax.ShapeDtypeStruct((), floatx),
grid=1)
def index(x_ref, i_ref, o_ref):
o_ref[()] = x_ref[i_ref[()]]
@ -485,13 +476,10 @@ class PallasCallTest(PallasBaseTest):
self.assertAllClose(res[0:1], to_store)
def test_vector_slicing(self):
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
# TODO: assertion failures on CPU in 64-bit mode
self.skipTest("On CPU the test works only in 32-bit mode")
if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
self.skipTest("On TPU the test works only in interpret mode")
@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),
self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), floatx),
grid=1)
def index(x_ref, idx_ref, o_ref):
idx = idx_ref[()]
@ -517,9 +505,6 @@ class PallasCallTest(PallasBaseTest):
if block_size_m <= m and block_size_n <= n and block_size_k <= k
])
def test_matmul(self, m, n, k, dtype, bm, bn, bk, gm):
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
# TODO: all sort of assertion failures on CPU in 64-bit mode
self.skipTest("On CPU the test works only in 32-bit mode")
if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
self.skipTest("On TPU the test works only in interpret mode")
k1, k2 = random.split(random.key(0))
@ -543,9 +528,6 @@ class PallasCallTest(PallasBaseTest):
if block_size_m <= m and block_size_n <= n and block_size_k <= k
])
def test_matmul_block_spec(self, m, n, k, dtype, bm, bn, bk):
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
# TODO: all sort of assertion failures on CPU in 64-bit mode
self.skipTest("On CPU the test works only in 32-bit mode")
if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
self.skipTest("On TPU the test works only in interpret mode")
k1, k2 = random.split(random.key(0))
@ -605,15 +587,12 @@ class PallasCallTest(PallasBaseTest):
np.testing.assert_allclose(dummy(x), jnp.ones_like(x), atol=1e-5, rtol=1e-5)
def test_with_input_output_aliasing(self):
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
# TODO: assertion failures on CPU in 64-bit mode
self.skipTest("On CPU the test works only in 32-bit mode")
if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
self.skipTest("On TPU the test works only in interpret mode")
def add_inplace_kernel(_, o_ref, *, block_size):
pid = pl.program_id(axis=0) # we use a 1d launch grid so axis is 0
block_start = pid * block_size
offsets = block_start + jnp.arange(block_size)
offsets = block_start + jnp.arange(block_size, dtype=jnp.int32)
mask = offsets < o_ref.shape[0]
x = pl.load(o_ref, (offsets,), mask=mask)
output = x + 1
@ -634,13 +613,10 @@ class PallasCallTest(PallasBaseTest):
np.testing.assert_allclose(out, expected)
def test_using_pallas_slice(self):
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
# TODO: assertion failures on CPU in 64-bit mode
self.skipTest("On CPU the test works only in 32-bit mode")
if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
self.skipTest("On TPU the test works only in interpret mode")
m, n = 32, 4
out_shape = jax.ShapeDtypeStruct((4, n), jnp.float32)
out_shape = jax.ShapeDtypeStruct((4, n), floatx)
@functools.partial(
self.pallas_call,
out_shape=out_shape,
@ -996,7 +972,7 @@ class PallasControlFlowTest(PallasBaseTest):
@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), jnp.float32),
out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), floatx),
in_specs=[
pl.BlockSpec((), lambda _, __: ()),
pl.BlockSpec((bx, 1), lambda i, _: (i, 0)),
@ -1073,15 +1049,13 @@ class PallasControlFlowTest(PallasBaseTest):
# dx, jnp.float32([0., 2, 4, 6, 0, 10, 12 + 12, 14]))
def test_scan_cond_vm_explicit_ref_arg(self):
if jtu.test_device_matches(["cpu"]):
# TODO: fix this
self.skipTest("Fails on CPU: assertion error")
if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
self.skipTest("TODO: error on TPU")
program = jnp.int32([0, 1, 2, 3, 2])
params = jnp.arange(len(program) * 3.).reshape(len(program), 3)
x = jnp.arange(7.)
params = jnp.arange(len(program) * 3., dtype=jnp.float32)
params = params.reshape(len(program), 3)
x = jnp.arange(7., dtype=jnp.float32)
bx = 4
@jax.jit
@ -1113,7 +1087,7 @@ class PallasControlFlowTest(PallasBaseTest):
return state, program_ref, params_ref
out_ref[...] = jax.lax.fori_loop(
0, len(program), body_fn,
(jnp.zeros(x.shape), program_ref, params_ref))[0]
(jnp.zeros(x.shape, dtype=jnp.float32), program_ref, params_ref))[0]
expected = (x * params[0, 0] +
2 * x * params[1, 1] +
@ -1127,16 +1101,14 @@ class PallasControlFlowTest(PallasBaseTest):
params, x)
def test_scan_cond_vm_closing_over_ref(self):
if jtu.test_device_matches(["cpu"]):
# TODO: fix this
self.skipTest("Fails on CPU: assertion error")
if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
self.skipTest("TODO: error on TPU")
# ** Difference is the closure over params_ref in the switch branches. **
program = jnp.int32([0, 1, 2, 3, 2, -1])
params = jnp.arange(len(program) * 3.).reshape(len(program), 3)
x = jnp.arange(7.)
params = jnp.arange(len(program) * 3., dtype=jnp.float32)
params = params.reshape(len(program), 3)
x = jnp.arange(7., dtype=jnp.float32)
bx = 4
@jax.jit
@ -1169,7 +1141,7 @@ class PallasControlFlowTest(PallasBaseTest):
return state, program_ref, params_ref
out_ref[...] = jax.lax.fori_loop(
0, len(program), body_fn,
(jnp.zeros(x.shape), program_ref, params_ref))[0]
(jnp.zeros(x.shape, dtype=jnp.float32), program_ref, params_ref))[0]
expected = (x * params[0, 0] +
2 * x * params[1, 1] +
@ -1375,7 +1347,7 @@ class PallasControlFlowTest(PallasBaseTest):
kernel,
grid=(1,),
out_specs=pl.BlockSpec((1, 1), memory_space=smem_on_tpu()),
out_shape=jax.ShapeDtypeStruct([1, 1], jnp.int32),
out_shape=jax.ShapeDtypeStruct([1, 1], intx),
in_specs=[
pl.BlockSpec(
(1, 8, 128),
@ -1439,7 +1411,7 @@ class PallasControlFlowTest(PallasBaseTest):
kernel,
grid=(4,),
out_specs=pl.BlockSpec((1, 1), memory_space=smem_on_tpu()),
out_shape=jax.ShapeDtypeStruct([1, 1], jnp.int32),
out_shape=jax.ShapeDtypeStruct([1, 1], intx),
in_specs=[
pl.BlockSpec(
(1, 8, 128),
@ -1634,7 +1606,7 @@ class PallasCallAutodifferentiationTest(PallasBaseTest):
grad_tol = 1e-1
@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32),
self.pallas_call, out_shape=jax.ShapeDtypeStruct((), floatx),
grid=1)
def pallas_impl(x_ref, o_ref):
x = x_ref[()]
@ -1656,7 +1628,7 @@ class PallasCallAutodifferentiationTest(PallasBaseTest):
def test_pallas_around_grad(self, impl):
@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((), jnp.float32),
out_shape=jax.ShapeDtypeStruct((), floatx),
name=self.id().split(".")[-1],
grid=1)
def pallas_impl(x_ref, o_ref):
@ -1675,7 +1647,7 @@ class PallasCallAutodifferentiationTest(PallasBaseTest):
grad_tol = 1e-1
@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), jnp.float32),
self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), floatx),
grid=1)
def pallas_impl(x_ref, o_ref):
x = x_ref[jnp.arange(2)]
@ -1730,16 +1702,9 @@ class PallasCallAutodifferentiationTest(PallasBaseTest):
class PallasCallAutodifferentiationInterpreterTest(PallasCallAutodifferentiationTest):
INTERPRET = True
def setUp(self):
super().setUp()
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
# TODO: assertion failures on CPU in 64-bit mode
self.skipTest("On CPU the test works only in 32-bit mode")
class PallasOutOfBoundsInterpreterTest(PallasBaseTest):
INTERPRET: bool = True
INTERPRET = True
def test_interpret_mode_out_of_bounds_access(self):
block_size = 32
@ -1818,7 +1783,7 @@ class PallasOutOfBoundsInterpreterTest(PallasBaseTest):
class PallasCheckifyInterpreterTest(PallasBaseTest):
# TODO(b/346651778): Support non-interpret mode checkify.
INTERPRET: bool = True
INTERPRET = True
def test_no_checkify(self,):
def kernel(y_ref):