mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
jax.mask and jax.shapecheck are being deprecated
Issue: #11557 PiperOrigin-RevId: 462315754
This commit is contained in:
parent
ba7ded4331
commit
07fcf79324
@ -28,6 +28,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* {func}`jax.tree_unflatten` is deprecated in favor of {func}`jax.tree_util.tree_unflatten`
|
||||
* The `sym_pos` argument of {func}`jax.scipy.linalg.solve` is deprecated in favor of `assume_a='pos'`,
|
||||
following a similar deprecation in {func}`scipy.linalg.solve`.
|
||||
* Deprecations:
|
||||
* {func}`jax.mask` {func}`jax.shapecheck` are being deprecated.
|
||||
See {jax-issue}`#11557`.
|
||||
|
||||
## jaxlib 0.3.15 (Unreleased)
|
||||
|
||||
|
@ -2249,6 +2249,8 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
|
||||
|
||||
|
||||
def mask(fun: Callable, in_shapes, out_shape=None) -> Callable:
|
||||
warn("`jax.mask` is deprecated and will be removed soon. ",
|
||||
DeprecationWarning)
|
||||
_check_callable(fun)
|
||||
unique_ids = masking.UniqueIds()
|
||||
|
||||
@ -2292,6 +2294,8 @@ def mask(fun: Callable, in_shapes, out_shape=None) -> Callable:
|
||||
|
||||
@curry
|
||||
def shapecheck(in_shapes, out_shape, fun: Callable):
|
||||
warn("`jax.shapecheck` is deprecated and will be removed soon. ",
|
||||
DeprecationWarning)
|
||||
_check_callable(fun)
|
||||
in_shapes, in_tree = tree_flatten(in_shapes)
|
||||
in_shapes = map(masking.parse_spec, in_shapes)
|
||||
|
@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Masking is **DEPRECATED** and is being removed."""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from collections import Counter, namedtuple
|
||||
from functools import partial, reduce
|
||||
|
@ -464,11 +464,6 @@ jax_test(
|
||||
},
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "masking_test",
|
||||
srcs = ["masking_test.py"],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "metadata_test",
|
||||
srcs = ["metadata_test.py"],
|
||||
|
@ -1859,67 +1859,6 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
what: ct_b
|
||||
1.""", testing_stream.output)
|
||||
|
||||
def test_tap_mask(self):
|
||||
|
||||
@partial(jax.mask, in_shapes=['n'], out_shape='')
|
||||
def padded_sum(x):
|
||||
three_x = hcb.id_print((x, 2 * x), result=3 * x, what="x",
|
||||
output_stream=testing_stream)
|
||||
return jnp.sum(three_x)
|
||||
|
||||
x = np.arange(5.)
|
||||
|
||||
self.assertAllClose(9., padded_sum([x], dict(n=3)))
|
||||
hcb.barrier_wait()
|
||||
self.assertMultiLineStrippedEqual("""
|
||||
transforms: [('mask', {'logical_shapes': 5})] what: x
|
||||
( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )""",
|
||||
testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
# With VMAP
|
||||
xv = np.arange(10.).reshape((2, 5)) # logical_shape = 5
|
||||
self.assertAllClose(
|
||||
np.array([9., 78.]),
|
||||
# batch_size = 2, n=3 and 4 for the two elements
|
||||
jax.vmap(padded_sum)([xv],
|
||||
dict(n=np.array([3., 4.]))))
|
||||
hcb.barrier_wait()
|
||||
self.assertMultiLineStrippedEqual("""
|
||||
transforms: [('mask', {'logical_shapes': 5}), ('batch', {'batch_dims': (0, 0, 0, 0)})] what: x
|
||||
( ( [[0. 1. 2. 3. 4.]
|
||||
[5. 6. 7. 8. 9.]]
|
||||
[[ 0. 2. 4. 6. 8.]
|
||||
[10. 12. 14. 16. 18.]] )
|
||||
( ( [3. 4.] ) ( [3. 4.] ) ) )""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
# With JVP
|
||||
self.assertAllClose((9., 0.9),
|
||||
jax.jvp(lambda arg: padded_sum([arg], dict(n=3)),
|
||||
(x,), (x * 0.1,)))
|
||||
hcb.barrier_wait()
|
||||
if FLAGS.jax_host_callback_ad_transforms:
|
||||
self.assertMultiLineStrippedEqual("""
|
||||
transforms: [('mask', {'logical_shapes': 5}), 'jvp'] what: x
|
||||
( ( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )
|
||||
( ( [0. 0.1 0.2 0.3 0.4] [0. 0.2 0.4 0.6 0.8] ) ( ( False ) ( False ) ) ) )""",
|
||||
testing_stream.output)
|
||||
else:
|
||||
self.assertMultiLineStrippedEqual("""
|
||||
transforms: [('mask', {'logical_shapes': 5})] what: x
|
||||
( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )""",
|
||||
testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
# Now with JIT
|
||||
self.assertAllClose(9., jax.jit(padded_sum)([x], dict(n=3)))
|
||||
hcb.barrier_wait()
|
||||
self.assertMultiLineStrippedEqual("""
|
||||
transforms: [('mask', {'logical_shapes': 5})] what: x
|
||||
( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )""",
|
||||
testing_stream.output)
|
||||
|
||||
def test_tap_callback_delay(self):
|
||||
hcb.callback_extra = lambda dev: time.sleep(1)
|
||||
|
||||
|
@ -1,830 +0,0 @@
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from absl.testing import absltest, parameterized
|
||||
|
||||
from jax import lax
|
||||
from jax import core
|
||||
from jax._src import test_util as jtu
|
||||
from jax.config import config
|
||||
from jax._src.util import safe_map, safe_zip
|
||||
from jax.tree_util import tree_flatten
|
||||
|
||||
import jax.numpy as jnp
|
||||
from jax.scipy.special import expit
|
||||
from jax import mask, vmap, jit, grad, shapecheck, make_jaxpr
|
||||
from jax.interpreters.masking import (
|
||||
shape_as_value, ShapeError, parse_spec, Poly, Mon, finalize_spec,
|
||||
eval_poly_shape, remap_ids, UniqueIds, UndefinedPoly)
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
map = safe_map
|
||||
zip = safe_zip
|
||||
|
||||
|
||||
# TODO:
|
||||
# These should be only the 'manual' tests for masking.
|
||||
# Move the more exhaustive, systematic tests into lax_test.py.
|
||||
|
||||
def constant_poly(c):
|
||||
return Poly({Mon(): c})
|
||||
|
||||
class PolyTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.parameters([
|
||||
['(m, n)', 'ShapeSpec(m, n)'],
|
||||
['(m * n)', 'ShapeSpec(m n)'],
|
||||
['m * n', 'ShapeSpec(m n)'],
|
||||
['(m * n,)', 'ShapeSpec(m n)'],
|
||||
['(3, m)', 'ShapeSpec(3, m)'],
|
||||
['(10, m)', 'ShapeSpec(10, m)'],
|
||||
['(-10, m)', 'ShapeSpec(-10, m)'],
|
||||
['(3 * m)', 'ShapeSpec(3 m)'],
|
||||
['m', 'ShapeSpec(m)'],
|
||||
['', 'ShapeSpec()'],
|
||||
['n + -1*n', 'ShapeSpec(0)'],
|
||||
['m + n', 'ShapeSpec(m + n)'],
|
||||
['m + n * k', 'ShapeSpec(k n + m)'],
|
||||
['m + 3 * k', 'ShapeSpec(3 k + m)'],
|
||||
['-3 + k + k * k', 'ShapeSpec(k^2 + k + -3)'],
|
||||
['', 'ShapeSpec()'],
|
||||
['_', 'ShapeSpec(_)'],
|
||||
])
|
||||
def test_parse_spec(self, spec, ans):
|
||||
self.assertEqual(str(parse_spec(spec)), ans)
|
||||
self.assertEqual(str(remap_ids(UniqueIds(), parse_spec(spec))), ans)
|
||||
|
||||
def test_Poly_equal(self):
|
||||
assert constant_poly(3) == 3
|
||||
assert np.array(3, np.int64) == constant_poly(3)
|
||||
assert np.array(3, np.int64)[()] == constant_poly(3)
|
||||
assert not np.array(3, np.int64) != constant_poly(3)
|
||||
assert constant_poly(4) != 3
|
||||
assert 3 == constant_poly(3)
|
||||
assert 4 != constant_poly(3)
|
||||
assert constant_poly(4) == constant_poly(4)
|
||||
assert constant_poly(3) != constant_poly(4)
|
||||
assert Poly({Mon(): 3, Mon({'n': 1}): 4}) == Poly({Mon({'n': 1}): 4, Mon(): 3})
|
||||
assert Poly({Mon(): 3, Mon({'n': 1}): 4}) != Poly({Mon(): 4, Mon({'n': 1}): 4})
|
||||
assert Poly({Mon(): 3, Mon({'n': 1}): 4}) != Poly({Mon(): 2})
|
||||
with self.assertRaisesRegex(UndefinedPoly, "inconclusive"):
|
||||
Poly({Mon(): 3, Mon({'n': 1}): 4}) != Poly({Mon(): 3, Mon({'n': 2}): 4})
|
||||
with self.assertRaisesRegex(UndefinedPoly, "inconclusive"):
|
||||
Poly({Mon(): 3, Mon({'m': 1}): 4}) != Poly({Mon(): 3, Mon({'n': 1}): 4})
|
||||
|
||||
def test_Poly_hash(self):
|
||||
assert not len({hash(Poly({Mon(): i})) for i in range(10)}) == 1
|
||||
assert (hash(Poly({Mon(): 3, Mon({'n': 1}): 4}))
|
||||
== hash(Poly({Mon({'n': 1}): 4, Mon(): 3})))
|
||||
|
||||
def test_Mon_hash(self):
|
||||
assert not len({hash(Mon({'a': i})) for i in range(10)}) == 1
|
||||
assert hash(Mon({'a': 1, 'b': 1})) == hash(Mon({'b': 1, 'a': 1}))
|
||||
|
||||
@parameterized.parameters([
|
||||
(Mon({'a': 1}), Mon({'b': 1})),
|
||||
(Mon({'a': 2, 'b': 1}), Mon({'b': 1})),
|
||||
])
|
||||
def test_Mon_floordiv(self, divisor, quotient):
|
||||
dividend = quotient * divisor
|
||||
self.assertEqual(quotient, dividend // divisor)
|
||||
|
||||
def test_Poly_compare(self):
|
||||
poly = Poly({Mon(): 3, Mon({'n': 1}): 4})
|
||||
# Assume poly > 0 to make various shape rules work with polymorphic shapes:
|
||||
assert poly >= 0
|
||||
assert poly >= 1
|
||||
assert poly > 0
|
||||
|
||||
assert 0 <= poly
|
||||
assert 0 < poly
|
||||
assert constant_poly(3) >= 1
|
||||
assert constant_poly(3) > 1
|
||||
assert poly >= poly
|
||||
assert poly >= poly - 1
|
||||
assert poly < poly + 1
|
||||
|
||||
poly >= 3
|
||||
poly > 2
|
||||
with self.assertRaisesRegex(UndefinedPoly, "inconclusive"):
|
||||
poly >= 4
|
||||
|
||||
n = Poly({Mon({'n': 1}): 1})
|
||||
m = Poly({Mon({'m': 1}): 1})
|
||||
|
||||
must_divide_msg = " must divide size"
|
||||
|
||||
@parameterized.parameters([
|
||||
(1, constant_poly(0), 0),
|
||||
(n, 0, 0),
|
||||
(2, n, 1),
|
||||
(5, 2 * n, 0),
|
||||
(5, 2 * n + 4, 3),
|
||||
(n * n, n + 1, 0),
|
||||
(2 * n + 1, 2 * n + 1, n + 2, must_divide_msg),
|
||||
(n * m + 1, m + n + 1, n - 1, must_divide_msg),
|
||||
(n, n, 0),
|
||||
(n, n, 1, must_divide_msg),
|
||||
(n + 1, -n + 1, -1, must_divide_msg),
|
||||
])
|
||||
def test_Poly_divmod(self, divisor, quotient, remainder, error_message=None):
|
||||
dividend = quotient * divisor + remainder
|
||||
expected = (quotient, remainder)
|
||||
if dividend.is_constant: dividend = int(dividend)
|
||||
if error_message:
|
||||
with self.assertRaisesRegex(UndefinedPoly, error_message):
|
||||
divmod(dividend, divisor)
|
||||
else:
|
||||
self.assertEqual(expected, divmod(dividend, divisor))
|
||||
|
||||
def test_Poly_rsub(self):
|
||||
n = Poly({Mon({'n': 1}): 1})
|
||||
assert -1 - n == -n - 1
|
||||
|
||||
class MaskingTest(jtu.JaxTestCase):
|
||||
def test_sum(self):
|
||||
@partial(mask, in_shapes=['n'], out_shape='')
|
||||
def padded_sum(x):
|
||||
return jnp.sum(x)
|
||||
|
||||
ans = padded_sum([jnp.array([3, 1, 4, 1, 5])], dict(n=3))
|
||||
expected = 8
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
ans = padded_sum([jnp.array([3, 1, 4, 1, 5])], dict(n=4))
|
||||
expected = 9
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_sum_vmap(self):
|
||||
@partial(mask, in_shapes=['n'], out_shape='')
|
||||
def padded_sum(x):
|
||||
return jnp.sum(x)
|
||||
|
||||
ans = vmap(padded_sum)([jnp.ones((5, 10))], dict(n=jnp.arange(5)))
|
||||
expected = np.array([0, 1, 2, 3, 4])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def check(self, fun, in_shapes, out_shape, logical_env, padded_in_shapes,
|
||||
dtypes, rng, rtol=None, atol=None):
|
||||
shapecheck(in_shapes, out_shape)(fun)
|
||||
masked_fun = mask(fun, in_shapes, out_shape)
|
||||
padded_args = [rng(shape, dtype)
|
||||
for shape, dtype in zip(padded_in_shapes, dtypes)]
|
||||
padded_outs, outs_tree = tree_flatten(masked_fun(padded_args, logical_env))
|
||||
|
||||
out_specs, _ = tree_flatten(out_shape)
|
||||
out_specs = map(parse_spec, out_specs)
|
||||
out_specs = map(finalize_spec, out_specs, map(np.shape, padded_outs))
|
||||
logical_out_shapes = [eval_poly_shape(s, logical_env)
|
||||
for s in out_specs]
|
||||
logical_out_slices = [tuple(map(slice, s)) for s in logical_out_shapes]
|
||||
logical_outs = [o[s] for o, s in zip(padded_outs, logical_out_slices)]
|
||||
|
||||
in_specs = map(parse_spec, in_shapes)
|
||||
in_specs = map(finalize_spec, in_specs, padded_in_shapes)
|
||||
logical_in_shapes = [eval_poly_shape(s, logical_env)
|
||||
for s in in_specs]
|
||||
logical_in_slices = [tuple(map(slice, s)) for s in logical_in_shapes]
|
||||
logical_args = [a[s] for a, s in zip(padded_args, logical_in_slices)]
|
||||
logical_outs_expected, logical_outs_tree = tree_flatten(fun(*logical_args))
|
||||
assert outs_tree == logical_outs_tree
|
||||
self.assertAllClose(logical_outs, logical_outs_expected, check_dtypes=True,
|
||||
atol=atol, rtol=rtol)
|
||||
|
||||
# Check that abstract evaluation works
|
||||
padded_outs_jit, _ = tree_flatten(jit(masked_fun)(padded_args, logical_env))
|
||||
self.assertAllClose(padded_outs_jit, padded_outs, check_dtypes=True,
|
||||
atol=atol, rtol=rtol)
|
||||
|
||||
def test_add(self):
|
||||
self.check(lax.add, ['n', ''], 'n', {'n': 3}, [(4,), ()], ['float_', 'float_'],
|
||||
jtu.rand_default(self.rng()))
|
||||
addvecs = mask(lax.add, in_shapes=['n', 'n'], out_shape='n')
|
||||
|
||||
x = jnp.array([3, 1, 4, 1, 5, 9])
|
||||
y = jnp.array([2, 6, 5, 3, 5, 8])
|
||||
ans = addvecs([x, y], dict(n=3))
|
||||
expected = np.array([5, 7, 9])
|
||||
self.assertAllClose(ans[:3], expected, check_dtypes=False)
|
||||
|
||||
thunk = lambda: addvecs([jnp.arange(5), jnp.arange(6)], dict(n=3))
|
||||
self.assertRaisesRegex(ShapeError, "", thunk)
|
||||
|
||||
def test_scan(self):
|
||||
@partial(mask, in_shapes=['n'], out_shape='')
|
||||
def cumsum(arr):
|
||||
out, _ = lax.scan(lambda c, x: (c + x, ()), 0, arr)
|
||||
return out
|
||||
|
||||
n = np.uint8(3) # Test non-default integer type for dynamic length.
|
||||
ans = cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=n))
|
||||
expected = 16
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_scan_vmap(self):
|
||||
@partial(mask, in_shapes=['n'], out_shape='')
|
||||
def cumsum(arr):
|
||||
out, _ = lax.scan(lambda c, x: (c + x, ()), 0, arr)
|
||||
return out
|
||||
|
||||
ans = vmap(cumsum)([jnp.arange(6).reshape(2, 3)], dict(n=jnp.array([1, 2])))
|
||||
expected = np.array([0, 7])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_scan_jit(self):
|
||||
@partial(mask, in_shapes=['n'], out_shape='')
|
||||
def cumsum(arr):
|
||||
out, _ = lax.scan(lambda c, x: (c + x, ()), 0, arr)
|
||||
return out
|
||||
|
||||
@jit
|
||||
def jit_cumsum(args, shape_env):
|
||||
assert python_should_be_executing
|
||||
return cumsum(args, shape_env)
|
||||
|
||||
python_should_be_executing = True
|
||||
ans = jit_cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=3))
|
||||
expected = 16
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
python_should_be_executing = False
|
||||
ans = jit_cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=4))
|
||||
expected = 17
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
python_should_be_executing = False
|
||||
ans = jit_cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=1))
|
||||
expected = 5
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
# TODO Shapecheck fails - shape_as_value can't deal with abstract eval yet
|
||||
@unittest.skip("Shapecheck fails")
|
||||
def test_mean(self):
|
||||
self.check(lambda x: jnp.sum(x) / shape_as_value(x.shape)[0], ['n'], '',
|
||||
{'n': 3}, [(4,)], ['float_'],
|
||||
jtu.rand_default(self.rng()))
|
||||
|
||||
@unittest.skip("Failing after fixing Poly unsoundness #4878")
|
||||
def test_arithmetic(self):
|
||||
@partial(mask, in_shapes=['(n, m)', 'm'], out_shape='(n, m)')
|
||||
def times(x, y):
|
||||
return x * y
|
||||
|
||||
# TODO(shoyer): enable this check when broadcast_in_dim supports masking
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
'Masking rule for broadcast_in_dim not implemented yet.'):
|
||||
times([jnp.array([[1, 2], [3, 4], [5, 6]]), jnp.array([1, 2])],
|
||||
dict(n=4, m=5))
|
||||
# expected = np.array([[1, 2, 3], [8, 10, 12]])
|
||||
# self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_stack(self):
|
||||
@partial(mask, in_shapes=['n','n'], out_shape='(2, n)')
|
||||
def stack(x, y):
|
||||
return jnp.stack([x, y], 0)
|
||||
|
||||
# TODO(shoyer): enable this check when broadcast_in_dim supports masking
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
'Masking rule for broadcast_in_dim not implemented yet.'):
|
||||
stack([jnp.array([1, 2, 3]), jnp.array([4, 5, 6])], dict(n=10))
|
||||
# expected = np.array([[1, 2, 3], [4, 5, 6]])
|
||||
# self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_monomorphic(self):
|
||||
@partial(mask, in_shapes=['(_, n)'], out_shape='')
|
||||
def padded_sum(x):
|
||||
return jnp.sum(x)
|
||||
|
||||
ans = padded_sum([jnp.array([[3, 4], [5, 6]])], dict(n=1))
|
||||
expected = 8
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_monomorphic2(self):
|
||||
@partial(mask, in_shapes=['(_, n)'], out_shape='n')
|
||||
def padded_sum(x):
|
||||
return jnp.sum(x, axis=0)
|
||||
|
||||
ans = padded_sum([jnp.array([[3, 4], [5, 6]])], dict(n=2))
|
||||
expected = jnp.array([8, 10])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_monomorphic3(self):
|
||||
@partial(mask, in_shapes=['(_, n)'], out_shape='_')
|
||||
def padded_sum(x):
|
||||
return jnp.sum(x, axis=1)
|
||||
|
||||
ans = padded_sum([jnp.array([[3, 4], [5, 6]])], dict(n=1))
|
||||
expected = jnp.array([3, 5])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@shapecheck(['(2*n, n)'], '_, n')
|
||||
def identity(x):
|
||||
return x
|
||||
|
||||
def test_rnn(self):
|
||||
n = 3
|
||||
|
||||
@partial(mask, in_shapes=['(_, _)', '(t, _)'], out_shape='_')
|
||||
def rnn(W, xs):
|
||||
def step(h, x):
|
||||
new_h = jnp.dot(W, h) + jnp.dot(W, x)
|
||||
return new_h, ()
|
||||
predicted, _ = lax.scan(step, jnp.zeros(n), xs)
|
||||
return predicted
|
||||
|
||||
rng = self.rng()
|
||||
W = jnp.eye(n)
|
||||
xs = rng.randn(10, n).astype(jnp.float_)
|
||||
ans = rnn([W, xs], dict(t=4))
|
||||
expected = xs[:4].sum(0)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_rnn_grad(self):
|
||||
n = 3
|
||||
|
||||
@partial(mask, in_shapes=['(_, _)', '(t, _)', '_'], out_shape='')
|
||||
def rnn(W, xs, target):
|
||||
def step(h, x):
|
||||
new_h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x))
|
||||
return new_h, ()
|
||||
predicted, _ = lax.scan(step, jnp.zeros(n), xs)
|
||||
return jnp.sum((predicted - target)**2)
|
||||
|
||||
rng = self.rng()
|
||||
W = rng.randn(n, n).astype(jnp.float_)
|
||||
xs = rng.randn(10, n).astype(jnp.float_)
|
||||
y = rng.randn(n).astype(jnp.float_)
|
||||
|
||||
ans = grad(lambda W: rnn([W, xs, y], dict(t=4)))(W)
|
||||
|
||||
def rnn_reference(W, xs, target):
|
||||
h = jnp.zeros(n)
|
||||
for x in xs:
|
||||
h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x))
|
||||
predicted = h
|
||||
return jnp.sum((predicted - target)**2)
|
||||
|
||||
expected = grad(lambda W: rnn_reference(W, xs[:4], y))(W)
|
||||
|
||||
self.assertAllClose(ans, expected, check_dtypes=False,
|
||||
rtol={np.float64: 1e-14, np.float32: 1e-5})
|
||||
|
||||
def test_ragged_batched_rnn(self):
|
||||
n = 3
|
||||
|
||||
@partial(mask, in_shapes=('(_, _)', '(t, _)', '_'), out_shape='')
|
||||
def rnn(W, xs, target):
|
||||
def step(h, x):
|
||||
new_h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x))
|
||||
return new_h, ()
|
||||
predicted, _ = lax.scan(step, jnp.zeros(n), xs)
|
||||
return jnp.sum((predicted - target)**2)
|
||||
|
||||
rng = self.rng()
|
||||
W = rng.randn(n, n).astype(jnp.float_)
|
||||
seqs = rng.randn(3, 10, n).astype(jnp.float_)
|
||||
ts = jnp.array([2, 5, 4])
|
||||
ys = rng.randn(3, n)
|
||||
|
||||
ans = grad(lambda W: vmap(rnn, ((None, 0, 0), 0))((W, seqs, ys), dict(t=ts)).sum())(W)
|
||||
|
||||
def rnn_reference(W, seqs, targets):
|
||||
total_loss = jnp.array(0.0)
|
||||
for xs, target in zip(seqs, targets):
|
||||
h = jnp.zeros(n)
|
||||
for x in xs:
|
||||
h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x))
|
||||
predicted = h
|
||||
total_loss = total_loss + jnp.sum((predicted - target)**2)
|
||||
return total_loss
|
||||
|
||||
seqs_ = [xs[:t] for xs, t in zip(seqs, ts)]
|
||||
expected = grad(lambda W: rnn_reference(W, seqs_, ys).sum())(W)
|
||||
|
||||
self.assertAllClose(
|
||||
ans, expected, check_dtypes=False,
|
||||
rtol=0.1 if jtu.device_under_test() == "tpu" else 1e-5)
|
||||
|
||||
def test_concatenate(self):
|
||||
self.check(lambda x, y, z: lax.concatenate([x, y, z], 0),
|
||||
['n', 'm', 'n'], 'm + 2 * n', {'n': 2, 'm': 3},
|
||||
[(4,), (3,), (4,)], ['float_', 'float_', 'float_'],
|
||||
jtu.rand_default(self.rng()))
|
||||
|
||||
def test_dot(self):
|
||||
self.check(lax.dot, ['(m, k)', '(k, n)'], '(m, n)',
|
||||
dict(m=2, k=3, n=4), [(4, 5), (5, 7)], ['float_', 'float_'],
|
||||
jtu.rand_default(self.rng()))
|
||||
self.check(lax.dot, ['(m, n)', 'n'], 'm', dict(m=2, n=3), [(4, 5), (5,)],
|
||||
['float_', 'float_'], jtu.rand_default(self.rng()))
|
||||
|
||||
# TODO(mattjj,j-towns): fix test failure and reenable.
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_jit(self):
|
||||
@partial(mask, in_shapes=['n'], out_shape='2*n')
|
||||
@jit
|
||||
def duplicate(x):
|
||||
assert python_should_be_executing
|
||||
return lax.concatenate([x, x], 0)
|
||||
|
||||
python_should_be_executing = True
|
||||
out = duplicate([jnp.arange(3)], dict(n=2))
|
||||
assert np.all(np.array([0, 1, 0, 1]) == out[:4])
|
||||
|
||||
python_should_be_executing = False
|
||||
out = duplicate([jnp.arange(3)], dict(n=2))
|
||||
assert np.all(np.array([0, 1, 0, 1]) == out[:4])
|
||||
|
||||
@unittest.skip("broken by omnistaging") # TODO(mattjj): update
|
||||
def test_jit2(self):
|
||||
# Trigger MaskTrace.post_process_call
|
||||
def fun(x):
|
||||
@jit
|
||||
def concat(y):
|
||||
return lax.concatenate([x, y], 0)
|
||||
return concat(jnp.array([1., 2., 3.], dtype='float32'))
|
||||
|
||||
self.check(fun, ['n'], '(n+3,)', {'n': 2}, [(3,)], ['float32'],
|
||||
jtu.rand_default(self.rng()))
|
||||
|
||||
@parameterized.named_parameters({
|
||||
'testcase_name': "padding_config={}_shapes={}".format(padding_config,
|
||||
shape),
|
||||
'padding_config': padding_config,
|
||||
'shape': shape} for padding_config, shape in (
|
||||
(((1, 2, 0),), (2,)),
|
||||
(((1, 2, 0), (3, 4, 0)), (1, 2)),
|
||||
(((0, 0, 0), (0, 0, 0)), (1, 2)),
|
||||
(((1, 2, 3),), (2,)),
|
||||
(((1, 2, 1), (3, 4, 2)), (3, 2)),
|
||||
(((-1, 2, 0),), (2,)),
|
||||
(((-1, -2, 0), (1, 2, 0)), (4, 2)),
|
||||
(((-1, 2, 0), (1, 2, 2)), (4, 2)),
|
||||
(((-1, -2, 2),), (5,)),
|
||||
(((-1, -2, 1), (1, 2, 2)), (4, 2))))
|
||||
@unittest.skip("Failing after fixing Poly unsoundness #4878")
|
||||
def test_pad(self, padding_config, shape):
|
||||
def pad(x):
|
||||
return lax.pad(x, jnp.array(1., x.dtype), padding_config)
|
||||
|
||||
if len(shape) == 1:
|
||||
padding_config_, = padding_config
|
||||
linear_coeff = padding_config_[2] + 1
|
||||
const_coeff = sum(padding_config_[:2]) - padding_config_[2]
|
||||
out_shape = str(linear_coeff) + ' * h + ' + str(const_coeff)
|
||||
self.check(pad, ['h'], out_shape, dict(h=shape[0]),
|
||||
[tuple(np.add(shape, 1))], ['float_'],
|
||||
jtu.rand_default(self.rng()))
|
||||
|
||||
|
||||
# TODO(mattjj,j-towns): fix test failure and reenable.
|
||||
@jtu.skip_on_devices("tpu")
|
||||
@unittest.skip("broken by omnistaging") # TODO(mattjj): update
|
||||
def test_numpy_pad(self):
|
||||
def numpy_pad(x):
|
||||
return jnp.pad(x, (0, 1), constant_values=5.)
|
||||
|
||||
self.check(numpy_pad, ['n'], 'n + 1', dict(n=2), [(3,)], ['float_'],
|
||||
jtu.rand_default(self.rng()))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{'testcase_name': "padding={}_lhs_dilation={}_"
|
||||
"dimension_numbers={}_lhs_perm={}_rhs_perm={}_out_perm={}".format(
|
||||
padding, lhs_dilation, dimension_numbers, lhs_perm,
|
||||
rhs_perm, out_perm),
|
||||
'padding': padding, 'lhs_dilation': lhs_dilation,
|
||||
'dimension_numbers': dimension_numbers, 'lhs_perm': lhs_perm,
|
||||
'rhs_perm': rhs_perm, 'out_perm': out_perm}
|
||||
for padding in ['SAME', 'VALID', ((0, 1), (2, 0))]
|
||||
for lhs_dilation in (None, (1, 2))
|
||||
for dimension_numbers, (lhs_perm, rhs_perm, out_perm) in (
|
||||
(("NCHW", "OIHW", "NCHW"), ((0, 1, 2, 3), (0, 1, 2, 3), (0, 1, 2, 3))),
|
||||
(("NHWC", "HWIO", "NHWC"), ((0, 2, 3, 1), (2, 3, 1, 0), (0, 2, 3, 1))),
|
||||
(("NCHW", "HWIO", "NHWC"), ((0, 1, 2, 3), (2, 3, 1, 0), (0, 2, 3, 1)))
|
||||
)
|
||||
# String padding is not implemented for transposed convolution, see
|
||||
# conv_general_dilated implementation:
|
||||
if (lhs_dilation is None or not isinstance(padding, str))))
|
||||
@unittest.skip("Failing after fixing Poly unsoundness #4878")
|
||||
def test_conv(
|
||||
self, padding, lhs_dilation, dimension_numbers, lhs_perm,
|
||||
rhs_perm, out_perm):
|
||||
def conv(lhs, rhs):
|
||||
return lax.conv_general_dilated(
|
||||
lhs, rhs, (1, 1), padding, lhs_dilation=lhs_dilation,
|
||||
dimension_numbers=dimension_numbers)
|
||||
|
||||
template = '({}, {}, {}, {})'
|
||||
lhs_shape = template.format(*np.take(['n', 'c', 'h', 'w'], lhs_perm))
|
||||
rhs_shape = template.format(*np.take(['o', 'c', '2', '3'], rhs_perm))
|
||||
if padding == 'VALID':
|
||||
out_shape = template.format(
|
||||
*np.take(['n', 'o', 'h+-1', 'w+-2'], out_perm))
|
||||
elif lhs_dilation:
|
||||
out_shape = template.format(
|
||||
*np.take(['n', 'o', 'h', '2*w+-1'], out_perm))
|
||||
else:
|
||||
out_shape = template.format(
|
||||
*np.take(['n', 'o', 'h', 'w'], out_perm))
|
||||
|
||||
logical_env = dict(n=3, c=2, h=4, w=5, o=6)
|
||||
|
||||
self.check(conv, [lhs_shape, rhs_shape], out_shape,
|
||||
logical_env, [tuple(np.take([4, 3, 6, 7], lhs_perm)),
|
||||
tuple(np.take([7, 3, 2, 3], rhs_perm))],
|
||||
['float_', 'float_'], jtu.rand_default(self.rng()), rtol=1e-4,
|
||||
atol=1e-4)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{'testcase_name': "padding={}_lhs_dilation={}_"
|
||||
"dimension_numbers={}_lhs_perm={}_rhs_perm={}_out_perm={}".format(
|
||||
padding, lhs_dilation, dimension_numbers, lhs_perm,
|
||||
rhs_perm, out_perm),
|
||||
'padding': padding, 'lhs_dilation': lhs_dilation,
|
||||
'dimension_numbers': dimension_numbers, 'lhs_perm': lhs_perm,
|
||||
'rhs_perm': rhs_perm, 'out_perm': out_perm}
|
||||
for padding in ['SAME', 'VALID', ((0, 1), (2, 0))]
|
||||
for lhs_dilation in (None, (1, 2))
|
||||
for dimension_numbers, (lhs_perm, rhs_perm, out_perm) in (
|
||||
(("NCHW", "OIHW", "NCHW"), ((0, 1, 2, 3), (0, 1, 2, 3), (0, 1, 2, 3))),
|
||||
(("NHWC", "HWIO", "NHWC"), ((0, 2, 3, 1), (2, 3, 1, 0), (0, 2, 3, 1))),
|
||||
(("NCHW", "HWIO", "NHWC"), ((0, 1, 2, 3), (2, 3, 1, 0), (0, 2, 3, 1)))
|
||||
)
|
||||
# String padding is not implemented for transposed convolution, see
|
||||
# conv_general_dilated implementation:
|
||||
if (lhs_dilation is None or not isinstance(padding, str))))
|
||||
@unittest.skip("Failing after fixing Poly unsoundness #4878")
|
||||
def test_conv_strided(
|
||||
self, padding, lhs_dilation, dimension_numbers, lhs_perm,
|
||||
rhs_perm, out_perm):
|
||||
def conv(lhs, rhs):
|
||||
return lax.conv_general_dilated(
|
||||
lhs, rhs, (2, 1), padding, lhs_dilation=lhs_dilation,
|
||||
dimension_numbers=dimension_numbers)
|
||||
|
||||
template = '({}, {}, {}, {})'
|
||||
rhs_shape = template.format(*np.take(['o', 'c', '2', '3'], rhs_perm))
|
||||
if padding == 'VALID':
|
||||
lhs_shape = template.format(*np.take(['n', 'c', '2*h+1', 'w'], lhs_perm))
|
||||
lhs_shape_padded = tuple(np.take([4, 3, 5, 7], lhs_perm))
|
||||
out_shape = template.format(*np.take(['n', 'o', 'h', 'w+-2'], out_perm))
|
||||
elif lhs_dilation:
|
||||
lhs_shape = template.format(*np.take(['n', 'c', '2*h', 'w'], lhs_perm))
|
||||
lhs_shape_padded = tuple(np.take([4, 3, 6, 7], lhs_perm))
|
||||
out_shape = template.format(*np.take(['n', 'o', 'h', '2*w+-1'], out_perm))
|
||||
else:
|
||||
lhs_shape = template.format(*np.take(['n', 'c', '2*h', 'w'], lhs_perm))
|
||||
lhs_shape_padded = tuple(np.take([4, 3, 6, 7], lhs_perm))
|
||||
out_shape = template.format(*np.take(['n', 'o', 'h', 'w'], out_perm))
|
||||
|
||||
logical_env = dict(n=3, c=2, h=4, w=5, o=6)
|
||||
|
||||
self.check(conv, [lhs_shape, rhs_shape], out_shape,
|
||||
logical_env, [lhs_shape_padded,
|
||||
tuple(np.take([7, 3, 2, 3], rhs_perm))],
|
||||
['float_', 'float_'], jtu.rand_default(self.rng()), rtol=1e-4,
|
||||
atol=1e-4)
|
||||
|
||||
@unittest.skip("requires gather support")
|
||||
def test_indexing(self):
|
||||
self.check(lambda x: x[0], ['n'], '', {'n': 2}, [(3,)], ['float_'],
|
||||
jtu.rand_default(self.rng()))
|
||||
self.check(lambda x: x[-1], ['n'], '', {'n': 2}, [(3,)], ['float_'],
|
||||
jtu.rand_default(self.rng()))
|
||||
|
||||
@unittest.skip("requires gather support")
|
||||
def test_slicing(self):
|
||||
self.check(lambda x: x[1:], ['n'], 'n+-1', {'n': 2}, [(3,)], ['float_'])
|
||||
self.check(lambda x: x[:-1], ['n'], 'n+-1', {'n': 2}, [(3,)], ['float_'])
|
||||
self.check(lambda x: x[..., -1], ['(n,3)'], 'n', {'n': 2}, [(3, 4)], ['float_'])
|
||||
|
||||
def test_rev(self):
|
||||
@shapecheck(['n'], 'n')
|
||||
def rev1(x):
|
||||
return lax.rev(x, (0,))
|
||||
|
||||
@shapecheck(['(m, n)'], '(m, n)')
|
||||
def rev2(x):
|
||||
return lax.rev(x, (1,))
|
||||
|
||||
@unittest.skip("TODO")
|
||||
def test_rev_by_indexing(self):
|
||||
|
||||
@shapecheck(['n'], 'n+-1')
|
||||
def rev1(x):
|
||||
return x[:0:-1]
|
||||
|
||||
@shapecheck(['n'], 'n+-1')
|
||||
def rev2(x):
|
||||
return x[-2::-1]
|
||||
|
||||
# TODO implement masking for rev_p:
|
||||
# self.check(lambda x: x[:0:-1], ['n'], dict(n=jnp.array([2, 3])), 'n+-1')
|
||||
# self.check(lambda x: x[-2::-1], ['n'], dict(n=jnp.array([2, 3])), 'n+-1')
|
||||
|
||||
@unittest.skip("Failing after fixing Poly unsoundness #4878")
|
||||
def test_lax_slice(self):
|
||||
self.check(lambda x: lax.slice(x, (1,), (x.shape[0],)), ['n'], 'n+-1',
|
||||
{'n': 2}, [(3,)], ['float_'], jtu.rand_default(self.rng()))
|
||||
# TODO self.check(lambda x: lax.slice(x, (x.shape[0] // 2,), (x.shape[0],)),
|
||||
# ['2*n'], 'n', {'n': 2}, [(6,)], ['float_'], jtu.rand_default(self.rng()))
|
||||
self.check(lambda x: lax.slice(x, (0,), (x.shape[0],), (x.shape[0],)),
|
||||
['n'], '1', {'n': 2}, [(5,)], ['float_'],
|
||||
jtu.rand_default(self.rng()))
|
||||
|
||||
@unittest.skip("Failing after fixing Poly unsoundness #4878")
|
||||
def test_reshape(self):
|
||||
self.check(lambda x: jnp.reshape(x, (x.shape[1], 2, 4, 1)),
|
||||
['1, n, 4, 2'], 'n, 2, 4, 1', dict(n=2), [(1, 3, 4, 2)],
|
||||
['float_'], jtu.rand_default(self.rng()))
|
||||
|
||||
self.check(lambda x: jnp.reshape(x, (x.shape[0] * 2,)),
|
||||
['n, 2'], '2 * n', dict(n=2), [(3, 2)],
|
||||
['float_'], jtu.rand_default(self.rng()))
|
||||
|
||||
self.check(lambda x: jnp.reshape(x, (x.shape[0] // 2, 2)),
|
||||
['2 * n'], 'n, 2', dict(n=2), [(6,)],
|
||||
['float_'], jtu.rand_default(self.rng()))
|
||||
|
||||
self.check(lambda x: jnp.reshape(x, (x.shape[0] * 4, 2)),
|
||||
['n, 2, 4'], '4 * n, 2', dict(n=2), [(3, 2, 4)],
|
||||
['float_'], jtu.rand_default(self.rng()))
|
||||
|
||||
self.check(lambda x: jnp.reshape(x, ((x.shape[0] - 1) // 4 + 1, 2, 4)),
|
||||
['4 * n + 4, 2'], 'n + 1, 2, 4', dict(n=2), [(12, 2)],
|
||||
['float_'], jtu.rand_default(self.rng()))
|
||||
|
||||
msg = "Reshape on padded dimensions causing fragmentation is not supported."
|
||||
with self.assertRaisesRegex(NotImplementedError, msg):
|
||||
self.check(lambda x: jnp.reshape(x, np.prod(x.shape)),
|
||||
['a, b'], 'a*b', dict(a=2, b=3), [(3, 4)],
|
||||
['float_'], jtu.rand_default(self.rng()))
|
||||
|
||||
with self.assertRaisesRegex(NotImplementedError, msg):
|
||||
self.check(lambda x: jnp.reshape(x, (x.shape[1], x.shape[0])),
|
||||
['a, b'], 'b, a', dict(a=2, b=3), [(3, 4)],
|
||||
['float_'], jtu.rand_default(self.rng()))
|
||||
|
||||
with self.assertRaisesRegex(NotImplementedError, msg):
|
||||
self.check(lambda x: jnp.reshape(x, (x.shape[1] * 2,)),
|
||||
['2, n'], '2 * n', dict(n=2), [(2, 3)],
|
||||
['float_'], jtu.rand_default(self.rng()))
|
||||
|
||||
self.check(lambda x: jnp.reshape(x, (x.shape[0], -1)),
|
||||
['n, 3, 1, 2'], 'n, 6', dict(n=1), [(2, 3, 1, 2)],
|
||||
['float_'], jtu.rand_default(self.rng()))
|
||||
|
||||
def test_transpose(self):
|
||||
self.check(lambda x: lax.transpose(x, (1, 0, 2)),
|
||||
['(a, b, c)'], 'b, a, c', dict(a=2, b=3, c=4), [(3, 4, 5)],
|
||||
['float_'], jtu.rand_default(self.rng()))
|
||||
|
||||
def test_sum_2d(self):
|
||||
self.check(jnp.sum, ['(m, n)'], '', dict(m=2, n=3), [(3, 4)], ['float_'],
|
||||
jtu.rand_default(self.rng()))
|
||||
|
||||
@unittest.skip("custom_jvp doesn't work with masking yet")
|
||||
def test_expit(self):
|
||||
self.check(expit, ['n'], 'n', dict(n=3), [(4,)], ['float_'],
|
||||
jtu.rand_default(self.rng()))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"_{dtype}", "dtype": np.dtype(dtype).name}
|
||||
for dtype in [np.float32, np.float64]))
|
||||
@unittest.skip("not yet implemented")
|
||||
def test_uniform(self, dtype):
|
||||
# TODO needs fix for https://github.com/google/jax/issues/2155
|
||||
pass
|
||||
|
||||
@unittest.skip("not yet implemented")
|
||||
def test_broadcast_in_dim(self):
|
||||
pass
|
||||
|
||||
def test_destructure(self):
|
||||
def d(key):
|
||||
key1, key2 = key
|
||||
return key1
|
||||
|
||||
self.check(d, ['2'], '', {}, [(2,)], ['int_'], jtu.rand_int(self.rng(), 0, 10))
|
||||
|
||||
# TODO(mattjj,j-towns): fix test failure and reenable.
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_where(self):
|
||||
self.check(lambda x: jnp.where(x < 0, x, 0. * x), ['n'], 'n',
|
||||
{'n': 2}, [(3,)], ['float_'], jtu.rand_default(self.rng()))
|
||||
|
||||
@unittest.skip("Failing after fixing Poly unsoundness #4878")
|
||||
def test_split(self):
|
||||
self.check(lambda x: jnp.split(x, 2), ['2*n'], ['n', 'n'], dict(n=4),
|
||||
[(8,)], ['float_'], jtu.rand_default(self.rng()))
|
||||
self.check(lambda x: jnp.split(x, [10]), ['n'], ['10', 'n+-10'], dict(n=12),
|
||||
[(12,)], ['float_'], jtu.rand_default(self.rng()))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list([{
|
||||
'testcase_name': f"operator={operator.__name__}", 'operator': operator}
|
||||
for operator in [jnp.sum, jnp.prod, jnp.max, jnp.min]]))
|
||||
def test_reduce(self, operator):
|
||||
self.check(operator, ['(m+1, n+1)'], '', {'m': 3, 'n': 4}, [(4, 5)], ['float_'],
|
||||
jtu.rand_default(self.rng()))
|
||||
|
||||
def test_output_shape_error(self):
|
||||
def thunk():
|
||||
shapecheck(['n'], 'n+-1')(lambda x: x)
|
||||
|
||||
message = "Output shapes should be (n + -1,) but are (n,)."
|
||||
self.assertRaisesWithLiteralMatch(ShapeError, message, thunk)
|
||||
|
||||
def thunk():
|
||||
shapecheck(['n'], ['7*n', 'n'])(lambda x: (x, x))
|
||||
|
||||
message = "Output shapes should be [(7 n,), (n,)] but are ((n,), (n,))."
|
||||
self.assertRaisesWithLiteralMatch(ShapeError, message, thunk)
|
||||
|
||||
def test_output_tree_error(self):
|
||||
def thunk():
|
||||
shapecheck(['n'], ('n', 'n'))(lambda x: [x, x])
|
||||
|
||||
message = "Output shapes should be ((n,), (n,)) but are [(n,), (n,)]."
|
||||
self.assertRaisesWithLiteralMatch(ShapeError, message, thunk)
|
||||
|
||||
def test_unsupported_op(self):
|
||||
p = core.Primitive('unsupported_op')
|
||||
p.def_abstract_eval(lambda x: x)
|
||||
p.def_impl(lambda x: x)
|
||||
|
||||
def thunk():
|
||||
mask(p.bind, ['n'], 'n')([np.arange(3)], {'n': 2})
|
||||
|
||||
message = "Masking rule for unsupported_op not implemented yet."
|
||||
self.assertRaisesWithLiteralMatch(NotImplementedError, message, thunk)
|
||||
|
||||
@unittest.skip("not yet implemented")
|
||||
def test_nesting(self):
|
||||
@partial(mask, in_shapes=['n'], out_shape='')
|
||||
def padded_sum(x):
|
||||
return jnp.sum(x)
|
||||
|
||||
batched_sum = vmap(padded_sum)
|
||||
|
||||
@partial(mask, in_shapes=['(m, _)', 'm'], out_shape='')
|
||||
def fun(x, ns):
|
||||
return batched_sum([x], dict(n=ns)).sum()
|
||||
|
||||
x = jnp.array([[3, 1, 4, 1],
|
||||
[5, 9, 2, 6],
|
||||
[5, 3, 5, 8]])
|
||||
ns = jnp.array([2, 3, 2])
|
||||
ans = fun([x, ns], dict(m=2))
|
||||
expected = 3+1 + 5+9+2
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
|
||||
def test_slice_oob_indexing(self):
|
||||
# https://github.com/google/jax/issues/2245
|
||||
self.assertAllClose(jnp.ones(5), jnp.ones(5)[:10])
|
||||
self.assertAllClose(jnp.ones(5), jnp.ones(5)[-10:])
|
||||
|
||||
def test_jaxpr_doesnt_include_trivial_operations(self):
|
||||
@partial(mask, in_shapes=['n'], out_shape='')
|
||||
def foo(x):
|
||||
return np.sum(x)
|
||||
|
||||
padded_x = np.array([0, 1, 2, 3, 999, 999])
|
||||
|
||||
jaxpr = make_jaxpr(foo)([padded_x], dict(n=3))
|
||||
self.assertNotIn('mul', str(jaxpr))
|
||||
self.assertNotIn('add', str(jaxpr))
|
||||
|
||||
def test_return_shape_to_user(self):
|
||||
@partial(mask, in_shapes=['n'])
|
||||
def foo(x):
|
||||
return [x, np.sum(x)]
|
||||
|
||||
out, out_shape = foo([np.arange(5)], dict(n=2))
|
||||
self.assertIsInstance(out_shape, list)
|
||||
self.assertLen(out_shape, 2)
|
||||
a, b = out_shape
|
||||
self.assertEqual(a.shape, (2,))
|
||||
self.assertEqual(b.shape, ())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
Loading…
x
Reference in New Issue
Block a user