mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[Pallas/Fuser] Add basic closed over consts support to pull_block_spec
PiperOrigin-RevId: 747657069
This commit is contained in:
parent
69d21c69ae
commit
4fa3cd91d3
@ -239,9 +239,7 @@ def pull_block_spec(
|
||||
jaxpr, consts, in_tree, out_tree_ = fuser_utils.make_jaxpr(
|
||||
f, *args, **kwargs
|
||||
)
|
||||
# TODO(sharadmv): handle these consts better, they should correspond to
|
||||
# scalar prefetch.
|
||||
del consts, out_tree_
|
||||
del out_tree_
|
||||
jaxpr_out_usages = [{Usage.REGULAR}] * len(jaxpr.outvars)
|
||||
block_specs_ = jax.tree.map(
|
||||
_unwrap_block_spec_scalar_prefetch, out_block_specs
|
||||
@ -263,6 +261,7 @@ def pull_block_spec(
|
||||
)
|
||||
kernel_fn = make_kernel_function(
|
||||
jaxpr,
|
||||
consts,
|
||||
in_tree,
|
||||
out_tree,
|
||||
read_usage_env,
|
||||
@ -408,6 +407,7 @@ def _pull_block_spec(
|
||||
|
||||
def make_kernel_function(
|
||||
jaxpr: core.Jaxpr,
|
||||
consts,
|
||||
in_tree,
|
||||
out_tree,
|
||||
read_usage_env,
|
||||
@ -505,6 +505,8 @@ def make_kernel_function(
|
||||
def write_env(var, val):
|
||||
env[var] = val
|
||||
|
||||
for const, constvar in zip(consts, jaxpr.constvars):
|
||||
env[constvar] = const
|
||||
for invar, arg, usage in zip(jaxpr.invars, flat_args, invar_usages):
|
||||
if Usage.REGULAR in usage:
|
||||
env[invar] = arg
|
||||
@ -1232,6 +1234,7 @@ def _jit_eval_rule(ctx: KernelEvalContext, *args, jaxpr, **kwargs):
|
||||
)
|
||||
kernel_fn = make_kernel_function(
|
||||
jaxpr,
|
||||
(),
|
||||
in_tree,
|
||||
out_tree,
|
||||
read_usage_env,
|
||||
@ -1289,6 +1292,7 @@ def _custom_jvp_call_eval_rule(
|
||||
)
|
||||
kernel_fn = make_kernel_function(
|
||||
jaxpr,
|
||||
(),
|
||||
in_tree,
|
||||
out_tree,
|
||||
read_usage_env,
|
||||
|
@ -769,6 +769,41 @@ class PullBlockSpecHOPTest(jtu.JaxTestCase):
|
||||
kernel_fn((0, 0, 0, 0), scalar_prefetch_values, (), x), relu_x
|
||||
)
|
||||
|
||||
def test_pull_block_spec_handles_closed_over_constants(self):
|
||||
x = jnp.ones((2, 512, 512))
|
||||
i = jnp.array(1)
|
||||
|
||||
def f():
|
||||
return x[i]
|
||||
|
||||
f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f)
|
||||
self.assertLen(new_values, 1)
|
||||
self.assertLen(scalar_prefetch_values, 1)
|
||||
|
||||
block_spec = pl.BlockSpec(
|
||||
(None, 1, 128, 128), lambda i, j, k, l, _: (i, j, k, l)
|
||||
)
|
||||
kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec(
|
||||
f2,
|
||||
block_spec,
|
||||
grid=(2, 2, 4, 4),
|
||||
scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
|
||||
)(new_values)
|
||||
self.assertLen(value_block_specs, 1)
|
||||
scalar_prefetch_values = jax.tree.map(
|
||||
lambda x: x[None], scalar_prefetch_values
|
||||
)
|
||||
fn = lambda x: kernel_fn((0, 0, 0, 0), scalar_prefetch_values, x)
|
||||
new_values_type = (jax.ShapeDtypeStruct((1, 128, 128), jnp.float32),)
|
||||
# Try pulling again
|
||||
# This should not raise an error.
|
||||
_ = block_spec_lib.pull_block_spec(
|
||||
fn,
|
||||
block_spec,
|
||||
grid=(1,),
|
||||
scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(),
|
||||
)(new_values_type)
|
||||
|
||||
|
||||
class PushBlockSpecTest(parameterized.TestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user