Fix jax.image._resize function (#3805)

This PR fixes a bug in jax.image._resize where the local `method_id`
variable may be used without being defined first.
This bug can be easily reproduced by passing to `jax.image.resize`
parameter `method` a `ResizeMethod` instead of an `str`. By doing
this, `method_id` is never defined and the instruction
`if method_id == ResizeMethod.NEAREST` raises an error. Currently,
this can be easily bypassed assigning parameter `method` a `str`.
To fix this bug, it only needs to rename `method_id` to `method`,
the same name of the input parameter.
This commit is contained in:
Claudio Fantacci 2020-07-20 21:15:40 +01:00 committed by GitHub
parent f6b3184f3e
commit ce14409025
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -165,10 +165,11 @@ def _resize(image, shape: Sequence[int], method: Union[str, ResizeMethod],
f' {shape} vs {image.shape}')
raise ValueError(msg)
if isinstance(method, str):
method_id = ResizeMethod.from_string(method)
if method_id == ResizeMethod.NEAREST:
method = ResizeMethod.from_string(method)
if method == ResizeMethod.NEAREST:
return _resize_nearest(image, shape)
kernel = _kernels[method_id]
assert isinstance(method, ResizeMethod)
kernel = _kernels[method]
scale = [float(o) / i for o, i in zip(shape, image.shape)]
if not jnp.issubdtype(image.dtype, jnp.inexact):
image = lax.convert_element_type(image, jnp.result_type(image, jnp.float32))