From 13a7034e6a35e56991b4c02647c40aececc96b75 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 14 Sep 2022 10:38:54 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 474331907 --- jaxlib/jax.bzl | 1 + tests/BUILD | 4 ++-- tests/typing_test.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 9577cf43c..bc20e6dcf 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -25,6 +25,7 @@ load("@flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_li cuda_library = _cuda_library rocm_library = _rocm_library pytype_library = native.py_library +pytype_test = native.py_test pyx_library = _pyx_library pybind_extension = _pybind_extension if_cuda_is_configured = _if_cuda_is_configured diff --git a/tests/BUILD b/tests/BUILD index d32eae56f..c6347dcd9 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -19,6 +19,7 @@ load( "jax_test_file_visibility", "py_deps", "pytype_library", + "pytype_test", ) licenses(["notice"]) @@ -732,8 +733,7 @@ py_test( ], ) -# TODO(jakevdp): make this a py_strict_test -py_test( +pytype_test( name = "typing_test", srcs = ["typing_test.py"], deps = [ diff --git a/tests/typing_test.py b/tests/typing_test.py index 9eb87c1e6..165d7b3aa 100644 --- a/tests/typing_test.py +++ b/tests/typing_test.py @@ -110,7 +110,7 @@ class TypingTest(jtu.JaxTestCase): self.assertTrue(jax.jit(is_array)(1.0)) self.assertTrue(is_array(x)) self.assertTrue(jax.jit(is_array)(x)) - self.assertTrue(jax.vmap(is_array)(x).all()) + self.assertTrue(jnp.all(jax.vmap(is_array)(x))) if __name__ == '__main__':