mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add typing_test.py
This commit is contained in:
parent
dc4922fd08
commit
b3c31ebe7d
@ -17,7 +17,7 @@ repos:
|
||||
rev: 'v0.971'
|
||||
hooks:
|
||||
- id: mypy
|
||||
files: jax/
|
||||
files: (jax/|tests/typing_test\.py)
|
||||
additional_dependencies: [types-requests==2.27.16, jaxlib==0.3.5]
|
||||
|
||||
- repo: https://github.com/mwouts/jupytext
|
||||
|
@ -35,8 +35,12 @@ class HasDtypeAttribute(Protocol):
|
||||
|
||||
Dtype = np.dtype
|
||||
|
||||
# Any is here to allow scalar types like np.int32.
|
||||
# TODO(jakevdp) figure out how to specify these more strictly.
|
||||
# DtypeLike is meant to annotate inputs to np.dtype that return
|
||||
# a valid JAX dtype. It's different than numpy.typing.DTypeLike
|
||||
# because JAX doesn't support objects or structured dtypes.
|
||||
# It does not include JAX dtype extensions such as KeyType and others.
|
||||
# For now, we use Any to allow scalar types like np.int32 & jnp.int32.
|
||||
# TODO(jakevdp) specify these more strictly.
|
||||
DtypeLike = Union[Any, str, np.dtype, HasDtypeAttribute]
|
||||
|
||||
# Shapes are tuples of dimension sizes, which are normally integers. We allow
|
||||
|
10
tests/BUILD
10
tests/BUILD
@ -732,6 +732,16 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
# TODO(jakevdp): make this a py_strict_test
|
||||
py_test(
|
||||
name = "typing_test",
|
||||
srcs = ["typing_test.py"],
|
||||
deps = [
|
||||
"//jax",
|
||||
"//jax:test_util",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "util_test",
|
||||
srcs = ["util_test.py"],
|
||||
|
117
tests/typing_test.py
Normal file
117
tests/typing_test.py
Normal file
@ -0,0 +1,117 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Typing tests
|
||||
------------
|
||||
This test is meant to be both a runtime test and a static type annotation test,
|
||||
so it should be checked with pytype/mypy as well as being run with pytest.
|
||||
"""
|
||||
from typing import Union
|
||||
|
||||
import jax
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import typing
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from absl.testing import absltest
|
||||
import numpy as np
|
||||
|
||||
|
||||
# DtypeLike is meant to annotate inputs to np.dtype that return
|
||||
# a valid JAX dtype, so we test with np.dtype.
|
||||
def dtypelike_to_dtype(x: typing.DtypeLike) -> typing.Dtype:
|
||||
return np.dtype(x)
|
||||
|
||||
|
||||
# ArrayLike is meant to annotate object that are valid as array
|
||||
# inputs to jax primitive functions; use convert_element_type here
|
||||
# for simplicity.
|
||||
def arraylike_to_array(x: typing.ArrayLike) -> typing.Array:
|
||||
return lax.convert_element_type(x, np.result_type(x))
|
||||
|
||||
|
||||
class HasDtype:
|
||||
dtype: np.dtype
|
||||
def __init__(self, dt):
|
||||
self.dtype = np.dtype(dt)
|
||||
|
||||
float32_dtype = np.dtype("float32")
|
||||
|
||||
|
||||
# Avoid test parameterization because we want to statically check these annotations.
|
||||
class TypingTest(jtu.JaxTestCase):
|
||||
|
||||
def testDtypeLike(self) -> None:
|
||||
out1: typing.Dtype = dtypelike_to_dtype("float32")
|
||||
self.assertEqual(out1, float32_dtype)
|
||||
|
||||
out2: typing.Dtype = dtypelike_to_dtype(np.float32)
|
||||
self.assertEqual(out2, float32_dtype)
|
||||
|
||||
out3: typing.Dtype = dtypelike_to_dtype(jnp.float32)
|
||||
self.assertEqual(out3, float32_dtype)
|
||||
|
||||
out4: typing.Dtype = dtypelike_to_dtype(np.dtype('float32'))
|
||||
self.assertEqual(out4, float32_dtype)
|
||||
|
||||
out5: typing.Dtype = dtypelike_to_dtype(HasDtype("float32"))
|
||||
self.assertEqual(out5, float32_dtype)
|
||||
|
||||
def testArrayLike(self) -> None:
|
||||
out1: typing.Array = arraylike_to_array(jnp.arange(4))
|
||||
self.assertArraysEqual(out1, jnp.arange(4))
|
||||
|
||||
out2: typing.Array = jax.jit(arraylike_to_array)(jnp.arange(4))
|
||||
self.assertArraysEqual(out2, jnp.arange(4))
|
||||
|
||||
out3: typing.Array = arraylike_to_array(np.arange(4))
|
||||
self.assertArraysEqual(out3, jnp.arange(4))
|
||||
|
||||
out4: typing.Array = arraylike_to_array(True)
|
||||
self.assertArraysEqual(out4, jnp.array(True))
|
||||
|
||||
out5: typing.Array = arraylike_to_array(1)
|
||||
self.assertArraysEqual(out5, jnp.array(1))
|
||||
|
||||
out6: typing.Array = arraylike_to_array(1.0)
|
||||
self.assertArraysEqual(out6, jnp.array(1.0))
|
||||
|
||||
out7: typing.Array = arraylike_to_array(1 + 1j)
|
||||
self.assertArraysEqual(out7, jnp.array(1 + 1j))
|
||||
|
||||
out8: typing.Array = arraylike_to_array(np.bool_(0))
|
||||
self.assertArraysEqual(out8, jnp.bool_(0))
|
||||
|
||||
out9: typing.Array = arraylike_to_array(np.float32(0))
|
||||
self.assertArraysEqual(out9, jnp.float32(0))
|
||||
|
||||
def testArrayInstanceChecks(self):
|
||||
# TODO(jakevdp): enable this test when `typing.Array` instance checks are implemented.
|
||||
self.skipTest("Test is broken for now.")
|
||||
|
||||
def is_array(x: typing.ArrayLike) -> Union[bool, typing.Array]:
|
||||
return isinstance(x, typing.Array)
|
||||
|
||||
x = jnp.arange(5)
|
||||
|
||||
self.assertFalse(is_array(1.0))
|
||||
self.assertTrue(jax.jit(is_array)(1.0))
|
||||
self.assertTrue(is_array(x))
|
||||
self.assertTrue(jax.jit(is_array)(x))
|
||||
self.assertTrue(jax.vmap(is_array)(x).all())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
Loading…
x
Reference in New Issue
Block a user