mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
134 lines
4.2 KiB
Python
134 lines
4.2 KiB
Python
# Copyright 2022 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
Typing tests
|
|
------------
|
|
This test is meant to be both a runtime test and a static type annotation test,
|
|
so it should be checked with pytype/mypy as well as being run with pytest.
|
|
"""
|
|
from typing import Any, Optional, Union
|
|
|
|
import jax
|
|
from jax._src import core
|
|
from jax._src import config as jax_config
|
|
from jax._src import test_util as jtu
|
|
from jax._src import typing
|
|
from jax import lax
|
|
import jax.numpy as jnp
|
|
|
|
from jax._src.array import ArrayImpl
|
|
|
|
from absl.testing import absltest
|
|
import numpy as np
|
|
|
|
|
|
# DTypeLike is meant to annotate inputs to np.dtype that return
|
|
# a valid JAX dtype, so we test with np.dtype.
|
|
def dtypelike_to_dtype(x: typing.DTypeLike) -> typing.DType:
  """Resolve a DTypeLike value into a concrete ``np.dtype``.

  Args:
    x: a DTypeLike value, e.g. a dtype name, a scalar type, an existing dtype
      instance, or an object carrying a ``dtype`` attribute.

  Returns:
    The ``np.dtype`` produced by ``np.dtype(x)``.
  """
  resolved = np.dtype(x)
  return resolved
|
|
|
|
|
|
# ArrayLike is meant to annotate objects that are valid as array
|
|
# inputs to jax primitive functions; use convert_element_type here
|
|
# for simplicity.
|
|
def arraylike_to_array(x: typing.ArrayLike) -> typing.Array:
  """Pass an ArrayLike input through a JAX primitive and return the result.

  ``lax.convert_element_type`` serves as a simple primitive that accepts
  ArrayLike inputs. The target dtype is whatever ``np.result_type`` infers
  for ``x``, so no deliberate dtype change is requested.

  Args:
    x: any ArrayLike value (JAX/NumPy array, Python scalar, NumPy scalar).

  Returns:
    The JAX array produced by ``lax.convert_element_type``.
  """
  inferred = np.result_type(x)
  return lax.convert_element_type(x, inferred)
|
|
|
|
|
|
class HasDType:
  """Plain object that exposes a ``dtype`` attribute and nothing else.

  Used to check that objects carrying a ``dtype`` are accepted as DTypeLike.
  """

  dtype: np.dtype

  def __init__(self, dt):
    # Normalize whatever was passed in to a real np.dtype instance.
    resolved = np.dtype(dt)
    self.dtype = resolved


# Reference dtype that the DTypeLike checks compare against.
float32_dtype = np.dtype("float32")
|
|
|
|
|
|
# Avoid test parameterization because we want to statically check these annotations.
|
|
class TypingTest(jtu.JaxTestCase):
  """Runtime checks for the aliases in ``jax._src.typing``.

  Every call site is spelled out individually (no parameterization) so that
  pytype/mypy also check the annotated assignments statically.
  """

  def testDTypeLike(self) -> None:
    """Several DTypeLike spellings all resolve to float32."""
    # dtype given as a string name
    out1: typing.DType = dtypelike_to_dtype("float32")
    self.assertEqual(out1, float32_dtype)

    # NumPy scalar type
    out2: typing.DType = dtypelike_to_dtype(np.float32)
    self.assertEqual(out2, float32_dtype)

    # jax.numpy scalar type
    out3: typing.DType = dtypelike_to_dtype(jnp.float32)
    self.assertEqual(out3, float32_dtype)

    # an existing np.dtype instance
    out4: typing.DType = dtypelike_to_dtype(np.dtype('float32'))
    self.assertEqual(out4, float32_dtype)

    # an arbitrary object exposing a ``dtype`` attribute
    out5: typing.DType = dtypelike_to_dtype(HasDType("float32"))
    self.assertEqual(out5, float32_dtype)

  def testArrayLike(self) -> None:
    """JAX arrays, NumPy arrays, Python scalars and NumPy scalars are ArrayLike."""
    # JAX array input
    out1: typing.Array = arraylike_to_array(jnp.arange(4))
    self.assertArraysEqual(out1, jnp.arange(4))

    # same helper traced under jax.jit
    out2: typing.Array = jax.jit(arraylike_to_array)(jnp.arange(4))
    self.assertArraysEqual(out2, jnp.arange(4))

    # NumPy array input; dtypes are not compared here, presumably because
    # NumPy's default integer width can differ from JAX's -- TODO confirm.
    out3: typing.Array = arraylike_to_array(np.arange(4))
    self.assertArraysEqual(out3, jnp.arange(4), check_dtypes=False)

    # Python scalars: bool, int, float, complex
    out4: typing.Array = arraylike_to_array(True)
    self.assertArraysEqual(out4, jnp.array(True))

    out5: typing.Array = arraylike_to_array(1)
    self.assertArraysEqual(out5, jnp.array(1))

    out6: typing.Array = arraylike_to_array(1.0)
    self.assertArraysEqual(out6, jnp.array(1.0))

    out7: typing.Array = arraylike_to_array(1 + 1j)
    self.assertArraysEqual(out7, jnp.array(1 + 1j))

    # NumPy scalars
    out8: typing.Array = arraylike_to_array(np.bool_(0))
    self.assertArraysEqual(out8, jnp.bool_(0))

    out9: typing.Array = arraylike_to_array(np.float32(0))
    self.assertArraysEqual(out9, jnp.float32(0))

  def testArrayInstanceChecks(self) -> None:
    """isinstance(x, typing.Array) holds for JAX arrays and for tracers."""
    # The return type admits typing.Array as well as bool, presumably because
    # the bool result of a traced call (jit/vmap below) comes back as an array.
    def is_array(x: typing.ArrayLike) -> Union[bool, typing.Array]:
      return isinstance(x, typing.Array)

    x = jnp.arange(5)

    # A plain Python float is ArrayLike but not an Array...
    self.assertFalse(is_array(1.0))
    # ...whereas inside jit the same argument is seen as a traced Array.
    self.assertTrue(jax.jit(is_array)(1.0))
    self.assertTrue(is_array(x))
    self.assertTrue(jax.jit(is_array)(x))
    # Under vmap each per-element value is also an Array.
    self.assertTrue(jnp.all(jax.vmap(is_array)(x)))

  def testAnnotations(self) -> None:
    """Static-typing check that Tracer and ArrayImpl satisfy typing.Array."""
    # This test is mainly meant for static type checking: we want to ensure that
    # Tracer and ArrayImpl are valid as array.Array.
    # NOTE(review): ``jax_config.jax_array`` may not exist in every JAX version
    # this file is run against -- confirm for the targeted release.
    with jax_config.jax_array(True):
      def f(x: Any) -> Optional[typing.Array]:
        # Both branches return ``x`` unchanged; the checker must accept each
        # narrowed type as an Optional[typing.Array] return value.
        if isinstance(x, core.Tracer):
          return x
        elif isinstance(x, ArrayImpl):
          return x
        else:
          return None

      x = jnp.arange(10)
      y = f(x)
      self.assertArraysEqual(x, y)
|
|
|
|
if __name__ == '__main__':
  # Run this module's tests through absltest using JAX's own test loader.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)
|