diff --git a/jax/_src/image/scale.py b/jax/_src/image/scale.py index 491c51287..3c0713b1e 100644 --- a/jax/_src/image/scale.py +++ b/jax/_src/image/scale.py @@ -233,7 +233,9 @@ def _resize_nearest(x, output_shape: core.Shape): m = input_shape[d] n = output_shape[d] offsets = (jnp.arange(n) + 0.5) * core.dimension_as_value(m) / core.dimension_as_value(n) - offsets = jnp.floor(offsets).astype(np.int32) + # TODO(b/206898375): this computation produces the wrong result on + # CPU and GPU when using float64. Use float32 until the bug is fixed. + offsets = jnp.floor(offsets.astype(np.float32)).astype(np.int32) indices = [slice(None)] * len(input_shape) indices[d] = offsets x = x[tuple(indices)] diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 105e9b9db..7c0abb375 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3762,7 +3762,7 @@ def arange(start: core.DimSize, stop: Optional[core.DimSize]=None, msg = "It arose in jax.numpy.arange argument `{}`.".format dtype = dtype or _dtype(start, *(x for x in [stop, step] if x is not None)) if stop is None and step is None: - if not core.is_dim_size(start): + if not core.is_special_dim_size(start): start = require(start, msg("stop")) start = np.ceil(start).astype(int) diff --git a/jax/core.py b/jax/core.py index c143d02b6..2a2804382 100644 --- a/jax/core.py +++ b/jax/core.py @@ -1376,13 +1376,10 @@ def _dim_handler_and_canonical(*dlist: DimSize) -> Tuple[DimensionHandler, Tuple raise ValueError(msg) return next(iter(special_handlers), _dimension_handler_int), tuple(canonical) -def is_dim_size(v: Any) -> bool: - """Checks if a value is a DimSize.""" - try: - handler, _ = _dim_handler_and_canonical(v) - return True - except TypeError: - return False +def is_special_dim_size(v: Any) -> bool: + """Checks if a value is a special DimSize.""" + handler = _SPECIAL_DIMENSION_HANDLERS.get(type(v)) + return (handler is not None) def is_constant_dim(d: DimSize) -> bool: handler, ds = _dim_handler_and_canonical(d) diff --git a/tests/image_test.py b/tests/image_test.py index cafd7944a..6b5c2a105 100644 --- a/tests/image_test.py +++ b/tests/image_test.py @@ -90,6 +90,7 @@ class ImageTest(jtu.JaxTestCase): "target_shape": target_shape, "method": method} for dtype in [np.float32] + for target_shape, image_shape in itertools.combinations_with_replacement( [[3, 2], [6, 4], [33, 17], [50, 39]], 2) for method in ["nearest", "bilinear", "lanczos3", "bicubic"]))