diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index d3f301c4f..263fa5961 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -16,34 +16,34 @@ from __future__ import annotations from collections.abc import Sequence -import itertools from functools import partial, reduce +import itertools from typing import Any, Callable import jax from jax import api_util -from jax import tree_util from jax import lax +from jax import tree_util +from jax._src import ad_util from jax._src import config +from jax._src import core as jax_core +from jax._src import effects +from jax._src import linear_util as lu from jax._src import state from jax._src.interpreters import ad from jax._src.interpreters import batching -from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import xla -from jax._src import ad_util -from jax._src import core as jax_core -from jax._src.state import primitives as sp -from jax._src import linear_util as lu +from jax._src.pallas import core as pallas_core from jax._src.state import discharge as state_discharge +from jax._src.state import primitives as sp from jax._src.util import ( split_list, safe_map, safe_zip, weakref_lru_cache, tuple_insert, partition_list, merge_lists) import jax.numpy as jnp import numpy as np -from jax._src.pallas import core as pallas_core - map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -302,8 +302,14 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear, jvp_jaxpr.invars, [len(primals), len(out_shapes), len(tangents)] ) invars = (*primal_refs, *tangent_refs, *primal_out_refs, *tangent_out_refs) - # TODO(sharadmv): Fix state effect tracking after invar switch. - jvp_jaxpr = jvp_jaxpr.replace(invars=invars) + effs = [] + for eff in jvp_jaxpr.effects: + if isinstance(eff, effects.JaxprInputEffect): + eff = eff.replace( + input_index=invars.index(jvp_jaxpr.invars[eff.input_index]) + ) + effs.append(eff) + jvp_jaxpr = jvp_jaxpr.replace(invars=invars, effects=effs) if debug: print(jvp_jaxpr) in_bms, out_bms = split_list(grid_mapping.block_mappings, [len(primals)]) diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 535731201..51c7ece8e 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -123,7 +123,7 @@ def matmul_block_spec(x, y, *, bm, bn, bk, interpret, debug=False): return matmul_kernel(x, y) -class PallasTest(parameterized.TestCase): +class PallasTest(jtu.JaxTestCase): INTERPRET = False def setUp(self): @@ -459,7 +459,7 @@ class PallasCallTest(PallasTest): ) def test_invalid_broadcasted_load(self, x_shape, mask_shape): if self.INTERPRET: - self.skipTest("No broadcasting checks in pl.load in interepreter mode") + self.skipTest("No broadcasting checks in pl.load in interpreter mode") @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32)