Updated the JVP rule for pallas_call_p to propagate new invar indices to effects

Prior to this change some of the tests in PallasTest were failing under
JAX_ENABLE_CHECKS=1, because the effects in the JVP jaxpr did not type check.
PiperOrigin-RevId: 638652928
This commit is contained in:
Sergei Lebedev 2024-05-30 07:57:55 -07:00 committed by jax authors
parent 8729952d82
commit daa99025b9
2 changed files with 19 additions and 13 deletions

View File

@ -16,34 +16,34 @@
from __future__ import annotations
from collections.abc import Sequence
import itertools
from functools import partial, reduce
import itertools
from typing import Any, Callable
import jax
from jax import api_util
from jax import tree_util
from jax import lax
from jax import tree_util
from jax._src import ad_util
from jax._src import config
from jax._src import core as jax_core
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import state
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src import ad_util
from jax._src import core as jax_core
from jax._src.state import primitives as sp
from jax._src import linear_util as lu
from jax._src.pallas import core as pallas_core
from jax._src.state import discharge as state_discharge
from jax._src.state import primitives as sp
from jax._src.util import (
split_list, safe_map, safe_zip, weakref_lru_cache,
tuple_insert, partition_list, merge_lists)
import jax.numpy as jnp
import numpy as np
from jax._src.pallas import core as pallas_core
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
@ -302,8 +302,14 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
jvp_jaxpr.invars, [len(primals), len(out_shapes), len(tangents)]
)
invars = (*primal_refs, *tangent_refs, *primal_out_refs, *tangent_out_refs)
# TODO(sharadmv): Fix state effect tracking after invar switch.
jvp_jaxpr = jvp_jaxpr.replace(invars=invars)
effs = []
for eff in jvp_jaxpr.effects:
if isinstance(eff, effects.JaxprInputEffect):
eff = eff.replace(
input_index=invars.index(jvp_jaxpr.invars[eff.input_index])
)
effs.append(eff)
jvp_jaxpr = jvp_jaxpr.replace(invars=invars, effects=effs)
if debug:
print(jvp_jaxpr)
in_bms, out_bms = split_list(grid_mapping.block_mappings, [len(primals)])

View File

@ -123,7 +123,7 @@ def matmul_block_spec(x, y, *, bm, bn, bk, interpret, debug=False):
return matmul_kernel(x, y)
class PallasTest(parameterized.TestCase):
class PallasTest(jtu.JaxTestCase):
INTERPRET = False
def setUp(self):
@ -459,7 +459,7 @@ class PallasCallTest(PallasTest):
)
def test_invalid_broadcasted_load(self, x_shape, mask_shape):
if self.INTERPRET:
self.skipTest("No broadcasting checks in pl.load in interepreter mode")
self.skipTest("No broadcasting checks in pl.load in interpreter mode")
@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32)