mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[x64] deprecate unsafe type casting in scatter-update operations
This commit is contained in:
parent
ca01d1b411
commit
d2f80ef117
@ -45,6 +45,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
traces as an alternative to the Tensorboard UI.
|
||||
* Added a `jax.named_scope` context manager that adds profiler metadata to
|
||||
Python programs (similar to `jax.named_call`).
|
||||
* In scatter-update operations (i.e. :attr:`jax.numpy.ndarray.at`), unsafe implicit
|
||||
dtype casts are deprecated, and now result in a `FutureWarning`.
|
||||
In a future release, this will become an error. An example of an unsafe implicit
|
||||
cast is `jnp.zeros(4, dtype=int).at[0].set(1.5)`, in which `1.5` previously was
|
||||
silently truncated to `1`.
|
||||
|
||||
## jaxlib 0.3.11 (Unreleased)
|
||||
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
import sys
|
||||
from typing import Any, Callable, Optional, Sequence, Tuple, Union
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -81,6 +82,13 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
|
||||
dtype = lax.dtype(x)
|
||||
weak_type = dtypes.is_weakly_typed(x)
|
||||
|
||||
if dtype != dtypes.result_type(x, y):
|
||||
# TODO(jakevdp): change this to an error after the deprecation period.
|
||||
warnings.warn("scatter inputs have incompatible types: cannot safely cast "
|
||||
f"value from dtype={lax.dtype(y)} to dtype={lax.dtype(x)}. "
|
||||
"In future JAX releases this will result in an error.",
|
||||
FutureWarning)
|
||||
|
||||
idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
|
||||
indexer = jnp._index_to_gather(jnp.shape(x), idx,
|
||||
normalize_indices=normalize_indices)
|
||||
|
@ -350,9 +350,10 @@ def _iterative_classical_gram_schmidt(Q, x, xnorm, max_iterations=2):
|
||||
|
||||
# This assumes that Q's leaves all have the same dimension in the last
|
||||
# axis.
|
||||
r = jnp.zeros(tree_leaves(Q)[0].shape[-1])
|
||||
Q0 = tree_leaves(Q)[0]
|
||||
r = jnp.zeros(Q0.shape[-1], dtype=Q0.dtype)
|
||||
q = x
|
||||
xnorm_scaled = xnorm / jnp.sqrt(2)
|
||||
xnorm_scaled = xnorm / jnp.sqrt(2.0)
|
||||
|
||||
def body_function(carry):
|
||||
k, q, r, qnorm_scaled = carry
|
||||
@ -368,7 +369,7 @@ def _iterative_classical_gram_schmidt(Q, x, xnorm, max_iterations=2):
|
||||
def qnorm(carry):
|
||||
k, _, q, qnorm_scaled = carry
|
||||
_, qnorm = _safe_normalize(q)
|
||||
qnorm_scaled = qnorm / jnp.sqrt(2)
|
||||
qnorm_scaled = qnorm / jnp.sqrt(2.0)
|
||||
return (k, False, q, qnorm_scaled)
|
||||
|
||||
init = (k, True, q, qnorm_scaled)
|
||||
|
@ -77,7 +77,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
"max", "get"])
|
||||
def test_jit_oob_update(self, update_fn):
|
||||
def f(x, i):
|
||||
return getattr(x.at[i], update_fn)(1.)
|
||||
return getattr(x.at[i], update_fn)(1)
|
||||
|
||||
f = jax.jit(f)
|
||||
checked_f = checkify.checkify(f, errors=checkify.index_checks)
|
||||
|
@ -960,7 +960,7 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
jnp.array([7, 7, 1, 2, 1, 4, 5, 7, 7, 7], jnp.int32))
|
||||
|
||||
def testIndexingWeakTypes(self):
|
||||
x = lax_internal._convert_element_type(jnp.arange(5), int, weak_type=True)
|
||||
x = lax_internal._convert_element_type(jnp.arange(5), float, weak_type=True)
|
||||
|
||||
a = x.at[0].set(1.0)
|
||||
self.assertEqual(a.dtype, x.dtype)
|
||||
@ -974,6 +974,67 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
self.assertEqual(c.dtype, x.dtype)
|
||||
self.assertTrue(dtypes.is_weakly_typed(c))
|
||||
|
||||
def testIndexingTypePromotion(self):
|
||||
def _check(x_type, y_type):
|
||||
x = jnp.arange(5, dtype=x_type)
|
||||
y = y_type(0)
|
||||
out = x.at[0].set(y)
|
||||
self.assertEqual(x.dtype, out.dtype)
|
||||
|
||||
@jtu.ignore_warning(category=np.ComplexWarning,
|
||||
message="Casting complex values to real")
|
||||
def _check_warns(x_type, y_type, msg):
|
||||
with self.assertWarnsRegex(FutureWarning, msg):
|
||||
_check(x_type, y_type)
|
||||
|
||||
def _check_raises(x_type, y_type, msg):
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
_check(x_type, y_type)
|
||||
|
||||
# Matching dtypes are always OK
|
||||
_check(jnp.int32, jnp.int32)
|
||||
_check(jnp.float32, jnp.float32)
|
||||
_check(jnp.complex64, jnp.complex64)
|
||||
|
||||
# Weakly-typed y values promote.
|
||||
_check(jnp.int32, int)
|
||||
_check(jnp.float32, int)
|
||||
_check(jnp.float32, float)
|
||||
_check(jnp.complex64, int)
|
||||
_check(jnp.complex64, float)
|
||||
_check(jnp.complex64, complex)
|
||||
|
||||
# in standard promotion mode, strong types can promote.
|
||||
msg = "scatter inputs have incompatible types"
|
||||
with jax.numpy_dtype_promotion('standard'):
|
||||
_check(jnp.int32, jnp.int16)
|
||||
_check(jnp.float32, jnp.float16)
|
||||
_check(jnp.float32, jnp.int32)
|
||||
_check(jnp.complex64, jnp.int32)
|
||||
_check(jnp.complex64, jnp.float32)
|
||||
|
||||
# TODO(jakevdp): make these _check_raises
|
||||
_check_warns(jnp.int16, jnp.int32, msg)
|
||||
_check_warns(jnp.int32, jnp.float32, msg)
|
||||
_check_warns(jnp.int32, jnp.complex64, msg)
|
||||
_check_warns(jnp.float16, jnp.float32, msg)
|
||||
_check_warns(jnp.float32, jnp.complex64, msg)
|
||||
|
||||
# in strict promotion mode, strong types do not promote.
|
||||
msg = "Input dtypes .* have no available implicit dtype promotion path"
|
||||
with jax.numpy_dtype_promotion('strict'):
|
||||
_check_raises(jnp.int32, jnp.int16, msg)
|
||||
_check_raises(jnp.float32, jnp.float16, msg)
|
||||
_check_raises(jnp.float32, jnp.int32, msg)
|
||||
_check_raises(jnp.complex64, jnp.int32, msg)
|
||||
_check_raises(jnp.complex64, jnp.float32, msg)
|
||||
|
||||
_check_raises(jnp.int16, jnp.int32, msg)
|
||||
_check_raises(jnp.int32, jnp.float32, msg)
|
||||
_check_raises(jnp.int32, jnp.complex64, msg)
|
||||
_check_raises(jnp.float16, jnp.float32, msg)
|
||||
_check_raises(jnp.float32, jnp.complex64, msg)
|
||||
|
||||
|
||||
def _broadcastable_shapes(shape):
|
||||
"""Returns all shapes that broadcast to `shape`."""
|
||||
@ -989,6 +1050,20 @@ def _broadcastable_shapes(shape):
|
||||
yield list(reversed(x))
|
||||
|
||||
|
||||
# TODO(jakevdp): move this implementation to jax.dtypes & use in scatter?
|
||||
def _can_cast(from_, to):
|
||||
return lax.dtype(to) == dtypes.result_type(from_, to)
|
||||
|
||||
|
||||
def _compatible_dtypes(op, dtype, inexact=False):
|
||||
if op == UpdateOps.ADD:
|
||||
return [dtype]
|
||||
elif inexact:
|
||||
return [dt for dt in float_dtypes if _can_cast(dt, dtype)]
|
||||
else:
|
||||
return [dt for dt in all_dtypes if _can_cast(dt, dtype)]
|
||||
|
||||
|
||||
class UpdateOps(enum.Enum):
|
||||
UPDATE = 0
|
||||
ADD = 1
|
||||
@ -1060,7 +1135,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
for op in s(UpdateOps)
|
||||
for dtype in s(UpdateOps.dtypes(op))
|
||||
for update_shape in s(_broadcastable_shapes(update_shape))
|
||||
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)
|
||||
for update_dtype in s(_compatible_dtypes(op, dtype))
|
||||
for mode in s(MODES))))
|
||||
def testStaticIndexing(self, shape, dtype, update_shape, update_dtype,
|
||||
indexer, op, mode):
|
||||
@ -1083,7 +1158,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
for op in s(UpdateOps)
|
||||
for dtype in s(UpdateOps.dtypes(op))
|
||||
for update_shape in s(_broadcastable_shapes(update_shape))
|
||||
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes))))
|
||||
for update_dtype in s(_compatible_dtypes(op, dtype)))))
|
||||
def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
|
||||
indexer, op):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -1106,7 +1181,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
for op in s(UpdateOps)
|
||||
for dtype in s(UpdateOps.dtypes(op))
|
||||
for update_shape in s(_broadcastable_shapes(update_shape))
|
||||
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes))))
|
||||
for update_dtype in s(_compatible_dtypes(op, dtype)))))
|
||||
def testAdvancedIndexingSorted(self, shape, dtype, update_shape, update_dtype,
|
||||
indexer, op):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -1130,7 +1205,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
for op in s(UpdateOps)
|
||||
for dtype in s(UpdateOps.dtypes(op))
|
||||
for update_shape in s(_broadcastable_shapes(update_shape))
|
||||
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes))))
|
||||
for update_dtype in s(_compatible_dtypes(op, dtype)))))
|
||||
def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
|
||||
indexer, op):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -1157,7 +1232,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
for op in [UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE]
|
||||
for dtype in float_dtypes
|
||||
for update_shape in _broadcastable_shapes(update_shape)
|
||||
for update_dtype in ([dtype] if op == UpdateOps.ADD else float_dtypes)))
|
||||
for update_dtype in _compatible_dtypes(op, dtype, inexact=True)))
|
||||
def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype,
|
||||
indexer, op, mode):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -1184,7 +1259,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
else [UpdateOps.ADD])
|
||||
for dtype in s(float_dtypes)
|
||||
for update_shape in s(_broadcastable_shapes(update_shape))
|
||||
for update_dtype in s([dtype] if op == UpdateOps.ADD else float_dtypes))))
|
||||
for update_dtype in s(_compatible_dtypes(op, dtype, inexact=True)))))
|
||||
def testAdvancedIndexingGrads(self, shape, dtype, update_shape, update_dtype,
|
||||
indexer, op, unique_indices):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
|
Loading…
x
Reference in New Issue
Block a user