From d7a22d37207e111001ea6d4500aaba42fed04e46 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 20 Jun 2024 15:17:37 -0700 Subject: [PATCH] [JAX] Teach jit fast path how to handle negative static_argnums correctly. PiperOrigin-RevId: 645172085 --- CHANGELOG.md | 4 ++++ tests/api_test.py | 10 ++++++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 80428fc72..ad4dfcecf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ Remember to align the itemized text with the first line of an item within a list ## jaxlib 0.4.31 +* Bug fixes + * Fixed a bug that meant that negative static_argnums to a jit were mishandled + by the jit dispatch fast path. + ## jax 0.4.30 (June 18, 2024) * Changes diff --git a/tests/api_test.py b/tests/api_test.py index 89a9b4e1f..e80c9e186 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -58,6 +58,7 @@ from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.lib import xla_client from jax._src.lib import xla_extension +from jax._src.lib import xla_extension_version import jax._src.util as jax_util from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint import jax.custom_batching @@ -4354,9 +4355,14 @@ class APITest(jtu.JaxTestCase): g = jax.grad(f, argnums=-1) g(x, y) # doesn't crash + @unittest.skipIf(xla_extension_version < 272, "requires jaxlib 0.4.31") def test_jit_negative_static_argnums(self): - g = jax.jit(lambda x, y: x * y, static_argnums=-1) - g(1, 2) # doesn't crash + @partial(jax.jit, static_argnums=-1) + def g(x, y): + assert isinstance(y, int) + return x * y + for i in range(3): # Loop verifies we exercise both Python and C++ dispatch + self.assertEqual(2 * i, g(2, i), msg=i) def test_fastpath_cache_confusion(self): # https://github.com/google/jax/issues/12542