Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)
[Pallas:MGPU] Adjust warpgroup lowering to the recent emit_pipeline changes
The Pallas-level pipelining generates a number of ops we haven't had to deal with before, such as conditionals and scans.

PiperOrigin-RevId: 730899808
parent a3a48af105
commit 3d87a01bea
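The changes below follow one pattern throughout: lowering rules are registered for both thread-semantics levels, and shared rules branch on ctx.thread_semantics to pick the value representation. A minimal sketch of the pattern (illustrative only; `some_p` is a placeholder primitive, not part of this commit):

    @register_lowering_rule(some_p, mgpu.ThreadSemantics.Lane)
    @register_lowering_rule(some_p, mgpu.ThreadSemantics.Warpgroup)
    def _some_lowering_rule(ctx: LoweringRuleContext, x):
      if ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
        x = _ensure_fa(x, ctx.avals_in[0].dtype)        # mgpu.FragmentedArray
      else:
        x = _ensure_ir_value(x, ctx.avals_in[0].dtype)  # plain MLIR scalar/vector
      ...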
@@ -692,7 +692,8 @@ def lower_jaxpr_to_mosaic_gpu(
   def read_env(atom: jax_core.Atom):
     return atom.val if isinstance(atom, jax_core.Literal) else env[atom]
 
-  def write_env(var: jax_core.Var, val):
+  def write_env(var: jax_core.Var, val, require_value: bool = True):
+    env[var] = val
     # TODO(apaszke): Handle other avals (refs, etc.).
     if isinstance(aval := var.aval, jax_core.ShapedArray):
       # TODO(apaszke): Clarify the type invariants for lane semantics?
@@ -701,7 +702,12 @@ def lower_jaxpr_to_mosaic_gpu(
       # Those with empty shapes should be represented by their scalar type.
       mlir_dtype = mgpu_utils.dtype_to_ir_type(aval.dtype)
       if not isinstance(val, ir.Value):
-        raise AssertionError(f"Shaped arrays must be represented by ir.Values, got: {val}")
+        if require_value:
+          raise AssertionError(f"Shaped arrays must be represented by ir.Values, got: {val}")
+        else:
+          if var.aval.shape:
+            raise AssertionError("Only scalars can be represented by non-ir.Values")
+          return  # Skip following checks.
       if aval.shape:
         if not ir.VectorType.isinstance(val.type):
           raise AssertionError(f"Non-scalar arrays must be represented by vectors, got: {val.type}")
@@ -715,10 +721,9 @@ def lower_jaxpr_to_mosaic_gpu(
           raise AssertionError(f"Scalars must be represented by non-vector types, got: {val.type}")
         if val.type != mlir_dtype:
           raise AssertionError(f"Scalar type must match ShapedArray dtype, got: {val.type} != {mlir_dtype}")
-    env[var] = val
 
   map(write_env, jaxpr.constvars, consts)
-  map(write_env, jaxpr.invars, args)
+  map(lambda v, a: write_env(v, a, require_value=False), jaxpr.invars, args)
   # TODO(justinfu): Handle transform scopes.
   last_local_name_stack: list[str] = []
   named_regions = []
@@ -958,9 +963,9 @@ def _get_lowering_rule_wg(ctx: LoweringRuleContext, x_smem, *leaves, tree):
   if shape:
     zero_index = arith_dialect.constant(ir.IndexType.get(), 0)
     indices = [zero_index for _ in range(len(shape))]
+    return vector_dialect.load(ty, x_smem, indices)
   else:
-    indices = []
-  return vector_dialect.load(ty, x_smem, indices)
+    return memref_dialect.load(x_smem, [])
 
 
 @register_lowering_rule(sp.swap_p, mgpu.ThreadSemantics.Lane)
@@ -1019,10 +1024,11 @@ def _swap_lowering_rule_wg(
   if shape:
     zero_index = arith_dialect.constant(ir.IndexType.get(), 0)
     indices = [zero_index for _ in range(len(shape))]
+    old_value = vector_dialect.load(ty, x_smem, indices)
+    vector_dialect.store(value, x_smem, indices)
   else:
-    indices = []
-  old_value = vector_dialect.load(ty, x_smem, indices)
-  vector_dialect.store(value, x_smem, indices)
+    old_value = memref_dialect.load(x_smem, [])
+    memref_dialect.store(value, x_smem, [])
   return old_value
 
 
@@ -1051,6 +1057,7 @@ def _slice_lowering_rule(
 
 
 @register_lowering_rule(lax.select_n_p, mgpu.ThreadSemantics.Lane)
+@register_lowering_rule(lax.select_n_p, mgpu.ThreadSemantics.Warpgroup)
 def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases):
   if len(cases) != 2:
     raise NotImplementedError(
@@ -1059,11 +1066,22 @@ def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases):
     )
   pred_aval, *cases_avals = ctx.avals_in
   [out_aval] = ctx.avals_out
-  pred = _ensure_fa(pred, pred_aval.dtype)
-  cases = _bcast(*cases, *cases_avals, out_aval)
-  # ``select`` expects the first case to be the true branch, but ``select_n``
-  # orders the cases in reverse.
-  return pred.select(*reversed(cases))
+  if ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
+    pred = _ensure_fa(pred, pred_aval.dtype)
+    cases = _bcast(*cases, *cases_avals, out_aval)
+    # ``select`` expects the first case to be the true branch, but ``select_n``
+    # orders the cases in reverse.
+    return pred.select(*reversed(cases))
+  else:
+    pred = _ensure_ir_value(pred, pred_aval.dtype)
+    cases = [_ensure_ir_value(c, c_aval.dtype) for c, c_aval in zip(cases, cases_avals)]
+    # TODO(bchetioui): support implicit broadcast.
+    if any(a.shape != out_aval.shape for a in ctx.avals_in):
+      raise NotImplementedError(
+          "Implicit broadcast not implemented with warpgroup semantics")
+    # ``select`` expects the first case to be the true branch, but ``select_n``
+    # orders the cases in reverse.
+    return arith_dialect.select(pred, *reversed(cases))
 
 
 @register_lowering_rule(lax.broadcast_in_dim_p, mgpu.ThreadSemantics.Lane)
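The reversal comment is easy to check from the JAX side: ``lax.select_n`` indexes its cases by the predicate value (case 0 where the predicate is false, case 1 where it is true), whereas a ``select`` takes the true branch first. A small illustration:

    import jax.numpy as jnp
    from jax import lax

    pred = jnp.array([True, False])
    # Case 1 (the ones) is taken where pred is True.
    lax.select_n(pred, jnp.zeros(2), jnp.ones(2))  # -> [1., 0.]

hence the ``reversed(cases)`` before the cases are handed to ``select``.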
@@ -1148,11 +1166,13 @@ def _convert_element_type_lowering_rule_wg(
   elif from_integer and to_integer:
     if ir.IntegerType(cur_dtype).width > ir.IntegerType(new_dtype).width:
       convert = arith_dialect.trunci
-    else:
+    elif ir.IntegerType(cur_dtype).width < ir.IntegerType(new_dtype).width:
       if mgpu_utils.is_signed(x_aval.dtype):
         convert = arith_dialect.extsi
       else:
         convert = arith_dialect.extui
+    else:
+      convert = lambda _, x: x  # signed <-> unsigned conversions
   elif from_integer and to_float:
     if mgpu_utils.is_signed(x_aval.dtype):
       convert = arith_dialect.sitofp
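For reference, the integer-to-integer branch above reduces to a width comparison. A hypothetical standalone helper (the name and signature are ours, not the diff's) would read:

    def _pick_int_convert(src: ir.IntegerType, dst: ir.IntegerType, signed: bool):
      if src.width > dst.width:
        return arith_dialect.trunci   # narrowing
      elif src.width < dst.width:
        # Widening: sign-extend signed sources, zero-extend unsigned ones.
        return arith_dialect.extsi if signed else arith_dialect.extui
      else:
        return lambda _, x: x  # same width: signed <-> unsigned is a no-op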
@@ -1229,20 +1249,17 @@ for op, si_impl, ui_impl, f_impl in [
         arith_dialect.divf,
     ),
     (lax.rem_p, arith_dialect.remsi, arith_dialect.remui, arith_dialect.remf),
-    (lax.and_p, arith_dialect.andi, arith_dialect.andi, None),
-    (lax.or_p, arith_dialect.ori, arith_dialect.ori, None),
-    (lax.xor_p, arith_dialect.xori, arith_dialect.xori, None),
     (
         lax.max_p,
         arith_dialect.maxsi,
         arith_dialect.maxui,
-        arith_dialect.maxnumf,
+        arith_dialect.maximumf,
     ),
     (
         lax.min_p,
         arith_dialect.minsi,
         arith_dialect.minui,
-        arith_dialect.minnumf,
+        arith_dialect.minimumf,
     ),
 ]:
   mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][op] = partial(
@@ -1252,6 +1269,23 @@ for op, si_impl, ui_impl, f_impl in [
       f_impl=f_impl,
   )
 
+
+def _binary_boolean_op_lowering_rule_wg(
+    ctx: LoweringRuleContext, x, y, *, impl
+):
+  x, y = _bcast_wg(x, y, *ctx.avals_in, *ctx.avals_out)
+  return impl(x, y)
+
+for op, impl in [
+    (lax.and_p, arith_dialect.andi),
+    (lax.or_p, arith_dialect.ori),
+    (lax.xor_p, arith_dialect.xori),
+]:
+  mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][op] = partial(
+      _binary_boolean_op_lowering_rule_wg,
+      impl=impl,
+  )
+
 CmpIPred = arith_dialect.CmpIPredicate
 CmpFPred = arith_dialect.CmpFPredicate
 
@@ -1262,7 +1296,7 @@ def _comparison_lowering_rule_wg(
   x, y = _bcast_wg(x, y, *ctx.avals_in, *ctx.avals_out)
   if jnp.issubdtype(x_aval, jnp.signedinteger):
     return arith_dialect.cmpi(si_pred, x, y)
-  elif jnp.issubdtype(x_aval, jnp.integer):
+  elif jnp.issubdtype(x_aval, jnp.integer) or jnp.issubdtype(x_aval, jnp.bool):
     return arith_dialect.cmpi(ui_pred, x, y)
   elif jnp.issubdtype(x_aval, jnp.floating):
     return arith_dialect.cmpf(f_pred, x, y)
@@ -1452,6 +1486,19 @@ def _debug_print_lowering_rule(
 
   return ()
 
+@register_lowering_rule(primitives.debug_print_p, mgpu.ThreadSemantics.Warpgroup)
+def _debug_print_lowering_rule(
+    ctx: LoweringRuleContext,
+    *args,
+    fmt,
+    has_placeholders: bool,
+):
+  del ctx, has_placeholders  # Unused.
+  if args:
+    raise NotImplementedError("debug_print only supports string messages in warpgroup semantics")
+  mgpu.debug_print(fmt)
+  return ()
 
 
 @register_lowering_rule(primitives.run_scoped_p, mgpu.ThreadSemantics.Lane)
+@register_lowering_rule(primitives.run_scoped_p, mgpu.ThreadSemantics.Warpgroup)
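With the rule added above, warpgroup semantics accepts only static strings for now. A hedged usage sketch (`acc` is a stand-in value):

    pl.debug_print("entered body")   # OK: no placeholders
    pl.debug_print("acc: {}", acc)   # NotImplementedError under warpgroup semantics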
@@ -1593,13 +1640,18 @@ def _lower_jaxpr_to_for_loop(
   out_avals = ctx.avals_out[-len(arg_avals):]
 
   is_acc = [isinstance(v, mgpu.WGMMAAccumulator) for v in args]
-  def as_fas(vals, avals):
+  def as_values(vals, avals):
     if is_acc != [isinstance(v, mgpu.WGMMAAccumulator) for v in vals]:
       raise ValueError("Unexpected loop carry w.r.t. accumulators.")
 
-    return [v if a else _ensure_fa(v, av) for a, v, av in zip(is_acc, vals, avals)]
+    _ensure = (
+        _ensure_fa
+        if ctx.thread_semantics == mgpu.ThreadSemantics.Lane
+        else _ensure_ir_value
+    )
+    return [v if a else _ensure(v, av) for a, v, av in zip(is_acc, vals, avals)]
 
-  @mgpu.fori(length, as_fas(args, arg_avals))
+  @mgpu.fori(length, as_values(args, arg_avals))
   def loop(loop_index, body_args):
     if has_loop_index:
       loop_index = arith_dialect.addi(loop_index, start)
@@ -1607,14 +1659,15 @@ def _lower_jaxpr_to_for_loop(
     else:
       jaxpr_args = [*consts, *body_args]
     outs = lower_jaxpr_to_mosaic_gpu(
-        ctx.module_ctx, ctx.launch_ctx, jaxpr, jaxpr_args
+        ctx.module_ctx, ctx.launch_ctx, jaxpr, jaxpr_args, thread_semantics=ctx.thread_semantics,
     )
-    return as_fas(outs, out_avals)
+    return as_values(outs, out_avals)
 
   return loop.results
 
 
 @register_lowering_rule(lax.scan_p, mgpu.ThreadSemantics.Lane)
+@register_lowering_rule(lax.scan_p, mgpu.ThreadSemantics.Warpgroup)
 def _scan_lowering_rule(
     ctx: LoweringRuleContext,
     *args,
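As a consequence of the two hunks above, a scalar jax.lax.fori_loop now lowers under warpgroup semantics as well (array carries still do not; see the test changes below). An illustrative kernel, assuming the usual Pallas imports:

    def kernel(o_ref):
      # Sums i over [2, 4), i.e. 2 + 3.
      acc = jax.lax.fori_loop(2, 4, lambda i, acc: acc + i, 0)
      o_ref[...] = jnp.full(o_ref.shape, acc, o_ref.dtype)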
@@ -1781,6 +1834,7 @@ def _while_lowering_rule(
 
 
 @register_lowering_rule(lax.cond_p, mgpu.ThreadSemantics.Lane)
+@register_lowering_rule(lax.cond_p, mgpu.ThreadSemantics.Warpgroup)
 def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
   index_aval, *_arg_avals = ctx.avals_in
 
@@ -1800,7 +1854,7 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
   # extract the return types
   with ir.InsertionPoint(ir.Module.create().body):
     outs = lower_jaxpr_to_mosaic_gpu(
-        ctx.module_ctx, ctx.launch_ctx, branches[0].jaxpr, args
+        ctx.module_ctx, ctx.launch_ctx, branches[0].jaxpr, args, thread_semantics=ctx.thread_semantics,
     )
   yielded_types = [v.type for v in jax.tree.leaves(_yielded_values(outs, ctx.avals_out))]
   del outs
@@ -1822,7 +1876,7 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
   for branch, region in zip(branches, regions):
     with ir.InsertionPoint(region.blocks.append()):
       outs = lower_jaxpr_to_mosaic_gpu(
-          ctx.module_ctx, ctx.launch_ctx, branch.jaxpr, args, consts=branch.consts
+          ctx.module_ctx, ctx.launch_ctx, branch.jaxpr, args, consts=branch.consts, thread_semantics=ctx.thread_semantics,
      )
 
       yielded_leaves, yielded_treedef = jax.tree.flatten(_yielded_values(outs, ctx.avals_out))
@@ -1917,48 +1971,41 @@ def _bcast_wg(
   if not out_aval.shape:
     return _ensure_ir_value(x, x_aval.dtype), _ensure_ir_value(y, y_aval.dtype)
   x_dtype = x_aval.dtype
-  if not isinstance(x, ir.Value) or not ir.VectorType.isinstance(x.type):
+  if not isinstance(x, ir.Value):
     if x_aval.weak_type:
       x_dtype = y_aval.dtype
-    x = _ensure_vector(x, x_dtype)
+    x = _ensure_ir_value(x, x_dtype)
   y_dtype = y_aval.dtype
-  if not isinstance(y, ir.Value) or not ir.VectorType.isinstance(y.type):
+  if not isinstance(y, ir.Value):
     if y_aval.weak_type:
       y_dtype = x_aval.dtype
-    y = _ensure_vector(y, y_dtype)
-  if x_aval.shape != out_aval.shape:
-    assert not x_aval.shape  # TODO(slebedev): Support non-scalar inputs.
+    y = _ensure_ir_value(y, y_dtype)
+  if not ir.VectorType.isinstance(x.type):
+    assert not x_aval.shape
     x = vector_dialect.splat(
         ir.VectorType.get(out_aval.shape, mgpu_utils.dtype_to_ir_type(x_dtype)),
         x,
     )
-  if y_aval.shape != out_aval.shape:
-    assert not y_aval.shape  # TODO(slebedev): Support non-scalar inputs.
+  elif x_aval.shape != out_aval.shape:
+    raise NotImplementedError("Unsupported broadcast")
+  if not ir.VectorType.isinstance(y.type):
+    assert not y_aval.shape
     y = vector_dialect.splat(
         ir.VectorType.get(out_aval.shape, mgpu_utils.dtype_to_ir_type(y_dtype)),
         y,
     )
+  elif y_aval.shape != out_aval.shape:
+    raise NotImplementedError("Unsupported broadcast")
   return x, y
 
 
-def _ensure_vector(x: object, dtype: jnp.dtype) -> ir.Value:
-  if isinstance(x, ir.Value) and ir.VectorType.isinstance(x.type):
-    assert ir.VectorType(x.type).element_type == mgpu_utils.dtype_to_ir_type(dtype)
-    return x
-
-  if isinstance(x, (np.number, np.ndarray, int, float)):
-    return _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype))
-
-  raise NotImplementedError(f"Unsupported conversion to vector for: {x!r}")
-
-
 def _ensure_ir_value(x: object, dtype: jnp.dtype) -> ir.Value:
   if isinstance(x, ir.Value):
     mlir_dtype = mgpu_utils.dtype_to_ir_type(dtype)
-    assert x.type == mlir_dtype
+    if ir.VectorType.isinstance(x.type):
+      assert ir.VectorType(x.type).element_type == mlir_dtype
+    else:
+      assert x.type == mlir_dtype, (x.type, mlir_dtype)
     return x
   elif isinstance(x, mgpu.FragmentedArray):
     assert x.mlir_dtype == mgpu_utils.dtype_to_ir_type(dtype)
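The weak-type handling in _bcast_wg mirrors JAX's promotion rules: a Python scalar carries a weak dtype and defers to the dtype of the other operand before being splat to the output shape. For example:

    x = jnp.arange(4, dtype=jnp.float16)
    (x + 1.0).dtype  # float16: the weak scalar adopts x's dtype, then is broadcast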
@@ -880,7 +880,7 @@ def _jaxpr_call_lowering_rule(
       program_ids[axis] = lowering._program_id(axis, ctx.module_ctx.squashed_dims)
   new_module_ctx = dataclasses.replace(ctx.module_ctx, program_ids=program_ids)
   return lowering.lower_jaxpr_to_mosaic_gpu(
-      new_module_ctx, ctx.launch_ctx, jaxpr, args
+      new_module_ctx, ctx.launch_ctx, jaxpr, args, thread_semantics=ctx.thread_semantics,
   )
 
 
@@ -114,8 +114,6 @@ class PallasCallTest(PallasTest):
       thread_semantics=[*plgpu.ThreadSemantics],
   )
   def test_binary_op(self, op, dtype, thread_semantics):
-    if thread_semantics == plgpu.ThreadSemantics.Warpgroup:
-      self.skipTest("Needs scan_p WG lowering")
 
     @functools.partial(
         pl.pallas_call,
@@ -146,8 +144,6 @@ class PallasCallTest(PallasTest):
       thread_semantics=[*plgpu.ThreadSemantics],
   )
   def test_comparison_op(self, op, dtype, thread_semantics):
-    if thread_semantics == plgpu.ThreadSemantics.Warpgroup:
-      self.skipTest("Needs scan_p WG lowering")
 
     @functools.partial(
         pl.pallas_call,
@@ -319,8 +315,6 @@ class PallasCallTest(PallasTest):
       thread_semantics=[*plgpu.ThreadSemantics],
   )
   def test_copy_smem_to_gmem(self, indexer, thread_semantics):
-    if thread_semantics == plgpu.ThreadSemantics.Warpgroup:
-      self.skipTest("Needs scan_p WG lowering")
 
     @functools.partial(
         pl.pallas_call,
@@ -788,8 +782,6 @@ class PallasCallTest(PallasTest):
 
   @parameterized.product(thread_semantics=[*plgpu.ThreadSemantics])
   def test_run_scoped(self, thread_semantics):
-    if thread_semantics == plgpu.ThreadSemantics.Warpgroup:
-      self.skipTest("Needs scan_p WG lowering")
 
     def kernel(x_ref, o_ref):
       def body(tmp_ref):
@@ -906,12 +898,18 @@ class PallasCallTest(PallasTest):
     x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128)
     np.testing.assert_array_equal(kernel(x), x)
 
-  @parameterized.parameters(False, True)
-  def test_fori_loop_array(self, force_while):
+  @parameterized.product(
+      force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics]
+  )
+  def test_fori_loop_array(self, force_while, thread_semantics):
+    if thread_semantics == plgpu.ThreadSemantics.Warpgroup:
+      # TODO(apaszke,bchetioui,slebedev): Support while + array carries.
+      self.skipTest("WG semantics unsupported")
 
     @functools.partial(
         pl.pallas_call,
         out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
+        compiler_params=plgpu.GPUCompilerParams(thread_semantics=thread_semantics),
     )
     def kernel(x_ref, o_ref):
       # Equivalent to x_ref[...] + 2 + 3.
@@ -920,12 +918,17 @@ class PallasCallTest(PallasTest):
     x = jnp.arange(256, dtype=jnp.int32)
     np.testing.assert_array_equal(kernel(x), x + 2 + 3)
 
-  @parameterized.parameters(False, True)
-  def test_fori_loop_scalar(self, force_while):
+  @parameterized.product(
+      force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics]
+  )
+  def test_fori_loop_scalar(self, force_while, thread_semantics):
+    if force_while and thread_semantics == plgpu.ThreadSemantics.Warpgroup:
+      self.skipTest("WG semantics does not support force_while.")
 
     @functools.partial(
         pl.pallas_call,
         out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
+        compiler_params=plgpu.GPUCompilerParams(thread_semantics=thread_semantics),
     )
     def kernel(o_ref):
       # Equivalent to 2 + 3.
@@ -1035,25 +1038,26 @@ class PallasCallTest(PallasTest):
     with self.assertRaisesRegex(ValueError, "has layout .*, when it should be"):
       kernel()
 
-  def test_cond(self):
+  @parameterized.parameters([*plgpu.ThreadSemantics])
+  def test_cond(self, thread_semantics):
     @functools.partial(
         pl.pallas_call,
         out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
+        compiler_params=plgpu.GPUCompilerParams(thread_semantics=thread_semantics),
     )
     def kernel(x_ref, o_ref):
-      acc = _sum_same_dtype(x_ref[...])
       jax.lax.cond(
-          acc % 2 == 0,
-          lambda: pl.debug_print("acc * 2: {}", acc * 2),
-          lambda: pl.debug_print("acc: {}", acc),
+          x_ref[0] % 2 == 0,
+          lambda: pl.debug_print("acc % 2"),
+          lambda: pl.debug_print("acc"),
       )
-      o_ref[...] = jnp.broadcast_to(acc, o_ref.shape)
+      o_ref[...] = jnp.broadcast_to(jnp.asarray(0, dtype=o_ref.dtype), o_ref.shape)
 
-    x = jnp.arange(256, dtype=jnp.int32)
+    x = jnp.full((256,), 1234, dtype=jnp.int32)
     with self.capture_stdout() as output:
       jax.block_until_ready(kernel(x))
 
-    self.assertIn("acc * 2:", output())
+    self.assertIn("acc % 2", output())
 
   def test_cond_returning_array(self):
     @functools.partial(
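For completeness, opting a kernel into warpgroup semantics uses the same compiler params the updated tests pass (a usage sketch; `body` is a placeholder kernel function):

    kernel = pl.pallas_call(
        body,
        out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
        compiler_params=plgpu.GPUCompilerParams(
            thread_semantics=plgpu.ThreadSemantics.Warpgroup
        ),
    )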