mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00

This PR is a follow up to #18881. The changes were generated by adding from __future__ import annotations to the files which did not already have them and running pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
185 lines
5.8 KiB
Python
185 lines
5.8 KiB
Python
# Copyright 2022 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
Typing tests
|
|
------------
|
|
This test is meant to be both a runtime test and a static type annotation test,
|
|
so it should be checked with pytype/mypy as well as being run with pytest.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, TYPE_CHECKING
|
|
|
|
import jax
|
|
from jax._src import core
|
|
from jax._src import test_util as jtu
|
|
from jax._src import typing
|
|
from jax import lax
|
|
import jax.numpy as jnp
|
|
|
|
from jax._src.array import ArrayImpl
|
|
|
|
from absl.testing import absltest
|
|
import numpy as np
|
|
|
|
|
|
# DTypeLike is meant to annotate inputs to np.dtype that return
|
|
# a valid JAX dtype, so we test with np.dtype.
|
|
def dtypelike_to_dtype(x: typing.DTypeLike) -> typing.DType:
|
|
return np.dtype(x)
|
|
|
|
|
|
# ArrayLike is meant to annotate object that are valid as array
|
|
# inputs to jax primitive functions; use convert_element_type here
|
|
# for simplicity.
|
|
def arraylike_to_array(x: typing.ArrayLike) -> typing.Array:
|
|
return lax.convert_element_type(x, np.result_type(x))
|
|
|
|
|
|
class HasDType:
|
|
dtype: np.dtype
|
|
def __init__(self, dt):
|
|
self.dtype = np.dtype(dt)
|
|
|
|
float32_dtype = np.dtype("float32")
|
|
|
|
# Avoid test parameterization because we want to statically check these annotations.
|
|
class TypingTest(jtu.JaxTestCase):
|
|
|
|
def testDTypeLike(self) -> None:
|
|
out1: typing.DType = dtypelike_to_dtype("float32")
|
|
self.assertEqual(out1, float32_dtype)
|
|
|
|
out2: typing.DType = dtypelike_to_dtype(np.float32)
|
|
self.assertEqual(out2, float32_dtype)
|
|
|
|
out3: typing.DType = dtypelike_to_dtype(jnp.float32)
|
|
self.assertEqual(out3, float32_dtype)
|
|
|
|
out4: typing.DType = dtypelike_to_dtype(np.dtype('float32'))
|
|
self.assertEqual(out4, float32_dtype)
|
|
|
|
out5: typing.DType = dtypelike_to_dtype(HasDType("float32"))
|
|
self.assertEqual(out5, float32_dtype)
|
|
|
|
def testArrayLike(self) -> None:
|
|
out1: typing.Array = arraylike_to_array(jnp.arange(4))
|
|
self.assertArraysEqual(out1, jnp.arange(4))
|
|
|
|
out2: typing.Array = jax.jit(arraylike_to_array)(jnp.arange(4))
|
|
self.assertArraysEqual(out2, jnp.arange(4))
|
|
|
|
out3: typing.Array = arraylike_to_array(np.arange(4))
|
|
self.assertArraysEqual(out3, jnp.arange(4), check_dtypes=False)
|
|
|
|
out4: typing.Array = arraylike_to_array(True)
|
|
self.assertArraysEqual(out4, jnp.array(True))
|
|
|
|
out5: typing.Array = arraylike_to_array(1)
|
|
self.assertArraysEqual(out5, jnp.array(1))
|
|
|
|
out6: typing.Array = arraylike_to_array(1.0)
|
|
self.assertArraysEqual(out6, jnp.array(1.0))
|
|
|
|
out7: typing.Array = arraylike_to_array(1 + 1j)
|
|
self.assertArraysEqual(out7, jnp.array(1 + 1j))
|
|
|
|
out8: typing.Array = arraylike_to_array(np.bool_(0))
|
|
self.assertArraysEqual(out8, jnp.bool_(0))
|
|
|
|
out9: typing.Array = arraylike_to_array(np.float32(0))
|
|
self.assertArraysEqual(out9, jnp.float32(0))
|
|
|
|
def testArrayInstanceChecks(self):
|
|
def is_array(x: typing.ArrayLike) -> bool | typing.Array:
|
|
return isinstance(x, typing.Array)
|
|
|
|
x = jnp.arange(5)
|
|
|
|
self.assertFalse(is_array(1.0))
|
|
self.assertTrue(jax.jit(is_array)(1.0))
|
|
self.assertTrue(is_array(x))
|
|
self.assertTrue(jax.jit(is_array)(x))
|
|
self.assertTrue(jnp.all(jax.vmap(is_array)(x)))
|
|
|
|
def testAnnotations(self):
|
|
# This test is mainly meant for static type checking: we want to ensure that
|
|
# Tracer and ArrayImpl are valid as array.Array.
|
|
def f(x: Any) -> typing.Array | None:
|
|
if isinstance(x, core.Tracer):
|
|
return x
|
|
elif isinstance(x, ArrayImpl):
|
|
return x
|
|
else:
|
|
return None
|
|
|
|
x = jnp.arange(10)
|
|
y = f(x)
|
|
self.assertArraysEqual(x, y)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
absltest.main(testLoader=jtu.JaxTestLoader())
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
# Here we do a number of static type assertions. We purposely don't cover the
|
|
# entire public API here, but rather spot-check some potentially problematic
|
|
# areas. The goals are:
|
|
#
|
|
# - Confirm the correctness of a few basic APIs
|
|
# - Confirm that types from *.pyi files are correctly pulled-in
|
|
# - Confirm that non-trivial overloads are behaving as expected.
|
|
#
|
|
import sys
|
|
if sys.version_info >= (3, 11):
|
|
from typing import assert_type # pytype: disable=not-supported-yet # py311-upgrade
|
|
else:
|
|
from typing_extensions import assert_type # pytype: disable=not-supported-yet
|
|
|
|
mat = jnp.zeros((2, 5))
|
|
vals = jnp.arange(5)
|
|
mask = jnp.array([True, False, True, False])
|
|
|
|
assert_type(mat, jax.Array)
|
|
assert_type(vals, jax.Array)
|
|
assert_type(mask, jax.Array)
|
|
|
|
# Functions with non-trivial typing overloads:
|
|
# jnp.linspace
|
|
assert_type(jnp.linspace(0, 10), jax.Array)
|
|
assert_type(jnp.linspace(0, 10, retstep=False), jax.Array)
|
|
assert_type(jnp.linspace(0, 10, retstep=True), tuple[jax.Array, jax.Array])
|
|
|
|
# jnp.where
|
|
assert_type(mask, jax.Array)
|
|
assert_type(jnp.where(mask, 0, 1), jax.Array)
|
|
assert_type(jnp.where(mask), tuple[jax.Array, ...])
|
|
|
|
# jnp.einsum
|
|
assert_type(jnp.einsum('ij', mat), jax.Array)
|
|
assert_type(jnp.einsum('ij,j->i', mat, vals), jax.Array)
|
|
assert_type(jnp.einsum(mat, (0, 0)), jax.Array)
|
|
|
|
# jnp.indices
|
|
assert_type(jnp.indices([2, 3]), jax.Array)
|
|
assert_type(jnp.indices([2, 3], sparse=False), jax.Array)
|
|
assert_type(jnp.indices([2, 3], sparse=True), tuple[jax.Array, ...])
|
|
|
|
# jnp.average
|
|
assert_type(jnp.average(vals), jax.Array)
|
|
assert_type(jnp.average(vals, returned=False), jax.Array)
|
|
assert_type(jnp.average(vals, returned=True), tuple[jax.Array, jax.Array])
|