mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Deprecate jax.lax.tie_in
This commit is contained in:
parent
71c9be14b8
commit
91a33362de
@ -84,6 +84,7 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* {mod}`jax.random`: passing batched keys directly to random number generation functions,
|
||||
such as {func}`~jax.random.bits`, {func}`~jax.random.gamma`, and others, is deprecated
|
||||
and will emit a `FutureWarning`. Use `jax.vmap` for explicit batching.
|
||||
* {func}`jax.lax.tie_in` is deprecated: it has been a no-op since JAX v0.2.0.
|
||||
|
||||
## jaxlib 0.4.24
|
||||
|
||||
|
@ -146,7 +146,7 @@ class ControlFlowOpsTest(tf_test_util.JaxToTfTestCase):
|
||||
# of the lax.while primitive.
|
||||
def cond(idx_carry):
|
||||
i, c = idx_carry
|
||||
return i < jnp.sum(lax.tie_in(i, cond_const)) # Capture cond_const
|
||||
return i < jnp.sum(cond_const) # Capture cond_const
|
||||
|
||||
def body(idx_carry):
|
||||
i, c = idx_carry
|
||||
|
@ -218,7 +218,7 @@ from jax._src.lax.lax import (
|
||||
tan_p as tan_p,
|
||||
tanh as tanh,
|
||||
tanh_p as tanh_p,
|
||||
tie_in as tie_in,
|
||||
tie_in as _deprecated_tie_in,
|
||||
top_k as top_k,
|
||||
top_k_p as top_k_p,
|
||||
transpose as transpose,
|
||||
@ -426,6 +426,11 @@ _deprecations = {
|
||||
"jax.lax.unop_dtype_rule is an internal API and has been deprecated.",
|
||||
_deprecated_unop_dtype_rule,
|
||||
),
|
||||
# Added January 18 2023
|
||||
"tie_in": (
|
||||
"jax.lax.tie_in is deprecated: it has been a no-op since JAX v0.2.0. "
|
||||
"Replace z = tie_in(x, y) with z = y.", _deprecated_tie_in,
|
||||
),
|
||||
}
|
||||
|
||||
import typing as _typing
|
||||
@ -438,6 +443,7 @@ if _typing.TYPE_CHECKING:
|
||||
standard_naryop = _deprecated_standard_naryop,
|
||||
standard_primitive = _deprecated_standard_primitive,
|
||||
standard_unop = _deprecated_standard_unop,
|
||||
tie_in = _deprecated_tie_in
|
||||
unop = _deprecated_unop,
|
||||
unop_dtype_rule = _deprecated_unop_dtype_rule,
|
||||
else:
|
||||
|
@ -7115,7 +7115,7 @@ class CustomJVPTest(jtu.JaxTestCase):
|
||||
def test_hard_stuff2(self):
|
||||
@jax.custom_jvp
|
||||
def f(x):
|
||||
return lax.tie_in(x, np.zeros(x.shape, x.dtype))
|
||||
return np.zeros(x.shape, x.dtype)
|
||||
|
||||
@f.defjvp
|
||||
def f_jvp(primals, tangents):
|
||||
|
@ -4568,15 +4568,16 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
7.66067839e-174], np.float64)
|
||||
self.assertAllClose(f(x), expected, check_dtypes=False)
|
||||
|
||||
def testIssue776(self):
|
||||
"""Tests that the scatter-add transpose rule instantiates symbolic zeros."""
|
||||
def f(u):
|
||||
y = jnp.ones_like(u, shape=10).at[np.array([2, 4, 5])].add(u)
|
||||
# The transpose rule for lax.tie_in returns a symbolic zero for its first
|
||||
# argument.
|
||||
return lax.tie_in(y, 7.)
|
||||
# Test removed because tie_in is deprecated.
|
||||
# def testIssue776(self):
|
||||
# """Tests that the scatter-add transpose rule instantiates symbolic zeros."""
|
||||
# def f(u):
|
||||
# y = jnp.ones_like(u, shape=10).at[np.array([2, 4, 5])].add(u)
|
||||
# # The transpose rule for lax.tie_in returns a symbolic zero for its first
|
||||
# # argument.
|
||||
# return lax.tie_in(y, 7.)
|
||||
|
||||
self.assertAllClose(np.zeros(3,), jax.grad(f)(np.ones(3,)))
|
||||
# self.assertAllClose(np.zeros(3,), jax.grad(f)(np.ones(3,)))
|
||||
|
||||
# NOTE(mattjj): I disabled this test when removing lax._safe_mul because this
|
||||
# is a numerical stability issue that should be solved with a custom jvp rule
|
||||
|
@ -2514,13 +2514,6 @@ class LaxTest(jtu.JaxTestCase):
|
||||
np.zeros((2, 2), dtype=np.float32),
|
||||
(np.int32(1), np.int16(2))))
|
||||
|
||||
def test_tie_in_error(self):
|
||||
raise SkipTest("test no longer needed after trivializing tie_in")
|
||||
# with core.skipping_checks():
|
||||
# with self.assertRaisesRegex(
|
||||
# TypeError, ".* of type .*tuple.* is not a valid JAX type"):
|
||||
# jax.make_jaxpr(lambda x: lax.tie_in((x, x), 1))(1.)
|
||||
|
||||
def test_primitive_jaxtype_error(self):
|
||||
with jax.enable_checks(False):
|
||||
with self.assertRaisesRegex(
|
||||
|
Loading…
x
Reference in New Issue
Block a user