mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Added workaround for bug in XLA
This commit is contained in:
parent
75155f5eda
commit
3715fcb930
@ -233,7 +233,9 @@ def _resize_nearest(x, output_shape: core.Shape):
|
|||||||
m = input_shape[d]
|
m = input_shape[d]
|
||||||
n = output_shape[d]
|
n = output_shape[d]
|
||||||
offsets = (jnp.arange(n) + 0.5) * core.dimension_as_value(m) / core.dimension_as_value(n)
|
offsets = (jnp.arange(n) + 0.5) * core.dimension_as_value(m) / core.dimension_as_value(n)
|
||||||
offsets = jnp.floor(offsets).astype(np.int32)
|
# TODO(b/206898375): this computation produces the wrong result on
|
||||||
|
# CPU and GPU when using float64. Use float32 until the bug is fixed.
|
||||||
|
offsets = jnp.floor(offsets.astype(np.float32)).astype(np.int32)
|
||||||
indices = [slice(None)] * len(input_shape)
|
indices = [slice(None)] * len(input_shape)
|
||||||
indices[d] = offsets
|
indices[d] = offsets
|
||||||
x = x[tuple(indices)]
|
x = x[tuple(indices)]
|
||||||
|
@ -3762,7 +3762,7 @@ def arange(start: core.DimSize, stop: Optional[core.DimSize]=None,
|
|||||||
msg = "It arose in jax.numpy.arange argument `{}`.".format
|
msg = "It arose in jax.numpy.arange argument `{}`.".format
|
||||||
dtype = dtype or _dtype(start, *(x for x in [stop, step] if x is not None))
|
dtype = dtype or _dtype(start, *(x for x in [stop, step] if x is not None))
|
||||||
if stop is None and step is None:
|
if stop is None and step is None:
|
||||||
if not core.is_dim_size(start):
|
if not core.is_special_dim_size(start):
|
||||||
start = require(start, msg("stop"))
|
start = require(start, msg("stop"))
|
||||||
start = np.ceil(start).astype(int)
|
start = np.ceil(start).astype(int)
|
||||||
|
|
||||||
|
11
jax/core.py
11
jax/core.py
@ -1376,13 +1376,10 @@ def _dim_handler_and_canonical(*dlist: DimSize) -> Tuple[DimensionHandler, Tuple
|
|||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
return next(iter(special_handlers), _dimension_handler_int), tuple(canonical)
|
return next(iter(special_handlers), _dimension_handler_int), tuple(canonical)
|
||||||
|
|
||||||
def is_dim_size(v: Any) -> bool:
|
def is_special_dim_size(v: Any) -> bool:
|
||||||
"""Checks if a value is a DimSize."""
|
"""Checks if a value is a special DimSize."""
|
||||||
try:
|
handler = _SPECIAL_DIMENSION_HANDLERS.get(type(v))
|
||||||
handler, _ = _dim_handler_and_canonical(v)
|
return (handler is not None)
|
||||||
return True
|
|
||||||
except TypeError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def is_constant_dim(d: DimSize) -> bool:
|
def is_constant_dim(d: DimSize) -> bool:
|
||||||
handler, ds = _dim_handler_and_canonical(d)
|
handler, ds = _dim_handler_and_canonical(d)
|
||||||
|
@ -90,6 +90,7 @@ class ImageTest(jtu.JaxTestCase):
|
|||||||
"target_shape": target_shape,
|
"target_shape": target_shape,
|
||||||
"method": method}
|
"method": method}
|
||||||
for dtype in [np.float32]
|
for dtype in [np.float32]
|
||||||
|
|
||||||
for target_shape, image_shape in itertools.combinations_with_replacement(
|
for target_shape, image_shape in itertools.combinations_with_replacement(
|
||||||
[[3, 2], [6, 4], [33, 17], [50, 39]], 2)
|
[[3, 2], [6, 4], [33, 17], [50, 39]], 2)
|
||||||
for method in ["nearest", "bilinear", "lanczos3", "bicubic"]))
|
for method in ["nearest", "bilinear", "lanczos3", "bicubic"]))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user