From 91a33362de22fa6792572d14f36fbcc2658e6d2a Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 18 Jan 2024 13:13:47 -0800 Subject: [PATCH] Deprecate jax.lax.tie_in --- CHANGELOG.md | 1 + .../jax2tf/tests/control_flow_ops_test.py | 2 +- jax/lax/__init__.py | 8 +++++++- tests/api_test.py | 2 +- tests/lax_numpy_test.py | 17 +++++++++-------- tests/lax_test.py | 7 ------- 6 files changed, 19 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 97700970d..942a27cae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -84,6 +84,7 @@ Remember to align the itemized text with the first line of an item within a list * {mod}`jax.random`: passing batched keys directly to random number generation functions, such as {func}`~jax.random.bits`, {func}`~jax.random.gamma`, and others, is deprecated and will emit a `FutureWarning`. Use `jax.vmap` for explicit batching. + * {func}`jax.lax.tie_in` is deprecated: it has been a no-op since JAX v0.2.0. ## jaxlib 0.4.24 diff --git a/jax/experimental/jax2tf/tests/control_flow_ops_test.py b/jax/experimental/jax2tf/tests/control_flow_ops_test.py index 8e3a1a0c0..253a5ffc6 100644 --- a/jax/experimental/jax2tf/tests/control_flow_ops_test.py +++ b/jax/experimental/jax2tf/tests/control_flow_ops_test.py @@ -146,7 +146,7 @@ class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase): # of the lax.while primitive. def cond(idx_carry): i, c = idx_carry - return i < jnp.sum(lax.tie_in(i, cond_const)) # Capture cond_const + return i < jnp.sum(cond_const) # Capture cond_const def body(idx_carry): i, c = idx_carry diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index 703d0ede4..7c8ed3436 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -218,7 +218,7 @@ from jax._src.lax.lax import ( tan_p as tan_p, tanh as tanh, tanh_p as tanh_p, - tie_in as tie_in, + tie_in as _deprecated_tie_in, top_k as top_k, top_k_p as top_k_p, transpose as transpose, @@ -426,6 +426,11 @@ _deprecations = { "jax.lax.unop_dtype_rule is an internal API and has been deprecated.", _deprecated_unop_dtype_rule, ), + # Added January 18 2023 + "tie_in": ( + "jax.lax.tie_in is deprecated: it has been a no-op since JAX v0.2.0. " + "Replace z = tie_in(x, y) with z = y.", _deprecated_tie_in, + ), } import typing as _typing @@ -438,6 +443,7 @@ if _typing.TYPE_CHECKING: standard_naryop = _deprecated_standard_naryop, standard_primitive = _deprecated_standard_primitive, standard_unop = _deprecated_standard_unop, + tie_in = _deprecated_tie_in unop = _deprecated_unop, unop_dtype_rule = _deprecated_unop_dtype_rule, else: diff --git a/tests/api_test.py b/tests/api_test.py index cbe0f9b2e..228d38349 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -7115,7 +7115,7 @@ class CustomJVPTest(jtu.JaxTestCase): def test_hard_stuff2(self): @jax.custom_jvp def f(x): - return lax.tie_in(x, np.zeros(x.shape, x.dtype)) + return np.zeros(x.shape, x.dtype) @f.defjvp def f_jvp(primals, tangents): diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index c4715808f..9bfcde74e 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -4568,15 +4568,16 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): 7.66067839e-174], np.float64) self.assertAllClose(f(x), expected, check_dtypes=False) - def testIssue776(self): - """Tests that the scatter-add transpose rule instantiates symbolic zeros.""" - def f(u): - y = jnp.ones_like(u, shape=10).at[np.array([2, 4, 5])].add(u) - # The transpose rule for lax.tie_in returns a symbolic zero for its first - # argument. - return lax.tie_in(y, 7.) + # Test removed because tie_in is deprecated. + # def testIssue776(self): + # """Tests that the scatter-add transpose rule instantiates symbolic zeros.""" + # def f(u): + # y = jnp.ones_like(u, shape=10).at[np.array([2, 4, 5])].add(u) + # # The transpose rule for lax.tie_in returns a symbolic zero for its first + # # argument. + # return lax.tie_in(y, 7.) - self.assertAllClose(np.zeros(3,), jax.grad(f)(np.ones(3,))) + # self.assertAllClose(np.zeros(3,), jax.grad(f)(np.ones(3,))) # NOTE(mattjj): I disabled this test when removing lax._safe_mul because this # is a numerical stability issue that should be solved with a custom jvp rule diff --git a/tests/lax_test.py b/tests/lax_test.py index 67e58b875..a70e3c1bd 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -2514,13 +2514,6 @@ class LaxTest(jtu.JaxTestCase): np.zeros((2, 2), dtype=np.float32), (np.int32(1), np.int16(2)))) - def test_tie_in_error(self): - raise SkipTest("test no longer needed after trivializing tie_in") - # with core.skipping_checks(): - # with self.assertRaisesRegex( - # TypeError, ".* of type .*tuple.* is not a valid JAX type"): - # jax.make_jaxpr(lambda x: lax.tie_in((x, x), 1))(1.) - def test_primitive_jaxtype_error(self): with jax.enable_checks(False): with self.assertRaisesRegex(