[Pallas/Fuser] Add basic closed over consts support to pull_block_spec

PiperOrigin-RevId: 747657069
This commit is contained in:
Sharad Vikram 2025-04-14 19:08:12 -07:00 committed by jax authors
parent 69d21c69ae
commit 4fa3cd91d3
2 changed files with 42 additions and 3 deletions

View File

@ -239,9 +239,7 @@ def pull_block_spec(
jaxpr, consts, in_tree, out_tree_ = fuser_utils.make_jaxpr(
f, *args, **kwargs
)
# TODO(sharadmv): handle these consts better, they should correspond to
# scalar prefetch.
del consts, out_tree_
del out_tree_
jaxpr_out_usages = [{Usage.REGULAR}] * len(jaxpr.outvars)
block_specs_ = jax.tree.map(
_unwrap_block_spec_scalar_prefetch, out_block_specs
@ -263,6 +261,7 @@ def pull_block_spec(
)
kernel_fn = make_kernel_function(
jaxpr,
consts,
in_tree,
out_tree,
read_usage_env,
@ -408,6 +407,7 @@ def _pull_block_spec(
def make_kernel_function(
jaxpr: core.Jaxpr,
consts,
in_tree,
out_tree,
read_usage_env,
@ -505,6 +505,8 @@ def make_kernel_function(
def write_env(var, val):
env[var] = val
for const, constvar in zip(consts, jaxpr.constvars):
env[constvar] = const
for invar, arg, usage in zip(jaxpr.invars, flat_args, invar_usages):
if Usage.REGULAR in usage:
env[invar] = arg
@ -1232,6 +1234,7 @@ def _jit_eval_rule(ctx: KernelEvalContext, *args, jaxpr, **kwargs):
)
kernel_fn = make_kernel_function(
jaxpr,
(),
in_tree,
out_tree,
read_usage_env,
@ -1289,6 +1292,7 @@ def _custom_jvp_call_eval_rule(
)
kernel_fn = make_kernel_function(
jaxpr,
(),
in_tree,
out_tree,
read_usage_env,

View File

@ -769,6 +769,41 @@ class PullBlockSpecHOPTest(jtu.JaxTestCase):
kernel_fn((0, 0, 0, 0), scalar_prefetch_values, (), x), relu_x
)
def test_pull_block_spec_handles_closed_over_constants(self):
x = jnp.ones((2, 512, 512))
i = jnp.array(1)
def f():
return x[i]
f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f)
self.assertLen(new_values, 1)
self.assertLen(scalar_prefetch_values, 1)
block_spec = pl.BlockSpec(
(None, 1, 128, 128), lambda i, j, k, l, _: (i, j, k, l)
)
kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec(
f2,
block_spec,
grid=(2, 2, 4, 4),
scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
)(new_values)
self.assertLen(value_block_specs, 1)
scalar_prefetch_values = jax.tree.map(
lambda x: x[None], scalar_prefetch_values
)
fn = lambda x: kernel_fn((0, 0, 0, 0), scalar_prefetch_values, x)
new_values_type = (jax.ShapeDtypeStruct((1, 128, 128), jnp.float32),)
# Try pulling again
# This should not raise an error.
_ = block_spec_lib.pull_block_spec(
fn,
block_spec,
grid=(1,),
scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
)(new_values_type)
class PushBlockSpecTest(parameterized.TestCase):