Prior refactoring before the C++ jax.jit. (#4045)

This commit is contained in:
Jean-Baptiste Lespiau 2020-08-18 10:43:52 +02:00 committed by GitHub
parent 2ab6b42a45
commit 8c2ee372f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -21,6 +21,7 @@ import re
import unittest
import warnings
import weakref
import functools
from absl import logging
from absl.testing import absltest, parameterized
@ -45,6 +46,116 @@ from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
class JitTest(jtu.JaxTestCase):
@parameterized.parameters([
# Integer support
(1, 2, 3, 4, 5),
# Numpy array support
(
np.asarray(1, np.int32),
np.asarray(2, np.int32),
np.asarray(3, np.int32),
np.asarray(4, np.int32),
np.asarray(5, np.int32),
),
])
def test_jit_static_args(self, one, two, three, four, five):
side = []
def f(x, y, z, flag=False, flag2=False):
del flag2 # unused
assert flag
side.append(None)
return 100 * x + 10 * y + z
f1 = jax.jit(f, static_argnums=(3, 4))
assert f1(one, two, three, True, False) == 123
assert len(side) == 1
assert f1(one, two, three, True, False) == 123
assert len(side) == 1 # Obvious cache hit.
assert f1(two, one, three, True, False) == 213
assert len(side) == 1 # Should cache hit because same signature.
assert f1(two, one, three, True, True) == 213
assert len(side) == 2
side[:] = []
f2 = jax.jit(f, static_argnums=(0, 2, 3, 4))
assert f2(one, two, three, True, False) == 123
assert len(side) == 1
assert f2(one, three, three, True, False) == 133
assert len(side) == 1
assert f2(two, two, three, True, False) == 223
assert len(side) == 2
assert f2(two, four, three, True, False) == 243
assert len(side) == 2
assert f2(two, four, three, True, True) == 243
assert len(side) == 3
assert f2(two, five, three, True, True) == 253
assert len(side) == 3
@parameterized.parameters([
(1, 2, 3),
(
np.asarray(1, np.int32),
np.asarray(2, np.int32),
np.asarray(3, np.int32),
),
])
def test_jit_kwargs(self, one, two, three):
side = []
def f(x, y, z):
print(x, y, z)
side.append(None)
return 100 * x + 10 * y + z
f = jax.jit(f)
assert f(one, two, three) == 123
assert len(side) == 1
assert f(one, two, three) == 123
assert len(side) == 1
assert f(one, two, z=three) == 123
assert len(side) == 2 # actually recompiles from kwarg
assert f(one, two, z=three) == 123
assert len(side) == 2 # but should still cache
f(one, two, z=np.zeros(3)) # doesn't crash
if FLAGS.jax_enable_x64:
# In the above call, three is of a new type (int64), thus it should
# trigger a new compilation.
assert len(side) == 3
def test_jit_device(self):
device = xb.devices()[-1]
x = jax.jit(lambda x: x, device=device)(3.)
self.assertIsInstance(x, xla.DeviceArray)
self.assertEqual(x.device_buffer.device(), device)
def test_complex_support(self):
self.assertEqual(jax.jit(lambda x: x+1)(1+1j), 2+1j)
def test_jit_with_many_args_works(self):
@jax.jit
def f(args_list):
return sum(args_list)
self.assertEqual(f(list(range(500))), sum(range(500)))
def test_static_argnum_on_method(self):
class A:
@functools.partial(jax.jit, static_argnums=(0,))
def my_func(self, x):
return x+2
A().my_func(3)
class APITest(jtu.JaxTestCase):
def test_grad_argnums(self):
@ -66,64 +177,6 @@ class APITest(jtu.JaxTestCase):
assert api.value_and_grad(f, argnums=1)(1.0, 1.0, 1.0, flag=True) == (y, 2.0)
assert api.value_and_grad(f, argnums=(2, 0))(1.0, 1.0, 1.0, flag=True) == (y, (3.0, 1.0))
def test_jit_static_args(self):
side = []
def f(x, y, z, flag=False, flag2=False):
assert flag
side.append(None)
return 100*x + 10*y + z
f1 = jit(f, static_argnums=(3, 4))
assert f1(1, 2, 3, True, False) == 123
assert len(side) == 1
assert f1(2, 1, 3, True, False) == 213
assert len(side) == 1
assert f1(2, 1, 3, True, True) == 213
assert len(side) == 2
side[:] = []
f2 = jit(f, static_argnums=(0, 2, 3, 4))
assert f2(1, 2, 3, True, False) == 123
assert len(side) == 1
assert f2(1, 3, 3, True, False) == 133
assert len(side) == 1
assert f2(2, 2, 3, True, False) == 223
assert len(side) == 2
assert f2(2, 4, 3, True, False) == 243
assert len(side) == 2
assert f2(2, 4, 3, True, True) == 243
assert len(side) == 3
assert f2(2, 5, 3, True, True) == 253
assert len(side) == 3
def test_jit_kwargs(self):
side = []
def f(x, y, z):
side.append(None)
return 100*x + 10*y + z
f = jit(f)
assert f(1, 2, 3) == 123
assert len(side) == 1
assert f(1, 2, 3) == 123
assert len(side) == 1
assert f(1, 2, z=3) == 123
assert len(side) == 2 # actually recompiles from kwarg
assert f(1, 2, z=3) == 123
assert len(side) == 2 # but should still cache
f(1, 2, z=np.zeros(3)) # doesn't crash
def test_jit_with_many_args_works(self):
@jit
def f(args_list):
return sum(args_list)
self.assertEqual(f(list(range(500))), sum(range(500)))
def test_grad_of_jit(self):
side = []