jax.mask and jax.shapecheck are being deprecated

Issue: #11557
PiperOrigin-RevId: 462315754
George Necula 2022-07-21 00:08:48 -07:00 committed by jax authors
parent ba7ded4331
commit 07fcf79324
6 changed files with 9 additions and 896 deletions

CHANGELOG.md

@@ -28,6 +28,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* {func}`jax.tree_unflatten` is deprecated in favor of {func}`jax.tree_util.tree_unflatten`
* The `sym_pos` argument of {func}`jax.scipy.linalg.solve` is deprecated in favor of `assume_a='pos'`,
following a similar deprecation in {func}`scipy.linalg.solve`.
* Deprecations:
* {func}`jax.mask` and {func}`jax.shapecheck` are being deprecated.
See {jax-issue}`#11557`.
## jaxlib 0.3.15 (Unreleased)

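For users hitting this change, a minimal sketch of the new behavior (illustrative only, not part of this commit): constructing a masked function now emits a `DeprecationWarning`.

```python
import warnings

import jax
import jax.numpy as jnp

# After this commit, calling jax.mask emits a DeprecationWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    padded_sum = jax.mask(jnp.sum, in_shapes=['n'], out_shape='')

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```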
jax/_src/api.py

@@ -2249,6 +2249,8 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
def mask(fun: Callable, in_shapes, out_shape=None) -> Callable:
warn("`jax.mask` is deprecated and will be removed soon. ",
DeprecationWarning)
_check_callable(fun)
unique_ids = masking.UniqueIds()
@@ -2292,6 +2294,8 @@ def mask(fun: Callable, in_shapes, out_shape=None) -> Callable:
@curry
def shapecheck(in_shapes, out_shape, fun: Callable):
warn("`jax.shapecheck` is deprecated and will be removed soon. ",
DeprecationWarning)
_check_callable(fun)
in_shapes, in_tree = tree_flatten(in_shapes)
in_shapes = map(masking.parse_spec, in_shapes)
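Projects that need time to migrate can silence just this warning. A sketch, assuming the message text added above (`warnings.filterwarnings` matches the regex against the start of the message):

```python
import warnings

with warnings.catch_warnings():
    warnings.filterwarnings(
        "ignore",
        message=r"`jax\.mask` is deprecated",
        category=DeprecationWarning)
    # ... code that still calls jax.mask goes here; jax.shapecheck
    # needs an analogous filter with its own message pattern ...
```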

jax/interpreters/masking.py

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Masking is **DEPRECATED** and is being removed."""
from contextlib import contextmanager
from collections import Counter, namedtuple
from functools import partial, reduce

tests/BUILD

@@ -464,11 +464,6 @@ jax_test(
},
)
jax_test(
name = "masking_test",
srcs = ["masking_test.py"],
)
jax_test(
name = "metadata_test",
srcs = ["metadata_test.py"],

tests/host_callback_test.py

@@ -1859,67 +1859,6 @@ class HostCallbackTapTest(jtu.JaxTestCase):
what: ct_b
1.""", testing_stream.output)
def test_tap_mask(self):
@partial(jax.mask, in_shapes=['n'], out_shape='')
def padded_sum(x):
three_x = hcb.id_print((x, 2 * x), result=3 * x, what="x",
output_stream=testing_stream)
return jnp.sum(three_x)
x = np.arange(5.)
self.assertAllClose(9., padded_sum([x], dict(n=3)))
hcb.barrier_wait()
self.assertMultiLineStrippedEqual("""
transforms: [('mask', {'logical_shapes': 5})] what: x
( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )""",
testing_stream.output)
testing_stream.reset()
# With VMAP
xv = np.arange(10.).reshape((2, 5)) # logical_shape = 5
self.assertAllClose(
np.array([9., 78.]),
# batch_size = 2, n=3 and 4 for the two elements
jax.vmap(padded_sum)([xv],
dict(n=np.array([3., 4.]))))
hcb.barrier_wait()
self.assertMultiLineStrippedEqual("""
transforms: [('mask', {'logical_shapes': 5}), ('batch', {'batch_dims': (0, 0, 0, 0)})] what: x
( ( [[0. 1. 2. 3. 4.]
[5. 6. 7. 8. 9.]]
[[ 0. 2. 4. 6. 8.]
[10. 12. 14. 16. 18.]] )
( ( [3. 4.] ) ( [3. 4.] ) ) )""", testing_stream.output)
testing_stream.reset()
# With JVP
self.assertAllClose((9., 0.9),
jax.jvp(lambda arg: padded_sum([arg], dict(n=3)),
(x,), (x * 0.1,)))
hcb.barrier_wait()
if FLAGS.jax_host_callback_ad_transforms:
self.assertMultiLineStrippedEqual("""
transforms: [('mask', {'logical_shapes': 5}), 'jvp'] what: x
( ( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )
( ( [0. 0.1 0.2 0.3 0.4] [0. 0.2 0.4 0.6 0.8] ) ( ( False ) ( False ) ) ) )""",
testing_stream.output)
else:
self.assertMultiLineStrippedEqual("""
transforms: [('mask', {'logical_shapes': 5})] what: x
( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )""",
testing_stream.output)
testing_stream.reset()
# Now with JIT
self.assertAllClose(9., jax.jit(padded_sum)([x], dict(n=3)))
hcb.barrier_wait()
self.assertMultiLineStrippedEqual("""
transforms: [('mask', {'logical_shapes': 5})] what: x
( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )""",
testing_stream.output)
def test_tap_callback_delay(self):
hcb.callback_extra = lambda dev: time.sleep(1)

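The deleted `test_tap_mask` above combined `jax.mask` with `hcb.id_print`. The tap itself survives the deprecation; one way to keep it is to mask explicitly. A sketch (the explicit `n` argument is an assumption, not part of the old API):

```python
import jax.numpy as jnp
from jax.experimental import host_callback as hcb

def padded_sum(x, n):
    # Keep the id_print tap, but zero out padding past the logical
    # length n instead of relying on jax.mask.
    three_x = hcb.id_print(3 * x, what="x")
    return jnp.sum(jnp.where(jnp.arange(x.shape[0]) < n, three_x, 0))
```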
tests/masking_test.py

@@ -1,830 +0,0 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import unittest
import numpy as np
from absl.testing import absltest, parameterized
from jax import lax
from jax import core
from jax._src import test_util as jtu
from jax.config import config
from jax._src.util import safe_map, safe_zip
from jax.tree_util import tree_flatten
import jax.numpy as jnp
from jax.scipy.special import expit
from jax import mask, vmap, jit, grad, shapecheck, make_jaxpr
from jax.interpreters.masking import (
shape_as_value, ShapeError, parse_spec, Poly, Mon, finalize_spec,
eval_poly_shape, remap_ids, UniqueIds, UndefinedPoly)
config.parse_flags_with_absl()
map = safe_map
zip = safe_zip
# TODO:
# These should be only the 'manual' tests for masking.
# Move the more exhaustive, systematic tests into lax_test.py.
def constant_poly(c):
return Poly({Mon(): c})
class PolyTest(jtu.JaxTestCase):
@parameterized.parameters([
['(m, n)', 'ShapeSpec(m, n)'],
['(m * n)', 'ShapeSpec(m n)'],
['m * n', 'ShapeSpec(m n)'],
['(m * n,)', 'ShapeSpec(m n)'],
['(3, m)', 'ShapeSpec(3, m)'],
['(10, m)', 'ShapeSpec(10, m)'],
['(-10, m)', 'ShapeSpec(-10, m)'],
['(3 * m)', 'ShapeSpec(3 m)'],
['m', 'ShapeSpec(m)'],
['', 'ShapeSpec()'],
['n + -1*n', 'ShapeSpec(0)'],
['m + n', 'ShapeSpec(m + n)'],
['m + n * k', 'ShapeSpec(k n + m)'],
['m + 3 * k', 'ShapeSpec(3 k + m)'],
['-3 + k + k * k', 'ShapeSpec(k^2 + k + -3)'],
['', 'ShapeSpec()'],
['_', 'ShapeSpec(_)'],
])
def test_parse_spec(self, spec, ans):
self.assertEqual(str(parse_spec(spec)), ans)
self.assertEqual(str(remap_ids(UniqueIds(), parse_spec(spec))), ans)
def test_Poly_equal(self):
assert constant_poly(3) == 3
assert np.array(3, np.int64) == constant_poly(3)
assert np.array(3, np.int64)[()] == constant_poly(3)
assert not np.array(3, np.int64) != constant_poly(3)
assert constant_poly(4) != 3
assert 3 == constant_poly(3)
assert 4 != constant_poly(3)
assert constant_poly(4) == constant_poly(4)
assert constant_poly(3) != constant_poly(4)
assert Poly({Mon(): 3, Mon({'n': 1}): 4}) == Poly({Mon({'n': 1}): 4, Mon(): 3})
assert Poly({Mon(): 3, Mon({'n': 1}): 4}) != Poly({Mon(): 4, Mon({'n': 1}): 4})
assert Poly({Mon(): 3, Mon({'n': 1}): 4}) != Poly({Mon(): 2})
with self.assertRaisesRegex(UndefinedPoly, "inconclusive"):
Poly({Mon(): 3, Mon({'n': 1}): 4}) != Poly({Mon(): 3, Mon({'n': 2}): 4})
with self.assertRaisesRegex(UndefinedPoly, "inconclusive"):
Poly({Mon(): 3, Mon({'m': 1}): 4}) != Poly({Mon(): 3, Mon({'n': 1}): 4})
def test_Poly_hash(self):
assert not len({hash(Poly({Mon(): i})) for i in range(10)}) == 1
assert (hash(Poly({Mon(): 3, Mon({'n': 1}): 4}))
== hash(Poly({Mon({'n': 1}): 4, Mon(): 3})))
def test_Mon_hash(self):
assert not len({hash(Mon({'a': i})) for i in range(10)}) == 1
assert hash(Mon({'a': 1, 'b': 1})) == hash(Mon({'b': 1, 'a': 1}))
@parameterized.parameters([
(Mon({'a': 1}), Mon({'b': 1})),
(Mon({'a': 2, 'b': 1}), Mon({'b': 1})),
])
def test_Mon_floordiv(self, divisor, quotient):
dividend = quotient * divisor
self.assertEqual(quotient, dividend // divisor)
def test_Poly_compare(self):
poly = Poly({Mon(): 3, Mon({'n': 1}): 4})
# Assume poly > 0 to make various shape rules work with polymorphic shapes:
assert poly >= 0
assert poly >= 1
assert poly > 0
assert 0 <= poly
assert 0 < poly
assert constant_poly(3) >= 1
assert constant_poly(3) > 1
assert poly >= poly
assert poly >= poly - 1
assert poly < poly + 1
poly >= 3
poly > 2
with self.assertRaisesRegex(UndefinedPoly, "inconclusive"):
poly >= 4
n = Poly({Mon({'n': 1}): 1})
m = Poly({Mon({'m': 1}): 1})
must_divide_msg = " must divide size"
@parameterized.parameters([
(1, constant_poly(0), 0),
(n, 0, 0),
(2, n, 1),
(5, 2 * n, 0),
(5, 2 * n + 4, 3),
(n * n, n + 1, 0),
(2 * n + 1, 2 * n + 1, n + 2, must_divide_msg),
(n * m + 1, m + n + 1, n - 1, must_divide_msg),
(n, n, 0),
(n, n, 1, must_divide_msg),
(n + 1, -n + 1, -1, must_divide_msg),
])
def test_Poly_divmod(self, divisor, quotient, remainder, error_message=None):
dividend = quotient * divisor + remainder
expected = (quotient, remainder)
if dividend.is_constant: dividend = int(dividend)
if error_message:
with self.assertRaisesRegex(UndefinedPoly, error_message):
divmod(dividend, divisor)
else:
self.assertEqual(expected, divmod(dividend, divisor))
def test_Poly_rsub(self):
n = Poly({Mon({'n': 1}): 1})
assert -1 - n == -n - 1
class MaskingTest(jtu.JaxTestCase):
def test_sum(self):
@partial(mask, in_shapes=['n'], out_shape='')
def padded_sum(x):
return jnp.sum(x)
ans = padded_sum([jnp.array([3, 1, 4, 1, 5])], dict(n=3))
expected = 8
self.assertAllClose(ans, expected, check_dtypes=False)
ans = padded_sum([jnp.array([3, 1, 4, 1, 5])], dict(n=4))
expected = 9
self.assertAllClose(ans, expected, check_dtypes=False)
def test_sum_vmap(self):
@partial(mask, in_shapes=['n'], out_shape='')
def padded_sum(x):
return jnp.sum(x)
ans = vmap(padded_sum)([jnp.ones((5, 10))], dict(n=jnp.arange(5)))
expected = np.array([0, 1, 2, 3, 4])
self.assertAllClose(ans, expected, check_dtypes=False)
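For migration reference: `test_sum` and `test_sum_vmap` above can be reproduced without `jax.mask` by masking explicitly. A sketch (passing the logical length `n` directly is an assumption):

```python
import jax
import jax.numpy as jnp

def padded_sum(x, n):
    # Zero out entries past the logical length n, then sum.
    return jnp.sum(jnp.where(jnp.arange(x.shape[0]) < n, x, 0))

assert padded_sum(jnp.array([3, 1, 4, 1, 5]), 3) == 8  # as in test_sum
# vmap over per-example logical lengths, as in test_sum_vmap:
sums = jax.vmap(padded_sum)(jnp.ones((5, 10)), jnp.arange(5))
```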
def check(self, fun, in_shapes, out_shape, logical_env, padded_in_shapes,
dtypes, rng, rtol=None, atol=None):
shapecheck(in_shapes, out_shape)(fun)
masked_fun = mask(fun, in_shapes, out_shape)
padded_args = [rng(shape, dtype)
for shape, dtype in zip(padded_in_shapes, dtypes)]
padded_outs, outs_tree = tree_flatten(masked_fun(padded_args, logical_env))
out_specs, _ = tree_flatten(out_shape)
out_specs = map(parse_spec, out_specs)
out_specs = map(finalize_spec, out_specs, map(np.shape, padded_outs))
logical_out_shapes = [eval_poly_shape(s, logical_env)
for s in out_specs]
logical_out_slices = [tuple(map(slice, s)) for s in logical_out_shapes]
logical_outs = [o[s] for o, s in zip(padded_outs, logical_out_slices)]
in_specs = map(parse_spec, in_shapes)
in_specs = map(finalize_spec, in_specs, padded_in_shapes)
logical_in_shapes = [eval_poly_shape(s, logical_env)
for s in in_specs]
logical_in_slices = [tuple(map(slice, s)) for s in logical_in_shapes]
logical_args = [a[s] for a, s in zip(padded_args, logical_in_slices)]
logical_outs_expected, logical_outs_tree = tree_flatten(fun(*logical_args))
assert outs_tree == logical_outs_tree
self.assertAllClose(logical_outs, logical_outs_expected, check_dtypes=True,
atol=atol, rtol=rtol)
# Check that abstract evaluation works
padded_outs_jit, _ = tree_flatten(jit(masked_fun)(padded_args, logical_env))
self.assertAllClose(padded_outs_jit, padded_outs, check_dtypes=True,
atol=atol, rtol=rtol)
def test_add(self):
self.check(lax.add, ['n', ''], 'n', {'n': 3}, [(4,), ()], ['float_', 'float_'],
jtu.rand_default(self.rng()))
addvecs = mask(lax.add, in_shapes=['n', 'n'], out_shape='n')
x = jnp.array([3, 1, 4, 1, 5, 9])
y = jnp.array([2, 6, 5, 3, 5, 8])
ans = addvecs([x, y], dict(n=3))
expected = np.array([5, 7, 9])
self.assertAllClose(ans[:3], expected, check_dtypes=False)
thunk = lambda: addvecs([jnp.arange(5), jnp.arange(6)], dict(n=3))
self.assertRaisesRegex(ShapeError, "", thunk)
def test_scan(self):
@partial(mask, in_shapes=['n'], out_shape='')
def cumsum(arr):
out, _ = lax.scan(lambda c, x: (c + x, ()), 0, arr)
return out
n = np.uint8(3) # Test non-default integer type for dynamic length.
ans = cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=n))
expected = 16
self.assertAllClose(ans, expected, check_dtypes=False)
def test_scan_vmap(self):
@partial(mask, in_shapes=['n'], out_shape='')
def cumsum(arr):
out, _ = lax.scan(lambda c, x: (c + x, ()), 0, arr)
return out
ans = vmap(cumsum)([jnp.arange(6).reshape(2, 3)], dict(n=jnp.array([1, 2])))
expected = np.array([0, 7])
self.assertAllClose(ans, expected, check_dtypes=False)
def test_scan_jit(self):
@partial(mask, in_shapes=['n'], out_shape='')
def cumsum(arr):
out, _ = lax.scan(lambda c, x: (c + x, ()), 0, arr)
return out
@jit
def jit_cumsum(args, shape_env):
assert python_should_be_executing
return cumsum(args, shape_env)
python_should_be_executing = True
ans = jit_cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=3))
expected = 16
self.assertAllClose(ans, expected, check_dtypes=False)
python_should_be_executing = False
ans = jit_cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=4))
expected = 17
self.assertAllClose(ans, expected, check_dtypes=False)
python_should_be_executing = False
ans = jit_cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=1))
expected = 5
self.assertAllClose(ans, expected, check_dtypes=False)
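The scan tests above can likewise be written without `jax.mask` by carrying a position index and masking inside the step function; a sketch, not part of the old API:

```python
from jax import lax
import jax.numpy as jnp

def cumsum_logical(arr, n):
    # Carry (position, running sum); elements past n contribute zero.
    def step(carry, x):
        i, c = carry
        c = c + jnp.where(i < n, x, jnp.zeros((), arr.dtype))
        return (i + 1, c), ()
    (_, out), _ = lax.scan(step, (0, jnp.zeros((), arr.dtype)), arr)
    return out

assert cumsum_logical(jnp.array([5, 2, 9, 1, 4]), 3) == 16  # as in test_scan
```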
# TODO Shapecheck fails - shape_as_value can't deal with abstract eval yet
@unittest.skip("Shapecheck fails")
def test_mean(self):
self.check(lambda x: jnp.sum(x) / shape_as_value(x.shape)[0], ['n'], '',
{'n': 3}, [(4,)], ['float_'],
jtu.rand_default(self.rng()))
@unittest.skip("Failing after fixing Poly unsoundness #4878")
def test_arithmetic(self):
@partial(mask, in_shapes=['(n, m)', 'm'], out_shape='(n, m)')
def times(x, y):
return x * y
# TODO(shoyer): enable this check when broadcast_in_dim supports masking
with self.assertRaisesRegex(
NotImplementedError,
'Masking rule for broadcast_in_dim not implemented yet.'):
times([jnp.array([[1, 2], [3, 4], [5, 6]]), jnp.array([1, 2])],
dict(n=4, m=5))
# expected = np.array([[1, 2, 3], [8, 10, 12]])
# self.assertAllClose(ans, expected, check_dtypes=False)
def test_stack(self):
@partial(mask, in_shapes=['n','n'], out_shape='(2, n)')
def stack(x, y):
return jnp.stack([x, y], 0)
# TODO(shoyer): enable this check when broadcast_in_dim supports masking
with self.assertRaisesRegex(
NotImplementedError,
'Masking rule for broadcast_in_dim not implemented yet.'):
stack([jnp.array([1, 2, 3]), jnp.array([4, 5, 6])], dict(n=10))
# expected = np.array([[1, 2, 3], [4, 5, 6]])
# self.assertAllClose(ans, expected, check_dtypes=False)
def test_monomorphic(self):
@partial(mask, in_shapes=['(_, n)'], out_shape='')
def padded_sum(x):
return jnp.sum(x)
ans = padded_sum([jnp.array([[3, 4], [5, 6]])], dict(n=1))
expected = 8
self.assertAllClose(ans, expected, check_dtypes=False)
def test_monomorphic2(self):
@partial(mask, in_shapes=['(_, n)'], out_shape='n')
def padded_sum(x):
return jnp.sum(x, axis=0)
ans = padded_sum([jnp.array([[3, 4], [5, 6]])], dict(n=2))
expected = jnp.array([8, 10])
self.assertAllClose(ans, expected, check_dtypes=False)
def test_monomorphic3(self):
@partial(mask, in_shapes=['(_, n)'], out_shape='_')
def padded_sum(x):
return jnp.sum(x, axis=1)
ans = padded_sum([jnp.array([[3, 4], [5, 6]])], dict(n=1))
expected = jnp.array([3, 5])
self.assertAllClose(ans, expected, check_dtypes=False)
@shapecheck(['(2*n, n)'], '_, n')
def identity(x):
return x
def test_rnn(self):
n = 3
@partial(mask, in_shapes=['(_, _)', '(t, _)'], out_shape='_')
def rnn(W, xs):
def step(h, x):
new_h = jnp.dot(W, h) + jnp.dot(W, x)
return new_h, ()
predicted, _ = lax.scan(step, jnp.zeros(n), xs)
return predicted
rng = self.rng()
W = jnp.eye(n)
xs = rng.randn(10, n).astype(jnp.float_)
ans = rnn([W, xs], dict(t=4))
expected = xs[:4].sum(0)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_rnn_grad(self):
n = 3
@partial(mask, in_shapes=['(_, _)', '(t, _)', '_'], out_shape='')
def rnn(W, xs, target):
def step(h, x):
new_h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x))
return new_h, ()
predicted, _ = lax.scan(step, jnp.zeros(n), xs)
return jnp.sum((predicted - target)**2)
rng = self.rng()
W = rng.randn(n, n).astype(jnp.float_)
xs = rng.randn(10, n).astype(jnp.float_)
y = rng.randn(n).astype(jnp.float_)
ans = grad(lambda W: rnn([W, xs, y], dict(t=4)))(W)
def rnn_reference(W, xs, target):
h = jnp.zeros(n)
for x in xs:
h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x))
predicted = h
return jnp.sum((predicted - target)**2)
expected = grad(lambda W: rnn_reference(W, xs[:4], y))(W)
self.assertAllClose(ans, expected, check_dtypes=False,
rtol={np.float64: 1e-14, np.float32: 1e-5})
def test_ragged_batched_rnn(self):
n = 3
@partial(mask, in_shapes=('(_, _)', '(t, _)', '_'), out_shape='')
def rnn(W, xs, target):
def step(h, x):
new_h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x))
return new_h, ()
predicted, _ = lax.scan(step, jnp.zeros(n), xs)
return jnp.sum((predicted - target)**2)
rng = self.rng()
W = rng.randn(n, n).astype(jnp.float_)
seqs = rng.randn(3, 10, n).astype(jnp.float_)
ts = jnp.array([2, 5, 4])
ys = rng.randn(3, n)
ans = grad(lambda W: vmap(rnn, ((None, 0, 0), 0))((W, seqs, ys), dict(t=ts)).sum())(W)
def rnn_reference(W, seqs, targets):
total_loss = jnp.array(0.0)
for xs, target in zip(seqs, targets):
h = jnp.zeros(n)
for x in xs:
h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x))
predicted = h
total_loss = total_loss + jnp.sum((predicted - target)**2)
return total_loss
seqs_ = [xs[:t] for xs, t in zip(seqs, ts)]
expected = grad(lambda W: rnn_reference(W, seqs_, ys).sum())(W)
self.assertAllClose(
ans, expected, check_dtypes=False,
rtol=0.1 if jtu.device_under_test() == "tpu" else 1e-5)
def test_concatenate(self):
self.check(lambda x, y, z: lax.concatenate([x, y, z], 0),
['n', 'm', 'n'], 'm + 2 * n', {'n': 2, 'm': 3},
[(4,), (3,), (4,)], ['float_', 'float_', 'float_'],
jtu.rand_default(self.rng()))
def test_dot(self):
self.check(lax.dot, ['(m, k)', '(k, n)'], '(m, n)',
dict(m=2, k=3, n=4), [(4, 5), (5, 7)], ['float_', 'float_'],
jtu.rand_default(self.rng()))
self.check(lax.dot, ['(m, n)', 'n'], 'm', dict(m=2, n=3), [(4, 5), (5,)],
['float_', 'float_'], jtu.rand_default(self.rng()))
# TODO(mattjj,j-towns): fix test failure and reenable.
@jtu.skip_on_devices("tpu")
def test_jit(self):
@partial(mask, in_shapes=['n'], out_shape='2*n')
@jit
def duplicate(x):
assert python_should_be_executing
return lax.concatenate([x, x], 0)
python_should_be_executing = True
out = duplicate([jnp.arange(3)], dict(n=2))
assert np.all(np.array([0, 1, 0, 1]) == out[:4])
python_should_be_executing = False
out = duplicate([jnp.arange(3)], dict(n=2))
assert np.all(np.array([0, 1, 0, 1]) == out[:4])
@unittest.skip("broken by omnistaging") # TODO(mattjj): update
def test_jit2(self):
# Trigger MaskTrace.post_process_call
def fun(x):
@jit
def concat(y):
return lax.concatenate([x, y], 0)
return concat(jnp.array([1., 2., 3.], dtype='float32'))
self.check(fun, ['n'], '(n+3,)', {'n': 2}, [(3,)], ['float32'],
jtu.rand_default(self.rng()))
@parameterized.named_parameters({
'testcase_name': "padding_config={}_shapes={}".format(padding_config,
shape),
'padding_config': padding_config,
'shape': shape} for padding_config, shape in (
(((1, 2, 0),), (2,)),
(((1, 2, 0), (3, 4, 0)), (1, 2)),
(((0, 0, 0), (0, 0, 0)), (1, 2)),
(((1, 2, 3),), (2,)),
(((1, 2, 1), (3, 4, 2)), (3, 2)),
(((-1, 2, 0),), (2,)),
(((-1, -2, 0), (1, 2, 0)), (4, 2)),
(((-1, 2, 0), (1, 2, 2)), (4, 2)),
(((-1, -2, 2),), (5,)),
(((-1, -2, 1), (1, 2, 2)), (4, 2))))
@unittest.skip("Failing after fixing Poly unsoundness #4878")
def test_pad(self, padding_config, shape):
def pad(x):
return lax.pad(x, jnp.array(1., x.dtype), padding_config)
if len(shape) == 1:
padding_config_, = padding_config
linear_coeff = padding_config_[2] + 1
const_coeff = sum(padding_config_[:2]) - padding_config_[2]
out_shape = str(linear_coeff) + ' * h + ' + str(const_coeff)
self.check(pad, ['h'], out_shape, dict(h=shape[0]),
[tuple(np.add(shape, 1))], ['float_'],
jtu.rand_default(self.rng()))
# TODO(mattjj,j-towns): fix test failure and reenable.
@jtu.skip_on_devices("tpu")
@unittest.skip("broken by omnistaging") # TODO(mattjj): update
def test_numpy_pad(self):
def numpy_pad(x):
return jnp.pad(x, (0, 1), constant_values=5.)
self.check(numpy_pad, ['n'], 'n + 1', dict(n=2), [(3,)], ['float_'],
jtu.rand_default(self.rng()))
@parameterized.named_parameters(jtu.cases_from_list(
{'testcase_name': "padding={}_lhs_dilation={}_"
"dimension_numbers={}_lhs_perm={}_rhs_perm={}_out_perm={}".format(
padding, lhs_dilation, dimension_numbers, lhs_perm,
rhs_perm, out_perm),
'padding': padding, 'lhs_dilation': lhs_dilation,
'dimension_numbers': dimension_numbers, 'lhs_perm': lhs_perm,
'rhs_perm': rhs_perm, 'out_perm': out_perm}
for padding in ['SAME', 'VALID', ((0, 1), (2, 0))]
for lhs_dilation in (None, (1, 2))
for dimension_numbers, (lhs_perm, rhs_perm, out_perm) in (
(("NCHW", "OIHW", "NCHW"), ((0, 1, 2, 3), (0, 1, 2, 3), (0, 1, 2, 3))),
(("NHWC", "HWIO", "NHWC"), ((0, 2, 3, 1), (2, 3, 1, 0), (0, 2, 3, 1))),
(("NCHW", "HWIO", "NHWC"), ((0, 1, 2, 3), (2, 3, 1, 0), (0, 2, 3, 1)))
)
# String padding is not implemented for transposed convolution, see
# conv_general_dilated implementation:
if (lhs_dilation is None or not isinstance(padding, str))))
@unittest.skip("Failing after fixing Poly unsoundness #4878")
def test_conv(
self, padding, lhs_dilation, dimension_numbers, lhs_perm,
rhs_perm, out_perm):
def conv(lhs, rhs):
return lax.conv_general_dilated(
lhs, rhs, (1, 1), padding, lhs_dilation=lhs_dilation,
dimension_numbers=dimension_numbers)
template = '({}, {}, {}, {})'
lhs_shape = template.format(*np.take(['n', 'c', 'h', 'w'], lhs_perm))
rhs_shape = template.format(*np.take(['o', 'c', '2', '3'], rhs_perm))
if padding == 'VALID':
out_shape = template.format(
*np.take(['n', 'o', 'h+-1', 'w+-2'], out_perm))
elif lhs_dilation:
out_shape = template.format(
*np.take(['n', 'o', 'h', '2*w+-1'], out_perm))
else:
out_shape = template.format(
*np.take(['n', 'o', 'h', 'w'], out_perm))
logical_env = dict(n=3, c=2, h=4, w=5, o=6)
self.check(conv, [lhs_shape, rhs_shape], out_shape,
logical_env, [tuple(np.take([4, 3, 6, 7], lhs_perm)),
tuple(np.take([7, 3, 2, 3], rhs_perm))],
['float_', 'float_'], jtu.rand_default(self.rng()), rtol=1e-4,
atol=1e-4)
@parameterized.named_parameters(jtu.cases_from_list(
{'testcase_name': "padding={}_lhs_dilation={}_"
"dimension_numbers={}_lhs_perm={}_rhs_perm={}_out_perm={}".format(
padding, lhs_dilation, dimension_numbers, lhs_perm,
rhs_perm, out_perm),
'padding': padding, 'lhs_dilation': lhs_dilation,
'dimension_numbers': dimension_numbers, 'lhs_perm': lhs_perm,
'rhs_perm': rhs_perm, 'out_perm': out_perm}
for padding in ['SAME', 'VALID', ((0, 1), (2, 0))]
for lhs_dilation in (None, (1, 2))
for dimension_numbers, (lhs_perm, rhs_perm, out_perm) in (
(("NCHW", "OIHW", "NCHW"), ((0, 1, 2, 3), (0, 1, 2, 3), (0, 1, 2, 3))),
(("NHWC", "HWIO", "NHWC"), ((0, 2, 3, 1), (2, 3, 1, 0), (0, 2, 3, 1))),
(("NCHW", "HWIO", "NHWC"), ((0, 1, 2, 3), (2, 3, 1, 0), (0, 2, 3, 1)))
)
# String padding is not implemented for transposed convolution, see
# conv_general_dilated implementation:
if (lhs_dilation is None or not isinstance(padding, str))))
@unittest.skip("Failing after fixing Poly unsoundness #4878")
def test_conv_strided(
self, padding, lhs_dilation, dimension_numbers, lhs_perm,
rhs_perm, out_perm):
def conv(lhs, rhs):
return lax.conv_general_dilated(
lhs, rhs, (2, 1), padding, lhs_dilation=lhs_dilation,
dimension_numbers=dimension_numbers)
template = '({}, {}, {}, {})'
rhs_shape = template.format(*np.take(['o', 'c', '2', '3'], rhs_perm))
if padding == 'VALID':
lhs_shape = template.format(*np.take(['n', 'c', '2*h+1', 'w'], lhs_perm))
lhs_shape_padded = tuple(np.take([4, 3, 5, 7], lhs_perm))
out_shape = template.format(*np.take(['n', 'o', 'h', 'w+-2'], out_perm))
elif lhs_dilation:
lhs_shape = template.format(*np.take(['n', 'c', '2*h', 'w'], lhs_perm))
lhs_shape_padded = tuple(np.take([4, 3, 6, 7], lhs_perm))
out_shape = template.format(*np.take(['n', 'o', 'h', '2*w+-1'], out_perm))
else:
lhs_shape = template.format(*np.take(['n', 'c', '2*h', 'w'], lhs_perm))
lhs_shape_padded = tuple(np.take([4, 3, 6, 7], lhs_perm))
out_shape = template.format(*np.take(['n', 'o', 'h', 'w'], out_perm))
logical_env = dict(n=3, c=2, h=4, w=5, o=6)
self.check(conv, [lhs_shape, rhs_shape], out_shape,
logical_env, [lhs_shape_padded,
tuple(np.take([7, 3, 2, 3], rhs_perm))],
['float_', 'float_'], jtu.rand_default(self.rng()), rtol=1e-4,
atol=1e-4)
@unittest.skip("requires gather support")
def test_indexing(self):
self.check(lambda x: x[0], ['n'], '', {'n': 2}, [(3,)], ['float_'],
jtu.rand_default(self.rng()))
self.check(lambda x: x[-1], ['n'], '', {'n': 2}, [(3,)], ['float_'],
jtu.rand_default(self.rng()))
@unittest.skip("requires gather support")
def test_slicing(self):
self.check(lambda x: x[1:], ['n'], 'n+-1', {'n': 2}, [(3,)], ['float_'])
self.check(lambda x: x[:-1], ['n'], 'n+-1', {'n': 2}, [(3,)], ['float_'])
self.check(lambda x: x[..., -1], ['(n,3)'], 'n', {'n': 2}, [(3, 4)], ['float_'])
def test_rev(self):
@shapecheck(['n'], 'n')
def rev1(x):
return lax.rev(x, (0,))
@shapecheck(['(m, n)'], '(m, n)')
def rev2(x):
return lax.rev(x, (1,))
@unittest.skip("TODO")
def test_rev_by_indexing(self):
@shapecheck(['n'], 'n+-1')
def rev1(x):
return x[:0:-1]
@shapecheck(['n'], 'n+-1')
def rev2(x):
return x[-2::-1]
# TODO implement masking for rev_p:
# self.check(lambda x: x[:0:-1], ['n'], dict(n=jnp.array([2, 3])), 'n+-1')
# self.check(lambda x: x[-2::-1], ['n'], dict(n=jnp.array([2, 3])), 'n+-1')
@unittest.skip("Failing after fixing Poly unsoundness #4878")
def test_lax_slice(self):
self.check(lambda x: lax.slice(x, (1,), (x.shape[0],)), ['n'], 'n+-1',
{'n': 2}, [(3,)], ['float_'], jtu.rand_default(self.rng()))
# TODO self.check(lambda x: lax.slice(x, (x.shape[0] // 2,), (x.shape[0],)),
# ['2*n'], 'n', {'n': 2}, [(6,)], ['float_'], jtu.rand_default(self.rng()))
self.check(lambda x: lax.slice(x, (0,), (x.shape[0],), (x.shape[0],)),
['n'], '1', {'n': 2}, [(5,)], ['float_'],
jtu.rand_default(self.rng()))
@unittest.skip("Failing after fixing Poly unsoundness #4878")
def test_reshape(self):
self.check(lambda x: jnp.reshape(x, (x.shape[1], 2, 4, 1)),
['1, n, 4, 2'], 'n, 2, 4, 1', dict(n=2), [(1, 3, 4, 2)],
['float_'], jtu.rand_default(self.rng()))
self.check(lambda x: jnp.reshape(x, (x.shape[0] * 2,)),
['n, 2'], '2 * n', dict(n=2), [(3, 2)],
['float_'], jtu.rand_default(self.rng()))
self.check(lambda x: jnp.reshape(x, (x.shape[0] // 2, 2)),
['2 * n'], 'n, 2', dict(n=2), [(6,)],
['float_'], jtu.rand_default(self.rng()))
self.check(lambda x: jnp.reshape(x, (x.shape[0] * 4, 2)),
['n, 2, 4'], '4 * n, 2', dict(n=2), [(3, 2, 4)],
['float_'], jtu.rand_default(self.rng()))
self.check(lambda x: jnp.reshape(x, ((x.shape[0] - 1) // 4 + 1, 2, 4)),
['4 * n + 4, 2'], 'n + 1, 2, 4', dict(n=2), [(12, 2)],
['float_'], jtu.rand_default(self.rng()))
msg = "Reshape on padded dimensions causing fragmentation is not supported."
with self.assertRaisesRegex(NotImplementedError, msg):
self.check(lambda x: jnp.reshape(x, np.prod(x.shape)),
['a, b'], 'a*b', dict(a=2, b=3), [(3, 4)],
['float_'], jtu.rand_default(self.rng()))
with self.assertRaisesRegex(NotImplementedError, msg):
self.check(lambda x: jnp.reshape(x, (x.shape[1], x.shape[0])),
['a, b'], 'b, a', dict(a=2, b=3), [(3, 4)],
['float_'], jtu.rand_default(self.rng()))
with self.assertRaisesRegex(NotImplementedError, msg):
self.check(lambda x: jnp.reshape(x, (x.shape[1] * 2,)),
['2, n'], '2 * n', dict(n=2), [(2, 3)],
['float_'], jtu.rand_default(self.rng()))
self.check(lambda x: jnp.reshape(x, (x.shape[0], -1)),
['n, 3, 1, 2'], 'n, 6', dict(n=1), [(2, 3, 1, 2)],
['float_'], jtu.rand_default(self.rng()))
def test_transpose(self):
self.check(lambda x: lax.transpose(x, (1, 0, 2)),
['(a, b, c)'], 'b, a, c', dict(a=2, b=3, c=4), [(3, 4, 5)],
['float_'], jtu.rand_default(self.rng()))
def test_sum_2d(self):
self.check(jnp.sum, ['(m, n)'], '', dict(m=2, n=3), [(3, 4)], ['float_'],
jtu.rand_default(self.rng()))
@unittest.skip("custom_jvp doesn't work with masking yet")
def test_expit(self):
self.check(expit, ['n'], 'n', dict(n=3), [(4,)], ['float_'],
jtu.rand_default(self.rng()))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_{dtype}", "dtype": np.dtype(dtype).name}
for dtype in [np.float32, np.float64]))
@unittest.skip("not yet implemented")
def test_uniform(self, dtype):
# TODO needs fix for https://github.com/google/jax/issues/2155
pass
@unittest.skip("not yet implemented")
def test_broadcast_in_dim(self):
pass
def test_destructure(self):
def d(key):
key1, key2 = key
return key1
self.check(d, ['2'], '', {}, [(2,)], ['int_'], jtu.rand_int(self.rng(), 0, 10))
# TODO(mattjj,j-towns): fix test failure and reenable.
@jtu.skip_on_devices("tpu")
def test_where(self):
self.check(lambda x: jnp.where(x < 0, x, 0. * x), ['n'], 'n',
{'n': 2}, [(3,)], ['float_'], jtu.rand_default(self.rng()))
@unittest.skip("Failing after fixing Poly unsoundness #4878")
def test_split(self):
self.check(lambda x: jnp.split(x, 2), ['2*n'], ['n', 'n'], dict(n=4),
[(8,)], ['float_'], jtu.rand_default(self.rng()))
self.check(lambda x: jnp.split(x, [10]), ['n'], ['10', 'n+-10'], dict(n=12),
[(12,)], ['float_'], jtu.rand_default(self.rng()))
@parameterized.named_parameters(jtu.cases_from_list([{
'testcase_name': f"operator={operator.__name__}", 'operator': operator}
for operator in [jnp.sum, jnp.prod, jnp.max, jnp.min]]))
def test_reduce(self, operator):
self.check(operator, ['(m+1, n+1)'], '', {'m': 3, 'n': 4}, [(4, 5)], ['float_'],
jtu.rand_default(self.rng()))
def test_output_shape_error(self):
def thunk():
shapecheck(['n'], 'n+-1')(lambda x: x)
message = "Output shapes should be (n + -1,) but are (n,)."
self.assertRaisesWithLiteralMatch(ShapeError, message, thunk)
def thunk():
shapecheck(['n'], ['7*n', 'n'])(lambda x: (x, x))
message = "Output shapes should be [(7 n,), (n,)] but are ((n,), (n,))."
self.assertRaisesWithLiteralMatch(ShapeError, message, thunk)
def test_output_tree_error(self):
def thunk():
shapecheck(['n'], ('n', 'n'))(lambda x: [x, x])
message = "Output shapes should be ((n,), (n,)) but are [(n,), (n,)]."
self.assertRaisesWithLiteralMatch(ShapeError, message, thunk)
def test_unsupported_op(self):
p = core.Primitive('unsupported_op')
p.def_abstract_eval(lambda x: x)
p.def_impl(lambda x: x)
def thunk():
mask(p.bind, ['n'], 'n')([np.arange(3)], {'n': 2})
message = "Masking rule for unsupported_op not implemented yet."
self.assertRaisesWithLiteralMatch(NotImplementedError, message, thunk)
@unittest.skip("not yet implemented")
def test_nesting(self):
@partial(mask, in_shapes=['n'], out_shape='')
def padded_sum(x):
return jnp.sum(x)
batched_sum = vmap(padded_sum)
@partial(mask, in_shapes=['(m, _)', 'm'], out_shape='')
def fun(x, ns):
return batched_sum([x], dict(n=ns)).sum()
x = jnp.array([[3, 1, 4, 1],
[5, 9, 2, 6],
[5, 3, 5, 8]])
ns = jnp.array([2, 3, 2])
ans = fun([x, ns], dict(m=2))
expected = 3+1 + 5+9+2
self.assertAllClose(ans, expected, check_dtypes=False)
def test_slice_oob_indexing(self):
# https://github.com/google/jax/issues/2245
self.assertAllClose(jnp.ones(5), jnp.ones(5)[:10])
self.assertAllClose(jnp.ones(5), jnp.ones(5)[-10:])
def test_jaxpr_doesnt_include_trivial_operations(self):
@partial(mask, in_shapes=['n'], out_shape='')
def foo(x):
return np.sum(x)
padded_x = np.array([0, 1, 2, 3, 999, 999])
jaxpr = make_jaxpr(foo)([padded_x], dict(n=3))
self.assertNotIn('mul', str(jaxpr))
self.assertNotIn('add', str(jaxpr))
def test_return_shape_to_user(self):
@partial(mask, in_shapes=['n'])
def foo(x):
return [x, np.sum(x)]
out, out_shape = foo([np.arange(5)], dict(n=2))
self.assertIsInstance(out_shape, list)
self.assertLen(out_shape, 2)
a, b = out_shape
self.assertEqual(a.shape, (2,))
self.assertEqual(b.shape, ())
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())
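Finally, code that used `jax.shapecheck` only to validate concrete output shapes can use `jax.eval_shape`, which is not deprecated. A sketch, echoing the deleted `test_rev`:

```python
import jax
import jax.numpy as jnp

def rev1(x):
    return jax.lax.rev(x, (0,))

# eval_shape traces the function abstractly (no FLOPs) and returns the
# output's ShapeDtypeStruct, which we can assert against.
out = jax.eval_shape(rev1, jax.ShapeDtypeStruct((4, 2), jnp.float32))
assert out.shape == (4, 2)
```

Unlike `shapecheck`, this checks a single concrete shape rather than a symbolic spec like `'(2*n, n)'`.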