Internal change

PiperOrigin-RevId: 474331907
This commit is contained in:
Jake VanderPlas 2022-09-14 10:38:54 -07:00 committed by jax authors
parent 0a5d8e8ec6
commit 13a7034e6a
3 changed files with 4 additions and 3 deletions

View File

@ -25,6 +25,7 @@ load("@flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_li
cuda_library = _cuda_library
rocm_library = _rocm_library
pytype_library = native.py_library
pytype_test = native.py_test
pyx_library = _pyx_library
pybind_extension = _pybind_extension
if_cuda_is_configured = _if_cuda_is_configured

View File

@ -19,6 +19,7 @@ load(
"jax_test_file_visibility",
"py_deps",
"pytype_library",
"pytype_test",
)
licenses(["notice"])
@ -732,8 +733,7 @@ py_test(
],
)
# TODO(jakevdp): make this a py_strict_test
py_test(
pytype_test(
name = "typing_test",
srcs = ["typing_test.py"],
deps = [

View File

@ -110,7 +110,7 @@ class TypingTest(jtu.JaxTestCase):
self.assertTrue(jax.jit(is_array)(1.0))
self.assertTrue(is_array(x))
self.assertTrue(jax.jit(is_array)(x))
self.assertTrue(jax.vmap(is_array)(x).all())
self.assertTrue(jnp.all(jax.vmap(is_array)(x)))
if __name__ == '__main__':