[x64] deprecate unsafe type casting in scatter-update operations

This commit is contained in:
Jake VanderPlas 2022-06-09 15:21:49 -07:00
parent ca01d1b411
commit d2f80ef117
5 changed files with 100 additions and 11 deletions

View File

@ -45,6 +45,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
traces as an alternative to the Tensorboard UI.
* Added a `jax.named_scope` context manager that adds profiler metadata to
Python programs (similar to `jax.named_call`).
* In scatter-update operations (i.e. :attr:`jax.numpy.ndarray.at`), unsafe implicit
dtype casts are deprecated, and now result in a `FutureWarning`.
In a future release, this will become an error. An example of an unsafe implicit
cast is `jnp.zeros(4, dtype=int).at[0].set(1.5)`, in which `1.5` previously was
silently truncated to `1`.
## jaxlib 0.3.11 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).

View File

@ -16,6 +16,7 @@
import sys
from typing import Any, Callable, Optional, Sequence, Tuple, Union
import warnings
import numpy as np
@ -81,6 +82,13 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
dtype = lax.dtype(x)
weak_type = dtypes.is_weakly_typed(x)
if dtype != dtypes.result_type(x, y):
# TODO(jakevdp): change this to an error after the deprecation period.
warnings.warn("scatter inputs have incompatible types: cannot safely cast "
f"value from dtype={lax.dtype(y)} to dtype={lax.dtype(x)}. "
"In future JAX releases this will result in an error.",
FutureWarning)
idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
indexer = jnp._index_to_gather(jnp.shape(x), idx,
normalize_indices=normalize_indices)

View File

@ -350,9 +350,10 @@ def _iterative_classical_gram_schmidt(Q, x, xnorm, max_iterations=2):
# This assumes that Q's leaves all have the same dimension in the last
# axis.
r = jnp.zeros(tree_leaves(Q)[0].shape[-1])
Q0 = tree_leaves(Q)[0]
r = jnp.zeros(Q0.shape[-1], dtype=Q0.dtype)
q = x
xnorm_scaled = xnorm / jnp.sqrt(2)
xnorm_scaled = xnorm / jnp.sqrt(2.0)
def body_function(carry):
k, q, r, qnorm_scaled = carry
@ -368,7 +369,7 @@ def _iterative_classical_gram_schmidt(Q, x, xnorm, max_iterations=2):
def qnorm(carry):
k, _, q, qnorm_scaled = carry
_, qnorm = _safe_normalize(q)
qnorm_scaled = qnorm / jnp.sqrt(2)
qnorm_scaled = qnorm / jnp.sqrt(2.0)
return (k, False, q, qnorm_scaled)
init = (k, True, q, qnorm_scaled)

View File

@ -77,7 +77,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
"max", "get"])
def test_jit_oob_update(self, update_fn):
def f(x, i):
return getattr(x.at[i], update_fn)(1.)
return getattr(x.at[i], update_fn)(1)
f = jax.jit(f)
checked_f = checkify.checkify(f, errors=checkify.index_checks)

View File

@ -960,7 +960,7 @@ class IndexingTest(jtu.JaxTestCase):
jnp.array([7, 7, 1, 2, 1, 4, 5, 7, 7, 7], jnp.int32))
def testIndexingWeakTypes(self):
x = lax_internal._convert_element_type(jnp.arange(5), int, weak_type=True)
x = lax_internal._convert_element_type(jnp.arange(5), float, weak_type=True)
a = x.at[0].set(1.0)
self.assertEqual(a.dtype, x.dtype)
@ -974,6 +974,67 @@ class IndexingTest(jtu.JaxTestCase):
self.assertEqual(c.dtype, x.dtype)
self.assertTrue(dtypes.is_weakly_typed(c))
def testIndexingTypePromotion(self):
def _check(x_type, y_type):
x = jnp.arange(5, dtype=x_type)
y = y_type(0)
out = x.at[0].set(y)
self.assertEqual(x.dtype, out.dtype)
@jtu.ignore_warning(category=np.ComplexWarning,
message="Casting complex values to real")
def _check_warns(x_type, y_type, msg):
with self.assertWarnsRegex(FutureWarning, msg):
_check(x_type, y_type)
def _check_raises(x_type, y_type, msg):
with self.assertRaisesRegex(ValueError, msg):
_check(x_type, y_type)
# Matching dtypes are always OK
_check(jnp.int32, jnp.int32)
_check(jnp.float32, jnp.float32)
_check(jnp.complex64, jnp.complex64)
# Weakly-typed y values promote.
_check(jnp.int32, int)
_check(jnp.float32, int)
_check(jnp.float32, float)
_check(jnp.complex64, int)
_check(jnp.complex64, float)
_check(jnp.complex64, complex)
# in standard promotion mode, strong types can promote.
msg = "scatter inputs have incompatible types"
with jax.numpy_dtype_promotion('standard'):
_check(jnp.int32, jnp.int16)
_check(jnp.float32, jnp.float16)
_check(jnp.float32, jnp.int32)
_check(jnp.complex64, jnp.int32)
_check(jnp.complex64, jnp.float32)
# TODO(jakevdp): make these _check_raises
_check_warns(jnp.int16, jnp.int32, msg)
_check_warns(jnp.int32, jnp.float32, msg)
_check_warns(jnp.int32, jnp.complex64, msg)
_check_warns(jnp.float16, jnp.float32, msg)
_check_warns(jnp.float32, jnp.complex64, msg)
# in strict promotion mode, strong types do not promote.
msg = "Input dtypes .* have no available implicit dtype promotion path"
with jax.numpy_dtype_promotion('strict'):
_check_raises(jnp.int32, jnp.int16, msg)
_check_raises(jnp.float32, jnp.float16, msg)
_check_raises(jnp.float32, jnp.int32, msg)
_check_raises(jnp.complex64, jnp.int32, msg)
_check_raises(jnp.complex64, jnp.float32, msg)
_check_raises(jnp.int16, jnp.int32, msg)
_check_raises(jnp.int32, jnp.float32, msg)
_check_raises(jnp.int32, jnp.complex64, msg)
_check_raises(jnp.float16, jnp.float32, msg)
_check_raises(jnp.float32, jnp.complex64, msg)
def _broadcastable_shapes(shape):
"""Returns all shapes that broadcast to `shape`."""
@ -989,6 +1050,20 @@ def _broadcastable_shapes(shape):
yield list(reversed(x))
# TODO(jakevdp): move this implementation to jax.dtypes & use in scatter?
def _can_cast(from_, to):
return lax.dtype(to) == dtypes.result_type(from_, to)
def _compatible_dtypes(op, dtype, inexact=False):
if op == UpdateOps.ADD:
return [dtype]
elif inexact:
return [dt for dt in float_dtypes if _can_cast(dt, dtype)]
else:
return [dt for dt in all_dtypes if _can_cast(dt, dtype)]
class UpdateOps(enum.Enum):
UPDATE = 0
ADD = 1
@ -1060,7 +1135,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
for op in s(UpdateOps)
for dtype in s(UpdateOps.dtypes(op))
for update_shape in s(_broadcastable_shapes(update_shape))
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)
for update_dtype in s(_compatible_dtypes(op, dtype))
for mode in s(MODES))))
def testStaticIndexing(self, shape, dtype, update_shape, update_dtype,
indexer, op, mode):
@ -1083,7 +1158,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
for op in s(UpdateOps)
for dtype in s(UpdateOps.dtypes(op))
for update_shape in s(_broadcastable_shapes(update_shape))
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes))))
for update_dtype in s(_compatible_dtypes(op, dtype)))))
def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
indexer, op):
rng = jtu.rand_default(self.rng())
@ -1106,7 +1181,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
for op in s(UpdateOps)
for dtype in s(UpdateOps.dtypes(op))
for update_shape in s(_broadcastable_shapes(update_shape))
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes))))
for update_dtype in s(_compatible_dtypes(op, dtype)))))
def testAdvancedIndexingSorted(self, shape, dtype, update_shape, update_dtype,
indexer, op):
rng = jtu.rand_default(self.rng())
@ -1130,7 +1205,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
for op in s(UpdateOps)
for dtype in s(UpdateOps.dtypes(op))
for update_shape in s(_broadcastable_shapes(update_shape))
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes))))
for update_dtype in s(_compatible_dtypes(op, dtype)))))
def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
indexer, op):
rng = jtu.rand_default(self.rng())
@ -1157,7 +1232,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
for op in [UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE]
for dtype in float_dtypes
for update_shape in _broadcastable_shapes(update_shape)
for update_dtype in ([dtype] if op == UpdateOps.ADD else float_dtypes)))
for update_dtype in _compatible_dtypes(op, dtype, inexact=True)))
def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype,
indexer, op, mode):
rng = jtu.rand_default(self.rng())
@ -1184,7 +1259,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
else [UpdateOps.ADD])
for dtype in s(float_dtypes)
for update_shape in s(_broadcastable_shapes(update_shape))
for update_dtype in s([dtype] if op == UpdateOps.ADD else float_dtypes))))
for update_dtype in s(_compatible_dtypes(op, dtype, inexact=True)))))
def testAdvancedIndexingGrads(self, shape, dtype, update_shape, update_dtype,
indexer, op, unique_indices):
rng = jtu.rand_default(self.rng())