mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Fix jax.image._resize
function (#3805)
This PR fixes a bug in jax.image._resize where the local `method_id` variable may be used without being defined first. This bug can be easily reproduced by passing to `jax.image.resize` parameter `method` a `ResizeMethod` instead of an `str`. By doing this, `method_id` is never defined and the instruction `if method_id == ResizeMethod.NEAREST` raises an error. Currently, this can be easily bypassed assigning parameter `method` a `str`. To fix this bug, it only needs to rename `method_id` to `method`, the same name of the input parameter.
This commit is contained in:
parent
f6b3184f3e
commit
ce14409025
@ -165,10 +165,11 @@ def _resize(image, shape: Sequence[int], method: Union[str, ResizeMethod],
|
||||
f' {shape} vs {image.shape}')
|
||||
raise ValueError(msg)
|
||||
if isinstance(method, str):
|
||||
method_id = ResizeMethod.from_string(method)
|
||||
if method_id == ResizeMethod.NEAREST:
|
||||
method = ResizeMethod.from_string(method)
|
||||
if method == ResizeMethod.NEAREST:
|
||||
return _resize_nearest(image, shape)
|
||||
kernel = _kernels[method_id]
|
||||
assert isinstance(method, ResizeMethod)
|
||||
kernel = _kernels[method]
|
||||
scale = [float(o) / i for o, i in zip(shape, image.shape)]
|
||||
if not jnp.issubdtype(image.dtype, jnp.inexact):
|
||||
image = lax.convert_element_type(image, jnp.result_type(image, jnp.float32))
|
||||
|
Loading…
x
Reference in New Issue
Block a user