Deprecate jax.lax.tie_in

This commit is contained in:
Jake VanderPlas 2024-01-18 13:13:47 -08:00
parent 71c9be14b8
commit 91a33362de
6 changed files with 19 additions and 18 deletions

View File

@ -84,6 +84,7 @@ Remember to align the itemized text with the first line of an item within a list
* {mod}`jax.random`: passing batched keys directly to random number generation functions,
such as {func}`~jax.random.bits`, {func}`~jax.random.gamma`, and others, is deprecated
and will emit a `FutureWarning`. Use `jax.vmap` for explicit batching.
* {func}`jax.lax.tie_in` is deprecated: it has been a no-op since JAX v0.2.0.
## jaxlib 0.4.24

View File

@ -146,7 +146,7 @@ class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):
# of the lax.while primitive.
def cond(idx_carry):
i, c = idx_carry
return i < jnp.sum(lax.tie_in(i, cond_const)) # Capture cond_const
return i < jnp.sum(cond_const) # Capture cond_const
def body(idx_carry):
i, c = idx_carry

View File

@ -218,7 +218,7 @@ from jax._src.lax.lax import (
tan_p as tan_p,
tanh as tanh,
tanh_p as tanh_p,
tie_in as tie_in,
tie_in as _deprecated_tie_in,
top_k as top_k,
top_k_p as top_k_p,
transpose as transpose,
@ -426,6 +426,11 @@ _deprecations = {
"jax.lax.unop_dtype_rule is an internal API and has been deprecated.",
_deprecated_unop_dtype_rule,
),
# Added January 18 2023
"tie_in": (
"jax.lax.tie_in is deprecated: it has been a no-op since JAX v0.2.0. "
"Replace z = tie_in(x, y) with z = y.", _deprecated_tie_in,
),
}
import typing as _typing
@ -438,6 +443,7 @@ if _typing.TYPE_CHECKING:
standard_naryop = _deprecated_standard_naryop,
standard_primitive = _deprecated_standard_primitive,
standard_unop = _deprecated_standard_unop,
tie_in = _deprecated_tie_in
unop = _deprecated_unop,
unop_dtype_rule = _deprecated_unop_dtype_rule,
else:

View File

@ -7115,7 +7115,7 @@ class CustomJVPTest(jtu.JaxTestCase):
def test_hard_stuff2(self):
@jax.custom_jvp
def f(x):
return lax.tie_in(x, np.zeros(x.shape, x.dtype))
return np.zeros(x.shape, x.dtype)
@f.defjvp
def f_jvp(primals, tangents):

View File

@ -4568,15 +4568,16 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
7.66067839e-174], np.float64)
self.assertAllClose(f(x), expected, check_dtypes=False)
def testIssue776(self):
"""Tests that the scatter-add transpose rule instantiates symbolic zeros."""
def f(u):
y = jnp.ones_like(u, shape=10).at[np.array([2, 4, 5])].add(u)
# The transpose rule for lax.tie_in returns a symbolic zero for its first
# argument.
return lax.tie_in(y, 7.)
# Test removed because tie_in is deprecated.
# def testIssue776(self):
# """Tests that the scatter-add transpose rule instantiates symbolic zeros."""
# def f(u):
# y = jnp.ones_like(u, shape=10).at[np.array([2, 4, 5])].add(u)
# # The transpose rule for lax.tie_in returns a symbolic zero for its first
# # argument.
# return lax.tie_in(y, 7.)
self.assertAllClose(np.zeros(3,), jax.grad(f)(np.ones(3,)))
# self.assertAllClose(np.zeros(3,), jax.grad(f)(np.ones(3,)))
# NOTE(mattjj): I disabled this test when removing lax._safe_mul because this
# is a numerical stability issue that should be solved with a custom jvp rule

View File

@ -2514,13 +2514,6 @@ class LaxTest(jtu.JaxTestCase):
np.zeros((2, 2), dtype=np.float32),
(np.int32(1), np.int16(2))))
def test_tie_in_error(self):
raise SkipTest("test no longer needed after trivializing tie_in")
# with core.skipping_checks():
# with self.assertRaisesRegex(
# TypeError, ".* of type .*tuple.* is not a valid JAX type"):
# jax.make_jaxpr(lambda x: lax.tie_in((x, x), 1))(1.)
def test_primitive_jaxtype_error(self):
with jax.enable_checks(False):
with self.assertRaisesRegex(