mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[JAX] Teach jit fast path how to handle negative static_argnums correctly.
PiperOrigin-RevId: 645172085
This commit is contained in:
parent
84d748f43c
commit
d7a22d3720
@ -10,6 +10,10 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
|
||||
## jaxlib 0.4.31
|
||||
|
||||
* Bug fixes
|
||||
* Fixed a bug that meant that negative static_argnums to a jit were mishandled
|
||||
by the jit dispatch fast path.
|
||||
|
||||
## jax 0.4.30 (June 18, 2024)
|
||||
|
||||
* Changes
|
||||
|
@ -58,6 +58,7 @@ from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib import xla_extension_version
|
||||
import jax._src.util as jax_util
|
||||
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
|
||||
import jax.custom_batching
|
||||
@ -4354,9 +4355,14 @@ class APITest(jtu.JaxTestCase):
|
||||
g = jax.grad(f, argnums=-1)
|
||||
g(x, y) # doesn't crash
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 272, "requires jaxlib 0.4.31")
|
||||
def test_jit_negative_static_argnums(self):
|
||||
g = jax.jit(lambda x, y: x * y, static_argnums=-1)
|
||||
g(1, 2) # doesn't crash
|
||||
@partial(jax.jit, static_argnums=-1)
|
||||
def g(x, y):
|
||||
assert isinstance(y, int)
|
||||
return x * y
|
||||
for i in range(3): # Loop verifies we exercise both Python and C++ dispatch
|
||||
self.assertEqual(2 * i, g(2, i), msg=i)
|
||||
|
||||
def test_fastpath_cache_confusion(self):
|
||||
# https://github.com/google/jax/issues/12542
|
||||
|
Loading…
x
Reference in New Issue
Block a user