mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Internal change
PiperOrigin-RevId: 474331907
This commit is contained in:
parent
0a5d8e8ec6
commit
13a7034e6a
@ -25,6 +25,7 @@ load("@flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_li
|
||||
cuda_library = _cuda_library
|
||||
rocm_library = _rocm_library
|
||||
pytype_library = native.py_library
|
||||
pytype_test = native.py_test
|
||||
pyx_library = _pyx_library
|
||||
pybind_extension = _pybind_extension
|
||||
if_cuda_is_configured = _if_cuda_is_configured
|
||||
|
@ -19,6 +19,7 @@ load(
|
||||
"jax_test_file_visibility",
|
||||
"py_deps",
|
||||
"pytype_library",
|
||||
"pytype_test",
|
||||
)
|
||||
|
||||
licenses(["notice"])
|
||||
@ -732,8 +733,7 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
# TODO(jakevdp): make this a py_strict_test
|
||||
py_test(
|
||||
pytype_test(
|
||||
name = "typing_test",
|
||||
srcs = ["typing_test.py"],
|
||||
deps = [
|
||||
|
@ -110,7 +110,7 @@ class TypingTest(jtu.JaxTestCase):
|
||||
self.assertTrue(jax.jit(is_array)(1.0))
|
||||
self.assertTrue(is_array(x))
|
||||
self.assertTrue(jax.jit(is_array)(x))
|
||||
self.assertTrue(jax.vmap(is_array)(x).all())
|
||||
self.assertTrue(jnp.all(jax.vmap(is_array)(x)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
x
Reference in New Issue
Block a user