mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Updated the JVP rule for pallas_call_p to propagate new invar indices to effects
Prior to this change some of the tests in PallasTest were failing under JAX_ENABLE_CHECKS=1, because the effects in the JVP jaxpr did not type check. PiperOrigin-RevId: 638652928
This commit is contained in:
parent
8729952d82
commit
daa99025b9
@ -16,34 +16,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
import itertools
|
||||
from functools import partial, reduce
|
||||
import itertools
|
||||
from typing import Any, Callable
|
||||
|
||||
import jax
|
||||
from jax import api_util
|
||||
from jax import tree_util
|
||||
from jax import lax
|
||||
from jax import tree_util
|
||||
from jax._src import ad_util
|
||||
from jax._src import config
|
||||
from jax._src import core as jax_core
|
||||
from jax._src import effects
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import state
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src import ad_util
|
||||
from jax._src import core as jax_core
|
||||
from jax._src.state import primitives as sp
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.pallas import core as pallas_core
|
||||
from jax._src.state import discharge as state_discharge
|
||||
from jax._src.state import primitives as sp
|
||||
from jax._src.util import (
|
||||
split_list, safe_map, safe_zip, weakref_lru_cache,
|
||||
tuple_insert, partition_list, merge_lists)
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from jax._src.pallas import core as pallas_core
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
@ -302,8 +302,14 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
|
||||
jvp_jaxpr.invars, [len(primals), len(out_shapes), len(tangents)]
|
||||
)
|
||||
invars = (*primal_refs, *tangent_refs, *primal_out_refs, *tangent_out_refs)
|
||||
# TODO(sharadmv): Fix state effect tracking after invar switch.
|
||||
jvp_jaxpr = jvp_jaxpr.replace(invars=invars)
|
||||
effs = []
|
||||
for eff in jvp_jaxpr.effects:
|
||||
if isinstance(eff, effects.JaxprInputEffect):
|
||||
eff = eff.replace(
|
||||
input_index=invars.index(jvp_jaxpr.invars[eff.input_index])
|
||||
)
|
||||
effs.append(eff)
|
||||
jvp_jaxpr = jvp_jaxpr.replace(invars=invars, effects=effs)
|
||||
if debug:
|
||||
print(jvp_jaxpr)
|
||||
in_bms, out_bms = split_list(grid_mapping.block_mappings, [len(primals)])
|
||||
|
@ -123,7 +123,7 @@ def matmul_block_spec(x, y, *, bm, bn, bk, interpret, debug=False):
|
||||
return matmul_kernel(x, y)
|
||||
|
||||
|
||||
class PallasTest(parameterized.TestCase):
|
||||
class PallasTest(jtu.JaxTestCase):
|
||||
INTERPRET = False
|
||||
|
||||
def setUp(self):
|
||||
@ -459,7 +459,7 @@ class PallasCallTest(PallasTest):
|
||||
)
|
||||
def test_invalid_broadcasted_load(self, x_shape, mask_shape):
|
||||
if self.INTERPRET:
|
||||
self.skipTest("No broadcasting checks in pl.load in interepreter mode")
|
||||
self.skipTest("No broadcasting checks in pl.load in interpreter mode")
|
||||
|
||||
@functools.partial(
|
||||
self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32)
|
||||
|
Loading…
x
Reference in New Issue
Block a user