mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Tests: require pillow>=9.1.0 & remove backward compatibility
This commit is contained in:
parent
88f2b5e86d
commit
f7731c8a29
@ -1,7 +1,7 @@
|
||||
cloudpickle
|
||||
colorama>=0.4.4
|
||||
matplotlib
|
||||
pillow>=8.3.1
|
||||
pillow>=9.1.0
|
||||
pytest-benchmark
|
||||
pytest-xdist
|
||||
wheel
|
||||
|
@ -42,16 +42,8 @@ except ImportError:
|
||||
|
||||
try:
|
||||
from PIL import Image as PIL_Image
|
||||
# TODO(jakevdp): remove this try/except when pillow requirement is updated.
|
||||
try:
|
||||
# pillow >=9.1.0
|
||||
PIL_Resampling = PIL_Image.Resampling
|
||||
except AttributeError:
|
||||
# pillow <9.1.0
|
||||
PIL_Resampling = PIL_Image
|
||||
except ImportError:
|
||||
PIL_Image = None
|
||||
PIL_Resampling = None
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
@ -114,10 +106,10 @@ class ImageTest(jtu.JaxTestCase):
|
||||
args_maker = lambda: (rng(image_shape, dtype),)
|
||||
def pil_fn(x):
|
||||
pil_methods = {
|
||||
"nearest": PIL_Resampling.NEAREST,
|
||||
"bilinear": PIL_Resampling.BILINEAR,
|
||||
"bicubic": PIL_Resampling.BICUBIC,
|
||||
"lanczos3": PIL_Resampling.LANCZOS,
|
||||
"nearest": PIL_Image.Resampling.NEAREST,
|
||||
"bilinear": PIL_Image.Resampling.BILINEAR,
|
||||
"bicubic": PIL_Image.Resampling.BICUBIC,
|
||||
"lanczos3": PIL_Image.Resampling.LANCZOS,
|
||||
}
|
||||
img = PIL_Image.fromarray(x.astype(np.float32))
|
||||
out = np.asarray(img.resize(target_shape[::-1], pil_methods[method]),
|
||||
|
Loading…
x
Reference in New Issue
Block a user