mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Added workaround for bug in XLA
This commit is contained in:
parent
75155f5eda
commit
3715fcb930
@ -233,7 +233,9 @@ def _resize_nearest(x, output_shape: core.Shape):
|
||||
m = input_shape[d]
|
||||
n = output_shape[d]
|
||||
offsets = (jnp.arange(n) + 0.5) * core.dimension_as_value(m) / core.dimension_as_value(n)
|
||||
offsets = jnp.floor(offsets).astype(np.int32)
|
||||
# TODO(b/206898375): this computation produces the wrong result on
|
||||
# CPU and GPU when using float64. Use float32 until the bug is fixed.
|
||||
offsets = jnp.floor(offsets.astype(np.float32)).astype(np.int32)
|
||||
indices = [slice(None)] * len(input_shape)
|
||||
indices[d] = offsets
|
||||
x = x[tuple(indices)]
|
||||
|
@ -3762,7 +3762,7 @@ def arange(start: core.DimSize, stop: Optional[core.DimSize]=None,
|
||||
msg = "It arose in jax.numpy.arange argument `{}`.".format
|
||||
dtype = dtype or _dtype(start, *(x for x in [stop, step] if x is not None))
|
||||
if stop is None and step is None:
|
||||
if not core.is_dim_size(start):
|
||||
if not core.is_special_dim_size(start):
|
||||
start = require(start, msg("stop"))
|
||||
start = np.ceil(start).astype(int)
|
||||
|
||||
|
11
jax/core.py
11
jax/core.py
@ -1376,13 +1376,10 @@ def _dim_handler_and_canonical(*dlist: DimSize) -> Tuple[DimensionHandler, Tuple
|
||||
raise ValueError(msg)
|
||||
return next(iter(special_handlers), _dimension_handler_int), tuple(canonical)
|
||||
|
||||
def is_dim_size(v: Any) -> bool:
|
||||
"""Checks if a value is a DimSize."""
|
||||
try:
|
||||
handler, _ = _dim_handler_and_canonical(v)
|
||||
return True
|
||||
except TypeError:
|
||||
return False
|
||||
def is_special_dim_size(v: Any) -> bool:
|
||||
"""Checks if a value is a special DimSize."""
|
||||
handler = _SPECIAL_DIMENSION_HANDLERS.get(type(v))
|
||||
return (handler is not None)
|
||||
|
||||
def is_constant_dim(d: DimSize) -> bool:
|
||||
handler, ds = _dim_handler_and_canonical(d)
|
||||
|
@ -90,6 +90,7 @@ class ImageTest(jtu.JaxTestCase):
|
||||
"target_shape": target_shape,
|
||||
"method": method}
|
||||
for dtype in [np.float32]
|
||||
|
||||
for target_shape, image_shape in itertools.combinations_with_replacement(
|
||||
[[3, 2], [6, 4], [33, 17], [50, 39]], 2)
|
||||
for method in ["nearest", "bilinear", "lanczos3", "bicubic"]))
|
||||
|
Loading…
x
Reference in New Issue
Block a user