Added workaround for bug in XLA

This commit is contained in:
George Necula 2021-11-18 10:23:53 +02:00
parent 75155f5eda
commit 3715fcb930
4 changed files with 9 additions and 9 deletions

View File

@ -233,7 +233,9 @@ def _resize_nearest(x, output_shape: core.Shape):
m = input_shape[d]
n = output_shape[d]
offsets = (jnp.arange(n) + 0.5) * core.dimension_as_value(m) / core.dimension_as_value(n)
offsets = jnp.floor(offsets).astype(np.int32)
# TODO(b/206898375): this computation produces the wrong result on
# CPU and GPU when using float64. Use float32 until the bug is fixed.
offsets = jnp.floor(offsets.astype(np.float32)).astype(np.int32)
indices = [slice(None)] * len(input_shape)
indices[d] = offsets
x = x[tuple(indices)]

View File

@ -3762,7 +3762,7 @@ def arange(start: core.DimSize, stop: Optional[core.DimSize]=None,
msg = "It arose in jax.numpy.arange argument `{}`.".format
dtype = dtype or _dtype(start, *(x for x in [stop, step] if x is not None))
if stop is None and step is None:
if not core.is_dim_size(start):
if not core.is_special_dim_size(start):
start = require(start, msg("stop"))
start = np.ceil(start).astype(int)

View File

@ -1376,13 +1376,10 @@ def _dim_handler_and_canonical(*dlist: DimSize) -> Tuple[DimensionHandler, Tuple
raise ValueError(msg)
return next(iter(special_handlers), _dimension_handler_int), tuple(canonical)
def is_dim_size(v: Any) -> bool:
"""Checks if a value is a DimSize."""
try:
handler, _ = _dim_handler_and_canonical(v)
return True
except TypeError:
return False
def is_special_dim_size(v: Any) -> bool:
"""Checks if a value is a special DimSize."""
handler = _SPECIAL_DIMENSION_HANDLERS.get(type(v))
return (handler is not None)
def is_constant_dim(d: DimSize) -> bool:
handler, ds = _dim_handler_and_canonical(d)

View File

@ -90,6 +90,7 @@ class ImageTest(jtu.JaxTestCase):
"target_shape": target_shape,
"method": method}
for dtype in [np.float32]
for target_shape, image_shape in itertools.combinations_with_replacement(
[[3, 2], [6, 4], [33, 17], [50, 39]], 2)
for method in ["nearest", "bilinear", "lanczos3", "bicubic"]))