rocm_jax/tests/typing_test.py
Sergei Lebedev f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00

185 lines
5.8 KiB
Python

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Typing tests
------------
This test is meant to be both a runtime test and a static type annotation test,
so it should be checked with pytype/mypy as well as being run with pytest.
"""
from __future__ import annotations
from typing import Any, TYPE_CHECKING
import jax
from jax._src import core
from jax._src import test_util as jtu
from jax._src import typing
from jax import lax
import jax.numpy as jnp
from jax._src.array import ArrayImpl
from absl.testing import absltest
import numpy as np
# DTypeLike is meant to annotate inputs to np.dtype that return
# a valid JAX dtype, so we test with np.dtype.
def dtypelike_to_dtype(x: typing.DTypeLike) -> typing.DType:
return np.dtype(x)
# ArrayLike is meant to annotate object that are valid as array
# inputs to jax primitive functions; use convert_element_type here
# for simplicity.
def arraylike_to_array(x: typing.ArrayLike) -> typing.Array:
return lax.convert_element_type(x, np.result_type(x))
class HasDType:
dtype: np.dtype
def __init__(self, dt):
self.dtype = np.dtype(dt)
float32_dtype = np.dtype("float32")
# Avoid test parameterization because we want to statically check these annotations.
class TypingTest(jtu.JaxTestCase):
def testDTypeLike(self) -> None:
out1: typing.DType = dtypelike_to_dtype("float32")
self.assertEqual(out1, float32_dtype)
out2: typing.DType = dtypelike_to_dtype(np.float32)
self.assertEqual(out2, float32_dtype)
out3: typing.DType = dtypelike_to_dtype(jnp.float32)
self.assertEqual(out3, float32_dtype)
out4: typing.DType = dtypelike_to_dtype(np.dtype('float32'))
self.assertEqual(out4, float32_dtype)
out5: typing.DType = dtypelike_to_dtype(HasDType("float32"))
self.assertEqual(out5, float32_dtype)
def testArrayLike(self) -> None:
out1: typing.Array = arraylike_to_array(jnp.arange(4))
self.assertArraysEqual(out1, jnp.arange(4))
out2: typing.Array = jax.jit(arraylike_to_array)(jnp.arange(4))
self.assertArraysEqual(out2, jnp.arange(4))
out3: typing.Array = arraylike_to_array(np.arange(4))
self.assertArraysEqual(out3, jnp.arange(4), check_dtypes=False)
out4: typing.Array = arraylike_to_array(True)
self.assertArraysEqual(out4, jnp.array(True))
out5: typing.Array = arraylike_to_array(1)
self.assertArraysEqual(out5, jnp.array(1))
out6: typing.Array = arraylike_to_array(1.0)
self.assertArraysEqual(out6, jnp.array(1.0))
out7: typing.Array = arraylike_to_array(1 + 1j)
self.assertArraysEqual(out7, jnp.array(1 + 1j))
out8: typing.Array = arraylike_to_array(np.bool_(0))
self.assertArraysEqual(out8, jnp.bool_(0))
out9: typing.Array = arraylike_to_array(np.float32(0))
self.assertArraysEqual(out9, jnp.float32(0))
def testArrayInstanceChecks(self):
def is_array(x: typing.ArrayLike) -> bool | typing.Array:
return isinstance(x, typing.Array)
x = jnp.arange(5)
self.assertFalse(is_array(1.0))
self.assertTrue(jax.jit(is_array)(1.0))
self.assertTrue(is_array(x))
self.assertTrue(jax.jit(is_array)(x))
self.assertTrue(jnp.all(jax.vmap(is_array)(x)))
def testAnnotations(self):
# This test is mainly meant for static type checking: we want to ensure that
# Tracer and ArrayImpl are valid as array.Array.
def f(x: Any) -> typing.Array | None:
if isinstance(x, core.Tracer):
return x
elif isinstance(x, ArrayImpl):
return x
else:
return None
x = jnp.arange(10)
y = f(x)
self.assertArraysEqual(x, y)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())
if TYPE_CHECKING:
# Here we do a number of static type assertions. We purposely don't cover the
# entire public API here, but rather spot-check some potentially problematic
# areas. The goals are:
#
# - Confirm the correctness of a few basic APIs
# - Confirm that types from *.pyi files are correctly pulled-in
# - Confirm that non-trivial overloads are behaving as expected.
#
import sys
if sys.version_info >= (3, 11):
from typing import assert_type # pytype: disable=not-supported-yet # py311-upgrade
else:
from typing_extensions import assert_type # pytype: disable=not-supported-yet
mat = jnp.zeros((2, 5))
vals = jnp.arange(5)
mask = jnp.array([True, False, True, False])
assert_type(mat, jax.Array)
assert_type(vals, jax.Array)
assert_type(mask, jax.Array)
# Functions with non-trivial typing overloads:
# jnp.linspace
assert_type(jnp.linspace(0, 10), jax.Array)
assert_type(jnp.linspace(0, 10, retstep=False), jax.Array)
assert_type(jnp.linspace(0, 10, retstep=True), tuple[jax.Array, jax.Array])
# jnp.where
assert_type(mask, jax.Array)
assert_type(jnp.where(mask, 0, 1), jax.Array)
assert_type(jnp.where(mask), tuple[jax.Array, ...])
# jnp.einsum
assert_type(jnp.einsum('ij', mat), jax.Array)
assert_type(jnp.einsum('ij,j->i', mat, vals), jax.Array)
assert_type(jnp.einsum(mat, (0, 0)), jax.Array)
# jnp.indices
assert_type(jnp.indices([2, 3]), jax.Array)
assert_type(jnp.indices([2, 3], sparse=False), jax.Array)
assert_type(jnp.indices([2, 3], sparse=True), tuple[jax.Array, ...])
# jnp.average
assert_type(jnp.average(vals), jax.Array)
assert_type(jnp.average(vals, returned=False), jax.Array)
assert_type(jnp.average(vals, returned=True), tuple[jax.Array, jax.Array])