[x64] deprecate unsafe type casting in scatter-update operations

This commit is contained in:
Jake VanderPlas 2022-06-09 15:21:49 -07:00
parent ca01d1b411
commit d2f80ef117
5 changed files with 100 additions and 11 deletions

View File

@ -45,6 +45,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
traces as an alternative to the Tensorboard UI. traces as an alternative to the Tensorboard UI.
* Added a `jax.named_scope` context manager that adds profiler metadata to * Added a `jax.named_scope` context manager that adds profiler metadata to
Python programs (similar to `jax.named_call`). Python programs (similar to `jax.named_call`).
* In scatter-update operations (i.e. :attr:`jax.numpy.ndarray.at`), unsafe implicit
dtype casts are deprecated, and now result in a `FutureWarning`.
In a future release, this will become an error. An example of an unsafe implicit
cast is `jnp.zeros(4, dtype=int).at[0].set(1.5)`, in which `1.5` previously was
silently truncated to `1`.
## jaxlib 0.3.11 (Unreleased) ## jaxlib 0.3.11 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main). * [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).

View File

@ -16,6 +16,7 @@
import sys import sys
from typing import Any, Callable, Optional, Sequence, Tuple, Union from typing import Any, Callable, Optional, Sequence, Tuple, Union
import warnings
import numpy as np import numpy as np
@ -81,6 +82,13 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
dtype = lax.dtype(x) dtype = lax.dtype(x)
weak_type = dtypes.is_weakly_typed(x) weak_type = dtypes.is_weakly_typed(x)
if dtype != dtypes.result_type(x, y):
# TODO(jakevdp): change this to an error after the deprecation period.
warnings.warn("scatter inputs have incompatible types: cannot safely cast "
f"value from dtype={lax.dtype(y)} to dtype={lax.dtype(x)}. "
"In future JAX releases this will result in an error.",
FutureWarning)
idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx) idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
indexer = jnp._index_to_gather(jnp.shape(x), idx, indexer = jnp._index_to_gather(jnp.shape(x), idx,
normalize_indices=normalize_indices) normalize_indices=normalize_indices)

View File

@ -350,9 +350,10 @@ def _iterative_classical_gram_schmidt(Q, x, xnorm, max_iterations=2):
# This assumes that Q's leaves all have the same dimension in the last # This assumes that Q's leaves all have the same dimension in the last
# axis. # axis.
r = jnp.zeros(tree_leaves(Q)[0].shape[-1]) Q0 = tree_leaves(Q)[0]
r = jnp.zeros(Q0.shape[-1], dtype=Q0.dtype)
q = x q = x
xnorm_scaled = xnorm / jnp.sqrt(2) xnorm_scaled = xnorm / jnp.sqrt(2.0)
def body_function(carry): def body_function(carry):
k, q, r, qnorm_scaled = carry k, q, r, qnorm_scaled = carry
@ -368,7 +369,7 @@ def _iterative_classical_gram_schmidt(Q, x, xnorm, max_iterations=2):
def qnorm(carry): def qnorm(carry):
k, _, q, qnorm_scaled = carry k, _, q, qnorm_scaled = carry
_, qnorm = _safe_normalize(q) _, qnorm = _safe_normalize(q)
qnorm_scaled = qnorm / jnp.sqrt(2) qnorm_scaled = qnorm / jnp.sqrt(2.0)
return (k, False, q, qnorm_scaled) return (k, False, q, qnorm_scaled)
init = (k, True, q, qnorm_scaled) init = (k, True, q, qnorm_scaled)

View File

@ -77,7 +77,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
"max", "get"]) "max", "get"])
def test_jit_oob_update(self, update_fn): def test_jit_oob_update(self, update_fn):
def f(x, i): def f(x, i):
return getattr(x.at[i], update_fn)(1.) return getattr(x.at[i], update_fn)(1)
f = jax.jit(f) f = jax.jit(f)
checked_f = checkify.checkify(f, errors=checkify.index_checks) checked_f = checkify.checkify(f, errors=checkify.index_checks)

View File

@ -960,7 +960,7 @@ class IndexingTest(jtu.JaxTestCase):
jnp.array([7, 7, 1, 2, 1, 4, 5, 7, 7, 7], jnp.int32)) jnp.array([7, 7, 1, 2, 1, 4, 5, 7, 7, 7], jnp.int32))
def testIndexingWeakTypes(self): def testIndexingWeakTypes(self):
x = lax_internal._convert_element_type(jnp.arange(5), int, weak_type=True) x = lax_internal._convert_element_type(jnp.arange(5), float, weak_type=True)
a = x.at[0].set(1.0) a = x.at[0].set(1.0)
self.assertEqual(a.dtype, x.dtype) self.assertEqual(a.dtype, x.dtype)
@ -974,6 +974,67 @@ class IndexingTest(jtu.JaxTestCase):
self.assertEqual(c.dtype, x.dtype) self.assertEqual(c.dtype, x.dtype)
self.assertTrue(dtypes.is_weakly_typed(c)) self.assertTrue(dtypes.is_weakly_typed(c))
def testIndexingTypePromotion(self):
def _check(x_type, y_type):
x = jnp.arange(5, dtype=x_type)
y = y_type(0)
out = x.at[0].set(y)
self.assertEqual(x.dtype, out.dtype)
@jtu.ignore_warning(category=np.ComplexWarning,
message="Casting complex values to real")
def _check_warns(x_type, y_type, msg):
with self.assertWarnsRegex(FutureWarning, msg):
_check(x_type, y_type)
def _check_raises(x_type, y_type, msg):
with self.assertRaisesRegex(ValueError, msg):
_check(x_type, y_type)
# Matching dtypes are always OK
_check(jnp.int32, jnp.int32)
_check(jnp.float32, jnp.float32)
_check(jnp.complex64, jnp.complex64)
# Weakly-typed y values promote.
_check(jnp.int32, int)
_check(jnp.float32, int)
_check(jnp.float32, float)
_check(jnp.complex64, int)
_check(jnp.complex64, float)
_check(jnp.complex64, complex)
# in standard promotion mode, strong types can promote.
msg = "scatter inputs have incompatible types"
with jax.numpy_dtype_promotion('standard'):
_check(jnp.int32, jnp.int16)
_check(jnp.float32, jnp.float16)
_check(jnp.float32, jnp.int32)
_check(jnp.complex64, jnp.int32)
_check(jnp.complex64, jnp.float32)
# TODO(jakevdp): make these _check_raises
_check_warns(jnp.int16, jnp.int32, msg)
_check_warns(jnp.int32, jnp.float32, msg)
_check_warns(jnp.int32, jnp.complex64, msg)
_check_warns(jnp.float16, jnp.float32, msg)
_check_warns(jnp.float32, jnp.complex64, msg)
# in strict promotion mode, strong types do not promote.
msg = "Input dtypes .* have no available implicit dtype promotion path"
with jax.numpy_dtype_promotion('strict'):
_check_raises(jnp.int32, jnp.int16, msg)
_check_raises(jnp.float32, jnp.float16, msg)
_check_raises(jnp.float32, jnp.int32, msg)
_check_raises(jnp.complex64, jnp.int32, msg)
_check_raises(jnp.complex64, jnp.float32, msg)
_check_raises(jnp.int16, jnp.int32, msg)
_check_raises(jnp.int32, jnp.float32, msg)
_check_raises(jnp.int32, jnp.complex64, msg)
_check_raises(jnp.float16, jnp.float32, msg)
_check_raises(jnp.float32, jnp.complex64, msg)
def _broadcastable_shapes(shape): def _broadcastable_shapes(shape):
"""Returns all shapes that broadcast to `shape`.""" """Returns all shapes that broadcast to `shape`."""
@ -989,6 +1050,20 @@ def _broadcastable_shapes(shape):
yield list(reversed(x)) yield list(reversed(x))
# TODO(jakevdp): move this implementation to jax.dtypes & use in scatter?
def _can_cast(from_, to):
return lax.dtype(to) == dtypes.result_type(from_, to)
def _compatible_dtypes(op, dtype, inexact=False):
if op == UpdateOps.ADD:
return [dtype]
elif inexact:
return [dt for dt in float_dtypes if _can_cast(dt, dtype)]
else:
return [dt for dt in all_dtypes if _can_cast(dt, dtype)]
class UpdateOps(enum.Enum): class UpdateOps(enum.Enum):
UPDATE = 0 UPDATE = 0
ADD = 1 ADD = 1
@ -1060,7 +1135,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
for op in s(UpdateOps) for op in s(UpdateOps)
for dtype in s(UpdateOps.dtypes(op)) for dtype in s(UpdateOps.dtypes(op))
for update_shape in s(_broadcastable_shapes(update_shape)) for update_shape in s(_broadcastable_shapes(update_shape))
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes) for update_dtype in s(_compatible_dtypes(op, dtype))
for mode in s(MODES)))) for mode in s(MODES))))
def testStaticIndexing(self, shape, dtype, update_shape, update_dtype, def testStaticIndexing(self, shape, dtype, update_shape, update_dtype,
indexer, op, mode): indexer, op, mode):
@ -1083,7 +1158,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
for op in s(UpdateOps) for op in s(UpdateOps)
for dtype in s(UpdateOps.dtypes(op)) for dtype in s(UpdateOps.dtypes(op))
for update_shape in s(_broadcastable_shapes(update_shape)) for update_shape in s(_broadcastable_shapes(update_shape))
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)))) for update_dtype in s(_compatible_dtypes(op, dtype)))))
def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype, def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
indexer, op): indexer, op):
rng = jtu.rand_default(self.rng()) rng = jtu.rand_default(self.rng())
@ -1106,7 +1181,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
for op in s(UpdateOps) for op in s(UpdateOps)
for dtype in s(UpdateOps.dtypes(op)) for dtype in s(UpdateOps.dtypes(op))
for update_shape in s(_broadcastable_shapes(update_shape)) for update_shape in s(_broadcastable_shapes(update_shape))
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)))) for update_dtype in s(_compatible_dtypes(op, dtype)))))
def testAdvancedIndexingSorted(self, shape, dtype, update_shape, update_dtype, def testAdvancedIndexingSorted(self, shape, dtype, update_shape, update_dtype,
indexer, op): indexer, op):
rng = jtu.rand_default(self.rng()) rng = jtu.rand_default(self.rng())
@ -1130,7 +1205,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
for op in s(UpdateOps) for op in s(UpdateOps)
for dtype in s(UpdateOps.dtypes(op)) for dtype in s(UpdateOps.dtypes(op))
for update_shape in s(_broadcastable_shapes(update_shape)) for update_shape in s(_broadcastable_shapes(update_shape))
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)))) for update_dtype in s(_compatible_dtypes(op, dtype)))))
def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype, def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
indexer, op): indexer, op):
rng = jtu.rand_default(self.rng()) rng = jtu.rand_default(self.rng())
@ -1157,7 +1232,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
for op in [UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE] for op in [UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE]
for dtype in float_dtypes for dtype in float_dtypes
for update_shape in _broadcastable_shapes(update_shape) for update_shape in _broadcastable_shapes(update_shape)
for update_dtype in ([dtype] if op == UpdateOps.ADD else float_dtypes))) for update_dtype in _compatible_dtypes(op, dtype, inexact=True)))
def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype, def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype,
indexer, op, mode): indexer, op, mode):
rng = jtu.rand_default(self.rng()) rng = jtu.rand_default(self.rng())
@ -1184,7 +1259,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
else [UpdateOps.ADD]) else [UpdateOps.ADD])
for dtype in s(float_dtypes) for dtype in s(float_dtypes)
for update_shape in s(_broadcastable_shapes(update_shape)) for update_shape in s(_broadcastable_shapes(update_shape))
for update_dtype in s([dtype] if op == UpdateOps.ADD else float_dtypes)))) for update_dtype in s(_compatible_dtypes(op, dtype, inexact=True)))))
def testAdvancedIndexingGrads(self, shape, dtype, update_shape, update_dtype, def testAdvancedIndexingGrads(self, shape, dtype, update_shape, update_dtype,
indexer, op, unique_indices): indexer, op, unique_indices):
rng = jtu.rand_default(self.rng()) rng = jtu.rand_default(self.rng())