mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Prior refactoring before the C++ jax.jit. (#4045)
This commit is contained in:
parent
2ab6b42a45
commit
8c2ee372f4
@ -21,6 +21,7 @@ import re
|
||||
import unittest
|
||||
import warnings
|
||||
import weakref
|
||||
import functools
|
||||
|
||||
from absl import logging
|
||||
from absl.testing import absltest, parameterized
|
||||
@ -45,6 +46,116 @@ from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
|
||||
class JitTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.parameters([
|
||||
# Integer support
|
||||
(1, 2, 3, 4, 5),
|
||||
# Numpy array support
|
||||
(
|
||||
np.asarray(1, np.int32),
|
||||
np.asarray(2, np.int32),
|
||||
np.asarray(3, np.int32),
|
||||
np.asarray(4, np.int32),
|
||||
np.asarray(5, np.int32),
|
||||
),
|
||||
])
|
||||
def test_jit_static_args(self, one, two, three, four, five):
|
||||
side = []
|
||||
|
||||
def f(x, y, z, flag=False, flag2=False):
|
||||
del flag2 # unused
|
||||
assert flag
|
||||
side.append(None)
|
||||
return 100 * x + 10 * y + z
|
||||
|
||||
f1 = jax.jit(f, static_argnums=(3, 4))
|
||||
assert f1(one, two, three, True, False) == 123
|
||||
assert len(side) == 1
|
||||
assert f1(one, two, three, True, False) == 123
|
||||
assert len(side) == 1 # Obvious cache hit.
|
||||
assert f1(two, one, three, True, False) == 213
|
||||
assert len(side) == 1 # Should cache hit because same signature.
|
||||
assert f1(two, one, three, True, True) == 213
|
||||
assert len(side) == 2
|
||||
|
||||
side[:] = []
|
||||
f2 = jax.jit(f, static_argnums=(0, 2, 3, 4))
|
||||
assert f2(one, two, three, True, False) == 123
|
||||
assert len(side) == 1
|
||||
assert f2(one, three, three, True, False) == 133
|
||||
assert len(side) == 1
|
||||
assert f2(two, two, three, True, False) == 223
|
||||
assert len(side) == 2
|
||||
assert f2(two, four, three, True, False) == 243
|
||||
assert len(side) == 2
|
||||
assert f2(two, four, three, True, True) == 243
|
||||
assert len(side) == 3
|
||||
assert f2(two, five, three, True, True) == 253
|
||||
assert len(side) == 3
|
||||
|
||||
@parameterized.parameters([
|
||||
(1, 2, 3),
|
||||
(
|
||||
np.asarray(1, np.int32),
|
||||
np.asarray(2, np.int32),
|
||||
np.asarray(3, np.int32),
|
||||
),
|
||||
])
|
||||
def test_jit_kwargs(self, one, two, three):
|
||||
side = []
|
||||
|
||||
def f(x, y, z):
|
||||
print(x, y, z)
|
||||
side.append(None)
|
||||
return 100 * x + 10 * y + z
|
||||
|
||||
f = jax.jit(f)
|
||||
assert f(one, two, three) == 123
|
||||
assert len(side) == 1
|
||||
assert f(one, two, three) == 123
|
||||
assert len(side) == 1
|
||||
|
||||
assert f(one, two, z=three) == 123
|
||||
assert len(side) == 2 # actually recompiles from kwarg
|
||||
assert f(one, two, z=three) == 123
|
||||
assert len(side) == 2 # but should still cache
|
||||
|
||||
f(one, two, z=np.zeros(3)) # doesn't crash
|
||||
if FLAGS.jax_enable_x64:
|
||||
# In the above call, three is of a new type (int64), thus it should
|
||||
# trigger a new compilation.
|
||||
assert len(side) == 3
|
||||
|
||||
def test_jit_device(self):
|
||||
device = xb.devices()[-1]
|
||||
x = jax.jit(lambda x: x, device=device)(3.)
|
||||
self.assertIsInstance(x, xla.DeviceArray)
|
||||
self.assertEqual(x.device_buffer.device(), device)
|
||||
|
||||
def test_complex_support(self):
|
||||
self.assertEqual(jax.jit(lambda x: x+1)(1+1j), 2+1j)
|
||||
|
||||
def test_jit_with_many_args_works(self):
|
||||
|
||||
@jax.jit
|
||||
def f(args_list):
|
||||
return sum(args_list)
|
||||
|
||||
self.assertEqual(f(list(range(500))), sum(range(500)))
|
||||
|
||||
def test_static_argnum_on_method(self):
|
||||
|
||||
class A:
|
||||
|
||||
@functools.partial(jax.jit, static_argnums=(0,))
|
||||
def my_func(self, x):
|
||||
return x+2
|
||||
|
||||
A().my_func(3)
|
||||
|
||||
|
||||
class APITest(jtu.JaxTestCase):
|
||||
|
||||
def test_grad_argnums(self):
|
||||
@ -66,64 +177,6 @@ class APITest(jtu.JaxTestCase):
|
||||
assert api.value_and_grad(f, argnums=1)(1.0, 1.0, 1.0, flag=True) == (y, 2.0)
|
||||
assert api.value_and_grad(f, argnums=(2, 0))(1.0, 1.0, 1.0, flag=True) == (y, (3.0, 1.0))
|
||||
|
||||
def test_jit_static_args(self):
|
||||
side = []
|
||||
|
||||
def f(x, y, z, flag=False, flag2=False):
|
||||
assert flag
|
||||
side.append(None)
|
||||
return 100*x + 10*y + z
|
||||
|
||||
f1 = jit(f, static_argnums=(3, 4))
|
||||
assert f1(1, 2, 3, True, False) == 123
|
||||
assert len(side) == 1
|
||||
assert f1(2, 1, 3, True, False) == 213
|
||||
assert len(side) == 1
|
||||
assert f1(2, 1, 3, True, True) == 213
|
||||
assert len(side) == 2
|
||||
|
||||
side[:] = []
|
||||
f2 = jit(f, static_argnums=(0, 2, 3, 4))
|
||||
assert f2(1, 2, 3, True, False) == 123
|
||||
assert len(side) == 1
|
||||
assert f2(1, 3, 3, True, False) == 133
|
||||
assert len(side) == 1
|
||||
assert f2(2, 2, 3, True, False) == 223
|
||||
assert len(side) == 2
|
||||
assert f2(2, 4, 3, True, False) == 243
|
||||
assert len(side) == 2
|
||||
assert f2(2, 4, 3, True, True) == 243
|
||||
assert len(side) == 3
|
||||
assert f2(2, 5, 3, True, True) == 253
|
||||
assert len(side) == 3
|
||||
|
||||
def test_jit_kwargs(self):
|
||||
side = []
|
||||
|
||||
def f(x, y, z):
|
||||
side.append(None)
|
||||
return 100*x + 10*y + z
|
||||
|
||||
f = jit(f)
|
||||
assert f(1, 2, 3) == 123
|
||||
assert len(side) == 1
|
||||
assert f(1, 2, 3) == 123
|
||||
assert len(side) == 1
|
||||
|
||||
assert f(1, 2, z=3) == 123
|
||||
assert len(side) == 2 # actually recompiles from kwarg
|
||||
assert f(1, 2, z=3) == 123
|
||||
assert len(side) == 2 # but should still cache
|
||||
|
||||
f(1, 2, z=np.zeros(3)) # doesn't crash
|
||||
|
||||
def test_jit_with_many_args_works(self):
|
||||
@jit
|
||||
def f(args_list):
|
||||
return sum(args_list)
|
||||
|
||||
self.assertEqual(f(list(range(500))), sum(range(500)))
|
||||
|
||||
def test_grad_of_jit(self):
|
||||
side = []
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user