mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)

commit d2f80ef117 (parent ca01d1b411)

[x64] deprecate unsafe type casting in scatter-update operations
@@ -45,6 +45,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
   traces as an alternative to the Tensorboard UI.
 * Added a `jax.named_scope` context manager that adds profiler metadata to
   Python programs (similar to `jax.named_call`).
+* In scatter-update operations (i.e. :attr:`jax.numpy.ndarray.at`), unsafe implicit
+  dtype casts are deprecated, and now result in a `FutureWarning`.
+  In a future release, this will become an error. An example of an unsafe implicit
+  cast is `jnp.zeros(4, dtype=int).at[0].set(1.5)`, in which `1.5` previously was
+  silently truncated to `1`.
 
 ## jaxlib 0.3.11 (Unreleased)
 * [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).

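In user code the deprecation looks roughly like this — a minimal sketch of the before/after behavior described in the entry above (the explicit-cast workarounds are suggestions, not part of this commit):

    import warnings
    import jax.numpy as jnp

    x = jnp.zeros(4, dtype=int)

    # Unsafe implicit cast: 1.5 is not representable in an integer array.
    # Previously it was silently truncated to 1; it now emits a FutureWarning
    # and is slated to become an error.
    with warnings.catch_warnings():
      warnings.simplefilter("error", FutureWarning)
      try:
        x.at[0].set(1.5)
      except FutureWarning as w:
        print("deprecated:", w)

    # Explicit alternatives that avoid the implicit cast:
    x_int = x.at[0].set(int(1.5))                    # cast the value yourself
    x_float = x.astype(jnp.float32).at[0].set(1.5)   # or widen the array first
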
@@ -16,6 +16,7 @@
 
 import sys
 from typing import Any, Callable, Optional, Sequence, Tuple, Union
+import warnings
 
 import numpy as np
 

@@ -81,6 +82,13 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
   dtype = lax.dtype(x)
   weak_type = dtypes.is_weakly_typed(x)
 
+  if dtype != dtypes.result_type(x, y):
+    # TODO(jakevdp): change this to an error after the deprecation period.
+    warnings.warn("scatter inputs have incompatible types: cannot safely cast "
+                  f"value from dtype={lax.dtype(y)} to dtype={lax.dtype(x)}. "
+                  "In future JAX releases this will result in an error.",
+                  FutureWarning)
+
   idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
   indexer = jnp._index_to_gather(jnp.shape(x), idx,
                                  normalize_indices=normalize_indices)

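The new check treats a cast as safe exactly when promoting the update against the target leaves the target's dtype unchanged. A small standalone illustration of that criterion (`scatter_cast_is_safe` is a hypothetical helper mirroring the check, not part of the diff):

    import jax.numpy as jnp

    def scatter_cast_is_safe(x, y):
      # Mirrors the new _scatter_impl check: writing y into x is safe iff the
      # promoted dtype of (x, y) is still x's own dtype.
      return jnp.result_type(x, y) == x.dtype

    x_int = jnp.zeros(4, dtype=jnp.int32)
    x_float = jnp.zeros(4, dtype=jnp.float32)

    print(scatter_cast_is_safe(x_int, 1))      # True:  weak int into int32
    print(scatter_cast_is_safe(x_float, 1.5))  # True:  weak float into float32
    print(scatter_cast_is_safe(x_int, 1.5))    # False: would truncate, now warns
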
@@ -350,9 +350,10 @@ def _iterative_classical_gram_schmidt(Q, x, xnorm, max_iterations=2):
 
   # This assumes that Q's leaves all have the same dimension in the last
   # axis.
-  r = jnp.zeros(tree_leaves(Q)[0].shape[-1])
+  Q0 = tree_leaves(Q)[0]
+  r = jnp.zeros(Q0.shape[-1], dtype=Q0.dtype)
   q = x
-  xnorm_scaled = xnorm / jnp.sqrt(2)
+  xnorm_scaled = xnorm / jnp.sqrt(2.0)
 
   def body_function(carry):
     k, q, r, qnorm_scaled = carry

@@ -368,7 +369,7 @@ def _iterative_classical_gram_schmidt(Q, x, xnorm, max_iterations=2):
     def qnorm(carry):
       k, _, q, qnorm_scaled = carry
       _, qnorm = _safe_normalize(q)
-      qnorm_scaled = qnorm / jnp.sqrt(2)
+      qnorm_scaled = qnorm / jnp.sqrt(2.0)
       return (k, False, q, qnorm_scaled)
 
     init = (k, True, q, qnorm_scaled)

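The two `_iterative_classical_gram_schmidt` hunks look like preparation for stricter dtype handling: `jnp.zeros` without a `dtype` argument produces the default float type, which need not match the dtype of `Q` (complex, or float64 under x64), so `r` is now built with `dtype=Q0.dtype`. A minimal sketch of the mismatch being avoided (the complex64 `Q0` is an assumed stand-in, not taken from the commit):

    import jax.numpy as jnp

    Q0 = jnp.zeros((4, 3), dtype=jnp.complex64)  # stand-in for tree_leaves(Q)[0]

    r_default = jnp.zeros(Q0.shape[-1])                  # default float dtype
    r_matched = jnp.zeros(Q0.shape[-1], dtype=Q0.dtype)  # complex64, matches Q

    print(r_default.dtype, r_matched.dtype)  # float32 complex64 (float64 with x64 enabled)
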
@@ -77,7 +77,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
                                     "max", "get"])
   def test_jit_oob_update(self, update_fn):
     def f(x, i):
-      return getattr(x.at[i], update_fn)(1.)
+      return getattr(x.at[i], update_fn)(1)
 
     f = jax.jit(f)
     checked_f = checkify.checkify(f, errors=checkify.index_checks)

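The update value changes from `1.` to `1`, presumably because the array updated by this test is integer-typed: with the new check, a weakly-typed float written into an integer array emits the FutureWarning. A minimal sketch of the distinction (the `jnp.arange` array is an assumption, not copied from the test):

    import jax.numpy as jnp

    x = jnp.arange(4)   # int32 by default
    x.at[0].set(1)      # weak int into an int32 array: no warning
    x.at[0].set(1.)     # weak float into an int32 array: now emits FutureWarning
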
@@ -960,7 +960,7 @@ class IndexingTest(jtu.JaxTestCase):
                      jnp.array([7, 7, 1, 2, 1, 4, 5, 7, 7, 7], jnp.int32))
 
   def testIndexingWeakTypes(self):
-    x = lax_internal._convert_element_type(jnp.arange(5), int, weak_type=True)
+    x = lax_internal._convert_element_type(jnp.arange(5), float, weak_type=True)
 
     a = x.at[0].set(1.0)
     self.assertEqual(a.dtype, x.dtype)

@@ -974,6 +974,67 @@ class IndexingTest(jtu.JaxTestCase):
     self.assertEqual(c.dtype, x.dtype)
     self.assertTrue(dtypes.is_weakly_typed(c))
 
+  def testIndexingTypePromotion(self):
+    def _check(x_type, y_type):
+      x = jnp.arange(5, dtype=x_type)
+      y = y_type(0)
+      out = x.at[0].set(y)
+      self.assertEqual(x.dtype, out.dtype)
+
+    @jtu.ignore_warning(category=np.ComplexWarning,
+                        message="Casting complex values to real")
+    def _check_warns(x_type, y_type, msg):
+      with self.assertWarnsRegex(FutureWarning, msg):
+        _check(x_type, y_type)
+
+    def _check_raises(x_type, y_type, msg):
+      with self.assertRaisesRegex(ValueError, msg):
+        _check(x_type, y_type)
+
+    # Matching dtypes are always OK
+    _check(jnp.int32, jnp.int32)
+    _check(jnp.float32, jnp.float32)
+    _check(jnp.complex64, jnp.complex64)
+
+    # Weakly-typed y values promote.
+    _check(jnp.int32, int)
+    _check(jnp.float32, int)
+    _check(jnp.float32, float)
+    _check(jnp.complex64, int)
+    _check(jnp.complex64, float)
+    _check(jnp.complex64, complex)
+
+    # in standard promotion mode, strong types can promote.
+    msg = "scatter inputs have incompatible types"
+    with jax.numpy_dtype_promotion('standard'):
+      _check(jnp.int32, jnp.int16)
+      _check(jnp.float32, jnp.float16)
+      _check(jnp.float32, jnp.int32)
+      _check(jnp.complex64, jnp.int32)
+      _check(jnp.complex64, jnp.float32)
+
+      # TODO(jakevdp): make these _check_raises
+      _check_warns(jnp.int16, jnp.int32, msg)
+      _check_warns(jnp.int32, jnp.float32, msg)
+      _check_warns(jnp.int32, jnp.complex64, msg)
+      _check_warns(jnp.float16, jnp.float32, msg)
+      _check_warns(jnp.float32, jnp.complex64, msg)
+
+    # in strict promotion mode, strong types do not promote.
+    msg = "Input dtypes .* have no available implicit dtype promotion path"
+    with jax.numpy_dtype_promotion('strict'):
+      _check_raises(jnp.int32, jnp.int16, msg)
+      _check_raises(jnp.float32, jnp.float16, msg)
+      _check_raises(jnp.float32, jnp.int32, msg)
+      _check_raises(jnp.complex64, jnp.int32, msg)
+      _check_raises(jnp.complex64, jnp.float32, msg)
+
+      _check_raises(jnp.int16, jnp.int32, msg)
+      _check_raises(jnp.int32, jnp.float32, msg)
+      _check_raises(jnp.int32, jnp.complex64, msg)
+      _check_raises(jnp.float16, jnp.float32, msg)
+      _check_raises(jnp.float32, jnp.complex64, msg)
+
+
 def _broadcastable_shapes(shape):
   """Returns all shapes that broadcast to `shape`."""

@@ -989,6 +1050,20 @@ def _broadcastable_shapes(shape):
       yield list(reversed(x))
 
 
+# TODO(jakevdp): move this implementation to jax.dtypes & use in scatter?
+def _can_cast(from_, to):
+  return lax.dtype(to) == dtypes.result_type(from_, to)
+
+
+def _compatible_dtypes(op, dtype, inexact=False):
+  if op == UpdateOps.ADD:
+    return [dtype]
+  elif inexact:
+    return [dt for dt in float_dtypes if _can_cast(dt, dtype)]
+  else:
+    return [dt for dt in all_dtypes if _can_cast(dt, dtype)]
+
+
 class UpdateOps(enum.Enum):
   UPDATE = 0
   ADD = 1

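The two new helpers drive the parameterized tests below: `_can_cast` reimplements the same promotion-lattice criterion as the scatter check, and `_compatible_dtypes` enumerates update dtypes that can be written into a target dtype without the deprecated implicit cast. Outside the test file the idea can be exercised roughly like this (a sketch using the public `jnp.result_type`; `can_cast` is a stand-in, and `float_dtypes`/`all_dtypes` in the diff are the test module's own dtype lists):

    import numpy as np
    import jax.numpy as jnp

    def can_cast(from_, to):
      # Same criterion as the new scatter check: the cast is safe iff promoting
      # the pair lands back on the target dtype.
      return jnp.result_type(to) == jnp.result_type(from_, to)

    # Which update dtypes may be scattered into a float32 array without the
    # deprecated implicit cast?
    candidates = [jnp.int32, jnp.float16, jnp.float32, jnp.complex64]
    print([np.dtype(dt).name for dt in candidates if can_cast(dt, jnp.float32)])
    # -> ['int32', 'float16', 'float32'] under standard promotion
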
@@ -1060,7 +1135,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
       for op in s(UpdateOps)
       for dtype in s(UpdateOps.dtypes(op))
       for update_shape in s(_broadcastable_shapes(update_shape))
-      for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)
+      for update_dtype in s(_compatible_dtypes(op, dtype))
       for mode in s(MODES))))
  def testStaticIndexing(self, shape, dtype, update_shape, update_dtype,
                         indexer, op, mode):

@@ -1083,7 +1158,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
       for op in s(UpdateOps)
       for dtype in s(UpdateOps.dtypes(op))
       for update_shape in s(_broadcastable_shapes(update_shape))
-      for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes))))
+      for update_dtype in s(_compatible_dtypes(op, dtype)))))
  def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
                           indexer, op):
    rng = jtu.rand_default(self.rng())

@@ -1106,7 +1181,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
       for op in s(UpdateOps)
       for dtype in s(UpdateOps.dtypes(op))
       for update_shape in s(_broadcastable_shapes(update_shape))
-      for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes))))
+      for update_dtype in s(_compatible_dtypes(op, dtype)))))
  def testAdvancedIndexingSorted(self, shape, dtype, update_shape, update_dtype,
                                 indexer, op):
    rng = jtu.rand_default(self.rng())

@@ -1130,7 +1205,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
       for op in s(UpdateOps)
       for dtype in s(UpdateOps.dtypes(op))
       for update_shape in s(_broadcastable_shapes(update_shape))
-      for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes))))
+      for update_dtype in s(_compatible_dtypes(op, dtype)))))
  def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
                                indexer, op):
    rng = jtu.rand_default(self.rng())

@@ -1157,7 +1232,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
      for op in [UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE]
      for dtype in float_dtypes
      for update_shape in _broadcastable_shapes(update_shape)
-     for update_dtype in ([dtype] if op == UpdateOps.ADD else float_dtypes)))
+     for update_dtype in _compatible_dtypes(op, dtype, inexact=True)))
  def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype,
                              indexer, op, mode):
    rng = jtu.rand_default(self.rng())

@@ -1184,7 +1259,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
                 else [UpdateOps.ADD])
      for dtype in s(float_dtypes)
      for update_shape in s(_broadcastable_shapes(update_shape))
-     for update_dtype in s([dtype] if op == UpdateOps.ADD else float_dtypes))))
+     for update_dtype in s(_compatible_dtypes(op, dtype, inexact=True)))))
  def testAdvancedIndexingGrads(self, shape, dtype, update_shape, update_dtype,
                                indexer, op, unique_indices):
    rng = jtu.rand_default(self.rng())
